Dropout

データ分析が好きです。マーケティングとか広告とか因果推論とか機械学習の解釈性に興味があります。

SHAP(SHapley Additive exPlanations)で機械学習モデルを解釈する


はじめに

ブラックボックスモデルを解釈する手法として、協力ゲーム理論のShapley Valueを応用したSHAP(SHapley Additive exPlanations)が非常に注目されています。SHAPは各インスタンスの予測値の解釈に使えるだけでなく、Partial Dependence Plotのように予測値と変数の関係をみることができ、さらに変数重要度としても解釈が可能であるなど、ミクロ的な解釈からマクロ的な解釈までを一貫して行える点で非常に優れた解釈手法です。

SHAPの論文の作者によって使いやすいPythonパッケージが開発されていることもあり、実際にパッケージを使った実用例はたくさん見かけるので、本記事では協力ゲーム理論の具体例、Shapley Valueのコンセプトと求め方、機械学習モデルを解釈するためのShapley Valueの使われ方を意識してまとめました。


※この記事をベースに2020年1月16日に行われたData Gateway Talk vol.5にて発表した資料は以下になります
speakerdeck.com

この記事で書いていること、書いていないこと

書いていること

  • 協力ゲーム理論の具体例とShapley Valueのコンセプトと求め方
  • 機械学習モデルを解釈するためのShapley Valueの使われ方
  • shapパッケージの簡単な使い方

書いていないこと

アルバイトゲームとShapley Value

ここでは協力ゲーム理論のアルバイトゲームを例にとって、Shapley Valueを直感的に理解することを目指します。

まずはアルバイトゲームについて説明します。*1
アルバイトは単独で行うこともできますし、他の人とチームを組んでやってもいいとします。

  • まずは、1人で働いた場合は、A君が1人でやると6万円、B君が1人でやると4万円、C君が1人でやると2万円がもらえるとしましょう。
  • 次に、2人で働いた場合は、A君とB君が2人でやったときは合計で20万円、A君とC君が2人でやったときは合計で15万円、B君とC君が2人でやったときは合計で10万円がもらえるとしましょう。
  • 最後に、A君B君C君の3人で働いた場合は、合計で24万円がもらえるとしましょう。

これをまとめると以下の表のようになります。

参加者 報酬
A君 6
B君 4
C君 2
A君B君 20
A君C君 15
B君C君 10
A君B君C君 24

今、A君B君C君の3人全員で働いて得た報酬24万円をどうやって分配するのが尤もらしいかを考えます。
直感的には、より貢献度の高い人による多くの報酬を分配するのがフェアな分配のひとつになりそうです。*2
とすると、問題は各人の貢献度をどうやって計るのかということになります。ここで限界貢献度という概念を導入しましょう。これは「各人がアルバイトに参加したときに追加的にどのくらい報酬が増えるか」で計算されます。
たとえば、A君について考えると、限界貢献度は以下のように計算されます。

  • 「誰もいない」→「A君のみ」だと6 - 0 = 6万円
  • 「B君のみ」→「A君とB君」だと20 - 4 = 16万円
  • 「C君のみ」→「A君とC君」だと15 - 2 = 13万円
  • 「B君とC君」→「A君とB君とC君」だと24 - 10 = 14万円

ここでわかるのは、A君の限界報酬はA君が参加する順番に依存するということです。
この影響を打ち消すため、考えられる全ての参加順を用いて平均的な限界貢献度を求めることにしましょう。
発生し得る参加順と、その場合の各人の限界貢献度は以下のようにまとめられます。

参加順 A君の限界貢献度 B君の限界貢献度 C君の限界貢献度
A君→B君→C君 6 14 4
A君→C君→B君 6 9 9
B君→A君→C君 16 4 4
B君→C君→A君 14 4 6
C君→A君→B君 13 9 2
C君→B君→A君 14 8 2

参加順は3!=6通りなので、

  • A君の平均的な限界貢献度は(6 + 6 + 16 + 14 + 13 + 14) / 6 = 11.5万円
  • B君の平均的な限界貢献度は(14 + 9 + 4 + 4 + 9 + 8) / 6 = 8万円
  • C君の平均的な限界貢献度は(4 + 9 + 4 + 6 + 2 + 2) / 6 = 4.5万円

この平均的な限界貢献度のことをShapley Valueと言い、このShapley Valueを用いて報酬を分配しよう、というのがある意味で尤もらしい分配の方法になります。実際、11.5 + 8 + 4.5 = 24万円できれいに分配ができていますし、より貢献度が高い人により多くの報酬が渡るという意味でフェアな分配にもなっています。*3

今回はゲームのプレイヤーがA君B君C君3人のケースを考えました。より一般的なケースとしてN人のプレイヤー\mathcal{N} = \{1, 2, \dots, N\}がゲームに参加するケースを考えると、プレイヤーiのShapley Value\phi_iは以下で計算できます。

\begin{align*}
\phi_i =
\sum_{\mathcal{S} \subset\mathcal{N} \setminus \{i\}}
\frac{S!(N - S - 1)!}{N!}
\bigg(v(\mathcal{S}\cup \{i\}) - v(\mathcal{S})\bigg)
\end{align*}

  • ここで、\mathcal{S} \mathcal{N}からプレイヤーiを除いたプレイヤーの組み合わせです。たとえばA君B君C君3人のケースで、A君のShapley Valueを考える場合は、 \{\emptyset\}, \{B\}, \{C\}, \{B, C\}が該当します。
  • S\mathcal{S}の要素の数、つまりプレイヤーの数になります。先程の例だと、それぞれ0人, 1人, 1人, 2人となります。
  •  v(\cdot)は効用を表す関数です。なので、v(\mathcal{S}\cup \{i\}) - v(\mathcal{S})はプレイヤー iが参加しているときと参加していないときでの効用の差となります。つまり、プレイヤー iが参加したときの限界貢献度となります。

まとめると、プレイヤー iが参加することの限界貢献度を出現する全ての組合わせで求めて平均しています。これはアルバイトゲームの具体例でやったことと全く同じ操作をしています。上式だけみると何を計算しているかよくわからないのですが、具体例をいれて確かめてみると何をやっているかわかりやすいんじゃないかと思います。

機械学習モデルへの応用

最近、協力ゲーム理論のShapley Value機械学習モデルの予測結果を解釈するために利用しよう、という研究が発達してきました。
これらの研究では、モデルに投入したひとつひとつの特徴量をゲームのプレイヤーと見立てて、各特徴量の予測への貢献度をShapley Valueで測ろう、ということをやっています。
具体的なシチュエーションを考えてみましょう。今、特徴量として \mathbf{X} = (X_1, X_2, X_3)の3つがあるとします。
モデルを f(\cdot)とすると、平均的な予測値は E[f( \mathbf{X})]となります。
ここで、ひとつのインスタンスを取り出すと、それぞれ \mathbf{x} = (x_1, x_2, x_3) という値をとっているとします。このとき、予測値としては f(\mathbf{x})が出力されます。

この、平均的な予測値 E[f( \mathbf{X})]と各インスタンスの予測値 f(\mathbf{x})の乖離に対して、各特徴量がどのくらい影響しているのかを調べます。

インスタンスの予測値 f(\mathbf{x})
\begin{align*}
E[f( \mathbf{X} | X_1 = x_1, X_2 = x_2, X_3 = x_3)] & = f( x_1, x_2, x_3 ) = f(\mathbf{x})
\end{align*}
なので、平均的な予測値 E[f( \mathbf{X})]から X_1, X_2, X_3を条件付けていくことで、その特徴量を知ることが各インスタンスの予測に対してどのように影響するかを知ることができます。具体的に、 X_1, X_2, X_3の順で条件づけていくとすると、たとえば下図のような推移が見られます。

f:id:dropout009:20191117175258p:plain

ここで、 \phi_j, j = 1, 2, 3を各特徴量が予測値に与える限界的な効果としています。(なお、 \phi_0は0と平均的な予測値の乖離で、特に重要ではありません。)

まず、特に何の情報もないところから X_1 = x_1という情報を得ると、予測値が \phi_1だけ大きくなります。さらにその状態から X_2 = x_2という情報を得ると、予測値がさらに \phi_2だけ大きくなります。最後に、 X_3 = x_3という情報を得ると、予測値が \phi_3だけ小さくなって、これが最終的なこのインスタンスへの予測結果となります。

ここで、先程のアルバイトゲームとの共通点が現れています。

  • 各特徴量を一つずつ条件付けていくことで予測値に与える限界的な効果を見ていますが、これはアルバイトゲームで言うところの限界貢献度に対応しています。
  • また、今回は X_1, X_2, X_3の順で条件づけていますが、当然別の順番で条件づけていくと予測値に与える限界的な効果は変化します。よって、考えうる全ての順序で限界効果を計算し、それを平均しなければなりません。この平均的な限界効果がShapley Valueに対応します。

詳細は論文をご確認頂きたいのですが、このようにShapley Valueを用いて計算された貢献度は、Shapley Valueが持つ望ましい性質を満たすことが証明されています。
ただ、Shapley Valueを計算するのは非常に計算コストがかかるので、実際の計算ではアルゴリズムごとの近似手法が用いられています。特に、tree系のアルゴリズムに対してはTreeSAHPという高速なアルゴリズムが提案されています。

機械学習モデルに対するShapley Valueを簡単で高速に計算できるPythonパッケージにSHAPがあります。

まずは必要なパッケージを読み込みましょう。

import numpy as np
import pandas as pd

# モデルはRandom Forestを使う
from sklearn.ensemble import RandomForestRegressor

# SHAP(SHapley Additive exPlanations)
import shap
shap.initjs() # いくつかの可視化で必要

データはshapパッケージにあるボストンの不動産価格のデータを用います。
データセットの細かい説明はリンクをご確認下さい。

X, y = shap.datasets.boston()
X.head()

f:id:dropout009:20191117191347p:plain


モデルはRandom Forestを使います。

model = RandomForestRegressor(n_estimators=500, n_jobs=-1)
model.fit(X, y)


ここからがshapの使い方になります。shapにはいくつかのExplainerが用意されていて、まずはExplainerにモデルを渡すします。今回はRandom ForestなのでTreeExplainer()を使います。

explainer = shap.TreeExplainer(model, X)

Explainerにはshap_values()メソッドが用意されています。これにShapley Valueを計算したいインプットの行列を渡すことでShapley Valueを計算できます。shapは高速な計算アルゴリズムが実装されており、特にこのデータセットは500行程度なのですぐに計算は終わりますが、計算コストは依然として高いので大きいデータセットのときは適当にサンプリングして渡す必要があります。

shap_values = explainer.shap_values(X)


まずは一つのインスタンスに対してShapley Valueを確認してみます。これは`force_plot()`を使います。

i = 0
shap.force_plot(explainer.expected_value, shap_values[i,:], X.iloc[i,:])

f:id:dropout009:20191117190407p:plain

basevalueの22.51がが全体平均で、output valueの25.64がこのインスタンスに対する予測値となります。なので、このインスタンスは平均よりも高い値が予測されています。なぜこのような予測になったのかを説明するために、各特徴量がどのくらい大きな因子となっているのかを、Shapley Valueで分解して可視化しています。たとえば、このインスタンスはLSTATが4.98をとっていて、これはこのインスタンスの不動産価格に対して大きなプラスの要因となっています。逆に、RMは6.575となっていて、これはマイナス要因であることも見て取れます。プラスとマイナスを総合するとプラスの方が大きくなっていて、最終的には全体平均より3.13高い予測結果になっていることがわかります。
ちなみにLSTATはその地域に住む低所得層の割合、RMは平均的な部屋の数です。

ひとつひとつのインスタンスでShapley Valueを見ていくことでミクロな分析ができますが、よりマクロな分析として、Shapley Valueを変数ごとに平均して変数重要度のように使うこともできます。
今、データセットインスタンス数が Nとすると、変数 p = 1, \dots, Pの変数重要度 \mathrm{Feature Importance}_pは以下で計算します。

\begin{align*}
\mathrm{Feature Importance}_p = \frac{1}{N}\sum_{i = 1}^{N} \left|\phi_p^i \right|
\end{align*}

ただし、 \phi_p^iインスタンス iでの変数 pのShapley Valueです。また、プラスマイナスの影響を無視するために絶対値をとっています。
この変数重要度を簡単に可視化するための関数として、summry_plot()が用意されています。

shap.summary_plot(shap_values, X, plot_type="bar")

f:id:dropout009:20191117190121p:plain

Shapley Valueの意味で、LSTATとRMが非常に重要な変数であることが見て取れます。


plot_type = "bar"とすると棒グラフが出ますが、指定しないと一つ一つのShapley Valueがそのまま打たれます。*4

shap.summary_plot(shap_values, X)

f:id:dropout009:20191117190134p:plain

上に来るほど先程の棒グラフの意味で重要な変数になります。色が赤いほどその変数の値が高いとき、青いほど低いときのShapley Valueになります。Shapley Valueの分布に加えて、LSTATは低いほうがプラスの要因に、RMは高いほうがプラスの要因になりそうなことが見て取れます。




さらに、特徴量の値とShapley Valueの散布図を書くことで、Partial Dependence Plotのような可視化も可能となります。
これは dependence_plot()を使います。今回は特に重要度の高かったLSTATを見てみましょう。

shap.dependence_plot("LSTAT", shap_values, X)

f:id:dropout009:20191117190116p:plain

LSTATの値が大きくなるほどShapley Valueが小さくなることが見て取れます。
色付けは交互作用が見れるように他の変数の値でされています。デフォルトでは交互作用が一番はっきり現れる変数が自動で選ばれます。今回はDISが選ばれていて、DISの値が大きいほど赤く、小さいほど青い点が打たれます。DISはemployment centreからの距離を表しています。グラフの左半分を見ると、同じLSTATの値でもDISが短いほどShapley Valueが高くなる傾向が見て取れます。


*1:この数値例は岡田章『新板ゲーム理論』の例をそのままお借りしています。

*2:たとえば、山分けで3等分も考えられますが、この設定ではそれはうまくいきません。A君B君C君の3人で働いて3等分すると報酬は8万円になります。この場合、A君B君だけで働いて報酬を2等分するとA君とB君は10万円を稼ぐことができ、C君を外すインセンティブが生まれてしまいます。

*3:Shapley Valueは数学的に証明された望ましい性質をいくつかもっていますが、ここでは具体例と直感的な性質のみ説明しています。詳細を知りたい場合は協力ゲーム理論の教科書か論文をご確認下さい。

*4:僕はこれの正式な名称を知りません。sina plotでいいのかな?

tidymodelsとDALEXによるtidyで解釈可能な機械学習



※この記事をベースにした2020年1月25日に行われた第83回Japan.Rでの発表資料は以下になります。
speakerdeck.com


はじめに

本記事では、tidymodelsを用いて機械学習モデルを作成し、それをDALEXを用いて解釈する方法をまとめています。
DALEX

Collection of tools for Visual Exploration, Explanation and Debugging of Predictive Models

を目的とした複数のパッケージから成り立っています。
具体的にどんなパッケージが所属しているかはリンクを確認して下さい。本記事では、特にメインパッケージであるDALEXingredientsを用います。
ingredientsには、変数重要度、Partial Dependence Plot(PDP)、Individual Conditional Expectation(ICE)、Accumulated Local Effect(ALE)など、特徴量とターゲット変数の関係を見るための関数が用意されています。

なお、本記事ではtidymodelsの使い方には触れないので、過去の記事を参考にして頂ければと思います。

dropout009.hatenablog.com
dropout009.hatenablog.com
dropout009.hatenablog.com

また、本記事では解釈手法そのものには詳しく触れません。Permutationベースの変数重要度やPDP, ICEの詳細な説明は以下の記事も参考にして頂ければと思います。
dropout009.hatenablog.com

パッケージ

本記事で用いるパッケージは以下になります。

library(tidymodels)
library(DALEX) #解釈
library(ingredients) #解釈

library(distributions3) #シミュレーション用
library(colorblindr) #可視化用

set.seed(42)

また、可視化用に以下の関数を用意しておきます

# visualization -----------------------------------------------------------
# ggplot2テーマ
theme_scatter = function() {
  theme_minimal(base_size = 12) %+replace%
    theme(panel.grid.minor = element_blank(),
          panel.grid.major = element_line(color = "gray", size = 0.1),
          legend.position = "top",
          axis.title = element_text(size = 15, color = "black"),
          axis.title.x = element_text(margin = margin(10, 0, 0, 0), hjust = 1),
          axis.title.y = element_text(margin = margin(0, 10, 0, 0), angle = 90, hjust = 1),
          axis.text = element_text(size = 12, color = "black"),
          strip.text = element_text(size = 15, color = "black", margin = margin(5, 5, 5, 5)),
          plot.title = element_text(size = 15, color = "black", margin = margin(0, 0, 18, 0)))
}

シミュレーション1

データ

PDPとICEの関係を浮き彫りにするため、今回はシミュレーションデータを用意します。

特徴量として X_1, X_2, X_3の3つが観測されていいて、
 X_2 X_1より強く Yと関係しているが、 X_3は関係してないという状況を考えます。
できるだけ簡単なほうが結果がわかりやすいと思うので、以下の線形で加法な形に特定します。*1

\begin{align}
Y &= X_1 - 5X_2 + \epsilon, \\
X_1, X_2, X_3 &\sim U(-1, 1),\\
\epsilon &\sim \mathcal{N}(0, 0.1^2)
\end{align}

ここで、 X_1, X_2, X_3はそれぞれ独立に区間 [0, 1]に一様分布し、 Yには平均0で標準偏差0.1の正規分布に従うノイズが乗るとしています。

シミュレーションのために乱数を発生させます。
Rには確率分布を扱うための関数がデフォルトで用意されていますが、distributions3を用いると、より直感的に確率分布を扱うことができます。

N = 1000 #サンプルサイズ

U = Uniform(-1, 1) # 分布を指定
Z = Normal(mu = 0, sigma = 0.1)

X1 = random(U, N) # 分布から乱数を生成
X2 = random(U, N)
X3 = random(U, N)
E = random(Z, N)

Y = X1 - 5*X2 + E 

df = tibble(Y, X1, X2, X3) # データフレームに

データを可視化して X Yの関係を確認しておきましょう。

df %>% 
  sample_n(200) %>% 
  pivot_longer(cols = contains("X"), values_to = "X") %>% 
  ggplot(aes(X, Y)) +
  geom_point(size = 3, 
             shape = 21,
             color = "white", fill = palette_OkabeIto[5], 
             alpha = 0.7) +
  facet_wrap(~name) + 
  theme_scatter()

f:id:dropout009:20191116212342p:plain

 Y X_1とは弱い正の関係が、 X_2とは強い負の関係があること、 X_3とは無相関であることが見て取れます。

モデル

データの準備ができたので、モデルを作成します。
モデルはtidymodelsparsnipを使って作成します。

ブラックボックスモデルとして、Random Forestを使うことにします。

fitted = rand_forest(mode = "regression", 
                     trees = 1000,
                     mtry = 3,
                     min_n = 1) %>% 
  set_engine(engine = "ranger", 
             num.threads = parallel::detectCores(), 
             seed = 42) %>% 
  fit(Y ~ ., data = df)
  • rand_forest()でモデルにRandom Forestを利用することと、そのハイパーパラメータを指定
  • set_engin()で利用パッケージとパッケージ固有のパラメータを指定
  • fit()でデータを指定して学習

までを一気に行っています。

DALEXによる解釈

モデルの学習が終わったので、DALEXを用いてモデルの振る舞いを解釈していきましょう。
DALEXによる解釈は、DALEX::explain()でexplainerオブジェクトを作るところから始まります。
嬉しいことに、parsnipで学習したモデルはexplain()にそのまま与えることができます。

> explainer = explain(fitted,# 学習済みモデル
+                     data = df %>% select(-Y), # インプット
+                     y = df %>% pull(Y),# ターゲット
+                     label = "Random Forest")# ラベルをつけておくことができる(なくてもいい)

Preparation of a new explainer is initiated
  -> model label       :  Random Forest 
  -> data              :  1000  rows  3  cols 
  -> data              :  tibbble converted into a data.frame 
  -> target variable   :  1000  values 
  -> predict function  :  yhat.model_fit  will be used (default)
  -> predicted values  :  numerical, min =  -5.886102 , mean =  0.008677733 , max =  5.799964  
  -> residual function :  difference between y and yhat (default)
  -> residuals         :  numerical, min =  -0.1736233 , mean =  -8.650032e-05 , max =  0.1944877  
A new explainer has been created!

あとはexplainerにingredientsで用意された様々な解釈手法を適応するだけです。

変数重要度

まずは変数重要度を見てみましょう。これはexplainerをingredients::feature_importance()に与えるだけで計算できます。

> fi = feature_importance(explainer, 
+                         loss_function = loss_root_mean_square, # 精度の評価関数
+                         type = "raw") # "ratio"にするとフルモデルと比べて何倍悪化するかが出る
> fi
      variable dropout_loss         label
1 _full_model_   0.05063265 Random Forest
2           X3   0.06979976 Random Forest
3           X1   0.80947038 Random Forest
4           X2   4.19405191 Random Forest
5   _baseline_   4.26605636 Random Forest

DALEXは学習アルゴリズムに依存しない手法がベースになっているので、変数重要度の計算もpermutationベースのものが使われます。
つまり、その変数をシャッフルしてどのくらい予測精度が落ちるのかがその変数の重要度として定義とされています。
なお、ingredientsによって作成されたオブジェクトはplot()で簡単に可視化することができます。ggplot2がベースになっているので、ggplot2の設定やレイヤーを+で重ねていくこともできます。

plot(fi)

f:id:dropout009:20191116221221p:plain

変数重要度を見ると、 X_2の重要度が一番高く、 X_1は想定的に重要度が低く、 X_3は全く重要ではないとしていることがわかります。これはそもそものシミュレーションの設定に沿っています。

PDP

次はPDPを見てみましょう。PDPはモデルの平均的な予測結果を可視化したものです。
これもexplainerをingredients::partial_dependency()に与えるだけで計算できます。

pdp = partial_dependency(explainer)

pdp %>% 
  plot() + # ggplot2のレイヤーや設定を重ねていくことができる
  scale_y_continuous(breaks = seq(-6, 6, 2)) + 
  theme_scatter() +
  theme(legend.position = "none")

f:id:dropout009:20191116221810p:plain


モデルが X_1, X_2, X_3 Yの関係をうまく捉えられていることが見て取れます。

シミュレーション2

データの作成

PDPは交互作用がない場合はうまくモデルを解釈する事ができますが、交互作用がある場合は可視化がうまく機能しないことが知られています。
以下のようなシミュレーションを考えてみましょう。

\begin{align}
Y &= X_1 - 5X_2 + 10X_2X_3 + \epsilon, \\
X_1, X_2 &\sim U(-1, 1),\\
X_1, X_2 &\sim Bernoulli(0.5),\\
\epsilon &\sim \mathcal{N}(0, 0.1^2)
\end{align}

今回、 X_3 = 0は0か1をとる変数で、それ単体では効果はありませんが、 X_2との交互作用があり、 X_2 X_3 = 0のときは正の、 X_3 = 1のときは負の効果があるという特定化になっています。

N = 1000

U = Uniform(-1, 1)
B = Bernoulli(p = 0.5)
Z = Normal(mu = 0, sigma = 0.1)

X1 = random(U, N)
X2 = random(U, N)
X3 = random(B, N)
E = random(Z, N)

Y = X1 - 5*X2 + 10*X2*X3 + E

df = tibble(Y, X1, X2, X3)


df %>% 
  sample_n(200) %>% 
  pivot_longer(cols = contains("X"), values_to = "X") %>% 
  ggplot(aes(X, Y)) +
  geom_point(size = 3, 
             shape = 21,
             color = "white", fill = palette_OkabeIto[5], 
             alpha = 0.7) +
  facet_wrap(~name) + 
  theme_scatter()

f:id:dropout009:20191116223552p:plain

DALEXによる解釈

PDP

シミュレーション1と同じくモデルを作成、学習し、explainerオブジェクトを作ります。

fitted = rand_forest(mode = "regression", 
                     trees = 1000,
                     mtry = 3,
                     min_n = 1) %>% 
  set_engine(engine = "ranger", 
             num.threads = parallel::detectCores(), 
             seed = 42) %>% 
  fit(Y ~ ., data = df)


explainer = DALEX::explain(fitted,
                           data = df %>% select(-Y),
                           y = df %>% pull(Y),
                           label = "Random Forest")

PDPも同様に作成します。
このシミュレーションでは Y X_2の関係をうまく可視化できるかにフォーカスすることにします。

pdp = partial_dependency(explainer, variables = "X2") #変数を限定

pdp %>% 
  plot() + 
  geom_abline(slope = -5, color = "gray70", size = 1) + 
  geom_abline(slope = 5, color = "gray70", size = 1) + 
  scale_y_continuous(breaks = seq(-6, 6, 2),
                     limits = c(-6, 6)) + 
  theme_scatter() +
  theme(legend.position = "none")

f:id:dropout009:20191116230316p:plain
青い線がPDP、グレーの線が本来の関係です。
PDPを見ると、 X_2が動いても予測値にはまるで効果がないように見えます。
念の為ですが、モデルの予測自体はうまくいっています。
Random Forestは交互作用をうまく学習できますし、実際、OOBでの R^2は0.96です。

> fitted
parsnip model object

Fit in:  275msRanger result

Call:
 ranger::ranger(formula = formula, data = data, mtry = ~3, num.trees = ~1000,      min.node.size = ~1, num.threads = ~parallel::detectCores(),      seed = ~42, verbose = FALSE) 

Type:                             Regression 
Number of trees:                  1000 
Sample size:                      1000 
Number of independent variables:  3 
Mtry:                             3 
Target node size:                 1 
Variable importance mode:         none 
Splitrule:                        variance 
OOB prediction error (MSE):       0.3297557 
R squared (OOB):                  0.9624996 

ここで起きている現象は、「モデルは交互作用を学習できているが、PDPはそれをうまく可視化できていない」というものです。PDPは交互作用を平均化してしまうため、プラスの交互作用とマイナスの交互作用が相殺して効果がないように見えてしまっています。

ICE Plot

交互作用をうまく捉える手法の一つがICEになります。
ICEは単純にPDPの平均化する前の出力を全てプロットするというものです。
平均化をしていないので交互作用が相殺されることを防ぐことができます。
例によって、ingredients::ceteris_paribus()にexplainerを与えるだけで計算できます。

# 線が多すぎるとわけがわからなくなるので100サンプルだけ抜き出す
# tibbleのまま渡すと警告が出るのでdata.frameにしている。
# 警告の内容を見るとうまく動かなさそうだが、いまのところtibbleのままでもうまく動いているように思う
ice = ceteris_paribus(explainer, 
                      variables = "X2",
                      new_observation = df %>% sample_n(100) %>% as.data.frame()) 

ice %>% 
  plot(alpha = 0.5, size = 0.5, color = colors_discrete_drwhy(1)) + 
  geom_abline(slope = -5, color = "gray70", size = 1) + 
  geom_abline(slope = 5, color = "gray70", size = 1) +
  scale_y_continuous(breaks = seq(-6, 6, 2),
                     limits = c(-6, 6)) + 
  theme_scatter() +
  theme(legend.position = "none")

f:id:dropout009:20191116230534p:plain
青い線の一本一本が各インスタンスの予測結果に対応しています。
PDPとは違って、ICEではモデルが交互作用を学習していることをうまく可視化することができています。
なお、各線のタテ方向のブレはインスタンスごとの X_1の大きさによるものです。仮に X_1の値が全てのインスタンスで共通なら一直線上に並んだきれいなX印になります。

Conditional PDP

PDPは交互作用のある変数についても平均化してしまうのが問題でした。
ICEはひとつの解決策でしたが、これはこれでやたら線が多くなってしまいますし、値の安定性も低くなります。
今回は X_3=0のときと X_3=1のときでX_2の効果が逆になることが問題でした。
ということは、 X_3の値でグループに分けてPDPを計算することで、交互作用の問題を解決できそうです。
これを行うための関数がingredients::aggregate_profiles()として用意されています。

# ICEを与える。グループを指定するとPDPをグループごとに求めることができる。
conditional_pdp = aggregate_profiles(ice, groups = "X3") 

conditional_pdp %>% 
  plot() + 
  geom_abline(slope = -5, color = "gray70", size = 1) + 
  geom_abline(slope = 5, color = "gray70", size = 1) +
  scale_y_continuous(breaks = seq(-6, 6, 2),
                     limits = c(-6, 6)) + 
  theme_scatter() +
  theme(legend.position = "none")

f:id:dropout009:20191116232138p:plain

単純なPDPでは捉えられなかった交互作用を可視化することに成功しました。
現実のシチュエーションでは、たとえば男女でグループ分けすることで、モデルが男女の効果の違いを捉えているのかを確認することができます。

clusterd ICE Plot

今回はシミュレーションなので X_3でグループ分けすればいいことがわかっていましたが、実際にはどの変数でグループ化すればいいかわからない場合も多いかと思います。
ingredients::cluster_profiles()を用いることで、似たようなICEをクラスター化してPDPを計算してくれます。

# ICEを与える。クラスター数を指定する。
clustered_ice = cluster_profiles(ice, k = 2)

clustered_ice %>% 
  plot() + 
  geom_abline(slope = -5, color = "gray70", size = 1) + 
  geom_abline(slope = 5, color = "gray70", size = 1) +
  scale_y_continuous(breaks = seq(-6, 6, 2),
                     limits = c(-6, 6)) + 
  theme_scatter() +
  theme(legend.position = "none")

f:id:dropout009:20191116232848p:plain

自分で X_3でグループ化した場合と同じく、単純なPDPでは捉えられなかった交互作用を可視化することに成功しました。

まとめ

本記事では、tidymodelsを用いて機械学習モデルを作成し、それをDALEXingredientsを用いて解釈する方法をまとめました。
もう一つの重要なパッケージであるiBreakDownは別の記事でまとめたいと思っています。

本記事で使用したコードは以下にまとめてあります。

github.com

*1:実際にこのデータに遭遇したら線形回帰を使ったほうがいいというのはありますが、わかりやすさのためこうしました

tidymodelsによるtidyな機械学習(その3:ハイパーパラメータのチューニング)

はじめに

前回の記事ではハイパーパラメータのチューニングをfor loopを用いたgrid searchでやっっていました。 tidymodels配下のdialstuneを用いることで、より簡単にハイパーパラメータのサーチを行えるので、本記事ではその使い方を紹介したいと思います。 なお、パラメータサーチ以外のtidymodelsの使い方には本記事では言及しないので、以下の記事を参考にして頂ければと思います。

dropout009.hatenablog.com

dropout009.hatenablog.com

前処理

まずは前回の記事と同様、rsampleで訓練/テストデータの分割を行います。なお、例によってデータはdiamondsを用います。

# パッケージ
library(tidyverse)
library(tidymodels)
set.seed(42)

# Train/Testの分割
df_split = initial_split(diamonds,  p = 0.8)

df_train = training(df_split)
df_test  = testing(df_split)

ハイパーパラメータのサーチ

最終的にはtune::tune_grid()でハイパーパラメータを探索しますが、 そのためにTrain/Validationに分割されたデータ、前処理レシピ、学習用モデル、ハイパーパラメータの候補の4つが必要になります。 最初の3つは以前の記事で触れているので、4つ目のハイパーパラメータに関する部分に詳しく触れていきます。

Train/Validationデータ

これはrsamplesの仕事です。普通はCross Validationで評価するので、それに合わせます。

>df_cv = vfold_cv(df_train, v = 5)

> df_cv
#  5-fold cross-validation 
# A tibble: 5 x 2
  splits               id   
  <named list>         <chr>
1 <split [34.5K/8.6K]> Fold1
2 <split [34.5K/8.6K]> Fold2
3 <split [34.5K/8.6K]> Fold3
4 <split [34.5K/8.6K]> Fold4
5 <split [34.5K/8.6K]> Fold5

前処理レシピ

次に、前処理のレシピを作成します。モデルはRandom Forestを使うので、前回の分析同様、前処理は最低限にしておきます。

> rec = recipe(price ~ ., data = df_train) %>% 
+   step_log(price) %>% 
+   step_ordinalscore(all_nominal())
> 
> rec
Data Recipe

Inputs:

      role #variables
   outcome          1
 predictor          9

Operations:

Log transformation on price
Scoring for all_nominal

学習用モデル

parsnipでモデルを設定します。今回もRandom Forestを使いましょう。

> model = rand_forest(mode = "regression",
+                     trees = 50, # 速度重視
+                     min_n = tune(),
+                     mtry = tune()) %>%
+   set_engine("ranger", num.threads = parallel::detectCores(), seed = 42)
> 
> model
Random Forest Model Specification (regression)

Main Arguments:
  mtry = tune()
  trees = 50
  min_n = tune()

Engine-Specific Arguments:
  num.threads = parallel::detectCores()
  seed = 42

Computational engine: ranger 

サーチしたいパラメータはここでは値を決めずtune::tune()を与えます。今回はmin_nmtryをサーチすることにします。

ハイパーパラメータ

ここが本記事での新しい内容になります。 dialesにはparsnipで指定できるハイパーパラメータに関して、探索するレンジを指定するための関数が用意されています。 たとえばmin_n()はRandom Forestのハイパーパラメータmin_nに対応する関数で、デフォルトだと2-40のレンジでハイパーパラメータが探索されます。 なお、これは最終ノードに最低でも必要なインスタンスの数を表していて、これを大きくするとより強い正則化がかかります。

> min_n()
Minimal Node Size  (quantitative)
Range: [2, 40]

自分でレンジを決めたい場合は引数で指定することができます。

> min_n(range = c(1, 10))
Minimal Node Size  (quantitative)
Range: [1, 10]

同様に、もうひとつのハイパーパラメータmtryに関しても見てみましょう。 こちらは各ツリーでの分割の際に用いる特徴量の数で、これを小さくするとより強い正則化がかかります。

> mtry()
# Randomly Selected Predictors  (quantitative)
Range: [1, ?]

min_n()とは違って、レンジの最大値が指定されていません。モデルに投入する特徴量の数よりも大きいmtryを探索しても意味がない*1ので、こちらで直接指定してあげる必要があります。 この際、mtry(range = c(1, 5))のようにレンジを指定することもできますが、実際にモデルに投入するデータフレームを与えてあげることでレンジを指定することもできます。

> # 前処理済み学習用データ
> df_input = rec %>% 
+   prep() %>% 
+   juice() %>% 
+   select(-price)
> 
> finalize(mtry(), df_input)
# Randomly Selected Predictors  (quantitative)
Range: [1, 9]

dials::finalize()の第1引数にレンジを指定したい関数を、第2引数に特徴量のデータフレームを指定することで、適切なレンジが指定されます。 特徴量の数は分析の途中で変動するので、このやり方は柔軟性があって良いんじゃないかと思います。

さて、これで探索範囲の指定ができるようになりました。 次に、探索したいハイパーパラメータのリストを作ってtune::parameters()に与えるとparametersオブジェクトを作ることができます。

> params = list(min_n(),
+               mtry() %>% finalize(rec %>% prep() %>% juice() %>% select(-price))) %>% 
+   parameters()
> 
> params
Collection of 2 parameters for tuning

    id parameter type object class
 min_n          min_n    nparam[+]
  mtry           mtry    nparam[+]

このparametersオブジェクトをdials::grid_*()に渡すことで、実際に探索するハイパーパラメータの値を作ることができます。 たとえばgrid_regular()ならグリッドサーチ、grid_random()ならランダムサーチになります。 今回はランダムサーチにしましょう。サーチの数はsizeで指定できます。

> df_grid = params %>% 
+   grid_random(size = 10) # 実際はもっと多いほうがいい
> 
> df_grid
# A tibble: 10 x 2
   min_n  mtry
   <int> <int>
 1     7     8
 2    19     8
 3    22     3
 4    38     1
 5    12     8
 6    36     1
 7    21     6
 8     7     5
 9    29     2
10    38     3

チューニング

これでハイパーパラメータの探索準備が整いました! これまでに作ったオブジェクトをtune::tune_grid()に渡します。

df_tuned = tune_grid(object = rec,
                     model = model, 
                     resamples = df_cv,
                     grid = df_grid,
                     metrics = metric_set(rmse, mae, rsq),
                     control = control_grid(verbose = T))
  • objectrecipeで作成された前処理レシピを渡します。resamplesに与えたデータがこの前処理を済ませたあとでモデルに投入されます。
  • modelparsnipで定義した学習用のモデルです。
  • resamplesrsamplesで作った学習/評価用のデータを渡します。普通はCross Varidationで評価すると思うのでrsamples::vfold_cv()で作ったデータフレームを渡すのがいいと思います。
  • griddialstuneで作ったハイパーパラメータの候補が格納されたデータフレームを渡します。
  • metrics:精度の評価指標です。yardsticに準備されている関数を指定することができます。
  • control:指定しなくても構いませんが、今回はログが出力されるようにしています。

学習/評価が終わると、以下のようなデータフレームが手に入ります。

> df_tuned
#  5-fold cross-validation 
# A tibble: 5 x 4
  splits               id    .metrics          .notes          
* <list>               <chr> <list>            <list>          
1 <split [34.5K/8.6K]> Fold1 <tibble [20 × 5]> <tibble [0 × 1]>
2 <split [34.5K/8.6K]> Fold2 <tibble [20 × 5]> <tibble [0 × 1]>
3 <split [34.5K/8.6K]> Fold3 <tibble [20 × 5]> <tibble [0 × 1]>
4 <split [34.5K/8.6K]> Fold4 <tibble [20 × 5]> <tibble [0 × 1]>
5 <split [34.5K/8.6K]> Fold5 <tibble [20 × 5]> <tibble [0 × 1]>

.metricsに予測精度が格納されています。unnest()でもとってこれますが、手っ取り早い関数としてtune::collect_metrics()が準備されています。

> df_tuned %>% 
+   collect_metrics()
# A tibble: 30 x 7
    mtry min_n .metric .estimator   mean     n  std_err
   <int> <int> <chr>   <chr>       <dbl> <int>    <dbl>
 1     1    38 mae     standard   0.120      5 0.00131 
 2     1    38 rmse    standard   0.159      5 0.00173 
 3     1    38 rsq     standard   0.979      5 0.000246
 4     2    29 mae     standard   0.0792     5 0.000317
 5     2    29 rmse    standard   0.107      5 0.000518
 6     2    29 rsq     standard   0.989      5 0.000160
 7     4    19 mae     standard   0.0677     5 0.000174
 8     4    19 rmse    standard   0.0932     5 0.000441
 9     4    19 rsq     standard   0.992      5 0.000109
10     5     3 mae     standard   0.0651     5 0.000176
# … with 20 more rows

特に精度の高いハイパーパラメータの候補が知りたい場合は、tune::show_best()で確認することができます。

> df_tuned %>% 
+   show_best(metric = "rmse", n_top = 3, maximize = FALSE)
# A tibble: 3 x 7
   mtry min_n .metric .estimator   mean     n  std_err
  <int> <int> <chr>   <chr>       <dbl> <int>    <dbl>
1     7     8 rmse    standard   0.0910     5 0.000702
2     5     3 rmse    standard   0.0912     5 0.000558
3     7    16 rmse    standard   0.0915     5 0.000679

一番精度の高かったハイパーパラメータを使って全訓練データでモデルを再学習する場合は、tune::select_best()でハイパーパラメータをとってきてupdate()でモデルをアップデートできます。

# 一番精度の良かったハイパーパラメータ
> df_best_param = df_tuned %>% 
+   select_best(metric = "rmse", maximize = FALSE)
> df_best_param
# A tibble: 1 x 2
   mtry min_n
  <int> <int>
1     7     8

# モデルのハイパーパラメータを更新
> model_best = update(model, df_best_param)
> model_best
Random Forest Model Specification (regression)

Main Arguments:
  mtry = 7
  trees = 50
  min_n = 8

Engine-Specific Arguments:
  num.threads = parallel::detectCores()
  seed = 42

Computational engine: ranger 

あとは通常通り学習・評価を行うだけです。

まとめ

本記事ではtidymodels配下のdialstuneを用いたハイパーパラメータのサーチを紹介しました。 前回の記事ではCross Validationデータにどうアクセスするか結構ややこしかったと思うのですが、tuneを用いることでよりスッキリと探索と評価を行うことができます。 dialstuneは公式ドキュメントに個別の使い方はまとまっているのですが、目的がわかりにくい部分や、どの情報が最新かわからないところもあり、典型的なタスクに対しての用途を自分でまとめ直したという経緯になります。

ちなみに、今回はランダムサーチを用いましたが、tuneではbaysian optimizationを用いたパラメータ探索も可能となっています。ぜひ確認してみて下さい*2

本記事で使用したコードは以下にまとめてあります。

github.com

参考文献

*1:実際は特徴量の数よりも大きい値を入れるとエラーを吐きます

*2:Classification Example • tune

Synthetic Difference In Differences(Arkhangelsky et. al., 2019)を読んだ


はじめに

GW中にSynthetic difference in differences(Arkhangelsky, D., Athey, S., Hirshberg, D. A., Imbens, G. W., & Wager, S. (2019)) を読みました。
ランダム化比較試験(RCT)が行えない状況でのパネルデータのから処置効果を推定する際には、バイアスを取り除くためにDifference In DifferencesやSynthetic Controlがよく用いられていますが、
Synthetic Difference In Differencesは名前の通りDifference In DifferencesとSynthetic Controlのいいとこ取りになっています。
パネルデータから因果推論を行う際に非常に強力な武器になると思ったので、本記事でコンセプトをまとめ、論文内の比較実験を再現しました。


コンセプト

セッティング

 i = 1, \dots, Nのユニット(たとえば個人)に対して t = 1, \dots, T時点でのデータがとれているパネルデータを考えます。このとき、アウトカム Y_{it}は以下の N \times T行列で表現できます。*1

\begin{align}
\mathbf{Y}
= \begin{pmatrix}
Y_{11} & \cdots & Y_{1T}\\
\vdots & \ddots & \vdots\\
Y_{N1} & \cdots & Y_{NT}
\end{pmatrix}
\end{align}


簡単のために、一番シンプルな設定として、 i = Nのユニットが  t = T 時点のみ処置を受けるケースを考えます。つまり、 i = 1, \dots, N-1 はコントロール群(c)で i=Nは処置群(t)になりますし、 t = 1, \dots T-1が処置前(pre)、 t = Tが処置後(post)となります。

f:id:dropout009:20190506184615p:plain

ここで、仮想的に i=Nのユニットが t=T時点で処置を受けた場合のアウトカムを Y_{NT}(1)、受けなかった場合のアウトカムを Y_{NT}(0)と書くことにしましょう。このとき、 Y_{NT}(1) - Y_{NT}(0)つまり処置を受けた場合と受けなかった場合の差をみることで、処置がアウトカムに与えるインパクトを測定することができます。ただ、実際に僕たちに観測できるのは処置を受けた Y_{NT}(1)のみで、処置を受けなかったときの Y_{NT}(0)は観測することができません。そこで、なんらかの手法を使って、実際には観測できない Y_{NT}(0)を推定する必要が出てきます。これを反事実と呼びます。以降、反事実 Y_{NT}(0)を推定する手法としてDifference In DIfference(DID)、Synthetic Control(SC)、そしてその融合であるSynthetic Difference In Difference(SDID)を比較していきます。

Difference in Differences (DID)

処置前コントロール群の値にに加えて、(1)コントロール群と処置群のアウトカムにそもそもどのくらい違いがあるのか、(2)処置を受けなくても処置の前後でどのくらいアウトカムが変わるのかを考慮することで、処置後処置群の値 Y_{NT}(0)が推定できるのではというのがDIDの発想です。具体的には、DIDによる推定量 \hat{Y}_{NT}^{\text{DID}}(0) は以下で定義されます。

\begin{align}
\hat{Y}_{NT}^{\text{DID}}(0) &= \bar{Y}^{c, pre} + \left(\bar{Y}^{c, post} - \bar{Y}^{c, pre}\right)+ \left(\bar{Y}^{t, pre} - \bar{Y}^{c, pre}\right)\\
&=\frac{1}{N- 1}\frac{1}{T-1}\sum_{i=1}^{N-1}\sum_{t=1}^{T-1}Y_{it} \\
&\quad+ \left(\frac{1}{N- 1}\sum_{i=1}^{N-1}Y_{iT} - \frac{1}{N- 1}\frac{1}{T-1}\sum_{i=1}^{N-1}\sum_{t=1}^{T-1}Y_{it}\right)\\
&\quad+ \left(\frac{1}{T-1}\sum_{t=1}^{T-1}Y_{Nt} - \frac{1}{N- 1}\frac{1}{T-1}\sum_{i=1}^{N-1}\sum_{t=1}^{T-1}Y_{it}\right)
\end{align}

ここで、 \bar{Y}はそれぞれのグループの平均値を表しています。上式は以下のように解釈できます。まず、単に処置前コントロール群のみを用いて予測を行うと(第一項)、コントロール群と処置群の違いでバイアスがかかり、次に処置前と処置後の違いでもバイアスがかかります。そこで、コントロール群<->処置群の差を第二項で、処置前<->処置後の差を第三項で補正しています。この意味において、DID推定量はバイアスを二重に補正していると言えます。

f:id:dropout009:20190506185121p:plain

最後に、こうしてDIDで推定した \hat{Y}_{NT}^{\text{DID}}(0)と実際の観測値 Y_{NT}(1)の差を見ることで処置のインパクトが推定できます。

\begin{align}
\hat{\tau}^{\text{DID}} = Y_{NT}(1) -\hat{Y}_{NT}^{\text{DID}}(0)
\end{align}

Synthetic Control (SC)

DIDは単純平均でしたが、コントロール群の加重平均で処置群を表現しようというのがSCの発想になります。SCは以下の2ステップを踏みます。まず、加重平均に用いる各コントロール i = 1, \dots, N-1の重み \hat{\omega}_iを、うまく処置群を近似できるように決めます。

\begin{align}
\sum_{i=1}^{N-1} \hat{\omega}_{i} Y_{it} \approx Y_{N t} \text { for all } t=1, \ldots, T-1
\end{align}

f:id:dropout009:20190506185108p:plain

具体的には、以下の二乗誤差を小さくするように \hat{\omega}_iを求めます。

\begin{align}
\hat{\omega}&=\underset{\omega \in \mathbb{W}}{\arg \min }\;\frac{1}{T-1} \sum_{t=1}^{T-1}\left(\sum_{i=1}^{N-1} \omega_{i} Y_{i t}-Y_{N t}\right)^{2} + \frac{1}{2}\zeta\|\omega\|_2\\
\text{where}\quad \mathbb{W}&=\left\{\omega \in \mathbb{R}^{N - 1} \;\bigg|\; \omega_{i} \geq 0, \; \sum_{i=1}^{N-1} \omega_{i}=1\right\}
\end{align}

 \omegaは加重平均の重みなので0以上、足して1の制約がかかっています。また、 L_2正則化を入れることで \omegaの推定を安定させています。

次に、この重みを使って処置後である Y_{NT}(0)を推定します。
\begin{align}
\hat{Y}_{N T}^{\mathrm{SC}}(0)=\frac{1}{T-1} \sum_{i=1}^{N-1} \sum_{t=1}^{T-1} \hat{\omega}_{i} Y_{i t}+\left(\sum_{i=1}^{N-1} \hat{\omega}_{i}Y_{i T}-\frac{1}{T-1} \sum_{i=1}^{N-1} \sum_{t=1}^{T-1} \hat{\omega}_{i}Y_{i t}\right)
\end{align}

DIDの推定量 \hat{Y}_{NT}^{\text{DID}}(0)と見比べると:

  • DIDでは単純平均でしたが、SCは加重平均を用いることで処置群の近似を改善しています。
  • その一方で、SCにはDIDにあった処置前後のバイアス補正(第三項)が存在しません。

Synthetic Difference In Differences (SDID)

DIDとSCを見比べることで、双方の利点と欠点が見えました。そこで、SDIDでは両方のいいとこ取りをします。つまり、

  • 単純平均ではなく加重平均を用いる
  • コントロール/トリートメントのバイアス補正だけでなく、処置前後のバイアス補正を入れる

ことで反事実 Y_{N T}(0)のよりよい推定を目指します。

\begin{align}
\hat{Y}_{N T}^{\mathrm{SDID}}(0)= \sum_{i=1}^{N-1} \sum_{t=1}^{T-1} \hat{\omega}_{i} \hat{\lambda}_{t} Y_{i t}+\left(\sum_{t=1}^{T-1} \hat{\lambda}_{t}Y_{N t}-\sum_{t=1}^{T-1} \sum_{i=1}^{N-1} \hat{\omega}_{i}\hat{\lambda}_{t} Y_{i t}\right)+\left(\sum_{i=1}^{N-1} \hat{\omega}_{i}Y_{i T}-\sum_{i=1}^{N-1} \sum_{t=1}^{T-1} \hat{\omega}_{i}\hat{\lambda}_{t} Y_{i t}\right)
\end{align}


ここで、 \hat{\lambda}_tは時間方向の重みであり、 \hat{\omega}_iと同様に以下の最小化問題を解くことで求めます。

\begin{align}
\hat{\lambda}&=\underset{\lambda_0\in\mathbb{R},\; \lambda \in \mathbb{L}}{\arg \min } \;\frac{1}{N-1}\sum_{i=1}^{N-1}\left(\lambda_0 + \sum_{t=1}^{T-1} \lambda_{t} Y_{i t}-Y_{i T}\right)^{2}+ \frac{1}{2}\xi\|\lambda\|_2\\
\text{where}\quad \mathbb{L}&=\left\{\lambda \in \mathbb{R}^{T - 1} \;\bigg|\; \lambda_{t} \geq 0, \; \sum_{t=1}^{T-1} \lambda_{t}=1\right\}
\end{align}


ただし、トレンドに対応するために \lambda_0が入っています。これは重みとしては用いません。


比較実験

論文ではAbadie, Diamond, and Hainmueller(2010)のデータを用いてDID, SC, SDIDを比較した実験を行っているので、それを再現した結果を紹介します。

ADH(2010)ではカリフォルニアの禁煙法が喫煙に与えたインパクトをSynthetic Controlを用いて推定しています。
使用データはカリフォルニアを含むアメリカ39州の年度別喫煙量のデータで、1970年から2000年にかけて31年分あります。なお、禁煙法は1989年から実施されています。
今回の実験では実際に禁煙法の効果を調べたいのではなく、 Y_{NT}(0)をうまく予測できるかを確かめたいので、処置前(1988年以前)のデータに関して、各州の1期先の喫煙量を他の38州の喫煙量からDID/SC/SDIDで予測し、精度を検証します。*2
よりよい精度で1期先の予測ができるなら、処置の効果 \hat{\tau}_{NT} = Y_{NT}(1) - \hat{Y}_{NT}(0)をより正確に推定できると言えます。

具体的には、各州 i = 1, \dots, 39に対して1980年から1988年までの9年分の予測をDID/SC/SDIDで行い、RMSEを計算します。
\begin{align}
\operatorname{RMSE}_{i}=\sqrt{\frac{1}{9} \sum_{t=1980}^{1988}\left(Y_{it}-\hat{Y}_{it}\right)^{2}}
\end{align}


その結果が以下になります。

f:id:dropout009:20190506172341p:plain


DID/SCと比較して、SDIDはより高い精度で Y_{NT}(0)を予測することに成功しています。中央値で見るとDIDと比較して50%、SCと比較しても15%の改善です。


具体的にどの部分で精度に差が出ているのかプロットを見て確認してみます。

f:id:dropout009:20190506172657p:plain

上図は各州の喫煙量の実測値とDID/SC/SDIDの推定値がプロットされています。横軸が年度で縦軸が喫煙量です。
図が細かいのですが、基本的にSDID(赤)はDID(オレンジ)/SC(緑)よりも実測値Y(青)をうまく予測できています。
特に差が顕著な部分をいくつか抜き出したものが下の図になります。

f:id:dropout009:20190506173610p:plain

DIDは全体的に精度が悪いと思いますが、
Synthetic Controlは非常に特徴的で、New HampshireやUtahなど、他の州の加重平均で表現できない州は予測がうまくいっていません。
これに対して、SDIDは非常にロバストな予測ができています。


まとめ

本記事では、Difference In DifferencesとSynthetic Controlを融合させたSynthetic Difference In Differencesのコンセプトをまとめ、実際に精度が改善することを確認しました。
SDIDはRCTが行えない状況でパネルデータの因果推論を行う際、非常に強力な武器になると思います。


参考文献

Arkhangelsky, D., Athey, S., Hirshberg, D. A., Imbens, G. W., & Wager, S. (2019). Synthetic difference in differences (No. w25532). National Bureau of Economic Research.[1812.09970] Synthetic Difference in Differences

GitHub - swager/sdid-script: Example scripts for synthetic difference in differences

Alberto Abadie, Alexis Diamond, and Jens Hainmueller. Synthetic control methods for comparative case studies: Estimating the effect of California’s tobacco control program. Journal of the American Statistical Association, 105(490):493–505, 2010. https://economics.mit.edu/files/11859

再現コード
GitHub - dropout009/sdid_python: replication python code for synthetic difference in difference simulation. please check https://arxiv.org/abs/1812.09970 and https://github.com/swager/sdid-script

*1:簡単のため、とりあえず共変量 X_{it}は無視することにします。

*2:論文ではやってませんが、カリフォルニアを除いてしまえば2000年まで実験することもできると思います

XGBoostの論文を読んだのでGBDTについてまとめた

はじめに

今更ですが、XGboostの論文を読んだので、2章GBDT部分のまとめ記事を書こうと思います。*1
この記事を書くにあたって、できるだけ数式の解釈を書くように心がけました。数式の意味をひとつひとつ追っていくことは、実際にXGBoost(またはLightGBMやCatBoostなどのGBDT実装)を使う際にも役立つと考えています。たとえばハイパーパラメータがどこに効いているかを理解することでチューニングを効率化したり、モデルを理解することでよりモデルに合った特徴量のエンジニアリングができるのではないかと思います。

また、この記事に限りませんが、記述に間違いや不十分な点などあればご指摘頂ければ嬉しいです。

XGBoost論文

目的関数の設定


一般的な状況として、サンプルサイズが Iで特徴量の数が Mのデータ \mathcal{D} = \left\{ (\mathbf{x}_i, y_i) \right\}(i \in\mathcal{I} = \{1, \dots, I\}, \; \mathbf{x}_i \in \mathbb{R}^M, \; y_i \in\mathbb{R})に対する予測モデルを構築することを想定しましょう。*2
今回はツリーをアンサンブルした予測モデルを構築します。
 \mathcal{K}  =\{1, \dots, K\}のツリーを加法的に組み合わせた予測モデルは以下のように定式化できます。*3


\begin{align}
\hat{y}_{i} &= \phi\left(\mathbf{x}_{i}\right)=\sum_{k\in\mathcal{K}} f_{k}\left(\mathbf{x}_{i}\right),\\
\text{where}\quad f_{k} \in \mathcal{F} &= \left\{f(\mathbf{x})=w_{q(\mathbf{x})}\right\}\left(q : \mathbb{R}^{m} \rightarrow \mathcal{T},\; \mathcal{T} = \{1, \dots, T\}, \; w \in \mathbb{R}^{T}\right)
\end{align}

ここで、 f_kはひとつひとつのツリーを表しています。ツリー f(\mathbf{x})は特徴量 \mathbf{x} が与えられると、それを q(\mathbf{x})に従って各ノード t = 1, \dots, Tに紐づけ、それぞれのノードに対応する予測値 w_{q(\mathbf{x})}を返します。そして、ひとつひとつのツリーの予測値を足し合わせることで、最終的な予測結果 \hat{y}_iとします。



では、具体的にツリーをどうやって作っていくかを決めるために、最小化したい目的関数 \mathcal{L}(\phi)を設定します。

\begin{align}
\mathcal{L}(\phi) &= \sum_{i\in \mathcal{I}} l\left(y_{i}, \hat{y}_{i}\right)+\sum_{k\in \mathcal{K}} \Omega\left(f_{k}\right), \\
\text{where}\quad\Omega(f) &= \gamma T+\frac{1}{2} \lambda\|w\|^{2} = \gamma T+\frac{1}{2} \lambda \sum_{t \in \mathcal{T}} w_{t}^{2}
\end{align}

ここで、 l(y_{i}, \hat{y}_{i})は損失関数で、たとえば二乗誤差になります。ただし、単に二乗誤差を最小化するのではなく、過適合を回避して汎化性能を上げるために正則化 \Omega(f)が追加されています。なお、 \gamma \lambda はハイパーパラメータであり、交差検証などで最適な値を探索する必要があります。

  •  \Omega(f)の第一項 \gamma Tはツリーのノードの数に応じてペナルティが課されるようになっています。ハイパーパラメータ \gammaを大きくするとよりノード数少ないツリーが好まれるようになります。ツリーの大きさに制限をかけることで過適合を回避することが目的です。
  •  \Omega(f)の第二項 \frac{1}{2}\lambda\|w\|^{2}は各ノードが返す値の大きさに対してペナルティがかかることを意味しています。ハイパーパラメータ \lambdaを大きくすると、(絶対値で見て)より小さい wが好まれるようになります。 wが小さいということは最終的な出力を決める \sum_{k\in\mathcal{K}} f_{k}部分で足し合わされる値が小さくなるので、過適合を避けることに繋がります。

勾配ブースティング

さて、目的関数 \mathcal{L}(\phi)を最小化するような K個のツリー構築したいわけですが、 K個ツリーを同時に構築して最適化するのではなく、 k個目のツリーを作る際には、 k-1個目までに構築したツリーを所与として、目的関数を最小化するようなツリーを作ることにしましょう。*4


\begin{align}
\min_{f_k}\;\mathcal{L}^{(k)}=\sum_{i\in\mathcal{I}} l\left(y_{i}, \hat{y}_i^{(k - 1)}+f_{k}\left(\mathbf{x}_{i}\right)\right)+\Omega\left(f_{k}\right)
\end{align}

このステップで作成する k個目のツリーを合わせた予測値は \hat{y}_i^{(k)} = \hat{y}_{i}^{(k - 1)}+f_{k}\left(\mathbf{x}_{i}\right)であり、 k-1個目までのツリーではうまく予測できていない部分に対してフィットするように新しいツリーを構築すると解釈できます。このように残差に対してフィットするツリーを逐次的に作成していく手法をブースティングと呼びます。



さて、損失関数 \sum_{i\in\mathcal{I}} l\left(y_{i}, \hat{y}_{i}^{(k - 1)}+f_{k}\left(\mathbf{x}_{i}\right)\right)を直接最適化するのではなく、その2階近似を最適化することにしましょう。後にわかるように、2次近似によって解析的に解を求めることができます。 f_k = 0の周りで2階のテイラー展開を行うと、目的関数 \mathcal{L}^{(k)}は以下で近似できます。

\begin{align}
\mathcal{L}^{(k)} &\approx \sum_{i\in\mathcal{I}}\left[l\left(y_{i}, \hat{y}^{(k - 1)}\right)+g_{i} f_{k}\left(\mathbf{x}_{i}\right)+\frac{1}{2} h_{i} f_{k}^{2}\left(\mathbf{x}_{i}\right)\right] + \Omega\left(f_{k}\right),\\
\text{where} \quad g_i &= \frac{\partial }{\partial \hat{y}^{(k - 1)}}l\left(y_{i}, \hat{y}^{(k - 1)}\right),\\
h_i &= \frac{\partial^2 }{\partial \left(\hat{y}^{(k - 1)}\right)^2}l\left(y_{i}, \hat{y}^{(k - 1)}\right)
\end{align}

ここで、 g_i h_iはそれぞれ損失関数の1階と2階の勾配情報になります。勾配情報を使ったブースティングなので勾配ブースティングと呼ばれています。*5
今回 f_kを動かすことで目的関数を最小化するので、 f_kと関係ない第一項は無視できます。

\begin{align}
\tilde{\mathcal{L}}^{(k)} &=\sum_{i\in\mathcal{I}}\left[g_{i} f_{k}\left(\mathbf{x}_{i}\right)+\frac{1}{2} h_{i} f_{k}^{2}\left(\mathbf{x}_{i}\right)\right]+\gamma T+\frac{1}{2} \lambda \sum_{t \in \mathcal{T}} w_{t}^{2} \\
&= \sum_{t \in \mathcal{T}}\left[\sum_{i \in \mathcal{I}_{t}} g_{i} f_{k}\left(\mathbf{x}_{i}\right)+\frac{1}{2} \sum_{i \in \mathcal{I}_{t}} h_{i} f_{k}^{2}\left(\mathbf{x}_{i}\right)\right]+\gamma T +\frac{1}{2} \lambda \sum_{t \in \mathcal{T}} w_{t}^{2}\\
&=\sum_{t \in \mathcal{T}}\left[\left(\sum_{i \in \mathcal{I}_{t}} g_{i}\right) w_{t}+\frac{1}{2}\left(\lambda + \sum_{i \in \mathcal{I}_{t}} h_{i}\right) w_{t}^{2}\right]+\gamma T
\end{align}

  • 1行目の式では、 f_kと関係ない第一項を取り除き、 \Omega(f_k)の中身を書き下しました。
  • 2行目への変換ですが、全ての i\in\mathcal{I}について足し合わせている部分を、まずノード tに所属する部分 i \in \mathcal{I}_t (\mathcal{I}_t = \{i | q(\mathbf{x}_i) = t \}) を足し合わせてから、全てのノード t \in \mathcal{T}について足し合わせるように分解しています。
  • 3行目への変換では、同じノードに所属する f_k(\mathbf{x}_i)は全て w_tを返すというツリーの性質を利用しています。また、 w^2_tの共通部分をくくっています。

さて、 \tilde{\mathcal{L}}^{(k)} w_tに関しての2次式なので、解析的に解くことができます。

\begin{align}
w_{t}^{*}=-\frac{\sum_{i \in \mathcal{I}_{t}} g_{i}}{\lambda + \sum_{i \in \mathcal{I}_{t}} h_{i}}
\end{align}

以上で、 k個目のツリーに関して、各ノードが返すべき値 w_t^*が解析的に求まりました。*6
この式からもハイパーパラメータ \lambdaを大きくすると w^*_tが(絶対値で見て)小さくなることが見て取れます。この w^*_tを元の目的関数に代入してあげることで

\begin{align}
\tilde{\mathcal{L}}^{(k)}(q)=-\frac{1}{2} \sum_{t\in\mathcal{T}} \frac{\left(\sum_{i \in \mathcal{I}_{t}} g_{i}\right)^2}{\lambda + \sum_{i \in \mathcal{I}_{t}} h_{i}}+\gamma T
\end{align}

を得ます。あとはツリーの構造 q、言い換えれば特徴量の分割ルールを決める必要があります。たとえば、一番シンプルなケースとして、全く分割を行わない場合( \mathcal{I})と一度だけ分割を行う場合( \mathcal{I}_L, \mathcal{I}_Rに分割)を比較しましょう。分割による目的関数の値の減少分は

\begin{align}
\mathcal{L}_{\text{split}}&= -\frac{1}{2} \frac{\left(\sum_{i \in \mathcal{I}} g_{i}\right)^2}{\lambda + \sum_{i \in \mathcal{I}} h_{i}}+\gamma - \left(-\frac{1}{2} \frac{\left(\sum_{i \in \mathcal{I}_L} g_{i}\right)^2}{\lambda + \sum_{i \in \mathcal{I}_L} h_{i}}-\frac{1}{2} \frac{\left(\sum_{i \in \mathcal{I}_R} g_{i}\right)^2}{\lambda + \sum_{i \in \mathcal{I}_R} h_{i}}+2\gamma \right)\\
&= \frac{1}{2}\left(\frac{\left(\sum_{i \in \mathcal{I}_L} g_{i}\right)^2}{\lambda + \sum_{i \in \mathcal{I}_L} h_{i}}+\frac{\left(\sum_{i \in \mathcal{I}_R} g_{i}\right)^2}{\lambda + \sum_{i \in \mathcal{I}_R} h_{i}} - \frac{\left(\sum_{i \in \mathcal{I}} g_{i}\right)^2}{\lambda + \sum_{i \in \mathcal{I}} h_{i}}\right) - \gamma
\end{align}


であり、これがプラスなら分割を行い、マイナスなら分割を行わないということになります。上式からも、ハイパーパラメータ \gamma を大きくするとより分割が行われなくなることが見て取れます。

ところで、そもそもどの特徴量のどの値で分割するべきかなのでしょうか?一番ナイーブな考え方は、全ての変数に対して全ての分割点を考慮して、一番目的関数の値を減少させるような分割を選ぶというものがあります。ただし、この方法は膨大な計算量が必要になるため、XGBoostでは近似手法が提供されており、3章に記述されています。さらに、4章移行では並列計算や比較実験などが記されています。


まとめ

XGboostの論文を読んだので、自身の理解を深めるために2章GBDT部分のまとめ記事を書きました。
今までなんとなく使っていたハイパーパラメータが具体的にどの部分に効いているのか学ぶことができて、とても有意義だったと感じています。

*1:なお、元論文のノーテーションがおかしい/統一的でないように感じたので、一部表記を変更しています。

*2:元論文では |\mathcal{D}| = Nですが、 i Nが混じるとややこしい気もするので Iにしました。

*3:元論文の q : \mathbb{R}^{m} \rightarrow Tとなっているのですが、 qはインデックス 1, \dots, Tを返す関数なので、タイポかと思われます。

*4:元論文の添字 tがツリーの数 Tとややこしいので kのまま進めることにしました。

*5:たぶん

*6:論文にはこれがimpuityみたいなものと書かれているのですが、不勉強で理解できませんでした。

tidymodelsによるtidyな機械学習(その2:Cross Varidation)

はじめに

本記事ではtidymodelsを用いたCross Validationとハイパーパラメータのチューニングについて紹介したいと思います。 なお、tidymodelsの基本的な操作方法については以下の記事をご覧下さい。

dropout009.hatenablog.com

前処理

まずは前回の記事と同様、訓練/テストデータの分割と前処理を行います。なお、例によってデータはdiamondsを用います。

# パッケージ
library(tidyverse)
library(tidymodels)
set.seed(42)

# 分割
df_split = initial_split(diamonds,  p = 0.8)

df_train = training(df_split)
df_test  = testing(df_split)

# 前処理レシピ
rec = recipe(price ~ ., data = df_train) %>% 
  step_log(price) %>% 
  step_ordinalscore(all_nominal())

Cross Validation

さて、ここからが新しい内容になります。

df_cv = vfold_cv(df_train, v = 10)

> df_cv
#  10-fold cross-validation 
# A tibble: 10 x 2
   splits               id    
   <list>               <chr> 
 1 <split [38.8K/4.3K]> Fold01
 2 <split [38.8K/4.3K]> Fold02
 3 <split [38.8K/4.3K]> Fold03
 4 <split [38.8K/4.3K]> Fold04
 5 <split [38.8K/4.3K]> Fold05
 6 <split [38.8K/4.3K]> Fold06
 7 <split [38.8K/4.3K]> Fold07
 8 <split [38.8K/4.3K]> Fold08
 9 <split [38.8K/4.3K]> Fold09
10 <split [38.8K/4.3K]> Fold10

tidymodels配下のrsampleパケージにはCross Validationを行うための関数vfold_cv()があるので、これを使います。

> df_cv$splits[[1]]
<38837/4316/43153>

splitsを確認します。全43153サンプルを38837の訓練データと4316の検証データに分割していることを示しています。

> df_cv$splits[[1]] %>% analysis()
# A tibble: 38,837 x 10
   carat cut       color clarity depth table price     x     y     z
   <dbl> <ord>     <ord> <ord>   <dbl> <dbl> <int> <dbl> <dbl> <dbl>
 1 0.23  Ideal     E     SI2      61.5    55   326  3.95  3.98  2.43
 2 0.21  Premium   E     SI1      59.8    61   326  3.89  3.84  2.31
 3 0.23  Good      E     VS1      56.9    65   327  4.05  4.07  2.31
 4 0.290 Premium   I     VS2      62.4    58   334  4.2   4.23  2.63
 5 0.31  Good      J     SI2      63.3    58   335  4.34  4.35  2.75
 6 0.24  Very Good J     VVS2     62.8    57   336  3.94  3.96  2.48
 7 0.26  Very Good H     SI1      61.9    55   337  4.07  4.11  2.53
 8 0.22  Fair      E     VS2      65.1    61   337  3.87  3.78  2.49
 9 0.23  Very Good H     VS1      59.4    61   338  4     4.05  2.39
10 0.3   Good      J     SI1      64      55   339  4.25  4.28  2.73

> df_cv$splits[[1]] %>% assessment()
# A tibble: 4,316 x 10
   carat cut       color clarity depth table price     x     y     z
   <dbl> <ord>     <ord> <ord>   <dbl> <dbl> <int> <dbl> <dbl> <dbl>
 1  0.24 Very Good I     VVS1     62.3    57   336  3.95  3.98  2.47
 2  0.3  Very Good J     VS2      62.2    57   357  4.28  4.3   2.67
 3  0.23 Very Good F     VS1      59.8    57   402  4.04  4.06  2.42
 4  0.31 Good      H     SI1      64      54   402  4.29  4.31  2.75
 5  0.32 Good      H     SI2      63.1    56   403  4.34  4.37  2.75
 6  0.22 Premium   E     VS2      61.6    58   404  3.93  3.89  2.41
 7  0.3  Very Good I     SI1      63      57   405  4.28  4.32  2.71
 8  0.24 Premium   E     VVS1     60.7    58   553  4.01  4.03  2.44
 9  0.26 Very Good D     VVS2     62.4    54   554  4.08  4.13  2.56
10  0.86 Fair      E     SI2      55.1    69  2757  6.45  6.33  3.52
# ... with 4,306 more rows

訓練データにはanalysis()で、検証データにはassesment()でアクセスできます。

CVデータが準備できたので、学習と予測を行いましょう。まずは適当なRandom Forestモデルを用意します。

rf = rand_forest(mode = "regression",
                 trees = 50,
                 min_n = 10,
                 mtry = 3) %>%
  set_engine("ranger", num.threads = parallel::detectCores(), seed = 42)

CVデータに前処理レシピを適用するにはmap()prepper()を使います。

> df_cv %>% 
+   mutate(recipes = map(splits, prepper, recipe = rec))
#  10-fold cross-validation 
# A tibble: 10 x 3
   splits               id     recipes     
 * <list>               <chr>  <list>      
 1 <split [38.8K/4.3K]> Fold01 <S3: recipe>
 2 <split [38.8K/4.3K]> Fold02 <S3: recipe>
 3 <split [38.8K/4.3K]> Fold03 <S3: recipe>
 4 <split [38.8K/4.3K]> Fold04 <S3: recipe>
 5 <split [38.8K/4.3K]> Fold05 <S3: recipe>
 6 <split [38.8K/4.3K]> Fold06 <S3: recipe>
 7 <split [38.8K/4.3K]> Fold07 <S3: recipe>
 8 <split [38.8K/4.3K]> Fold08 <S3: recipe>
 9 <split [38.8K/4.3K]> Fold09 <S3: recipe>
10 <split [38.8K/4.3K]> Fold10 <S3: recipe>

prepper()prep()前のレシピを渡すことで、CV訓練データを用いてprep()したレシピを返してくれます。

次に前処理済みのCVデータでモデルを学習します。

> df_cv %>% 
+   mutate(recipes = map(splits, prepper, recipe = rec),
+          fitted = map(recipes, ~ fit(rf, price ~ ., data = juice(.))))
#  10-fold cross-validation 
# A tibble: 10 x 4
   splits               id     recipes      fitted  
 * <list>               <chr>  <list>       <list>  
 1 <split [38.8K/4.3K]> Fold01 <S3: recipe> <fit[+]>
 2 <split [38.8K/4.3K]> Fold02 <S3: recipe> <fit[+]>
 3 <split [38.8K/4.3K]> Fold03 <S3: recipe> <fit[+]>
 4 <split [38.8K/4.3K]> Fold04 <S3: recipe> <fit[+]>
 5 <split [38.8K/4.3K]> Fold05 <S3: recipe> <fit[+]>
 6 <split [38.8K/4.3K]> Fold06 <S3: recipe> <fit[+]>
 7 <split [38.8K/4.3K]> Fold07 <S3: recipe> <fit[+]>
 8 <split [38.8K/4.3K]> Fold08 <S3: recipe> <fit[+]>
 9 <split [38.8K/4.3K]> Fold09 <S3: recipe> <fit[+]>
10 <split [38.8K/4.3K]> Fold10 <S3: recipe> <fit[+]>

学習モデルを用いて検証用データに対して予測を行います。

# ラッパーを作っておく
pred_wrapper = function(split_obj, rec_obj, model_obj, ...) {
  df_pred = bake(rec_obj, assessment(split_obj)) %>%
    bind_cols(predict(model_obj, .)) %>% 
    select(price, .pred)
  
  return(df_pred)
}

> df_cv %>% 
+   mutate(recipes = map(splits, prepper, recipe = rec),
+          fitted = map(recipes, ~ fit(rf, price ~ ., data = juice(.))),
+          predicted = pmap(list(splits, recipes, fitted), pred_wrapper))
#  10-fold cross-validation 
# A tibble: 10 x 5
   splits               id     recipes      fitted   predicted           
 * <list>               <chr>  <list>       <list>   <list>              
 1 <split [38.8K/4.3K]> Fold01 <S3: recipe> <fit[+]> <tibble [4,316 × 2]>
 2 <split [38.8K/4.3K]> Fold02 <S3: recipe> <fit[+]> <tibble [4,316 × 2]>
 3 <split [38.8K/4.3K]> Fold03 <S3: recipe> <fit[+]> <tibble [4,316 × 2]>
 4 <split [38.8K/4.3K]> Fold04 <S3: recipe> <fit[+]> <tibble [4,315 × 2]>
 5 <split [38.8K/4.3K]> Fold05 <S3: recipe> <fit[+]> <tibble [4,315 × 2]>
 6 <split [38.8K/4.3K]> Fold06 <S3: recipe> <fit[+]> <tibble [4,315 × 2]>
 7 <split [38.8K/4.3K]> Fold07 <S3: recipe> <fit[+]> <tibble [4,315 × 2]>
 8 <split [38.8K/4.3K]> Fold08 <S3: recipe> <fit[+]> <tibble [4,315 × 2]>
 9 <split [38.8K/4.3K]> Fold09 <S3: recipe> <fit[+]> <tibble [4,315 × 2]>
10 <split [38.8K/4.3K]> Fold10 <S3: recipe> <fit[+]> <tibble [4,315 × 2]>

pmap()は入力列が複数の場合のmap()の拡張です。

予測結果を用いて検証用データで精度を計算して完了です。

> df_cv %>% 
+   mutate(recipes = map(splits, prepper, recipe = rec),
+          fitted = map(recipes, ~ fit(rf, price ~ ., data = juice(.))),
+          predicted = pmap(list(splits, recipes, fitted), pred_wrapper),
+          evaluated = map(predicted, metrics, price, .pred))
#  10-fold cross-validation 
# A tibble: 10 x 6
   splits               id     recipes      fitted   predicted            evaluated       
 * <list>               <chr>  <list>       <list>   <list>               <list>          
 1 <split [38.8K/4.3K]> Fold01 <S3: recipe> <fit[+]> <tibble [4,316 × 2]> <tibble [3 × 3]>
 2 <split [38.8K/4.3K]> Fold02 <S3: recipe> <fit[+]> <tibble [4,316 × 2]> <tibble [3 × 3]>
 3 <split [38.8K/4.3K]> Fold03 <S3: recipe> <fit[+]> <tibble [4,316 × 2]> <tibble [3 × 3]>
 4 <split [38.8K/4.3K]> Fold04 <S3: recipe> <fit[+]> <tibble [4,315 × 2]> <tibble [3 × 3]>
 5 <split [38.8K/4.3K]> Fold05 <S3: recipe> <fit[+]> <tibble [4,315 × 2]> <tibble [3 × 3]>
 6 <split [38.8K/4.3K]> Fold06 <S3: recipe> <fit[+]> <tibble [4,315 × 2]> <tibble [3 × 3]>
 7 <split [38.8K/4.3K]> Fold07 <S3: recipe> <fit[+]> <tibble [4,315 × 2]> <tibble [3 × 3]>
 8 <split [38.8K/4.3K]> Fold08 <S3: recipe> <fit[+]> <tibble [4,315 × 2]> <tibble [3 × 3]>
 9 <split [38.8K/4.3K]> Fold09 <S3: recipe> <fit[+]> <tibble [4,315 × 2]> <tibble [3 × 3]>
10 <split [38.8K/4.3K]> Fold10 <S3: recipe> <fit[+]> <tibble [4,315 × 2]> <tibble [3 × 3]>

ハイパーパラメータのサーチ

Cross Validationで精度を計算する方法がわかったので、これを用いてハイパーパラメータのサーチを行います。 今回はmin_nのみグリッドサーチすることにします。

grid_min_n = c(5, 10, 15)
df_result_cv = tibble() # 結果をいれる
for (n in grid_min_n) {
  
  rf = rand_forest(mode = "regression",
                   trees = 50,
                   min_n = n,
                   mtry = 3) %>%
    set_engine("ranger", num.threads = parallel::detectCores(), seed = 42)
  
  tmp_result = df_cv %>% 
    mutate(recipes = map(splits, prepper, recipe = rec),
           fitted = map(recipes, ~ fit(rf, price ~ ., data = juice(.))),
           predicted = pmap(list(splits, recipes, fitted), pred_wrapper),
           evaluated = map(predicted, metrics, price, .pred))
  
  df_result_cv = df_result_cv %>% 
    bind_rows(tmp_result %>% 
                select(id, evaluated) %>% 
                mutate(min_n = n))
}

> df_result_cv %>% 
+   unnest() %>% 
+   group_by(min_n, .metric) %>% 
+   summarise(mean(.estimate))
# A tibble: 9 x 3
# Groups:   min_n [?]
  min_n .metric `mean(.estimate)`
  <dbl> <chr>               <dbl>
1     5 mae                0.0671
2     5 rmse               0.0932
3     5 rsq                0.992 
4    10 mae                0.0679
5    10 rmse               0.0940
6    10 rsq                0.991 
7    15 mae                0.0686
8    15 rmse               0.0947
9    15 rsq                0.991 

どうやらmin_n = 5の精度が最も良さそうです。 このパラメータを用いて訓練データ全体でモデルを再学習し、予測精度を確認します。 parsnipで作ったモデルはupdate()でパラメータを更新できます。

rf_best = update(rf, min_n = 5)

rec_preped = rec %>% 
  prep(df_train)

fitted = rf_best %>% 
  fit(price ~ ., data = juice(rec_preped))


df_test_baked = bake(rec_preped, df_test)


> df_test_baked %>% 
+   bind_cols(predict(fitted, df_test_baked)) %>% 
+   metrics(price, .pred)
# A tibble: 3 x 3
  .metric .estimator .estimate
  <chr>   <chr>          <dbl>
1 rmse    standard      0.0905
2 rsq     standard      0.992 
3 mae     standard      0.0665

予測結果と実際の値もプロットして確認します。

df_test_baked %>% 
  bind_cols(predict(fitted, df_test_baked)) %>% 
  ggplot(aes(.pred, price)) +
  geom_hex() +
  geom_abline(slope = 1, intercept = 0) + 
  coord_fixed() +
  scale_x_continuous(breaks = seq(6, 10, 1), limits = c(5.8, 10.2)) + 
  scale_y_continuous(breaks = seq(6, 10, 1), limits = c(5.8, 10.2)) + 
  scale_fill_viridis_c() + 
  labs(x = "Prediction", y = "Truth") + 
  theme_minimal(base_size = 12) +
  theme(panel.grid.minor = element_blank(),
        axis.text = element_text(color = "black", size =12))

f:id:dropout009:20190109213420p:plain

大きく外したり系統的なバイアスも見られないので、うまく予測できていると言えるのではないかと思います。

まとめ

tidymodels配下のパッケージ、特にpurrrrsampleを用いることでtidyにCross Validationを行うことができました。

参考

変数重要度とPartial Dependence Plotで機械学習モデルを解釈する

はじめに

RF/GBDT/NNなどの機械学習モデルは古典的な線形回帰モデルよりも高い予測精度が得られる一方で、インプットとアウトプットの関係がよくわからないという解釈性の問題を抱えています。 この予測精度と解釈性のトレードオフでどちらに重点を置くかは解くべきタスクによって変わってくると思いますが、私が仕事で行うデータ分析はクライアントの意思決定に繋げる必要があり、解釈性に重きを置いていることが多いです。

とはいえ機械学習モデルの高い予測精度は惜しく、悩ましかったのですが、学習アルゴリズムによらずモデルに解釈性を与えられる手法が注目され始めました。 本記事では変数重要度とPDP/ICE Plot (Partial Dependence/Individual Conditional Expectation)を用いて所謂ブラックボックスモデルを解釈する方法を紹介します。

モデルの学習

まずはパッケージを読み込みます。

library(tidyverse)
library(tidymodels)

library(pdp) # partial dependence plot
library(vip) # variable importance plot

例によってdiamondsデータを使用し、Rondom Forestでダイヤの価格を予測するモデルを作ります。 tidymodelsの使い方は以前の記事をご覧下さい。

set.seed(42)

df_split = initial_split(diamonds,  p = 0.9)

df_train = training(df_split)
df_test  = testing(df_split)


rec = recipe(price ~ carat + clarity + color + cut, data = df_train) %>% 
  step_log(price) %>% 
  step_ordinalscore(all_nominal()) %>% 
  prep(df_train)

df_train_baked = juice(rec)
df_test_baked = bake(rec, df_test)

fitted = rand_forest(mode = "regression", 
                     trees = 100,
                     min_n = 5,
                     mtry = 2) %>% 
  set_engine("ranger", num.threads = parallel::detectCores(), seed = 42) %>% 
  fit(price ~ ., data = juice(rec))

> fitted
parsnip model object

Ranger result

Call:
 ranger::ranger(formula = formula, data = data, mtry = ~2, num.trees = ~100,      min.node.size = ~5, num.threads = ~parallel::detectCores(),      seed = ~42, verbose = FALSE) 

Type:                             Regression 
Number of trees:                  100 
Sample size:                      48547 
Number of independent variables:  4 
Mtry:                             2 
Target node size:                 5 
Variable importance mode:         none 
Splitrule:                        variance 
OOB prediction error (MSE):       0.01084834 
R squared (OOB):                  0.9894755 

変数重要度

モデルの解釈において最も重要なのは、どの変数がアウトプットに強く影響し、どの変数は影響しないのかを特定すること、つまり変数重要度を見ることだと思います。 たとえば特に重要な変数を見定めて施策を打つことでアウトプットをより効率よく改善できる可能性が高まりますし*1、ありえない変数の重要度が高く出ていればデータやモデルが間違っていることに気づけるかもしれません。

tree系のアルゴリズムだとSplit(分割の回数)やGain(分割したときの誤差の減少量)などで変数重要度を定義することもできますが、ここではアルゴリズムに依存しない変数重要度であるPermutationベースのものを紹介します。 Permutationベースの変数重要度のコンセプトは極めて直感的で、「ある変数の情報を壊した時にモデルの予測精度がすごく落ちるならそれは大事な変数だ」というものです。

f:id:dropout009:20190107002315p:plain

たとえば今回のモデルでcaratの変数重要度を見たいとします。

  1. 訓練データを使って学習済みモデルを用意し、テストデータに対する予測精度を出します。
  2. テストデータのcarat列の値をシャッフルしますし、シャッフル済みテストデータを使って同様に予測を行います。
  3. 2つの予測精度を比較し、予測精度の減少度合いを変数重要度とします。

carat列のシャッフルによって、もしcaratが重要な変数ならモデルは的はずれな予測をするようになりますし、逆に重要でないなら特に影響は出ないはずで、これはある意味においての変数の重要度を捉えらていると思います。

また、この手法はあくまで予測の際に用いるデータをシャッフルしていて、学習をやり直しているわけではないため、計算が軽いことも利点です。

さて、実際にPermutationベースの変数重要度を計算してみましょう。 パッケージはvipを使います。変数重要度が計算できるパッケージは他にもいくつかありますが、vipはPermutationベースの変数重要度が計算でき、次に見るPartial Dependence Plotで利用するパッケージpdpと同じシンタックスが使えます(作者が同じなので)

# objectとnewdataを受けて予測結果をベクトルで返す関数を作っておく必要がある
pred_wrapper = function(object, newdata) {
  return(predict(object, newdata) %>% pull(.pred))
}


fitted %>% 
  vip(method = "permute",  # 変数重要度の計算方法
      pred_fun = pred_wrapper, 
      target = df_test_baked %>% pull(price),
      train = df_test_baked %>% select(-price),
      metric = "rsquared", # 予測精度の評価関数) 

f:id:dropout009:20190106230934p:plain

一番手っ取り早い方法はvip()関数を使った可視化です(グラフを自分で細かくいじりたい場合はvi()での計算結果を自分で可視化します)。 caratの変数重要度がその他の変数と比べてダントツであることが見て取れます。

Partial Dependence Plot

各変数の重要度がわかったら、次に行うべきは重要な変数とアウトカムの関係を見ることだと思います。 ただ、一般にブラックボックスモデルにおいてインプットとアウトカムの関係は非常に複雑で、可視化することは困難です。

そこで、複雑な関係を要約する手法の一つにPartial Dependence Plot(PDP)があります。PDPは興味のある変数以外の影響を周辺化して消してしまうことで、インプットとアウトカムの関係を単純化しようというものです。

特徴量を  x とした学習済みモデル  \hat{f}(x) があるとします。 x を興味のある変数  x_s とその他の変数 x_c に分割し、以下のpartial dependence function

 \begin{align} \hat{f_{s}}(x_s) = E_c\left[ \hat{f}(x_s, x_c) \right] = \int \hat{f}(x_s, x_c) p(x_c) d x_c \end{align}

を定義し、これを

\begin{align} \bar{f_{s}}(x_s) = \frac{1}{N}\sum_i\hat{f}\left(x_s, x_c^{(i)}\right) \end{align}

で推定します。具体的には以下のようなことをやっています。

f:id:dropout009:20190107002309p:plain

なお、PDPは平均をとりますが、平均を取らずに各 iについて個別にプロットしたものをIndividual Conditional Expectation (ICE) Plotと呼びます。 特に変数間の交互作用がある場合にPDPでは見失ってしまう関係を確認することができます。

それでは、特に重要だった変数caratについて実際にPDPを作成してみましょう。パッケージはpdpを使います。

ice_carat = fitted %>% 
  partial(pred.var = "carat",
          pred.fun = pred_wrapper,
          train = df_test_baked, 
          type = "regression")

ice_carat %>% 
  autoplot()

f:id:dropout009:20190107011226p:plain

partial()でPD/ICE Plotの元となるデータを計算できます。細部にこだわらなければautoplot()で可視化するのが一番手っ取り早いです。 赤い線がPDP、黒い線がICEになります。caratlog(price)に与える影響が徐々に逓減していく様子が見て取れます。 事前に関数形の仮定を置かずにインプットとアウトカムの非線形な関係を柔軟に捉えることができるのがRF+PDPの強力な点かと思います。

まとめ

アルゴリズムによらずモデルを解釈する手法としては他にもALE、SHAP value、LIMEなどがあり、次回以降に紹介できればと思っています。

参考

*1:変数間の関係が相関関係なのか因果関係なのかという問題はありますが