前回からの続きです。
今回は、 k近傍法によるあやめの分類問題において、グリッドサーチを使うことでハイパーパラメータの探索を行います。
ハイパーパラメータの探索は、sklearn
のサイトでは、以下にまとめられています。
3.2. Tuning the hyper-parameters of an estimator
for ループを用いたグリッドサーチ
グリッドサーチとは、端的に言うと、候補として挙げたハイパーパラメータを全て試して、その中で最も良い結果を残したものを選ぶ手法です。
k 近傍法で使われるハイパーパラメータは、scikit-learn
のk近傍法のページで確認できます。
sklearn.neighbors
.KNeighborsClassifier
それぞれのハイパーパラメータの意味を、アルゴリズムの説明を読んで理解します。
1.6.4. Nearest Neighbor Algorithms
ここでは、以下を使い探索を行うことにしました。
param_gs_knn ={'est__n_neighbors':[1, 3, 5, 7, 9, 11, 15, 21], 'est__weights':['uniform','distance'], 'est__p':[1,2]}
また、グリッドサーチは交差検証と組み合わせて行うのが一般的です。
ループで全てのハイパーパラメータを試す時に交差検証を行うことで、ハイパーパラメータが特定のトレーニングセットに依存してしまうのを防ぎます。
この流れで実装してみます。
まずは、データセットを読み込みます。
In [1]: from sklearn.datasets import load_iris ...: import pandas as pd ...: iris = load_iris() ...: df = pd.DataFrame(iris.data, columns=iris.feature_names)
訓練データとテストデータを分割します。
In [2]: from sklearn.model_selection import train_test_split ...: X_train, X_test, y_train, y_test = train_test_split(df[iris.feature_names], ...: iris.target, ...: test_size=0.2, ...: stratify=iris.target, ...: random_state=0)
探索するハイパーパラメータを定義します。
In [3]: param_gs_knn ={'est__n_neighbors':[1, 3, 5, 7, 9, 11, 15, 21], ...: 'est__weights':['uniform','distance'], ...: 'est__p':[1,2]}
for ループと交差検証でハイパーパラメータの探索を行います。
In [4]: import numpy as np ...: from sklearn.neighbors import KNeighborsClassifier ...: from sklearn.model_selection import cross_val_score ...: ...: best_score = 0 ...: best_params = {} ...: ...: for n_neighbors in param_gs_knn['est__n_neighbors']: ...: for weights in param_gs_knn['est__weights']: ...: for p in param_gs_knn['est__p']: ...: knn = KNeighborsClassifier(n_neighbors=n_neighbors, ...: weights=weights, ...: p=p) ...: # 交差検証によるハイパーパラメータの探索 ...: scores = cross_val_score(knn, X_train, y_train, cv=5) ...: score = np.mean(scores) ...: if score > best_score: ...: best_score = score ...: best_params = {'n_neighbors': n_neighbors, ...: 'weights': weights, ...: 'p': p}
最も結果の良かったハイパーパラメータとその交差検証のスコアを確認します。
In [5]: print(best_score) ...: print(best_params) 0.9666666666666666 {'n_neighbors': 11, 'weights': 'uniform', 'p': 2}
最も結果の良かったハイパーパラメータを用いて学習を行い、テストデータでテストします。
In [6]: knn = KNeighborsClassifier(**best_params) ...: knn.fit(X_train, y_train) ...: socre = knn.score(X_test, y_test) ...: print(score) 0.9666666666666666
GridSearchCV
scikit-learn
には、 GridSearchCV というグリッドサーチのための交差検証が用意されており、これを使うことでグリッドサーチを簡単に行うことができます。
sklearn.model_selection
.GridSearchCV
データの読み込みを行います。
In [1]: from sklearn.datasets import load_iris ...: import pandas as pd ...: iris = load_iris() ...: df = pd.DataFrame(iris.data, columns=iris.feature_names)
データを分割します。
In [2]: from sklearn.model_selection import train_test_split ...: X_train, X_test, y_train, y_test = train_test_split(df[iris.feature_names], ...: iris.target, ...: test_size=0.2, ...: stratify=iris.target, ...: random_state=0)
用いるモデルと探索するハイパーパラメータを決めます。
In [3]: from sklearn.neighbors import KNeighborsClassifier ...: knn = KNeighborsClassifier() ...: ...: param_gs_knn ={'n_neighbors':[1, 3, 5, 7, 9, 11, 15, 21], ...: 'weights':['uniform','distance'], ...: 'p':[1,2]}
交差検証によるグリッドサーチを GridSearchCV で行います。
In [4]: from sklearn.model_selection import GridSearchCV ...: grid_search = GridSearchCV(knn, param_gs_knn, cv=5) ...: grid_search .fit(X_train, y_train) Out[4]: GridSearchCV(cv=5, error_score='raise-deprecating', estimator=KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski', metric_params=None, n_jobs=None, n_neighbors=5, p=2, weights='uniform'), iid='warn', n_jobs=None, param_grid={'n_neighbors': [1, 3, 5, 7, 9, 11, 15, 21], 'p': [1, 2], 'weights': ['uniform', 'distance']}, pre_dispatch='2*n_jobs', refit=True, return_train_score=False, scoring=None, verbose=0)
ベストスコアとハイパーパラメータを確認します。
In [5]: print(grid_search.best_params_) ...: print(grid_search.best_score_) {'n_neighbors': 11, 'p': 2, 'weights': 'uniform'} 0.9666666666666667
上のハイパーパラメータを用いて学習とテストを行います。
In [5]: print(grid_search.best_params_) ...: print(grid_search.best_score_) {'n_neighbors': 11, 'p': 2, 'weights': 'uniform'} 0.9666666666666667
for ループを使う場合に比べ、簡単にグリッドサーチが行えます。