Machine Learning/scikit-learn
[교차 검증] cross_val_score()
ISLA!
2023. 9. 26. 23:57
교차 검증을 간편하게 할 수 있는 ▶︎ cross_val_score()
- 해당 API는 내부에서 학습, 예측, 평가를 시켜주므로, 간단하게 교차 검증을 수행할 수 있음!
- 분류를 수행할 때는 StratifiedKFold를 사용 / 회귀에서는 K 폴드 사용
- K-fold 에 비해 검증 정확도가 올라가 일반화가 용이함 (과대적합 방지)
📑 적용 방식
1. 폴드 세트를 설정
2. for 루프에서 반복으로 학습 및 데이터 인덱스를 추출
3. 반복적으로 학습&예측을 수행 후, 예측 성능 반환
from sklearn.model_selection import cross_val_score, cross_validate
cross_val_score( estimator, X, y = None, scoring = None, cv = None, n_jobs = 1, verbose = 0, fit_params = None, pre_dispatch = '2*n_jobs' )
- estimator : Classifier 또는 Regressor 알고리즘 클래스
- X : 피쳐 데이터 세트
- y : 레이블 데이터 세트
- scoring : 예측 성능 평가 지표
- cv : 교차 검증 폴드 수
- 👉 cv 로 지정된 횟수만큼 scoring 파라미터로 지정된 평가 지표로 평가 결괏값을 배열로 반환
- 👉 일반적으로 이를 평균하여 평가 수치로 사용
📑 예제
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import cross_val_score, cross_validate
from sklearn.datasets import load_iris
iris_data = load_iris()
dt_clf = DecisionTreeClassifier(random_state = 11)
data = iris_data.data
label = iris_data.target
# 성능 징표는 정확도, 교차 검증 세트는 3개
scores = cross_val_score(dt_clf, data, label, scoring = 'accuracy', cv = 3)
print('교차 검증별 정확도:', np.round(scores, 4))
print('평균 검증 정확도:', np.round(np.mean(scores), 4))
교차 검증별 정확도: [0.98 0.92 0.98]
평균 검증 정확도: 0.96
👀 참고사항
- cross_validate() 메서드 : 여러 개의 평가 지표를 반환하며, 학습 데이터에 대한 성능 평가 지표와 수행 시간도 함께 제공
- cross_val_score() 메서드 : 한 개의 평가 지표만 가능(대부분의 경우 사용)
728x90