そのために解析は必要? 未加工の生データこそ宝?
往々にして複雑過ぎ、情報多すぎ、そのままでは手に負えない
print(ggplot2::diamonds)
carat cut color clarity depth table price x y z
1 0.23 Ideal E SI2 61.5 55 326 3.95 3.98 2.43
2 0.21 Premium E SI1 59.8 61 326 3.89 3.84 2.31
3 0.23 Good E VS1 56.9 65 327 4.05 4.07 2.31
4 0.29 Premium I VS2 62.4 58 334 4.20 4.23 2.63
--
53937 0.72 Good D SI1 63.1 55 2757 5.69 5.75 3.61
53938 0.70 Very Good D SI1 62.8 60 2757 5.66 5.68 3.56
53939 0.86 Premium H SI2 61.0 58 2757 6.15 6.12 3.74
53940 0.75 Ideal D SI2 62.2 55 2757 5.83 5.87 3.64
ダイヤモンド53,940個について10項目の値を持つデータセット
各列の平均とか標準偏差とか:
stat carat depth table price
1 mean 0.80 61.75 57.46 3932.80
2 sd 0.47 1.43 2.23 3989.44
3 max 5.01 79.00 95.00 18823.00
4 min 0.20 43.00 43.00 326.00
大きさ carat
と価格 price
の相関係数はかなり高い:
carat depth table price
carat 1.00
depth 0.03 1.00
table 0.18 -0.30 1.00
price 0.92 -0.01 0.13 1.00
生のままよりは把握しやすいかも。
目的や状況に応じて使い分けよう。
「こんなことがたまたま起こる確率はすごく低いです!」
をちゃんと示す手続きが統計的仮説検定。
同じデータでも見せ方で印象・情報量が変わる。
情報をうまく絞って整理 → 直感的にわかる
carat
が大きいほど price
も高いらしい。
その度合いは clarity
によって異なるらしい。
データをうまくまとめ、それに基づいて推論するための手法。
「グラフを眺めてなんとなく分かる」以上の分析にはモデルが必要
対象システムを単純化・理想化して扱いやすくしたもの
データ生成をうまく真似できそうな仮定の数式表現。
データ生成をうまく真似できそうな仮定の数式表現。
e.g., 大きいほど高く売れる: $\text{price} = A \times \text{carat} + B + \epsilon$
新しく採れたダイヤモンドの価格予想とかにも使える。
このように「YをXの関数として表す」ようなモデルを回帰と呼ぶ。
単純な直線あてはめから出発し、ちょっとずつ統計モデリング。
でも統計モデリングはいわゆる“機械学習”とは違う気もする…?
項目 | 統計モデリング | 近年の機械学習 |
---|---|---|
モデル構造 | 単純化したい | 性能のためなら複雑化 |
モデル解釈 | ここが強み | 難しい。重視しない。途上。 |
予測・生成 | うまくすれば頑健 | 主目的。強力。高精度 |
データ量 | 少なくてもそれなり | 大量に必要 |
計算量 | 場合による | 場合による |
例 | 一般化線形モデル 階層ベイズモデル |
ランダムフォレスト ニューラルネットワーク |
教科書的には概ねこんな感じとして、実際の仕事ではどうだろう?
協力: @kato_kohaku
さん、@teuder
さん
久保先生の"緑本"こと
「データ解析のための統計モデリング入門」
をベースに回帰分析の概要を紹介。
統計解析と作図の機能が充実したプログラミング言語・環境
Workspace (Environment) = 一時オブジェクトの集まり
RStudio → Preferences command,
Tools → Global options
“Restore …” のチェックを外して、 “Save …” のNeverを選択
File → New Project… → New Directory → New Project →
→ Directory name: r-training-2022
→ as subdirectory of: ~/project
or C:/Users/yourname/project
📁 ディレクトリ = フォルダ。 ~/
= ホームディレクトリ
File → New File → R script
File → New File → R script
Select text with shift←↓↑→
Execute them with ctrlreturn
hello.R
🔰 いろんな四則演算を試して hello.R
に保存してみよう。
e.g., 1 + 2 + 3
, 3 * 7 * 2
, 4 / 2
, 4 / 3
, etc.
スクリプト、データ、結果を分けて整理する例:
r-training-2022/ # プロジェクトの最上階
├── r-training-2022.Rproj # これダブルクリックでRStudioを起動
├── hello.R
├── transform.R # データ整理・変形のスクリプト
├── visualize.R # 作図のスクリプト
├── data/ # 元データを置くところ
│ ├── iris.tsv
│ └── diamonds.xlsx
└── output/ # 結果の出力先
├── iris-petal.png
└── iris-summary.tsv
プロジェクト最上階を作業ディレクトリとし、
ファイル読み書きの基準にする。(後で詳しく)
ほんの一例です。好きな構造に決めてください。
とにかく手を動かして体感しよう!
こういう枠が出てきたら、自分のRスクリプトにコピペして保存:
head(iris)
実行してコンソールを確認。思ったとおりの出力?
Error
や Warning
があったらよく読んで対処する。
(無視していい Warning
もたまーにあるけど)
🔰若葉マークの練習問題があれば解いてみる。
そこまでのコードのコピペ+改変でできるはず。
疑問・困りごとがある場合は気軽にChat欄に書き込んでください。
x = 2 # Create x
x # What's in x?
[1] 2
y = 5 # Create y
y # What's in y?
[1] 5
Rでは代入演算子として矢印 <-
も使えるけど私は =
推奨。
#
記号より右はRに無視される。コメントを書くのに便利。
x + y
[1] 7
🔰 x
と y
の引き算、掛け算、割り算をやってみよう
+
とか *
のような演算子(operator)を変数の間に置く。
10 + 3 # addition
10 - 3 # subtraction
10 * 3 # multiplication
10 / 3 # division
10 %/% 3 # integer division
10 %% 3 # modulus 剰余
10 ** 3 # exponent 10^3
🔰 コピペして結果を確認してみよう。
変数を受け取って、何か仕事して、返す命令セット
x = seq(1, 3) # 1と3を渡すとvectorが返ってくる
x
[1] 1 2 3
sum(x) # vectorを渡すと足し算が返ってくる
[1] 6
square = function(something) { # 自分の関数を定義
something ** 2
}
square(x) # 使ってみる
[1] 1 4 9
🔰 自分の関数を何か作ってみよう。
e.g., 2倍にする関数 twice
x = 42 # Create x
x # What's in x?
[1] 42
y = "24601" # Create y
y # What's in y?
[1] "24601"
この x
と y
を足そうとするとエラーになる。なぜ?
x + y # Error! Why?
Error in x + y: non-numeric argument to binary operator
class(x)
[1] "numeric"
is.numeric(x)
[1] TRUE
is.character(x)
[1] FALSE
as.character(x)
[1] "42"
🔰 さっき作った y
にも同じ関数を適用してみよう。
vector
: 基本型。一次元の配列。
logical
: 論理値 (TRUE
or FALSE
)numeric
: 数値 (整数 42L
or 実数 3.1416
)character
: 文字列 ("a string"
)factor
: 因子 (文字列っぽいけど微妙に違う)array
: 多次元配列。vector
同様、全要素が同じ型。
matrix
: 行列 = 二次元の配列。list
: 異なる型でも詰め込める太っ腹ベクトル。data.frame
: 同じ長さのベクトルを並べた長方形のテーブル。重要。 tibble
とか tbl_df
と呼ばれる亜種もあるけどほぼ同じ。1個の値でもベクトル扱い。
同じ長さ(または長さ1)の相手との計算が得意。
x = c(1, 2, 9) # 長さ3の数値ベクトル
x + x # 同じ長さ同士の計算
[1] 2 4 18
y = 10 # 長さ1の数値ベクトル
x + y # 長さ3 + 長さ1 = 長さ3 (それぞれ足し算)
[1] 11 12 19
x < 5 # 5より小さいか
[1] TRUE TRUE FALSE
🔰 この x, y
を使っていろいろな演算を試してみよう
[]
を使う。番号は1から始まる。
letters
[1] "a" "b" "c" "d" "e" "f" "g" "h" "i" "j" "k" "l" "m" "n" "o" "p" "q" "r" "s" "t" "u" "v" "w" "x" "y" "z"
letters[3]
[1] "c"
letters[seq(4, 6)] # 4 5 6
[1] "d" "e" "f"
letters[seq(1, 26) < 4] # TRUE TRUE TRUE FALSE FALSE ...
[1] "a" "b" "c"
各要素に適用するもの:
x = c(1, 2, 9)
y = sqrt(x) # square root
y
[1] 1.000000 1.414214 3.000000
全要素を集約した値を返すもの:
z = sum(x)
z
[1] 12
🔰 log()
, exp()
, length()
, max()
, mean()
にvectorを渡してみよう。
1本のvectorを折り曲げて長方形にしたもの。
中身は全て同じ型。機械学習とか画像処理とかで使う。
v = seq(1, 8) # c(1, 2, 3, 4, 5, 6, 7, 8)
x = matrix(v, nrow = 2) # 2行に畳む。列ごとに詰める
x
[,1] [,2] [,3] [,4]
[1,] 1 3 5 7
[2,] 2 4 6 8
y = matrix(v, nrow = 2, byrow = TRUE) # 行ごとに詰める
y
[,1] [,2] [,3] [,4]
[1,] 1 2 3 4
[2,] 5 6 7 8
🔰 結果を確認してみよう: x + y
, dim(x)
, nrow(x)
, ncol(x)
.
同じ長さの列vectorを複数束ねた長方形の表。
e.g., 長さ150の数値ベクトル4本と因子ベクトル1本:
print(iris)
Sepal.Length Sepal.Width Petal.Length Petal.Width Species
1 5.1 3.5 1.4 0.2 setosa
2 4.9 3.0 1.4 0.2 setosa
3 4.7 3.2 1.3 0.2 setosa
4 4.6 3.1 1.5 0.2 setosa
--
147 6.3 2.5 5.0 1.9 virginica
148 6.5 3.0 5.2 2.0 virginica
149 6.2 3.4 5.4 2.3 virginica
150 5.9 3.0 5.1 1.8 virginica
iris
はアヤメ属3種150個体に関する測定データ。
Rに最初から入ってて、例としてよく使われる。
概要を掴む:
head(iris, 6) # 先頭だけ見てみる。末尾は tail()
nrow(iris) # 行数: Number of ROWs
ncol(iris) # 列数: Number of COLumns
names(iris) # 列名
summary(iris) # 要約
View(iris) # RStudioで閲覧
str(iris) # 構造が分かる形で表示
tibble [150 × 5] (S3: tbl_df/tbl/data.frame)
$ Sepal.Length: num [1:150] 5.1 4.9 4.7 4.6 5 5.4 4.6 5 4.4 4.9 ...
$ Sepal.Width : num [1:150] 3.5 3 3.2 3.1 3.6 3.9 3.4 3.4 2.9 3.1 ...
$ Petal.Length: num [1:150] 1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ...
$ Petal.Width : num [1:150] 0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...
$ Species : Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 1 1 1 1 1 1 ...
🔰 ほかのデータもいろいろ見てみよう。
e.g., mtcars
, quakes
, data()
部分的なdata.frameを取得する:
iris[2, ] # 2行目
iris[2:5, ] # 2行目から5行目まで
iris[, 3:4] # 3-4列目
iris[2:5, 3:4] # 2-5行目, 3-4列目
vectorとして取得する:
iris[[3]] # 3列目
iris$Petal.Length # Petal.Length列
iris[["Petal.Length"]] # Petal.Length列
iris[["Petal.Length"]][2] # Petal.Length列の2番目
結果がdata.frameになるかvectorになるか微妙:
iris[, 3] # 3列目
iris[, "Petal.Length"] # Petal.Length列
iris[2, 3] # 2行目3列目
iris[2, "Petal.Length"] # 2行目Petal.Length列
同じ長さの 列(column) vector を結合して作る:
x = c(1, 2, 3)
y = c("A", "B", "C")
mydata = data.frame(x, y)
print(mydata)
x y
1 1 A
2 2 B
3 3 C
🔰 次のようなdata.frameを作って theDF
と名付けよう:
i s
24 x
25 y
26 z
ヒント: c()
無しでクリアすることも可能。
readxlパッケージを使えば .xlsx
ファイルも読める、けど
カンマ区切り(CSV)とかタブ区切り(TSV)のテキストが無難。
ファイル名は作業ディレクトリからの相対パスで指定。
install.packages("readr") # R標準の read.table() とかは難しいので
library(readr) # パッケージのやつを使うよ
write_tsv(iris, "data/iris.tsv") # 書き出し
iris2 = read_tsv("data/iris.tsv") # 読み込み
あれれー、エラーが出る?
Error: Cannot open file for writing:
* 'data/iris.tsv'
冷静に、現在の作業ディレクトリとその中身を確認しよう:
getwd() # GET Working Directory
list.files(".") # List files in "." (here)
list.files("data") # List files in "./data"
dir.create("data") # Create directory
よくあるエラー集 (石川由希さん@名古屋大) を読んでおきましょう。
🔰 R組み込みデータや自作データを読み書きしてみよう。
便利な関数やデータセットなどをひとまとめにしたもの。
install.packages("readr") # 一度やればOK
library(readr) # 読み込みはRを起動するたびに必要
update.packages() # たまには更新しよう
install.packages("tidyverse")
library(conflicted) # 安全のおまじない
library(tidyverse) # 一挙に読み込み
── Attaching core tidyverse packages ──── tidyverse 2.0.0 ──
✔ dplyr 1.1.1 ✔ readr 2.1.4
✔ forcats 1.0.0 ✔ stringr 1.5.0
✔ ggplot2 3.4.1 ✔ tibble 3.2.1
✔ lubridate 1.9.2 ✔ tidyr 1.3.0
✔ purrr 1.0.1
一貫したデザインでデータ解析の様々な工程をカバー
R標準のやつとは根本的に違うシステムで作図する。
+
で重ねていく+
で重ねていくggplot(data = diamonds) # diamondsデータでキャンバス準備
# aes(x = carat, y = price) + # carat,price列をx,y軸にmapping
# geom_point() + # 散布図を描く
# facet_wrap(vars(clarity)) + # clarity列に応じてパネル分割
# stat_smooth(method = lm) + # 直線回帰を追加
# coord_cartesian(ylim = c(0, 2e4)) + # y軸の表示範囲を狭く
# theme_classic(base_size = 20) # クラシックなテーマで
+
で重ねていくggplot(data = diamonds) + # diamondsデータでキャンバス準備
aes(x = carat, y = price) # carat,price列をx,y軸にmapping
# geom_point() + # 散布図を描く
# facet_wrap(vars(clarity)) + # clarity列に応じてパネル分割
# stat_smooth(method = lm) + # 直線回帰を追加
# coord_cartesian(ylim = c(0, 2e4)) + # y軸の表示範囲を狭く
# theme_classic(base_size = 20) # クラシックなテーマで
+
で重ねていくggplot(data = diamonds) + # diamondsデータでキャンバス準備
aes(x = carat, y = price) + # carat,price列をx,y軸にmapping
geom_point() # 散布図を描く
# facet_wrap(vars(clarity)) + # clarity列に応じてパネル分割
# stat_smooth(method = lm) + # 直線回帰を追加
# coord_cartesian(ylim = c(0, 2e4)) + # y軸の表示範囲を狭く
# theme_classic(base_size = 20) # クラシックなテーマで
+
で重ねていくggplot(data = diamonds) + # diamondsデータでキャンバス準備
aes(x = carat, y = price) + # carat,price列をx,y軸にmapping
geom_point() + # 散布図を描く
facet_wrap(vars(clarity)) # clarity列に応じてパネル分割
# stat_smooth(method = lm) + # 直線回帰を追加
# coord_cartesian(ylim = c(0, 2e4)) + # y軸の表示範囲を狭く
# theme_classic(base_size = 20) # クラシックなテーマで
+
で重ねていくggplot(data = diamonds) + # diamondsデータでキャンバス準備
aes(x = carat, y = price) + # carat,price列をx,y軸にmapping
geom_point() + # 散布図を描く
facet_wrap(vars(clarity)) + # clarity列に応じてパネル分割
stat_smooth(method = lm) # 直線回帰を追加
# coord_cartesian(ylim = c(0, 2e4)) + # y軸の表示範囲を狭く
# theme_classic(base_size = 20) # クラシックなテーマで
+
で重ねていくggplot(data = diamonds) + # diamondsデータでキャンバス準備
aes(x = carat, y = price) + # carat,price列をx,y軸にmapping
geom_point() + # 散布図を描く
facet_wrap(vars(clarity)) + # clarity列に応じてパネル分割
stat_smooth(method = lm) + # 直線回帰を追加
coord_cartesian(ylim = c(0, 2e4)) # y軸の表示範囲を狭く
# theme_classic(base_size = 20) # クラシックなテーマで
+
で重ねていくggplot(data = diamonds) + # diamondsデータでキャンバス準備
aes(x = carat, y = price) + # carat,price列をx,y軸にmapping
geom_point() + # 散布図を描く
facet_wrap(vars(clarity)) + # clarity列に応じてパネル分割
stat_smooth(method = lm) + # 直線回帰を追加
coord_cartesian(ylim = c(0, 2e4)) + # y軸の表示範囲を狭く
theme_classic(base_size = 20) # クラシックなテーマで
+
で重ねていくggplot(data = diamonds) + # diamondsデータでキャンバス準備
aes(x = carat, y = price) + # carat,price列をx,y軸にmapping
geom_point() + # 散布図を描く
# facet_wrap(vars(clarity)) + # clarity列に応じてパネル分割
# stat_smooth(method = lm) + # 直線回帰を追加
# coord_cartesian(ylim = c(0, 2e4)) + # y軸の表示範囲を狭く
theme_classic(base_size = 20) # クラシックなテーマで
p1 = ggplot(data = diamonds)
p2 = p1 + aes(x = carat, y = price)
p3 = p2 + geom_point()
p4 = p3 + facet_wrap(vars(clarity))
print(p3)
width
やheight
が小さいほど、文字・点・線が相対的に大きく
# 7inch x 300dpi = 2100px四方 (デフォルト)
ggsave("dia1.png", p3) # width = 7, height = 7, dpi = 300
# 4 x 300 = 1200 全体7/4倍ズーム
ggsave("dia2.png", p3, width = 4, height = 4) # dpi = 300
# 2 x 600 = 1200 全体をさらに2倍ズーム
ggsave("dia3.png", p3, width = 2, height = 2, dpi = 600)
# 4 x 300 = 1200 テーマを使って文字だけ拡大
ggsave("dia4.png", p3 + theme_bw(base_size = 22), width = 4, height = 4)
別のパッケージ (cowplot や patchwork) の助けを借りて
pAB = cowplot::plot_grid(p3, p3, labels = c("A", "B"), nrow = 1L)
cowplot::plot_grid(pAB, p3, labels = c("", "C"), ncol = 1L)
ggplot(data, ...)
, glm(..., data, ...)
, …Happy families are all alike;
every unhappy family is unhappy in its own way
— Leo Tolstoy “Anna Karenina”
tidy datasets are all alike,
but every messy dataset is messy in its own way
— Hadley Wickham
print(ggplot2::diamonds)
carat cut color clarity depth table price x y z
1 0.23 Ideal E SI2 61.5 55 326 3.95 3.98 2.43
2 0.21 Premium E SI1 59.8 61 326 3.89 3.84 2.31
3 0.23 Good E VS1 56.9 65 327 4.05 4.07 2.31
4 0.29 Premium I VS2 62.4 58 334 4.20 4.23 2.63
--
53937 0.72 Good D SI1 63.1 55 2757 5.69 5.75 3.61
53938 0.70 Very Good D SI1 62.8 60 2757 5.66 5.68 3.56
53939 0.86 Premium H SI2 61.0 58 2757 6.15 6.12 3.74
53940 0.75 Ideal D SI2 62.2 55 2757 5.83 5.87 3.64
x軸、y軸、色分け、パネル分けなどを列の名前で指定して簡単作図:
ggplot(diamonds) + aes(x = carat, y = price) +
geom_point(mapping = aes(color = color, size = clarity)) +
facet_wrap(vars(cut))
print(VADeaths)
Rural Male Rural Female Urban Male Urban Female
50-54 11.7 8.7 15.4 8.4
55-59 18.1 11.7 24.3 13.6
60-64 26.9 20.3 37.0 19.3
65-69 41.0 30.9 54.6 35.1
70-74 66.0 54.3 71.1 50.0
↓ 下ごしらえ: 作図・解析で使いやすい整然データに
lbound ubound region sex death
1 50 54 Rural Male 11.7
2 50 54 Rural Female 8.7
3 50 54 Urban Male 15.4
4 50 54 Urban Female 8.4
--
17 70 74 Rural Male 66.0
18 70 74 Rural Female 54.3
19 70 74 Urban Male 71.1
20 70 74 Urban Female 50.0
シンプルな関数がたくさん。繋げて使う (piping)
select()
,filter()
, distinct()
, slice()
group_by()
, summarize()
, count()
arrange()
, relocate()
mutate()
, rename()
bind_rows()
left_join()
, inner_join()
, full_join()
小さな関数を繋げて使う流れ作業:
result = diamonds |> # 生データから出発して
select(carat, cut, price) |> # 列を抽出して
filter(carat > 1) |> # 行を抽出して
group_by(cut) |> # グループ化して
summarize(mean(price)) |> # 平均を計算
print() # 表示してみる
cut mean(price)
1 Fair 7177.856
2 Good 7753.601
3 Very Good 8340.549
4 Premium 8487.249
5 Ideal 8674.227
この見慣れぬ記号 |>
は何?
(select()
など個々の関数には今日は触れません)
|>
パイプの左側の変数を、右側の関数の第一引数にねじ込む:
diamonds |> filter(carat > 1)
filter(diamonds, carat > 1) # これと同じ
# 前処理の流れ作業に便利:
diamonds |> filter(carat > 1) |> select(carat, price) |> ...
potatoes |> cut() |> fry() |> season("salt") |> eat()
🔰 パイプを使わない形に書き換え、出力を確認しよう:
seq(1, 6) |> sum()
[1] 21
letters |> toupper() |> head(3)
[1] "A" "B" "C"
[解答例]
sum(seq(1, 6))
head(toupper(letters), 3)
|>
を使わない方法😐 一時変数をイチイチ作る:
tmp1 = select(diamonds, carat, cut, price) # 列を抽出して
tmp2 = filter(tmp1, carat > 1) # 行を抽出して
tmp3 = group_by(tmp2, cut) # グループ化して
result = summarize(tmp3, mean(price)) # 平均を計算
😐 同じ名前を使い回す:
result = select(diamonds, carat, cut, price) # 列を抽出して
result = filter(result, carat > 1) # 行を抽出して
result = group_by(result, cut) # グループ化して
result = summarize(result, mean(price)) # 平均を計算
どちらも悪くない。 何度も変数名を入力するのがやや冗長。
|>
を使わない方法😫 一時変数を使わずに:
result = summarize( # 平均を計算
group_by( # グループ化して
filter( # 行を抽出して
select(diamonds, carat, cut, price), # 列を抽出して
carat > 1), # 行を抽出して
cut), # グループ化して
mean(price)) # 平均を計算
🤪 改行さえせずに:
result = summarize(group_by(filter(select(diamonds, carat, cut, price), carat > 1), cut), mean(price))
論理の流れとプログラムの流れが合わず、目が行ったり来たり。
さっきのほうがぜんぜんマシ。
|>
を使おう😁 慣れれば、論理の流れを追いやすい:
result = diamonds |>
select(carat, cut, price) |> # 列を抽出して
filter(carat > 1) |> # 行を抽出して
group_by(cut) |> # グループ化して
summarize(mean(price)) |> # 平均を計算
print() # 表示してみる
cut mean(price)
1 Fair 7177.856
2 Good 7753.601
3 Very Good 8340.549
4 Premium 8487.249
5 Ideal 8674.227
tidyverseパッケージ群はこういう使い方をしやすい設計。
使わなければならないわけではないが、読めたほうがいい。
R < 4.2 までよく使われていた %>%
もほぼ同じ。
pivot_longer()
, gather()
pivot_wider()
, spread()
separate()
, unite()
nest()
, unnest()
etc.
pivot_longer()
横広から縦長に複数列にまたがる値を1列にする。
そのラベルも合わせて移動。
table4a
pivot_longer(table4a, 2:3, names_to = "year", values_to = "cases")
pivot_wider()
縦長から横広に1列に収まっていた値を複数列の行列に変換。
そのラベルを列の名前にする。
pivot_wider(table2, names_from = type, values_from = count)
2022年9月に開講。おそらくe-なんとかで視聴可能。
2023年の開講は未定…?
(説明のために作った架空のデータ。今後もほぼそうです)
Define a family of models: だいたいどんな形か、式をたてる
Generate a fitted model: データに合うようにパラメータを調整
なんとなく $y = a x + b$ でいい線が引けそう
なんとなく $y = a x + b$ でいい線が引けそう
じゃあ傾き a と切片 b、どう決める?
回帰直線からの残差平方和(RSS)を最小化する。
ランダムに試してみて、上位のものを採用。
この程度の試行回数では足りなそう。
グリッドサーチ: パラメータ空間の一定範囲内を均等に試す。
さっきのランダムよりはちょっとマシか。
こうした最適化の手法はいろいろあるけど、ここでは扱わない。
par_init = c(intercept = 0, slope = 0)
result = optim(par_init, fn = rss_weight, data = df_weight)
result$par
intercept slope
-69.68394 78.53490
上記コードは最適化一般の書き方。
回帰が目的なら次ページのようにするのが楽 →
lm()
で直線あてはめしてみる架空のデータを作る(乱数生成については後述):
n = 50
df_weight = tibble::tibble(
height = rnorm(n, 1.70, 0.05),
bmi = rnorm(n, 22, 1),
weight = bmi * (height**2)
)
データと関係式(Y ~ X
の形)を lm()
に渡して係数を読む:
fit = lm(data = df_weight, formula = weight ~ height)
coef(fit)
(Intercept) height
-69.85222 78.63444
せっかくなので作図もやってみる→
lm()
の結果をggplotするdf = modelr::add_predictions(df_weight, fit, type = "response")
head(df, 2L)
height bmi weight pred
1 1.718019 21.55500 63.62151 65.24322
2 1.782862 22.83775 72.59199 70.34213
ggplot(df) +
aes(height, weight) +
geom_point() +
geom_line(aes(y = pred), linewidth = 1, color = "#3366ff")
lm()
を試してみようfit = lm(data = mpg, formula = hwy ~ displ)
coef(fit)
(Intercept) displ
35.697651 -3.530589
mpg_added = modelr::add_predictions(mpg, fit)
ggplot(mpg_added) + aes(displ, hwy) + geom_point() +
geom_line(aes(y = pred), linewidth = 1, color = "#3366ff")
🔰 diamonds
などほかのデータでも lm()
を試してみよう。
No such file or directory
str(iris)
, attributes(iris)
?sum
, help.start()