본문 바로가기
Machine Learning/scikit-learn

[교차 검증] cross_val_score()

by ISLA! 2023. 9. 26.

교차 검증을 간편하게 할 수 있는  ▶︎  cross_val_score()

 

 

  • 해당 API는 내부에서 학습, 예측, 평가를 시켜주므로, 간단하게 교차 검증을 수행할 수 있음!
  • 분류를 수행할 때는 StratifiedKFold를 사용 / 회귀에서는 K 폴드 사용
  • K-fold 에 비해 검증 정확도가 올라가 일반화가 용이함 (과대적합 방지)

 

📑 적용 방식

1. 폴드 세트를 설정

2. for 루프에서 반복으로 학습 및 데이터 인덱스를 추출

3. 반복적으로 학습&예측을 수행 후, 예측 성능 반환

from sklearn.model_selection import cross_val_score, cross_validate

cross_val_score( estimator, X, y = None, scoring = None, cv = None, n_jobs = 1, verbose = 0, fit_params  = None, pre_dispatch = '2*n_jobs' )

 

  • estimator : Classifier 또는 Regressor 알고리즘 클래스
  • X : 피쳐 데이터 세트
  • y : 레이블 데이터 세트
  • scoring : 예측 성능 평가 지표
  • cv : 교차 검증 폴드 수
    • 👉 cv 로 지정된 횟수만큼 scoring 파라미터로 지정된 평가 지표로 평가 결괏값을 배열로 반환
    • 👉 일반적으로 이를 평균하여 평가 수치로 사용

 

 

📑 예제

from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import cross_val_score, cross_validate
from sklearn.datasets import load_iris

iris_data = load_iris()
dt_clf = DecisionTreeClassifier(random_state = 11)

data = iris_data.data
label = iris_data.target

# 성능 징표는 정확도, 교차 검증 세트는 3개
scores = cross_val_score(dt_clf, data, label, scoring = 'accuracy', cv = 3)
print('교차 검증별 정확도:', np.round(scores, 4))
print('평균 검증 정확도:', np.round(np.mean(scores), 4))
교차 검증별 정확도: [0.98 0.92 0.98]
평균 검증 정확도: 0.96

 

👀  참고사항

  • cross_validate() 메서드 : 여러 개의 평가 지표를 반환하며, 학습 데이터에 대한 성능 평가 지표와 수행 시간도 함께 제공
  • cross_val_score() 메서드 : 한 개의 평가 지표만 가능(대부분의 경우 사용)

 

728x90