"ranger: A Fast Implementation of Random Forests"のメモ書き
前書き
とあるRのお悩み相談室*1にて、激しい圧力を受けたRandom Forestの新しいパッケージ{ranger}の「行数よりも列数の方が大きい疎なデータ」への適用に関して、とてもざっくりとした申し訳ない程度のメモを書きました。
パッケージ自体は下記リンクをご参照ください。
- mnwright/ranger · GitHub
- CRAN - Package ranger
- [1508.04409] ranger: A Fast Implementation of Random Forests for High Dimensional Data in C++ and R
「行数よりも列数の方が大きい疎なデータ」として、今回はLIBSVMの二値分類タスクのデータセットのうち、news20.binaryを利用しました。
「news20.binary」を取り扱った話には次のようなブログ記事があります。
- Perceptron を勉強する前にオンライン機械学習ライブラリを試してみる (nakatani @ cybozu labs)
- SVMソフトウェアの比較 - tsubosakaの日記
- NBSVMを試す - Negative/Positive Thinking
Rコード(準備)
# 関数定義 conv2Sparse <- function (feature, feature_num) { feature <- feature[stringr::str_length(string = feature) > 0] feature_map <- do.call("rbind", stringr::str_split(string = feature[-1], pattern = ":")) if (nrow(feature_map) == 1) { return( Matrix::cBind( as.numeric(feature[1]), Matrix::Matrix(data = 0, nrow = 1, ncol = feature_num) ) ) } return( Matrix::cBind( as.numeric(feature[1]), Matrix::sparseMatrix( i = numeric(length = nrow(feature_map)) + 1, j = as.integer(feature_map[, 1]), x = as.numeric(feature_map[, 2]), dims = c(1, feature_num) ) ) ) } # データ準備 ## https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary.html#news20.binary ### of classes: 2 ### of data: 19996 ### of features: 1355191 SET_DATASET <- "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/news20.binary.bz2" SET_DATASET_INFO <- list(FEATURE_NUM = 1355191) ## 並列処理設定 SET_PARALLE <- list(IS_PARALLEL = TRUE, CORE = 3, ITERATORS_CHUNK_SIZE = 100) ## パッケージ名 SET_LOAD_LIB <- c("hadleyverse", "pforeach", "Matrix", "ranger", "glmnet") # パッケージロード all(sapply(X = SET_LOAD_LIB, FUN = library, character.only = TRUE, logical.return = TRUE))
# 対象データ処理 ## 対象ファイルのダウンロード download.file(url = SET_DATASET, destfile = basename(SET_DATASET), quiet = FALSE) ## 入力ファイルデータを疎行列化 input_con <- file(description = basename(SET_DATASET), open = "r") iter_file <- iterators::ireadLines(con = input_con, n = SET_PARALLE$ITERATORS_CHUNK_SIZE) sparse_features <- pforeach::pforeach( read_df = iter_file, .c = rBind, .export = c("conv2Sparse", "SET_DATASET_INFO"), .packages = as.character(na.omit(object = stringr::str_match(string = search(), pattern = "^package:(.*)")[, 2])), .parallel = SET_PARALLE$IS_PARALLEL, .cores = SET_PARALLE$CORE )({ do.call( "rBind", lapply( X = stringr::str_split(string = read_df, pattern = "[:blank:]"), FUN = conv2Sparse, feature_num = SET_DATASET_INFO$FEATURE_NUM ) ) }) sparse_features <- sparse_features[rowSums(sparse_features[, -1]) != 0, ] colnames(sparse_features) <- c("Y", numeric(length = ncol(sparse_features) - 1)) set.seed(seed = 71) train_ind <- sample(x = c(0, 1, 2), size = nrow(sparse_features), replace = TRUE, prob = c(0.01, 0.05, 0.94)) # 確認用実行コード # 学習データ > sum(train_ind == 1) [1] 1007 > table(sparse_features[train_ind == 1, 1]) -1 1 507 500 # テストデータ > sum(train_ind == 0) [1] 196 > table(sparse_features[train_ind == 0, 1]) -1 1 86 110
Rコード(実行)
(注)下記以降のコードはメモリを非常に使用します(10GB以上推奨。gc()の適宜実行を推奨)。
# 疎行列のままでモデル作成 > class(sparse_features) [1] "dgCMatrix" attr(,"package") [1] "Matrix" > model_fit <- ranger::ranger( + dependent.variable.name = "Y", data = sparse_features[train_ind == 1, ], + save.memory = TRUE, classification = TRUE, importance = "impurity", write.forest = TRUE, + seed = 71 + ) Growing trees.. Progress: 3%. Estimated remaining time: 15 minutes, 37 seconds. Growing trees.. Progress: 3%. Estimated remaining time: 15 minutes, 37 seconds. Growing trees.. Progress: 10%. Estimated remaining time: 10 minutes, 12 seconds. Growing trees.. Progress: 15%. Estimated remaining time: 9 minutes, 44 seconds. Growing trees.. Progress: 22%. Estimated remaining time: 7 minutes, 53 seconds. Growing trees.. Progress: 31%. Estimated remaining time: 6 minutes, 5 seconds. Growing trees.. Progress: 43%. Estimated remaining time: 4 minutes, 14 seconds. Growing trees.. Progress: 50%. Estimated remaining time: 3 minutes, 47 seconds. Growing trees.. Progress: 59%. Estimated remaining time: 3 minutes, 2 seconds. Growing trees.. Progress: 73%. Estimated remaining time: 1 minute, 49 seconds. Growing trees.. Progress: 87%. Estimated remaining time: 47 seconds. # 学習データへの適合を確認 > table(sparse_features[train_ind == 1, 1], model_fit$predictions) -1 1 -1 356 151 1 62 438 # テストデータにモデルを適用して予測したいが、predict関数でエラーが起きる > predict_res <- predict(object = model_fit$forest, data = sparse_features[train_ind == 0, -1]) predict.ranger.forest(object = model_fit$forest, data = sparse_features[train_ind == でエラー: Error: Invalid forest object.
{Matrix}の疎行列を通常のデータフレームに変換してモデル作成(こちらだと上記のpredict関数のエラーが起きない)。
> to_df_trn <- data.frame( + as(object = sparse_features[train_ind == 1, ], Class = "matrix"), + stringsAsFactors = FALSE + ) # 目的変数を因子化 > to_df_trn$Y <- as.factor(to_df_trn$Y) > df_model_fit <- ranger::ranger( + dependent.variable.name = "Y", data = to_df_trn, + save.memory = TRUE, classification = TRUE, importance = "impurity", write.forest = TRUE, + seed = 71 + ) Growing trees.. Progress: 2%. Estimated remaining time: 25 minutes, 19 seconds. Growing trees.. Progress: 5%. Estimated remaining time: 19 minutes, 57 seconds. Growing trees.. Progress: 9%. Estimated remaining time: 15 minutes, 27 seconds. Growing trees.. Progress: 16%. Estimated remaining time: 10 minutes, 47 seconds. Growing trees.. Progress: 26%. Estimated remaining time: 7 minutes, 42 seconds. Growing trees.. Progress: 32%. Estimated remaining time: 6 minutes, 51 seconds. Growing trees.. Progress: 37%. Estimated remaining time: 6 minutes, 18 seconds. Growing trees.. Progress: 43%. Estimated remaining time: 5 minutes, 32 seconds. Growing trees.. Progress: 50%. Estimated remaining time: 4 minutes, 49 seconds. Growing trees.. Progress: 58%. Estimated remaining time: 3 minutes, 48 seconds. Growing trees.. Progress: 63%. Estimated remaining time: 3 minutes, 28 seconds. Growing trees.. Progress: 67%. Estimated remaining time: 3 minutes, 5 seconds. Growing trees.. Progress: 77%. Estimated remaining time: 2 minutes, 4 seconds. Growing trees.. Progress: 84%. Estimated remaining time: 1 minute, 26 seconds. Growing trees.. Progress: 99%. Estimated remaining time: 6 seconds. # 学習データへの適合を確認(疎行列時と同じ結果) > table(to_df_trn$Y, df_model_fit$predictions) -1 1 -1 356 151 1 62 438 # データフレームに変換してテストデータを適用 > to_df_tst <- data.frame( + as(object = sparse_features[train_ind == 0, ], Class = "matrix"), + stringsAsFactors = FALSE + ) > to_df_tst$Y <- as.factor(to_df_tst$Y) > class(to_df_tst) [1] "data.frame" # predict関数の実行でもメモリを多く消費するので注意 > to_df_predict <- predict(object = df_model_fit, data = to_df_tst[, -1]) > table(to_df_tst$Y, to_df_predict$predictions) -1 1 -1 62 24 1 14 96
学習データに適合するだけなら疎行列でもデータフレームでもよさそうですが、入力はデータフレームにしておいた方が賢明そうです。
疎行列(dgcMatrix)で適用可能な他パッケージと比較してみる。
# {glmnet}と比較(適当なパラメータで試したRandom Forestに合わせて、glmnet::cv.glmnet関数は使わず) > glmnet_mdl <- glmnet::glmnet( + x = sparse_features[train_ind == 1, -1], y = as.factor(sparse_features[train_ind == 1, 1]), + family = "binomial", alpha = 0.5 + ) # 学習データへの適合を確認 > table( + sparse_features[train_ind == 1, 1], + sign(rowSums( + t(apply( + X = predict(object = glmnet_mdl, type = "class", newx = sparse_features[train_ind == 1, -1]), + MARGIN = 1, FUN = as.numeric + )) + )) + ) -1 1 -1 507 0 1 0 500 # テストデータにモデルを適用 > table( + sparse_features[train_ind == 0, 1], + sign(rowSums( + t(apply( + X = predict(object = glmnet_mdl, type = "class", newx = sparse_features[train_ind == 0, -1]), + MARGIN = 1, FUN = as.numeric + )) + )) + ) -1 1 -1 75 11 1 13 97
大敗を喫しております。
まとめ
二値分類タスクの「行数よりも列数の方が大きい疎なデータ」に対して、{ranger}を適用してみました。予測時の適合度(accuracy)が約80%ほどで、{glmnet}の約87%と比べると劣っていましたが、特徴数が大きくないデータで試すと{glmnet}と同程度になりました。上記に記載していませんが、LIBSVMのデータセットのページにある「a7a」のデータで試したところ、予測時の適合度が両方とも約84%ほどでした。
データに依存しそうなので(当たり前の話ですが)、興味がある方は入力データを変えて試してみましょう。
疎行列(特に特徴数が多い)に{ranger}が適用できるか試したのですが、私のMac Book Proのメモリを容赦なく使い尽くしました。メモリ効率が良いかどうか、よくわかりません(誰かの検証を期待したいところです)。
ただ、下記のソースコードを読むと、論文中にあったsparse データの話は、{GenABEL}のgwaa.dataでないとダメな気が(きちんと論文を読まないとダメダメですね)。
ranger/predict.R at master · mnwright/ranger · GitHub
もう少し調べたいと思います。
(2015.10.12 追記)
{ranger}の疎行列に関する扱いについて、@sfchaos さんがTokyoR51の発表にて補足してくださいました。ありがとうございます。
www.slideshare.net
参考
- RでランダムフォレストやるならRboristかrangerか - 盆栽日記
- TagTeam :: Predicting Titanic deaths on Kaggle V: Ranger - R-bloggers - Statistics and Visualization
- [R言語]library("ranger")とlibrary("randomForest")の速度を比較する - gepulog
- 新型のランダムフォレスト(Random Forest)パッケージ比較:Rborist・ranger・randomForest - My Life as a Mock Quant
実行環境
> devtools::session_info() Session info ----------------------------------------------------------------------------------------- setting value version R version 3.2.1 (2015-06-18) system x86_64, darwin13.4.0 ui RStudio (0.99.467) language (EN) collate ja_JP.UTF-8 tz Asia/Tokyo Packages --------------------------------------------------------------------------------------------- package * version date source assertthat * 0.1 2013-12-06 CRAN (R 3.2.0) class 7.3-13 2015-06-29 CRAN (R 3.2.0) codetools 0.2-11 2015-03-10 CRAN (R 3.2.0) colorspace 1.2-6 2015-03-11 CRAN (R 3.2.0) crayon 1.3.0 2015-06-05 CRAN (R 3.2.1) curl 0.9 2015-06-19 CRAN (R 3.2.0) DBI 0.3.1 2014-09-24 CRAN (R 3.2.0) devtools * 1.8.0 2015-05-09 CRAN (R 3.2.0) digest 0.6.8 2014-12-31 CRAN (R 3.2.0) doParallel 1.0.8 2014-02-28 CRAN (R 3.2.0) doRNG 1.6 2014-03-07 CRAN (R 3.2.0) dplyr * 0.4.2.9002 2015-07-25 Github (hadley/dplyr@75e8303) e1071 * 1.6-4 2014-09-01 CRAN (R 3.2.0) foreach 1.4.2 2014-04-11 CRAN (R 3.2.0) GenABEL * 1.8-0 2013-12-27 CRAN (R 3.2.0) GenABEL.data * 1.0.0 2013-12-27 CRAN (R 3.2.0) ggplot2 * 1.0.1 2015-03-17 CRAN (R 3.2.0) git2r 0.10.1 2015-05-07 CRAN (R 3.2.0) gtable 0.1.2 2012-12-05 CRAN (R 3.2.0) hadleyverse * 0.1 2015-08-09 Github (aaboyles/hadleyverse@16532fe) haven * 0.2.0 2015-04-09 CRAN (R 3.2.0) iterators 1.0.7 2014-04-11 CRAN (R 3.2.0) lattice 0.20-31 2015-03-30 CRAN (R 3.2.0) lazyeval * 0.1.10.9000 2015-07-25 Github (hadley/lazyeval@ecb8dc0) lubridate * 1.3.3 2013-12-31 CRAN (R 3.2.0) magrittr 1.5 2014-11-22 CRAN (R 3.2.0) MASS * 7.3-41 2015-06-18 CRAN (R 3.2.0) Matrix * 1.2-1 2015-06-01 CRAN (R 3.2.0) memoise 0.2.1 2014-04-22 CRAN (R 3.2.0) munsell 0.4.2 2013-07-11 CRAN (R 3.2.0) pforeach * 1.3 2015-07-25 Github (hoxo-m/pforeach@2c44f3b) pkgmaker 0.22 2014-05-14 CRAN (R 3.2.0) plyr * 1.8.3 2015-06-12 CRAN (R 3.2.0) proto 0.3-10 2012-12-22 CRAN (R 3.2.0) R6 2.0.1 2014-10-29 CRAN (R 3.2.0) ranger * 0.2.7 2015-07-29 CRAN (R 3.2.1) Rcpp 0.12.0 2015-07-26 Github (RcppCore/Rcpp@6ae91cc) readr * 0.1.1.9000 2015-07-25 Github (hadley/readr@f4a3956) readxl * 0.1.0 2015-04-14 CRAN (R 3.2.0) registry 0.2 2012-01-24 CRAN (R 3.2.0) reshape2 1.4.1 2014-12-06 CRAN (R 3.2.0) rngtools 1.2.4 2014-03-06 CRAN (R 3.2.0) rversions 1.0.1 2015-06-06 CRAN (R 3.2.0) scales 0.2.5 2015-06-12 CRAN (R 3.2.0) stringi * 0.5-5 2015-06-29 CRAN (R 3.2.0) stringr * 1.0.0.9000 2015-07-25 Github (hadley/stringr@380c88f) testthat * 0.10.0 2015-05-22 CRAN (R 3.2.0) tidyr * 0.2.0.9000 2015-07-25 Github (hadley/tidyr@0dc87b2) xml2 * 0.1.1 2015-06-02 CRAN (R 3.2.0) xtable 1.7-4 2014-09-12 CRAN (R 3.2.0)