tidymodelsによるtidyな機械学習(その3:ハイパーパラメータのチューニング)
はじめに
前回の記事ではハイパーパラメータのチューニングをfor loopを用いたgrid searchでやっっていました。 tidymodels配下のdialsとtuneを用いることで、より簡単にハイパーパラメータのサーチを行えるので、本記事ではその使い方を紹介したいと思います。 なお、パラメータサーチ以外のtidymodelsの使い方には本記事では言及しないので、以下の記事を参考にして頂ければと思います。
前処理
まずは前回の記事と同様、rsampleで訓練/テストデータの分割を行います。なお、例によってデータはdiamonds
を用います。
# パッケージ library(tidyverse) library(tidymodels) set.seed(42) # Train/Testの分割 df_split = initial_split(diamonds, p = 0.8) df_train = training(df_split) df_test = testing(df_split)
ハイパーパラメータのサーチ
最終的にはtune::tune_grid()
でハイパーパラメータを探索しますが、
そのためにTrain/Validationに分割されたデータ、前処理レシピ、学習用モデル、ハイパーパラメータの候補の4つが必要になります。
最初の3つは以前の記事で触れているので、4つ目のハイパーパラメータに関する部分に詳しく触れていきます。
Train/Validationデータ
これはrsamplesの仕事です。普通はCross Validationで評価するので、それに合わせます。
>df_cv = vfold_cv(df_train, v = 5) > df_cv # 5-fold cross-validation # A tibble: 5 x 2 splits id <named list> <chr> 1 <split [34.5K/8.6K]> Fold1 2 <split [34.5K/8.6K]> Fold2 3 <split [34.5K/8.6K]> Fold3 4 <split [34.5K/8.6K]> Fold4 5 <split [34.5K/8.6K]> Fold5
前処理レシピ
次に、前処理のレシピを作成します。モデルはRandom Forestを使うので、前回の分析同様、前処理は最低限にしておきます。
> rec = recipe(price ~ ., data = df_train) %>% + step_log(price) %>% + step_ordinalscore(all_nominal()) > > rec Data Recipe Inputs: role #variables outcome 1 predictor 9 Operations: Log transformation on price Scoring for all_nominal
学習用モデル
parsnipでモデルを設定します。今回もRandom Forestを使いましょう。
> model = rand_forest(mode = "regression", + trees = 50, # 速度重視 + min_n = tune(), + mtry = tune()) %>% + set_engine("ranger", num.threads = parallel::detectCores(), seed = 42) > > model Random Forest Model Specification (regression) Main Arguments: mtry = tune() trees = 50 min_n = tune() Engine-Specific Arguments: num.threads = parallel::detectCores() seed = 42 Computational engine: ranger
サーチしたいパラメータはここでは値を決めずtune::tune()
を与えます。今回はmin_n
とmtry
をサーチすることにします。
ハイパーパラメータ
ここが本記事での新しい内容になります。
dialesにはparsnipで指定できるハイパーパラメータに関して、探索するレンジを指定するための関数が用意されています。
たとえばmin_n()
はRandom Forestのハイパーパラメータmin_n
に対応する関数で、デフォルトだと2-40のレンジでハイパーパラメータが探索されます。
なお、これは最終ノードに最低でも必要なインスタンスの数を表していて、これを大きくするとより強い正則化がかかります。
> min_n() Minimal Node Size (quantitative) Range: [2, 40]
自分でレンジを決めたい場合は引数で指定することができます。
> min_n(range = c(1, 10)) Minimal Node Size (quantitative) Range: [1, 10]
同様に、もうひとつのハイパーパラメータmtry
に関しても見てみましょう。
こちらは各ツリーでの分割の際に用いる特徴量の数で、これを小さくするとより強い正則化がかかります。
> mtry() # Randomly Selected Predictors (quantitative) Range: [1, ?]
min_n()
とは違って、レンジの最大値が指定されていません。モデルに投入する特徴量の数よりも大きいmtry
を探索しても意味がない*1ので、こちらで直接指定してあげる必要があります。
この際、mtry(range = c(1, 5))
のようにレンジを指定することもできますが、実際にモデルに投入するデータフレームを与えてあげることでレンジを指定することもできます。
> # 前処理済み学習用データ > df_input = rec %>% + prep() %>% + juice() %>% + select(-price) > > finalize(mtry(), df_input) # Randomly Selected Predictors (quantitative) Range: [1, 9]
dials::finalize()
の第1引数にレンジを指定したい関数を、第2引数に特徴量のデータフレームを指定することで、適切なレンジが指定されます。
特徴量の数は分析の途中で変動するので、このやり方は柔軟性があって良いんじゃないかと思います。
さて、これで探索範囲の指定ができるようになりました。
次に、探索したいハイパーパラメータのリストを作ってtune::parameters()
に与えるとparameters
オブジェクトを作ることができます。
> params = list(min_n(), + mtry() %>% finalize(rec %>% prep() %>% juice() %>% select(-price))) %>% + parameters() > > params Collection of 2 parameters for tuning id parameter type object class min_n min_n nparam[+] mtry mtry nparam[+]
このparameters
オブジェクトをdials::grid_*()
に渡すことで、実際に探索するハイパーパラメータの値を作ることができます。
たとえばgrid_regular()
ならグリッドサーチ、grid_random()
ならランダムサーチになります。
今回はランダムサーチにしましょう。サーチの数はsize
で指定できます。
> df_grid = params %>% + grid_random(size = 10) # 実際はもっと多いほうがいい > > df_grid # A tibble: 10 x 2 min_n mtry <int> <int> 1 7 8 2 19 8 3 22 3 4 38 1 5 12 8 6 36 1 7 21 6 8 7 5 9 29 2 10 38 3
チューニング
これでハイパーパラメータの探索準備が整いました!
これまでに作ったオブジェクトをtune::tune_grid()
に渡します。
df_tuned = tune_grid(object = rec, model = model, resamples = df_cv, grid = df_grid, metrics = metric_set(rmse, mae, rsq), control = control_grid(verbose = T))
object
:recipeで作成された前処理レシピを渡します。resamples
に与えたデータがこの前処理を済ませたあとでモデルに投入されます。model
:parsnipで定義した学習用のモデルです。resamples
:rsamplesで作った学習/評価用のデータを渡します。普通はCross Varidationで評価すると思うのでrsamples::vfold_cv()
で作ったデータフレームを渡すのがいいと思います。grid
:dialsとtuneで作ったハイパーパラメータの候補が格納されたデータフレームを渡します。metrics
:精度の評価指標です。yardsticに準備されている関数を指定することができます。control
:指定しなくても構いませんが、今回はログが出力されるようにしています。
学習/評価が終わると、以下のようなデータフレームが手に入ります。
> df_tuned # 5-fold cross-validation # A tibble: 5 x 4 splits id .metrics .notes * <list> <chr> <list> <list> 1 <split [34.5K/8.6K]> Fold1 <tibble [20 × 5]> <tibble [0 × 1]> 2 <split [34.5K/8.6K]> Fold2 <tibble [20 × 5]> <tibble [0 × 1]> 3 <split [34.5K/8.6K]> Fold3 <tibble [20 × 5]> <tibble [0 × 1]> 4 <split [34.5K/8.6K]> Fold4 <tibble [20 × 5]> <tibble [0 × 1]> 5 <split [34.5K/8.6K]> Fold5 <tibble [20 × 5]> <tibble [0 × 1]>
.metrics
に予測精度が格納されています。unnest()
でもとってこれますが、手っ取り早い関数としてtune::collect_metrics()
が準備されています。
> df_tuned %>% + collect_metrics() # A tibble: 30 x 7 mtry min_n .metric .estimator mean n std_err <int> <int> <chr> <chr> <dbl> <int> <dbl> 1 1 38 mae standard 0.120 5 0.00131 2 1 38 rmse standard 0.159 5 0.00173 3 1 38 rsq standard 0.979 5 0.000246 4 2 29 mae standard 0.0792 5 0.000317 5 2 29 rmse standard 0.107 5 0.000518 6 2 29 rsq standard 0.989 5 0.000160 7 4 19 mae standard 0.0677 5 0.000174 8 4 19 rmse standard 0.0932 5 0.000441 9 4 19 rsq standard 0.992 5 0.000109 10 5 3 mae standard 0.0651 5 0.000176 # … with 20 more rows
特に精度の高いハイパーパラメータの候補が知りたい場合は、tune::show_best()
で確認することができます。
> df_tuned %>% + show_best(metric = "rmse", n_top = 3, maximize = FALSE) # A tibble: 3 x 7 mtry min_n .metric .estimator mean n std_err <int> <int> <chr> <chr> <dbl> <int> <dbl> 1 7 8 rmse standard 0.0910 5 0.000702 2 5 3 rmse standard 0.0912 5 0.000558 3 7 16 rmse standard 0.0915 5 0.000679
一番精度の高かったハイパーパラメータを使って全訓練データでモデルを再学習する場合は、tune::select_best()
でハイパーパラメータをとってきてupdate()
でモデルをアップデートできます。
# 一番精度の良かったハイパーパラメータ > df_best_param = df_tuned %>% + select_best(metric = "rmse", maximize = FALSE) > df_best_param # A tibble: 1 x 2 mtry min_n <int> <int> 1 7 8 # モデルのハイパーパラメータを更新 > model_best = update(model, df_best_param) > model_best Random Forest Model Specification (regression) Main Arguments: mtry = 7 trees = 50 min_n = 8 Engine-Specific Arguments: num.threads = parallel::detectCores() seed = 42 Computational engine: ranger
あとは通常通り学習・評価を行うだけです。
まとめ
本記事ではtidymodels配下のdialsとtuneを用いたハイパーパラメータのサーチを紹介しました。 前回の記事ではCross Validationデータにどうアクセスするか結構ややこしかったと思うのですが、tuneを用いることでよりスッキリと探索と評価を行うことができます。 dialsとtuneは公式ドキュメントに個別の使い方はまとまっているのですが、目的がわかりにくい部分や、どの情報が最新かわからないところもあり、典型的なタスクに対しての用途を自分でまとめ直したという経緯になります。
ちなみに、今回はランダムサーチを用いましたが、tuneではbaysian optimizationを用いたパラメータ探索も可能となっています。ぜひ確認してみて下さい*2
本記事で使用したコードは以下にまとめてあります。
参考文献
*1:実際は特徴量の数よりも大きい値を入れるとエラーを吐きます