Stan — 高速MCMCでパラメータ推定
数あるMCMCアルゴリズムの中でも効率的なHMC(Hybrid/Hamiltonian Monte Carlo)を用いてベイズ推定を行うツール。 Pythonやコマンドラインなどいろんな形で利用可能だが、 とりあえずRでRStanを使ってみる。
インストール
Rからinstall.packages("rstan")
で一発。
jagsと違ってstan本体も同時に入れてくれる。
RStan-Getting-Started
を見ると、時代や環境によってはいろいろ難しいかったのかも。
標準的な開発環境(Mac なら Command Line Tools、Ubuntu なら build-essential)はどっちみち必要。
基本的な流れ
-
rstanを読み込む
library(rstan) rstan_options(auto_write = TRUE) options(mc.cores = parallel::detectCores())
-
名前付きlistとしてデータを用意する。 e.g., 平均10、標準偏差3の正規乱数。
observation = list(x = rnorm(10000, 10, 3)) observation$length = length(observation$x)
-
Stan言語でモデルを記述する。 別ファイルにしてもいいし、下記のようにR文字列でもいい。 e.g., 与えられたデータが正規分布から取れてきたとすると、 その平均と標準偏差はどれくらいだったか?
stan_code = " data { int length; real x[length]; } parameters { real mu; real<lower=0> sigma; } model { x ~ normal(mu, sigma); }"
-
モデルをC++に変換してコンパイルする。 ファイルから読み込んだ場合は中間ファイル
*.rda
がキャッシュされる。mod = rstan::stan_model(model_code = stan_code) # or mod = rstan::stan_model(file = "model.stan")
-
コンパイル済みモデルを使ってMCMCサンプリング
fit = rstan::sampling(mod, data = observation, iter = 10000, chains = 3)
-
結果を見てみる
print(fit) summary(fit) plot(fit) pairs(fit) rstan::traceplot(fit) rstan::stan_trace(fit) rstan::stan_hist(fit) rstan::stan_dens(fit)
Stan文法
https://mc-stan.org/documentation/
ブロック
コード内に登場できるブロックは7種類で、順番はこの通りでなければならない。
functions {...}
- 関数を定義できる。
data {...}
- Rから受け取る定数の宣言。
transformed data {...}
- 定数の宣言と代入。 決め打ちのハイパーパラメータとか。 決定論的な変換のみ可能。
parameters {...}
- サンプリングされる変数の宣言。
transformed parameters {...}
- 変数の宣言と代入。 モデルで使いやすい形にパラメータを変形しておくとか?
model {...}
- 唯一の必須ブロック。 サンプルされないローカル変数を宣言してもよいが、制約をかけることはできない。
generated quantities {...}
- サンプリング後の値を使って好きなことをするとこ?
normal_rng()
などによる乱数生成が許される唯一のブロック。 rstanならここを使わずRで結果を受け取ってからどうにかするほうが簡単?
モデリング
あるパラメータにおけるlog probabilityと近傍での傾きを計算し、
それらを元に次の値にジャンプする、という操作が繰り返される。
modelブロック内で暗黙的に定義されている target
変数に対して
+=
演算子で対数確率をどんどん加算していく。
(昔は隠れ変数lp__
やincrement_log_prob()
などを使ってた。)
サンプリング文(sampling statement)はそれを簡単に記述するためのショートカット。 名前とは裏腹に、確率分布からのサンプリングが行われるわけではないので紛らわしい。 例えば以下の表現はほぼ等価。 (定数の扱い方がうまいとかでサンプリング文のほうが効率的らしいけど)
x ~ normal(0.0, 1.0);
target += normal_lpdf(x | 0.0, 1.0);
target += -0.5 * square(x);
確率分布としての正規化はうまいことやっといてくれるから気にしなくていいらしい
(が、T[,]
によるtruncated distributionではこうやって調整する、
とかいう記述もあるので、そのへんはまだよく分からない)。
名のある確率分布はだいたい関数として用意されている。 形のバリエーションとしては:
- 確率密度関数:
*_lpdf(y | ...)
,*_lpmf(y | ...)
- 累積分布関数:
*_cdf(y | ...)
,*_lcdf(y | ...)
- 相補累積分布関数:
*_lccdf(y | ...)
- 乱数生成:
*_rng(...)
(対数版のsuffixは昔は _cdf_log()
, _ccdf_log()
という形だった)
型
整数(int
)、実数(real
)、実数ベクトル(vector
, row_vector
)、実数行列(matrix
)。
内部的に Eigen::Vector
や Eigen::Matrix
が使われているので、
可能な限りfor
文よりも行列演算を使うように心がける。
配列(array)は std::vector
で実装されていて、
整数配列や行列配列など何でも作れるが、行列演算はできない。
宣言時に上限下限を設定できる (constrained integer/real)。
bool型は無くて基本的に整数の1/0。分岐ではnon-zeroがtrue扱い。
int i;
int v[42];
real x;
real x[42];
int<lower=1,upper=6> dice;
vector[3] v;
row_vector[3] r;
matrix[3, 3] m;
x * v // vector[3]
r * v // real
v * r // matrix[3, 3]
m * v // vector[3]
m * m // matrix[3, 3]
m[1] // row_vector[3]
そのほかの特殊な制約つきの型
simplex
: 合計が1になる非負実数ベクトルunit_vector
: 二乗和が1になる実数ベクトルordered
,positive_ordered
: 昇順実数ベクトル。降順にしたければtransformed parameters
ブロックで。cov_matrix
,corr_matrix
,cholesky_factor_cov
,cholesky_factor_corr
Tips
条件分岐するときはなるべくif
文を避けて三項演算子やステップ関数を使うべし、
という言語が多いけどStanでは逆にif
文を素直に書くほうが良いらしい。
if_else()
では真値でも両方の引数が評価されちゃうし、
step()
や int_step()
からの掛け算は遅いのだとか。
代入演算子は普通に =
イコール。(昔は <-
矢印だった)
対数尤度の値を確認したいときは print("log_prob: ", target())
可視化
https://www.rdocumentation.org/packages/rstan/topics/Plots
stan_plot()
stan_trace()
stan_scat()
stan_hist()
stan_dens()
stan_ac()
# S3 method
pairs()
print()
stanfit
クラスのmethodとして plot()
や traceplot()
が定義されているが、
いくつかのチェックとともに stan_plot()
系の関数を呼び出すだけで大きな違いは無さそう。
トラブル対処
StanHeaders version is ahead of rstan version
Stanのヘッダーライブラリとrstanは別々のパッケージで提供されていて、
Stan更新への追従にタイムラグがあるらしい。
こんなん開発者側でどうにかして欲しいけど、
とりあえず古い StanHeaders
を入れてしのぐしかない。
https://github.com/stan-dev/rstan/wiki/RStan-Transition-Periods
install.packages("https://cran.r-project.org/src/contrib/Archive/StanHeaders/StanHeaders_2.9.0.tar.gz", repos=NULL, type="source")
https://cran.r-project.org/src/contrib/Archive/StanHeaders/
最新版をGitHubからインストール
リポジトリの構造が標準とはちょっと違う
remotes::install_github("stan-dev/rstan", ref="develop", subdir="rstan/rstan")