본문 바로가기
Machine Learning/scikit-learn

[교차 검증] Stratified K 폴드

by ISLA! 2023. 8. 17.

Stratified K 폴드

 

 

  • Stratified K 폴드는 불균형한 분포도를 가진 레이블 데이터 집합을 위한 K 폴드 방식이다
  • 불균형한 분포도를 가진 데이터 집합은 특정 레이블 값이 특이하게 많거나 매우 적어서 값의 분포가 한 쪽으로 치우치는 것을 말한다.
  • Stratified K 폴드는 K 폴드가 레이블 데이터 집합이 원본 데이터 집합의 레이블 분포를 학습 및 테스트 세트에 제대로 분배하지 못하는 경우의 문제를 해결해준다
  • 따라서, 왜곡된 레이블 데이터 세트에서 분류를 할 때는, 반드시 Stratified K 폴드를 사용해야한다.
  • 일반적인 분류에서도 Stratified K 폴드를 사용하며 회귀에서는 지원되지 않는다.(분포 측정이 불필요하기 때문)

 

📑 동작 방식

1. 원본 데이터의 레이블 분포를 먼저 고려

2. 이 분포와 동일하게 학습과 검증 데이터 세트를 분배

 

 

📑 예제

  • Kfold와 다른 점은,  split() 메서드에 인자로 피처 데이터 세트 & 레이블 데이터 세트를 모두 넣어줘야 한다는 것!
  • 그래야 레이블 데이터 세트의 분포를 고려하여 데이터를 분할할 수 있기 때문이다.
  • 객체 생성시 n_splits의 값을 먼저 지정해준다.
from sklearn.model_selection import StratifiedKFold
import pandas as pd

iris = load_iris()
iris_df = pd.DataFrame(data = iris.data, columns = iris.feature_names)
iris_df['label'] = iris.target

skf = StratifiedKFold(n_splits=3)
n_iter = 0

for train_index, test_index in skf.split(iris_df, iris_df['label']):
    n_iter += 1
    label_train = iris_df['label'].iloc[train_index]
    label_test = iris_df['label'].iloc[test_index]
    
    print('##교차검증: {0}'.format(n_iter))
    print('학습 레이블 데이터 분포:\n', label_train.value_counts())
    print('검증 레이블 데이터 분포:\n', label_test.value_counts())

 

📑 결과

  • 0, 1, 2의 3개 범주가 모두 균일하게 학습 / 검증 레이블로 치우치지 않게 분리됨을 확인했다.

결과


📑 예제 2

  • StratifiedKFold 를 사용하여, iris data 분류 및 예측 수행
dt_clf = DecisionTreeClassifier(random_state = 11)

skfold = StratifiedKFold(n_splits = 3)
n_iter = 0
cv_accuracy = []

for train_index, test_index in skf.split(features, label):
    X_train, X_test = features[train_index], features[test_index]
    y_train, y_test = label[train_index], label[test_index]
    
    #학습 및 예측
    dt_clf.fit(X_train, y_train)
    pred = dt_clf.predict(X_test)
    
    #반복시마다 정확도 측정
    n_iter += 1
    accuracy = np.round(accuracy_score(y_test, pred), 4)
    train_size = X_train.shape[0]
    test_size = X_test.shape[0]
    print('\n#{0} 교차 검증 정확도 :{1}, 학습 데이터 크기:{2}, 검증 데이터 크기:{3}'.format(n_iter, accuracy, train_size, test_size))
    print('#{0} 검증 세트 인덱스:{1}'.format(n_iter, test_index))
    cv_accuracy.append(accuracy)
    
#교차 검증별 정확도 및 평균 정확도 계산
print('\n## 교차 검증 정확도', np.round(cv_accuracy, 4))
print('## 평균 교차 검증 정확도:', np.round(np.mean(cv_accuracy),4))

 

728x90