Dropout

データサイエンスについて勉強したことを書いていきます。機械学習、解釈性、因果推論など。

purrrとbroomで複数の回帰モデルを効率的に管理する

私は探索的にデータを見てく段階では、可視化に加えて複数の回帰モデルを作成して比較をする、ということをよくやっています。 モデルの数が少ない場合は個別にモデルを作成してsummary()で見ていく事もできますが、モデルの数が増えるにつれてそのやり方では管理が難しくなってきます。 そこで、本記事では、purrrmap()broomtidy(), glance()を用いて複数の回帰モデルを効率的に扱う方法を紹介したいと思います。

まずはライブラリを読み込みます。tidyverseはおなじみのデータハンドリングと可視化のためのパッケージ群です。tidymodelsモデリングをtidyなやり方で統一的に扱えるようにするパッケージ群になります。今回はbroomのみ用いますが、後日他のパッケージの紹介記事も書ければと思っています。

library(tidyverse)
library(tidymodels)

# 可視化用
theme_scatter = theme_minimal() + 
    theme(panel.grid.minor = element_blank(), 
          axis.text  = element_text(color = "black")) 

theme_minimal2 = theme_scatter + 
    theme(panel.grid.major.x = element_blank()) 

データセットはこれまたおなじみのdiamondを使います。 詳細は?diamondsを見て頂ければと思いますが、ダイヤの重さ(carat)や透明度(clarity)とその値段(price)などが入ったデータセットになります。 今回はこのデータセットを用いて、ダイヤの属性がダイヤの値段にどんな影響を与えるのかを探索することにします。

可視化によるデータ分析

まずは非説明変数である価格の分布を見てみましょう。

df = diamonds #面倒なので名前をdfに

df %>% 
    ggplot(aes(price)) +
    geom_histogram(fill = "#56B4E9", color = "white") +
    theme_minimal2

f:id:dropout009:20190102144514p:plain

ものすごく歪んだ分布をしています。 歪んだ分布への簡便な対応策として、今回は対数を取ることにします。

df %>% 
    ggplot(aes(log(price))) +
    geom_histogram(fill = "#56B4E9", color = "white") +
    theme_minimal2

f:id:dropout009:20190102144521p:plain

きれいな正規分布とまではいきませんが先程よりは中心によった分布になりました。価格は対数を取って分析することにします。

次にダイヤの属性と値段の関係を見ていきましょう。 たとえばダイヤが重ければ重いほど(大きければ大きいほど)値段は高くなりそうです。

df %>% 
    sample_frac(0.1) %>% #データが多いので減らす
    ggplot(aes(carat, log(price))) +
    geom_point(color = "#0072B2", alpha = 0.5) +
    theme_scatter

f:id:dropout009:20190102144526p:plain

実際に可視化してみると、どうやらこの仮説は正しそうです。 ただ、caratlog(price)の関係は非線形に見えます。 caratlog(price)に与えるインパクトはcaratが大きくなるにつれて逓減するという関係を反映するために、caratにも対数をとってみましょう。

df %>% 
    sample_frac(0.1) %>% #データが多いので減らす
    ggplot(aes(log(carat), log(price))) +
    geom_point(color = "#0072B2", alpha = 0.5) +
    theme_scatter

f:id:dropout009:20190102144533p:plain

両変数に対数をとることで線形の関係が構築できました! これなら線形回帰でうまくモデリングできそうです。

同様に透明度と価格の関係も見てみましょう。 直感的には透明度が高ければ高いほど価格が高くなる傾向がありそうに思えます。

df %>% 
    ggplot(aes(clarity, log(price))) +
    geom_boxplot(fill = "#56B4E9") +
    theme_minimal2

f:id:dropout009:20190102144540p:plain 直感に反する結果となりました。上のグラフは右に行くほど透明度のランクが高いのですが、透明度が高いほどダイヤが安くなっています。

これは一体どういうことでしょうか? 透明度とカラット数の関係をみることでこの謎が解けます。

df %>% 
    ggplot(aes(clarity, log(carat))) +
    geom_boxplot(fill = "#56B4E9") +
    theme_minimal2

f:id:dropout009:20190102144545p:plain 透明度が高い場合にはカラット数が小さくなる傾向が見て取れます。 僕はダイヤモンドに詳しくないのですが、大きくて透明なダイヤを作るのは難しいということのようです。*1

つまり、ここでは「ダイヤのカラット数と価格には正の相関がある一方で、カラット数と透明度には負の相関があり、結果として透明度と価格に負の相関があるように見えてしまう」という典型的な交絡の問題が起きています。 この場合は、重回帰分析でカラット数の影響を取り除くことで透明度と価格の正しい関係を見ることができます。 今回のように、単純に一つ一つの説明変数と被説明変数の関係を見ているだけでは間違った結論を下していまう可能性があるので注意が必要です。

モデリングによるデータ分析

ここまでの可視化で得られた仮説を回帰モデルで確かめてみましょう。

回帰モデルを作る前に、カテゴリカル変数がordere=TRUEになっているとlm()の挙動が面倒なので、factor型の変数は全てordered=FALSEにしておきます。

df_input = df %>% 
    mutate_if(is.ordered, factor, ordered = FALSE)

mutate_if()は複数の列に同じ処理を行う際に便利な関数です。 第一引数で条件を指定して(is.ordered)、条件に当てはまった列にのみ第二引数の関数を適用します(factor)。関数のオプションは後ろにくっつけて指定すればOKです(ordered = FALSE)。

複数のモデルを当てはめる場合

今回は

  1. 価格を透明度のみで説明するモデル
  2. 透明度とカラット数で説明するモデル
  3. 上と同じだが、カラット数に対数をとったモデル

の3つのモデルを作成してみます。

formulas = c(log(price) ~ clarity,
             log(price) ~ clarity + carat,
             log(price) ~ clarity + log(carat)) %>% 
    enframe("model_no", "formula")

formulas 

# A tibble: 3 x 2
  model_no formula      
     <int> <list>       
1        1 <S3: formula>
2        2 <S3: formula>
3        3 <S3: formula>

enframe()はベクトルをデータフレームにしてくれる関数です。データフレームとして持っていたほうが後々管理がしやすいと思うので、変換しておきました。

上のデータフレームにmap()を適用することで複数の回帰モデルを一気に作成できます。

df_result = formulas %>% 
    mutate(model = map(formula, lm, data = df_input), #(1)
           tidied = map(model, tidy), #(2)
           glanced = map(model, glance)) #(3)

df_result

# A tibble: 3 x 5
  model_no formula       model    tidied           glanced          
     <int> <list>        <list>   <list>           <list>           
1        1 <S3: formula> <S3: lm> <tibble [8 × 5]> <tibble [1 × 11]>
2        2 <S3: formula> <S3: lm> <tibble [9 × 5]> <tibble [1 × 11]>
3        3 <S3: formula> <S3: lm> <tibble [9 × 5]> <tibble [1 × 11]>

(1)ではmap()を用いて「formula列の一つ一つの値を引数としてlm()を実行する」ということをやっています。lm()はデータも指定する必要がありますが、map()mutate_if()同様後ろにくっつけて指定できます。 さらに、(2), (3)では推定されたモデルに対して、それぞれtidy()glance()を適用しています。

分析結果をデータフレームとして持つことで、どのモデルがどの結果に対応するかを間違うことなく効率的に管理することができます。

tidy()は回帰モデルの係数をtidyなデータフレームとして持ってきてくれる関数です。データフレームの中のデータフレームを取り出すためにunnest()を使います。

df_coef = df_result %>% 
    select(model_no, tidied) %>% 
    unnest() %>% 
    mutate_if(is.double, round, digits=2) 

df_coef

# A tibble: 26 x 6
   model_no term        estimate std.error statistic p.value
      <int> <chr>          <dbl>     <dbl>     <dbl>   <dbl>
 1        1 (Intercept)    8.03       0.04    221.         0
 2        1 claritySI2     0.14       0.04      3.69       0
 3        1 claritySI1    -0.18       0.04     -4.81       0
 4        1 clarityVS2    -0.26       0.04     -7.08       0
 5        1 clarityVS1    -0.3        0.04     -7.99       0
 6        1 clarityVVS2   -0.5        0.04    -12.8        0
 7        1 clarityVVS1   -0.7        0.04    -17.7        0
 8        1 clarityIF     -0.62       0.04    -14.4        0
 9        2 (Intercept)    5.36       0.01    372.         0
10        2 claritySI2     0.570      0.01     40.1        0
# ... with 16 more rows

3つのモデルを比較してみましょう。 モデルを横に並べるためにspread()を使います。 そのままだと回帰係数がアルファベット順になってしまうので、並び順を維持するためにfct_inorder()を使っています(出てきた順に並ぶ)。

df_coef %>% 
    mutate(term = fct_inorder(term)) %>% 
    select(model_no, term, estimate) %>% 
    spread(model_no, estimate)

# A tibble: 10 x 4
   term           `1`    `2`   `3`
   <fct>        <dbl>  <dbl> <dbl>
 1 (Intercept)   8.03  5.36   7.77
 2 claritySI2    0.14  0.570  0.48
 3 claritySI1   -0.18  0.72   0.62
 4 clarityVS2   -0.26  0.82   0.78
 5 clarityVS1   -0.3   0.86   0.82
 6 clarityVVS2  -0.5   0.93   0.98
 7 clarityVVS1  -0.7   0.92   1.03
 8 clarityIF    -0.62  1      1.11
 9 carat        NA     2.08  NA   
10 log(carat)   NA    NA      1.81

モデル1は透明度のみのモデルであり、カラット数の影響を取り除いていないので、透明度と価格には負の関係があるように見えています。 その一方で、モデル2はカラット数を変数として加えることで影響を取り除いた上での透明度と価格の影響を見ています。モデル2の係数を見ると透明度が高いほど価格が高くなる傾向が見て取れ、こちらのほうがより尤もらしい結果であると言えそうです。

モデル3はカラット数に対数をとったモデルになります。 モデル2もモデル3も透明度と価格の関係には大きな変化はなさそうです。

glance()はモデルの性能をtidyなデータフレームとしてまとめてくれる関数です。

df_result %>% 
    select(model_no, glanced) %>% 
    unnest() %>% 
    mutate_if(is.double, round, digits=2) 

# A tibble: 3 x 12
  model_no r.squared adj.r.squared sigma statistic p.value    df  logLik
     <int>     <dbl>         <dbl> <dbl>     <dbl>   <dbl> <int>   <dbl>
1        1      0.05          0.05  0.99      415.       0     8 -75907.
2        2      0.87          0.87  0.37    43737.       0     9 -23024.
3        3      0.97          0.97  0.19   187918.       0     9  13378.
# ... with 4 more variables: AIC <dbl>, BIC <dbl>, deviance <dbl>,
#   df.residual <int>

様々な指標がまとめて出力されますが、今回は自由度調整済み決定係数(adj.r.squared)を見ることにしましょう。 透明度のみモデル(モデル1)と比べると、カラット数を加えることでモデルの説明力は大幅に上昇しています(モデル2)。 カラット数に対数をとるとさらに説明力が改善されており(モデル3)、3つのモデルの中ではモデル3がベストであると言えそうです。

サブサンプルに分けて分析する場合

先程は複数のモデルを同じデータに当てはめましたが、同じモデルを複数のデータにまとめてフィットさせることもできます。 たとえば透明度によって価格とカラット数の関係が異なるのかを調べてみることにしましょう。 透明度の低いダイヤは大きくてもあまり価値は上がらないが、透明度が高い場合はサイズが大きくなると価値が大きく上昇する、といった関係があるかもしれません。 これは透明度でデータをサブサンプルに分割して、各データに対して回帰モデルを当てはめることで確認できます。

df_nested = df %>% 
    group_by(clarity) %>% #nest(-clarity)でも同じ
    nest() %>% 
    arrange(clarity)

df_nested

# A tibble: 8 x 2
  clarity data                 
  <ord>   <list>               
1 I1      <tibble [741 × 9]>   
2 SI2     <tibble [9,194 × 9]> 
3 SI1     <tibble [13,065 × 9]>
4 VS2     <tibble [12,258 × 9]>
5 VS1     <tibble [8,171 × 9]> 
6 VVS2    <tibble [5,066 × 9]> 
7 VVS1    <tibble [3,655 × 9]> 
8 IF      <tibble [1,790 × 9]> 

nest()はグループごとにデータフレームを分割してdata列に格納する関数です。 data列にmap()を適用することで各サブサンプルに対して同じモデルをまとめてフィットさせることができます。

df_result2 = df_nested %>% 
    mutate(model = map(data, ~lm(log(price) ~ log(carat), data = .)),
           tidied = map(model, tidy))

モデルが作成できたので、回帰係数を見てみましょう。

df_result2 %>% 
    select(clarity, tidied) %>% 
    unnest() %>% 
    filter(term != "(Intercept)") %>% #切片は無視
    mutate_if(is.double, round, digits=2) 

# A tibble: 8 x 6
  clarity term       estimate std.error statistic p.value
  <ord>   <chr>         <dbl>     <dbl>     <dbl>   <dbl>
1 I1      log(carat)     1.53      0.02      84.4       0
2 SI2     log(carat)     1.79      0        557.        0
3 SI1     log(carat)     1.82      0        676.        0
4 VS2     log(carat)     1.78      0        563.        0
5 VS1     log(carat)     1.83      0        473.        0
6 VVS2    log(carat)     1.85      0.01     338.        0
7 VVS1    log(carat)     1.83      0.01     245.        0
8 IF      log(carat)     1.87      0.01     169.        0

I1だけは係数が小さくなっており、その他はほとんど同じ係数です。 どうやらと透明度が最低ランクの場合は、他のランクと比べてカラット数が価格に与える影響は小さいようです。

散布図も確認すると、確かにI1のみ傾きが少し緩やかになっているような気がします。

df %>% 
    sample_frac(0.3) %>% #データが多いので減らす
    ggplot(aes(log(carat), log(price))) +
    geom_point(color = "gray", alpha = 0.5) +
    facet_wrap(~clarity) + 
    geom_smooth(method = "lm", se = F, color = "#0072B2")  +
    theme_scatter

f:id:dropout009:20190102144550p:plain

まとめ

このように、探索的にデータを見ていく際には、purrrbroomを使うと複数の回帰モデルを効率的に比較することができ、とても便利です。