@senspond
>
사이킷런의 GridSearchCV에 대해서 정리해 본 글입니다. 매개변수들에 대한 설명, 성능평가 방법, 샘플예제를 포함하고 있습니다.
안녕하세요. 오늘은 사이킷런의 GridSearchCV에 대해서 정리해보고자 합니다.
GridSearch 는 지정해준 몇 가지 잠재적 Parameter들의 후보군들의 조합 중에서 가장 Best 조합을 찾아줍니다
GridSearch는 sklearn 패키지의 model_selection에 있습니다.
from sklearn.model_selection import GridSearchCV
estimator : classifier, regressor, pipeline 등 가능
param_grid : 튜닝을 위해 파라미터, 사용될 파라미터를 dictionary 형태로 만들어서 넣는다.
scoring : 예측 성능을 측정할 평가 방법을 지정
cv : 교차 검증에서 몇개로 분할되는지 지정한다.
refit : True가 디폴트 / True로 하면 최적의 하이퍼 파라미터를 찾아서 재학습 시킨다.
n-jobs : 병렬처리에 쓰일 코어의 갯수, -1 이면 최대로
return_train_score : 훈련 점수 포함 여부 (True/False)
param_grid = {
'min_samples_leaf': [8, 12, 24, 48],
'min_samples_split': [1, 2, 4],
'max_depth': [6, 8, 12,24],
}
et_reg = ExtraTreesRegressor()
gsc = GridSearchCV(et_reg, param_grid, cv=5,
scoring='neg_mean_squared_error',
n_jobs=-1,
return_train_score=True)
gsc.fit(X_train_scaled, y_train)
print(gsc.best_params_)
print(gsc.best_estimator_)
print(gsc.best_score_)
그리드 서치는 파라미터들의 경우의 수들을 다 돌려가며 최적의 조합을 찾기 때문에 다소 오래걸릴 수가 있습니다. 그리고 학습시킬 데이터가 많다면 더더욱 오래걸리는 것 같은데요. 저는 학습 데이터가 200백만건이 넘어가는 데이터 셋을 준비해서 돌려봤더니 30분 가까이 걸렸네요.
param_gird 는 estimator 에 들어가는 모델에 따라서 달라 질 수가 있습니다.
여기서 scoring 은 클 수록 모델 성능이 좋은 것으로 인식하는데, MSE/MAE 같은 평가를 하려고 한다면 작을 수록 모델성능이 좋은 것이기 때문에, neg_ 를 붙여서 neg_mean_absolute_error 로 써줘야 합니다.
모델 성능 평가방법은 정말 다양하게 있지만 몇가지 추려서 정리해보면 아래와 같습니다.
평가방법 | 사이킷런 평가 지표 API | GridSearchCV scoring 적용값 |
MAE (평균 절대오차) | metrics.mean_absoulte_error | 'neg_mean_absolute_error' |
MSE (평균 제곱오차) | metrics.mean_squared_error | 'neg_mean_squared_error' |
R2 (결정계수) | metrics.r2_score | 'r2' |
AP( 평균 정밀도) | metrics.average_precision_score | 'average_precision' |
F-score | metrics.f1_score | 'f1' |
정확도(예측과 결과가 얼마나 일치) | metrics.accuracy_score | 'accuracy' |
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_squared_error
from sklearn.metrics import r2_score
from sklearn.metrics import average_precision_score
from sklearn.metrics import f1_score
from sklearn.metrics import accuracy_score
사이킷런 내장 데이터셋을 불러와 GridSearchCV로 RandomForestClassifier 의 파라미터 후보군들을 최적화한 후 성능평가 및 피쳐 중요도를 시각화 해보는 예제입니다.
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import f1_score
import pandas as pd
# 위스콘신 유방암 데이터 로드, train_test_split으로 데이터 나누기
cancer = load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(cancer.data, cancer.target, test_size=0.2)
sklearn 에 내장되어있는 위스콘신 유방암 데이터셋을 가져와 이진분류기 모델을 그리드서치로 최적화해보는 예시입니다.
from sklearn.model_selection import GridSearchCV
param_grid = {
'min_samples_split' : list(range(2,10,2)),
'min_samples_leaf': list(range(1,4,1)),
'max_depth': list(range(6, 40, 2)),
}
gsc = GridSearchCV(RandomForestClassifier(), param_grid, cv=5,
scoring='accuracy',
n_jobs=-1,
return_train_score=True)
gsc.fit(X_train, y_train)
test_score = gsc.score(X_test, y_test)
print("테스트 세트 점수: {:.2f}".format( test_score ))
print("최고 교차 검증 점수: {:.2f}".format(gsc.best_score_))
print("최적 매개변수: {}".format(gsc.best_params_))
경우의 수는 많으나 학습시키는 데이터가 그리 많지 않기에... 비교적 오래 걸리지는 않습니다.
list(range(6, 40, 2)) 와 같이 써서 [6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38] 와 같은 리스트를 만들어 낼 수 있습니다.
단순 정확도 성능평가 기준으로는 최적의 성능을 내주는 파라미터는 {'max_depth': 6, 'min_samples_leaf': 2, 'min_samples_split': 6} 라는 결과를 얻었습니다.
본래라면 수동으로 일일히 조정해가며 노가다를 해야하는 미세조정이었지만, 그리드서치의 위력입니다.
이번에는 성능평가를 f1_score 로 바꿔봅니다.
위와는 다소 차이가 있는 매개변수 조합을 찾아주었네요.
즉, scoring 방법에 따라서 최적 매개변수가 차이가 날 수 있다는 점이 되겠습니다.
다음은 그리드서치를 통해 얻은 최적 모델을 가지고 피쳐 중요도를 뽑아서 시각해보는 예시입니다.
피쳐 중요도 상위15개 추출
feature_importances = gsc.best_estimator_.feature_importances_
#sorted(zip(feature_importances, cancer.feature_names), reverse=True)[:15]
ser = pd.Series(rf.feature_importances_, index=cancer.feature_names)
top15 = ser.sort_values(ascending=False)[:15]
top15
worst area 0.147701
worst perimeter 0.145425
worst radius 0.106434
worst concave points 0.098417
(...)
피쳐 중요도 상위15개 시각화
import matplotlib.pyplot as plt
import seaborn as sns
plt.figure(figsize=(8,4))
plt.title('Feature Importances Top 15')
sns.barplot(x=top15, y=top15.index, hue=top15.index)
plt.xlabel('feature_importances_')
plt.ylabel('feature_names')
plt.show()
안녕하세요. Red, Green, Blue 가 만나 새로운 세상을 만들어 나가겠다는 이상을 가진 개발자의 개인공간입니다.
현재글에서 작성자가 발행한 같은 카테고리내 이전, 다음 글들을 보여줍니다
@senspond
>