superi

主にキャリアと金融を嗜むメディア

XGBoostのハイパーパラメータをチューニングする


スポンサーリンク

前回、XGBoostで予測モデルを作成しました。

 

www.superi.jp

 

 精度を上げるためにはパラメーターチューニングが必要です。しかし、こちらのアルゴリズムは設定すべきパラメータが多く、手動で探索すると手間がかかります。

そこでscikit-kearnのgrid_searchを利用して、最適な値を探索します。

公式リファレンスによるとmax_depthmin_child_weightgammaのチューニングが精度改善に重要みたいなので、これらの最善の値を探索します。  

 

事前準備

グリッドサーチにはsklearn.grid_searchのGridSearchCVを使用します。

ライブラリとデータのインポートは前回と同じなので続きから書きたいと思います。

 Grid search

パラメータ選択において、max_depthmin_child_weightgammaに複数値をセットしました。18通りの計算をすることになります。

modelのfitは、GridSearchCVへくわせてからfitさせます。このとき、scoringで何を評価するか指定します。ここではau_ROCを指定していますが、f1precisionrecallも選択できます。

best_params_とbest_score_で、ベストな値を得ています。

 

#Cross Validation

X_train, X_test, Y_train, Y_test = cv.train_test_split(X, Y, random_state=0)
kf = cv.KFold(n=len(X_train), n_folds=5, shuffle=True)

#parameter select
param_grid = {
'learning_rate':[0.1],
'n_estimators':[1000],
'max_depth':[3,5],
'min_child_weight':[1,2,3],
'max_delta_step':[5],
'gamma':[0,3,10],
'subsample':[0.8],
'colsample_bytree':[0.8],
'objective':['binary:logistic'],
'nthread':[4],
'scale_pos_weight':[1],
'seed':[0]}

#model fit
clf = GridSearchCV(XGBClassifier(), param_grid=param_grid, cv=kf, scoring='roc_auc')
clf.fit(X_train, Y_train)
clf.score(X_test, Y_test)

print("Best parameters: %s" % clf.best_params_)
print("Best auroc score: %s" % clf.best_score_)

 

アウトプットは以下のようになります。

Best parameters: {'max_depth': 3, 'subsample': 0.8, 'seed': 0, 'colsample_bytree': 0.8, 'n_estimators': 1000, 'objective': 'binary:logistic', 'gamma': 10, 'nthread': 4, 'max_delta_step': 5, 'scale_pos_weight': 1, 'learning_rate': 0.1, 'min_child_weight': 3}
Best auroc score: 0.828113064769

ちなみにループ処理でベストパラメータ以外もプロットしてみます。

print("Grid scores on development set:")
print()
for params, mean_score, scores in clf.grid_scores_:
print("%0.3f (+/-%0.03f) for %r"
% (mean_score, scores.std() * 2, params))

出力結果は載せませんが、AUCは0.75~0.82の間をとっていました。私の実務経験上、パラメータチューニングをしても精度向上は0.02くらいに留まります。今回のデータセットは770程度のデータ数なので、一部のデータにオーバーフィットした結果の可能性が高いですね。

 

ハイパーパラメータのチューニングですが、むやみに値を入れると計算に時間がかかるので、ある程度当てをつけつつ、並列処理することをお勧めします。