Stratified K 폴드
- Stratified K 폴드는 불균형한 분포도를 가진 레이블 데이터 집합을 위한 K 폴드 방식이다
- 불균형한 분포도를 가진 데이터 집합은 특정 레이블 값이 특이하게 많거나 매우 적어서 값의 분포가 한 쪽으로 치우치는 것을 말한다.
- Stratified K 폴드는 K 폴드가 레이블 데이터 집합이 원본 데이터 집합의 레이블 분포를 학습 및 테스트 세트에 제대로 분배하지 못하는 경우의 문제를 해결해준다
- 따라서, 왜곡된 레이블 데이터 세트에서 분류를 할 때는, 반드시 Stratified K 폴드를 사용해야한다.
- 일반적인 분류에서도 Stratified K 폴드를 사용하며 회귀에서는 지원되지 않는다.(분포 측정이 불필요하기 때문)
📑 동작 방식
1. 원본 데이터의 레이블 분포를 먼저 고려
2. 이 분포와 동일하게 학습과 검증 데이터 세트를 분배
📑 예제
- Kfold와 다른 점은, split() 메서드에 인자로 피처 데이터 세트 & 레이블 데이터 세트를 모두 넣어줘야 한다는 것!
- 그래야 레이블 데이터 세트의 분포를 고려하여 데이터를 분할할 수 있기 때문이다.
- 객체 생성시 n_splits의 값을 먼저 지정해준다.
from sklearn.model_selection import StratifiedKFold
import pandas as pd
iris = load_iris()
iris_df = pd.DataFrame(data = iris.data, columns = iris.feature_names)
iris_df['label'] = iris.target
skf = StratifiedKFold(n_splits=3)
n_iter = 0
for train_index, test_index in skf.split(iris_df, iris_df['label']):
n_iter += 1
label_train = iris_df['label'].iloc[train_index]
label_test = iris_df['label'].iloc[test_index]
print('##교차검증: {0}'.format(n_iter))
print('학습 레이블 데이터 분포:\n', label_train.value_counts())
print('검증 레이블 데이터 분포:\n', label_test.value_counts())
📑 결과
- 0, 1, 2의 3개 범주가 모두 균일하게 학습 / 검증 레이블로 치우치지 않게 분리됨을 확인했다.
📑 예제 2
- StratifiedKFold 를 사용하여, iris data 분류 및 예측 수행
dt_clf = DecisionTreeClassifier(random_state = 11)
skfold = StratifiedKFold(n_splits = 3)
n_iter = 0
cv_accuracy = []
for train_index, test_index in skf.split(features, label):
X_train, X_test = features[train_index], features[test_index]
y_train, y_test = label[train_index], label[test_index]
#학습 및 예측
dt_clf.fit(X_train, y_train)
pred = dt_clf.predict(X_test)
#반복시마다 정확도 측정
n_iter += 1
accuracy = np.round(accuracy_score(y_test, pred), 4)
train_size = X_train.shape[0]
test_size = X_test.shape[0]
print('\n#{0} 교차 검증 정확도 :{1}, 학습 데이터 크기:{2}, 검증 데이터 크기:{3}'.format(n_iter, accuracy, train_size, test_size))
print('#{0} 검증 세트 인덱스:{1}'.format(n_iter, test_index))
cv_accuracy.append(accuracy)
#교차 검증별 정확도 및 평균 정확도 계산
print('\n## 교차 검증 정확도', np.round(cv_accuracy, 4))
print('## 평균 교차 검증 정확도:', np.round(np.mean(cv_accuracy),4))
728x90
'Machine Learning > scikit-learn' 카테고리의 다른 글
[분류_성능 평가 지표] 정확도와 오차 행렬(이진 분류) (0) | 2023.08.18 |
---|---|
[GridSearchCV] 교차 검증 & 하이퍼 파라미터 튜닝을 한 번에 (0) | 2023.08.17 |
[교차 검증] K-폴드 교차 검증 (0) | 2023.08.17 |
[데이터 전처리] MinMaxScaler를 이용한 데이터 정규화 예제 (0) | 2023.08.17 |
[하이퍼 파라미터] hyper parameter란? (0) | 2023.08.17 |