バイアスと戯れる

Rと言語処理と(Rによる言語処理100本ノック終了)

"ranger: A Fast Implementation of Random Forests"のメモ書き

前書き

 とあるRのお悩み相談室*1にて、激しい圧力を受けたRandom Forestの新しいパッケージ{ranger}の「行数よりも列数の方が大きい疎なデータ」への適用に関して、とてもざっくりとした申し訳ない程度のメモを書きました。
 パッケージ自体は下記リンクをご参照ください。


 「行数よりも列数の方が大きい疎なデータ」として、今回はLIBSVMの二値分類タスクのデータセットのうち、news20.binaryを利用しました。


 「news20.binary」を取り扱った話には次のようなブログ記事があります。



Rコード(準備)

# 関数定義

conv2Sparse <- function (feature, feature_num) {
  feature <- feature[stringr::str_length(string = feature) > 0]
  feature_map <- do.call("rbind", stringr::str_split(string = feature[-1], pattern = ":"))
  if (nrow(feature_map) == 1) {
    return(
      Matrix::cBind(
        as.numeric(feature[1]),
        Matrix::Matrix(data = 0, nrow = 1, ncol = feature_num)
      )
    )
  }
  return(
    Matrix::cBind(
      as.numeric(feature[1]),
      Matrix::sparseMatrix(
        i = numeric(length = nrow(feature_map)) + 1,  j = as.integer(feature_map[, 1]),
        x = as.numeric(feature_map[, 2]),
        dims = c(1, feature_num)
      )
    )
  )
}


# データ準備
## https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary.html#news20.binary
### of classes: 2
### of data: 19996
### of features: 1355191
SET_DATASET <- "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/news20.binary.bz2"
SET_DATASET_INFO <- list(FEATURE_NUM = 1355191)

## 並列処理設定
SET_PARALLE <- list(IS_PARALLEL = TRUE, CORE = 3, ITERATORS_CHUNK_SIZE = 100)
## パッケージ名
SET_LOAD_LIB <- c("hadleyverse", "pforeach", "Matrix", "ranger", "glmnet")


# パッケージロード
all(sapply(X = SET_LOAD_LIB, FUN = library, character.only = TRUE, logical.return = TRUE))
# 対象データ処理
## 対象ファイルのダウンロード
download.file(url = SET_DATASET, destfile = basename(SET_DATASET), quiet = FALSE)

## 入力ファイルデータを疎行列化
input_con <- file(description = basename(SET_DATASET), open = "r")
iter_file <- iterators::ireadLines(con = input_con, n = SET_PARALLE$ITERATORS_CHUNK_SIZE)
sparse_features <- pforeach::pforeach(
  read_df = iter_file,
  .c = rBind,
  .export = c("conv2Sparse", "SET_DATASET_INFO"),
  .packages = as.character(na.omit(object = stringr::str_match(string = search(), pattern = "^package:(.*)")[, 2])),
  .parallel = SET_PARALLE$IS_PARALLEL, .cores = SET_PARALLE$CORE
)({
  do.call(
    "rBind",
    lapply(
      X = stringr::str_split(string = read_df, pattern = "[:blank:]"),
      FUN = conv2Sparse,
      feature_num = SET_DATASET_INFO$FEATURE_NUM
    )
  )
}) 
sparse_features <- sparse_features[rowSums(sparse_features[, -1]) != 0, ]
colnames(sparse_features) <- c("Y", numeric(length = ncol(sparse_features) - 1))

set.seed(seed = 71)
train_ind <- sample(x = c(0, 1, 2), size = nrow(sparse_features), replace = TRUE, prob = c(0.01, 0.05, 0.94))


# 確認用実行コード
# 学習データ
> sum(train_ind == 1)
[1] 1007
> table(sparse_features[train_ind == 1, 1])

 -1   1 
507 500 

# テストデータ
> sum(train_ind == 0)
[1] 196
> table(sparse_features[train_ind == 0, 1])

 -1   1 
 86 110 

Rコード(実行)

(注)下記以降のコードはメモリを非常に使用します(10GB以上推奨。gc()の適宜実行を推奨)。 

# 疎行列のままでモデル作成
> class(sparse_features)
[1] "dgCMatrix"
attr(,"package")
[1] "Matrix"
> model_fit <- ranger::ranger(
+   dependent.variable.name = "Y", data = sparse_features[train_ind == 1, ],
+   save.memory = TRUE, classification = TRUE, importance = "impurity", write.forest = TRUE,
+   seed = 71
+ )
Growing trees.. Progress: 3%. Estimated remaining time: 15 minutes, 37 seconds.
Growing trees.. Progress: 3%. Estimated remaining time: 15 minutes, 37 seconds.
Growing trees.. Progress: 10%. Estimated remaining time: 10 minutes, 12 seconds.
Growing trees.. Progress: 15%. Estimated remaining time: 9 minutes, 44 seconds.
Growing trees.. Progress: 22%. Estimated remaining time: 7 minutes, 53 seconds.
Growing trees.. Progress: 31%. Estimated remaining time: 6 minutes, 5 seconds.
Growing trees.. Progress: 43%. Estimated remaining time: 4 minutes, 14 seconds.
Growing trees.. Progress: 50%. Estimated remaining time: 3 minutes, 47 seconds.
Growing trees.. Progress: 59%. Estimated remaining time: 3 minutes, 2 seconds.
Growing trees.. Progress: 73%. Estimated remaining time: 1 minute, 49 seconds.
Growing trees.. Progress: 87%. Estimated remaining time: 47 seconds.

# 学習データへの適合を確認
> table(sparse_features[train_ind == 1, 1], model_fit$predictions)
    
      -1   1
  -1 356 151
  1   62 438

# テストデータにモデルを適用して予測したいが、predict関数でエラーが起きる
> predict_res <- predict(object = model_fit$forest, data = sparse_features[train_ind == 0, -1])
 predict.ranger.forest(object = model_fit$forest, data = sparse_features[train_ind ==  でエラー: 
  Error: Invalid forest object.


{Matrix}の疎行列を通常のデータフレームに変換してモデル作成(こちらだと上記のpredict関数のエラーが起きない)。

> to_df_trn <- data.frame(
+   as(object = sparse_features[train_ind == 1, ], Class = "matrix"),
+   stringsAsFactors = FALSE
+ )
# 目的変数を因子化
> to_df_trn$Y <- as.factor(to_df_trn$Y)
> df_model_fit <- ranger::ranger(
+   dependent.variable.name = "Y", data = to_df_trn,
+   save.memory = TRUE, classification = TRUE, importance = "impurity", write.forest = TRUE,
+   seed = 71
+ )
Growing trees.. Progress: 2%. Estimated remaining time: 25 minutes, 19 seconds.
Growing trees.. Progress: 5%. Estimated remaining time: 19 minutes, 57 seconds.
Growing trees.. Progress: 9%. Estimated remaining time: 15 minutes, 27 seconds.
Growing trees.. Progress: 16%. Estimated remaining time: 10 minutes, 47 seconds.
Growing trees.. Progress: 26%. Estimated remaining time: 7 minutes, 42 seconds.
Growing trees.. Progress: 32%. Estimated remaining time: 6 minutes, 51 seconds.
Growing trees.. Progress: 37%. Estimated remaining time: 6 minutes, 18 seconds.
Growing trees.. Progress: 43%. Estimated remaining time: 5 minutes, 32 seconds.
Growing trees.. Progress: 50%. Estimated remaining time: 4 minutes, 49 seconds.
Growing trees.. Progress: 58%. Estimated remaining time: 3 minutes, 48 seconds.
Growing trees.. Progress: 63%. Estimated remaining time: 3 minutes, 28 seconds.
Growing trees.. Progress: 67%. Estimated remaining time: 3 minutes, 5 seconds.
Growing trees.. Progress: 77%. Estimated remaining time: 2 minutes, 4 seconds.
Growing trees.. Progress: 84%. Estimated remaining time: 1 minute, 26 seconds.
Growing trees.. Progress: 99%. Estimated remaining time: 6 seconds.

# 学習データへの適合を確認(疎行列時と同じ結果)
> table(to_df_trn$Y, df_model_fit$predictions)
    
      -1   1
  -1 356 151
  1   62 438

# データフレームに変換してテストデータを適用
> to_df_tst <- data.frame(
+   as(object = sparse_features[train_ind == 0, ], Class = "matrix"),
+   stringsAsFactors = FALSE
+ )
> to_df_tst$Y <- as.factor(to_df_tst$Y)
> class(to_df_tst)
[1] "data.frame"
# predict関数の実行でもメモリを多く消費するので注意
> to_df_predict <- predict(object = df_model_fit, data = to_df_tst[, -1])
> table(to_df_tst$Y, to_df_predict$predictions)
    
     -1  1
  -1 62 24
  1  14 96

 学習データに適合するだけなら疎行列でもデータフレームでもよさそうですが、入力はデータフレームにしておいた方が賢明そうです。


 疎行列(dgcMatrix)で適用可能な他パッケージと比較してみる。

# {glmnet}と比較(適当なパラメータで試したRandom Forestに合わせて、glmnet::cv.glmnet関数は使わず)
> glmnet_mdl <- glmnet::glmnet(
+     x = sparse_features[train_ind == 1, -1],  y = as.factor(sparse_features[train_ind == 1, 1]),
+     family = "binomial", alpha = 0.5
+ )

# 学習データへの適合を確認
> table(
+   sparse_features[train_ind == 1, 1],
+   sign(rowSums(
+     t(apply(
+       X = predict(object = glmnet_mdl, type = "class", newx = sparse_features[train_ind == 1, -1]),
+       MARGIN = 1, FUN = as.numeric
+     ))
+   ))
+ )
    
      -1   1
  -1 507   0
  1    0 500

# テストデータにモデルを適用
> table(
+   sparse_features[train_ind == 0, 1],
+   sign(rowSums(
+     t(apply(
+       X = predict(object = glmnet_mdl, type = "class", newx = sparse_features[train_ind == 0, -1]),
+       MARGIN = 1, FUN = as.numeric
+     ))
+   ))
+ )
    
     -1  1
  -1 75 11
  1  13 97

 大敗を喫しております。



まとめ

 二値分類タスクの「行数よりも列数の方が大きい疎なデータ」に対して、{ranger}を適用してみました。予測時の適合度(accuracy)が約80%ほどで、{glmnet}の約87%と比べると劣っていましたが、特徴数が大きくないデータで試すと{glmnet}と同程度になりました。上記に記載していませんが、LIBSVMのデータセットのページにある「a7a」のデータで試したところ、予測時の適合度が両方とも約84%ほどでした。
 データに依存しそうなので(当たり前の話ですが)、興味がある方は入力データを変えて試してみましょう。

 疎行列(特に特徴数が多い)に{ranger}が適用できるか試したのですが、私のMac Book Proのメモリを容赦なく使い尽くしました。メモリ効率が良いかどうか、よくわかりません(誰かの検証を期待したいところです)。
 ただ、下記のソースコードを読むと、論文中にあったsparse データの話は、{GenABEL}のgwaa.dataでないとダメな気が(きちんと論文を読まないとダメダメですね)。
 ranger/predict.R at master · mnwright/ranger · GitHub

 もう少し調べたいと思います。


 (2015.10.12 追記)
 {ranger}の疎行列に関する扱いについて、@sfchaos さんがTokyoR51の発表にて補足してくださいました。ありがとうございます。

 

www.slideshare.net


参考



実行環境

> devtools::session_info()
Session info -----------------------------------------------------------------------------------------
 setting  value                       
 version  R version 3.2.1 (2015-06-18)
 system   x86_64, darwin13.4.0        
 ui       RStudio (0.99.467)          
 language (EN)                        
 collate  ja_JP.UTF-8                 
 tz       Asia/Tokyo                  

Packages ---------------------------------------------------------------------------------------------
 package      * version     date       source                               
 assertthat   * 0.1         2013-12-06 CRAN (R 3.2.0)                       
 class          7.3-13      2015-06-29 CRAN (R 3.2.0)                       
 codetools      0.2-11      2015-03-10 CRAN (R 3.2.0)                       
 colorspace     1.2-6       2015-03-11 CRAN (R 3.2.0)                       
 crayon         1.3.0       2015-06-05 CRAN (R 3.2.1)                       
 curl           0.9         2015-06-19 CRAN (R 3.2.0)                       
 DBI            0.3.1       2014-09-24 CRAN (R 3.2.0)                       
 devtools     * 1.8.0       2015-05-09 CRAN (R 3.2.0)                       
 digest         0.6.8       2014-12-31 CRAN (R 3.2.0)                       
 doParallel     1.0.8       2014-02-28 CRAN (R 3.2.0)                       
 doRNG          1.6         2014-03-07 CRAN (R 3.2.0)                       
 dplyr        * 0.4.2.9002  2015-07-25 Github (hadley/dplyr@75e8303)        
 e1071        * 1.6-4       2014-09-01 CRAN (R 3.2.0)                       
 foreach        1.4.2       2014-04-11 CRAN (R 3.2.0)                       
 GenABEL      * 1.8-0       2013-12-27 CRAN (R 3.2.0)                       
 GenABEL.data * 1.0.0       2013-12-27 CRAN (R 3.2.0)                       
 ggplot2      * 1.0.1       2015-03-17 CRAN (R 3.2.0)                       
 git2r          0.10.1      2015-05-07 CRAN (R 3.2.0)                       
 gtable         0.1.2       2012-12-05 CRAN (R 3.2.0)                       
 hadleyverse  * 0.1         2015-08-09 Github (aaboyles/hadleyverse@16532fe)
 haven        * 0.2.0       2015-04-09 CRAN (R 3.2.0)                       
 iterators      1.0.7       2014-04-11 CRAN (R 3.2.0)                       
 lattice        0.20-31     2015-03-30 CRAN (R 3.2.0)                       
 lazyeval     * 0.1.10.9000 2015-07-25 Github (hadley/lazyeval@ecb8dc0)     
 lubridate    * 1.3.3       2013-12-31 CRAN (R 3.2.0)                       
 magrittr       1.5         2014-11-22 CRAN (R 3.2.0)                       
 MASS         * 7.3-41      2015-06-18 CRAN (R 3.2.0)                       
 Matrix       * 1.2-1       2015-06-01 CRAN (R 3.2.0)                       
 memoise        0.2.1       2014-04-22 CRAN (R 3.2.0)                       
 munsell        0.4.2       2013-07-11 CRAN (R 3.2.0)                       
 pforeach     * 1.3         2015-07-25 Github (hoxo-m/pforeach@2c44f3b)     
 pkgmaker       0.22        2014-05-14 CRAN (R 3.2.0)                       
 plyr         * 1.8.3       2015-06-12 CRAN (R 3.2.0)                       
 proto          0.3-10      2012-12-22 CRAN (R 3.2.0)                       
 R6             2.0.1       2014-10-29 CRAN (R 3.2.0)                       
 ranger       * 0.2.7       2015-07-29 CRAN (R 3.2.1)                       
 Rcpp           0.12.0      2015-07-26 Github (RcppCore/Rcpp@6ae91cc)       
 readr        * 0.1.1.9000  2015-07-25 Github (hadley/readr@f4a3956)        
 readxl       * 0.1.0       2015-04-14 CRAN (R 3.2.0)                       
 registry       0.2         2012-01-24 CRAN (R 3.2.0)                       
 reshape2       1.4.1       2014-12-06 CRAN (R 3.2.0)                       
 rngtools       1.2.4       2014-03-06 CRAN (R 3.2.0)                       
 rversions      1.0.1       2015-06-06 CRAN (R 3.2.0)                       
 scales         0.2.5       2015-06-12 CRAN (R 3.2.0)                       
 stringi      * 0.5-5       2015-06-29 CRAN (R 3.2.0)                       
 stringr      * 1.0.0.9000  2015-07-25 Github (hadley/stringr@380c88f)      
 testthat     * 0.10.0      2015-05-22 CRAN (R 3.2.0)                       
 tidyr        * 0.2.0.9000  2015-07-25 Github (hadley/tidyr@0dc87b2)        
 xml2         * 0.1.1       2015-06-02 CRAN (R 3.2.0)                       
 xtable         1.7-4       2014-09-12 CRAN (R 3.2.0)