Machine Learning/scikit-learn
[교차 검증] Stratified K 폴드
ISLA!
2023. 8. 17. 17:57
Stratified K 폴드
- Stratified K 폴드는 불균형한 분포도를 가진 레이블 데이터 집합을 위한 K 폴드 방식이다
- 불균형한 분포도를 가진 데이터 집합은 특정 레이블 값이 특이하게 많거나 매우 적어서 값의 분포가 한 쪽으로 치우치는 것을 말한다.
- Stratified K 폴드는 K 폴드가 레이블 데이터 집합이 원본 데이터 집합의 레이블 분포를 학습 및 테스트 세트에 제대로 분배하지 못하는 경우의 문제를 해결해준다
- 따라서, 왜곡된 레이블 데이터 세트에서 분류를 할 때는, 반드시 Stratified K 폴드를 사용해야한다.
- 일반적인 분류에서도 Stratified K 폴드를 사용하며 회귀에서는 지원되지 않는다.(분포 측정이 불필요하기 때문)
📑 동작 방식
1. 원본 데이터의 레이블 분포를 먼저 고려
2. 이 분포와 동일하게 학습과 검증 데이터 세트를 분배
📑 예제
- Kfold와 다른 점은, split() 메서드에 인자로 피처 데이터 세트 & 레이블 데이터 세트를 모두 넣어줘야 한다는 것!
- 그래야 레이블 데이터 세트의 분포를 고려하여 데이터를 분할할 수 있기 때문이다.
- 객체 생성시 n_splits의 값을 먼저 지정해준다.
from sklearn.model_selection import StratifiedKFold
import pandas as pd
iris = load_iris()
iris_df = pd.DataFrame(data = iris.data, columns = iris.feature_names)
iris_df['label'] = iris.target
skf = StratifiedKFold(n_splits=3)
n_iter = 0
for train_index, test_index in skf.split(iris_df, iris_df['label']):
n_iter += 1
label_train = iris_df['label'].iloc[train_index]
label_test = iris_df['label'].iloc[test_index]
print('##교차검증: {0}'.format(n_iter))
print('학습 레이블 데이터 분포:\n', label_train.value_counts())
print('검증 레이블 데이터 분포:\n', label_test.value_counts())
📑 결과
- 0, 1, 2의 3개 범주가 모두 균일하게 학습 / 검증 레이블로 치우치지 않게 분리됨을 확인했다.
📑 예제 2
- StratifiedKFold 를 사용하여, iris data 분류 및 예측 수행
dt_clf = DecisionTreeClassifier(random_state = 11)
skfold = StratifiedKFold(n_splits = 3)
n_iter = 0
cv_accuracy = []
for train_index, test_index in skf.split(features, label):
X_train, X_test = features[train_index], features[test_index]
y_train, y_test = label[train_index], label[test_index]
#학습 및 예측
dt_clf.fit(X_train, y_train)
pred = dt_clf.predict(X_test)
#반복시마다 정확도 측정
n_iter += 1
accuracy = np.round(accuracy_score(y_test, pred), 4)
train_size = X_train.shape[0]
test_size = X_test.shape[0]
print('\n#{0} 교차 검증 정확도 :{1}, 학습 데이터 크기:{2}, 검증 데이터 크기:{3}'.format(n_iter, accuracy, train_size, test_size))
print('#{0} 검증 세트 인덱스:{1}'.format(n_iter, test_index))
cv_accuracy.append(accuracy)
#교차 검증별 정확도 및 평균 정확도 계산
print('\n## 교차 검증 정확도', np.round(cv_accuracy, 4))
print('## 평균 교차 검증 정확도:', np.round(np.mean(cv_accuracy),4))
728x90