Mục Lục
- 1 1. Giới thiệu
- 2 2. Validation: Đánh Giá Mô Hình Khách Quan
- 3 3. Regularization: “Phạt” Độ Phức Tạp
- 3.1 3.1. Early Stopping: Dừng Đúng Lúc
- 3.2 3.2. Thêm Số Hạng Regularization vào Hàm Mất Mát
- 3.3 3.3. L2 Regularization (Weight Decay): Giảm Kích Thước Các Tham Số
- 3.4 3.4. L1 Regularization: Tạo Ra Các Mô Hình Sparse
- 3.5 3.5. Elastic Net Regularization: Kết Hợp L1 và L2
- 3.6 3.6. Regularization trong Sklearn
- 4 4. Các Phương Pháp Khác
- 5 5. Tóm Tắt Nội Dung
- 6 6. Tài Liệu Tham Khảo
1. Giới thiệu
Trong lĩnh vực Machine Learning, mục tiêu là xây dựng các mô hình có khả năng dự đoán chính xác trên dữ liệu mới, chưa từng thấy. Tuy nhiên, một vấn đề phổ biến là overfitting, xảy ra khi mô hình “học thuộc” dữ liệu huấn luyện (training data) quá kỹ, dẫn đến hiệu suất kém trên dữ liệu kiểm tra (test data). Bài viết này sẽ đi sâu vào khái niệm overfitting và các kỹ thuật để giảm thiểu hiện tượng này, giúp bạn xây dựng các mô hình mạnh mẽ và đáng tin cậy hơn.
Hãy bắt đầu với một ví dụ trực quan. Tưởng tượng bạn đang cố gắng dự đoán giá nhà dựa trên diện tích. Nếu bạn chỉ có vài căn nhà để huấn luyện mô hình, bạn có thể tạo ra một mô hình phức tạp hoàn toàn phù hợp với những dữ liệu đó. Tuy nhiên, mô hình này có thể hoạt động rất tệ trên các căn nhà mới vì nó đã “học” những đặc điểm riêng biệt của tập dữ liệu nhỏ đó thay vì các quy luật chung.
Hình trên minh họa hiện tượng overfitting. Các điểm dữ liệu màu đỏ là dữ liệu huấn luyện, các điểm màu vàng là dữ liệu kiểm tra và đường màu xanh lục là mô hình “thực tế”. Mô hình bậc cao (ví dụ bậc 16) khớp hoàn hảo với dữ liệu huấn luyện, nhưng lại dự đoán sai lệch đáng kể trên dữ liệu kiểm tra. Mô hình bậc 4 cho kết quả tốt hơn vì nó nắm bắt được xu hướng chung mà không bị ảnh hưởng bởi nhiễu trong dữ liệu huấn luyện.
Về bản chất, overfitting xảy ra khi mô hình trở nên quá phức tạp so với lượng dữ liệu huấn luyện có sẵn. Độ phức tạp của mô hình có thể được hiểu là số lượng tham số mà mô hình có thể điều chỉnh để phù hợp với dữ liệu. Ví dụ, trong mô hình đa thức, bậc của đa thức quyết định độ phức tạp. Trong mạng nơ-ron (Neural Networks), số lượng lớp ẩn (hidden layers) và số lượng nơ-ron (units) trong mỗi lớp quyết định độ phức tạp.
Vậy, làm thế nào để tránh overfitting? Phần tiếp theo sẽ trình bày các kỹ thuật phổ biến và hiệu quả nhất.
2. Validation: Đánh Giá Mô Hình Khách Quan
2.1. Tập Validation: “Kỳ Thi Thử” Trước Khi Đánh Giá Thực Tế
Để đánh giá khả năng khái quát hóa của mô hình (khả năng hoạt động tốt trên dữ liệu mới), chúng ta cần một tập dữ liệu mà mô hình chưa từng thấy trong quá trình huấn luyện. Đó là lý do chúng ta chia tập dữ liệu ban đầu thành ba phần:
- Tập huấn luyện (Training set): Dùng để huấn luyện mô hình.
- Tập validation (Validation set): Dùng để điều chỉnh các siêu tham số (hyperparameters) của mô hình và lựa chọn mô hình tốt nhất.
- Tập kiểm tra (Test set): Dùng để đánh giá hiệu suất cuối cùng của mô hình đã chọn.
Tập validation đóng vai trò như một “kỳ thi thử” trước khi đánh giá mô hình trên tập kiểm tra. Bằng cách đánh giá mô hình trên tập validation, chúng ta có thể điều chỉnh các siêu tham số (ví dụ, bậc của đa thức trong ví dụ trên) để đạt được hiệu suất tốt nhất trên dữ liệu chưa từng thấy.
Hình trên minh họa cách sử dụng tập validation để lựa chọn mô hình. Train error thường giảm khi độ phức tạp của mô hình tăng lên. Tuy nhiên, validation error có thể bắt đầu tăng sau một điểm nhất định, cho thấy mô hình đang bắt đầu overfitting. Chúng ta nên chọn mô hình có validation error thấp nhất.
2.2. Cross-validation: Tối Ưu Hóa Dữ Liệu Khiêm Tốn
Trong nhiều trường hợp, chúng ta có thể không có đủ dữ liệu để tạo ra một tập validation lớn. Lúc này, cross-validation là một kỹ thuật mạnh mẽ để đánh giá mô hình một cách hiệu quả hơn.
K-fold cross-validation hoạt động bằng cách chia tập huấn luyện thành k phần (folds) có kích thước gần bằng nhau. Sau đó, chúng ta huấn luyện mô hình k lần, mỗi lần sử dụng một fold khác nhau làm tập validation và k-1 fold còn lại làm tập huấn luyện. Hiệu suất của mô hình được tính trung bình trên k lần huấn luyện này.
K-fold Cross-validation: Đánh giá mô hình trên nhiều tập Validation khác nhau.
Hình trên minh họa K-fold cross-validation với k=5. Dữ liệu được chia thành 5 folds. Mỗi fold được sử dụng làm tập validation một lần, trong khi 4 fold còn lại được sử dụng làm tập huấn luyện.
Leave-one-out cross-validation là một trường hợp đặc biệt của K-fold cross-validation, trong đó k bằng với số lượng mẫu trong tập huấn luyện. Mỗi mẫu được sử dụng làm tập validation một lần, trong khi tất cả các mẫu còn lại được sử dụng làm tập huấn luyện.
Sklearn cung cấp nhiều công cụ để thực hiện cross-validation một cách dễ dàng.
Cross-validation trong Sklearn: Dễ dàng thực hiện đánh giá mô hình.
Cross-validation giúp chúng ta đánh giá mô hình một cách khách quan hơn, đặc biệt khi lượng dữ liệu có sẵn hạn chế.
3. Regularization: “Phạt” Độ Phức Tạp
Regularization là một nhóm các kỹ thuật giúp ngăn chặn overfitting bằng cách thêm một “hình phạt” vào hàm mất mát (loss function) để khuyến khích mô hình đơn giản hơn. Ý tưởng là thay vì cố gắng khớp hoàn hảo với dữ liệu huấn luyện (có thể bao gồm cả nhiễu), chúng ta muốn mô hình tìm ra các quy luật chung và bỏ qua các chi tiết không quan trọng.
3.1. Early Stopping: Dừng Đúng Lúc
Trong quá trình huấn luyện, đặc biệt là với các thuật toán lặp như Gradient Descent, chúng ta có thể theo dõi hiệu suất của mô hình trên tập validation. Nếu chúng ta thấy rằng validation error bắt đầu tăng lên, đó là dấu hiệu cho thấy mô hình đang bắt đầu overfitting. Early stopping là kỹ thuật dừng quá trình huấn luyện trước khi mô hình đạt đến điểm overfitting.
Hình trên minh họa early stopping. Huấn luyện được dừng lại tại điểm mà validation error đạt giá trị nhỏ nhất, trước khi bắt đầu tăng lên do overfitting.
3.2. Thêm Số Hạng Regularization vào Hàm Mất Mát
Một phương pháp regularization phổ biến khác là thêm một số hạng vào hàm mất mát để “phạt” độ phức tạp của mô hình. Hàm mất mát mới được gọi là regularized loss function:
J_reg(theta) = J(theta) + lambda * R(theta)
Trong đó:
J(theta)là hàm mất mát ban đầu (ví dụ, mean squared error cho bài toán hồi quy).R(theta)là số hạng regularization, đo lường độ phức tạp của mô hình.lambdalà tham số regularization, điều chỉnh mức độ “phạt” độ phức tạp.lambdacàng lớn, mô hình càng đơn giản.
3.3. L2 Regularization (Weight Decay): Giảm Kích Thước Các Tham Số
L2 regularization, còn được gọi là weight decay, là một trong những kỹ thuật regularization phổ biến nhất. Nó thêm một số hạng vào hàm mất mát để “phạt” các tham số có giá trị lớn:
R(w) = ||w||_2^2
Trong đó w là vector chứa tất cả các tham số của mô hình và ||w||_2^2 là bình phương của L2 norm (Euclidean norm) của w.
L2 regularization khuyến khích các tham số có giá trị nhỏ, giúp mô hình trở nên đơn giản hơn và ít nhạy cảm hơn với nhiễu trong dữ liệu huấn luyện.
Hình trên minh họa ảnh hưởng của weight decay. Khi tham số regularization (lambda) tăng lên, các đường phân chia giữa các lớp trở nên mượt mà hơn, cho thấy mô hình ít overfitting hơn.
Trong Neural Networks, weight decay thường chỉ được áp dụng cho các trọng số (weights) và không áp dụng cho các bias.
3.4. L1 Regularization: Tạo Ra Các Mô Hình Sparse
L1 regularization thêm một số hạng vào hàm mất mát để “phạt” tổng giá trị tuyệt đối của các tham số:
R(w) = ||w||_1 = sum(|w_i|)
L1 regularization có xu hướng tạo ra các mô hình sparse, trong đó nhiều tham số có giá trị bằng 0. Điều này có thể hữu ích để lựa chọn các đặc trưng quan trọng nhất và loại bỏ các đặc trưng không liên quan.
3.5. Elastic Net Regularization: Kết Hợp L1 và L2
Elastic Net regularization kết hợp cả L1 và L2 regularization:
R(w) = alpha * ||w||_1 + (1 - alpha) * ||w||_2^2
Trong đó alpha là một tham số điều chỉnh mức độ kết hợp giữa L1 và L2 regularization.
3.6. Regularization trong Sklearn
Sklearn cung cấp các công cụ để dễ dàng áp dụng các kỹ thuật regularization khác nhau. Ví dụ, trong Logistic Regression, bạn có thể sử dụng các tham số penalty (để chọn loại regularization) và C (nghịch đảo của lambda) để điều chỉnh mức độ regularization.
4. Các Phương Pháp Khác
Ngoài các kỹ thuật đã đề cập ở trên, còn có nhiều phương pháp khác để chống overfitting, bao gồm:
- Dropout: Một kỹ thuật đặc biệt hiệu quả trong Deep Learning, trong đó các nơ-ron được “tắt” ngẫu nhiên trong quá trình huấn luyện.
- Data Augmentation: Tăng kích thước của tập dữ liệu huấn luyện bằng cách tạo ra các bản sao đã được biến đổi nhẹ của dữ liệu hiện có (ví dụ, xoay, lật, hoặc thêm nhiễu vào ảnh).
- Pruning: Loại bỏ các nhánh không quan trọng trong cây quyết định (Decision Trees).
5. Tóm Tắt Nội Dung
Overfitting là một vấn đề phổ biến trong Machine Learning, xảy ra khi mô hình “học thuộc” dữ liệu huấn luyện quá kỹ, dẫn đến hiệu suất kém trên dữ liệu mới. Để chống overfitting, chúng ta có thể sử dụng nhiều kỹ thuật khác nhau, bao gồm:
- Validation: Sử dụng tập validation để đánh giá mô hình một cách khách quan và điều chỉnh các siêu tham số.
- Cross-validation: Đánh giá mô hình trên nhiều tập validation khác nhau để có được ước tính chính xác hơn về hiệu suất.
- Regularization: Thêm một số hạng vào hàm mất mát để “phạt” độ phức tạp của mô hình.
- Các phương pháp khác: Dropout, Data Augmentation, Pruning, v.v.
Việc lựa chọn kỹ thuật phù hợp phụ thuộc vào từng bài toán cụ thể và cần được thực hiện một cách cẩn thận để đạt được hiệu suất tốt nhất.
