最適な学習アルゴリズム・重み・ハイパーパラメータの決め方

本記事は、ディープラーニング入門シリーズの第4回目です。

スポンサーリンク

最適化アルゴリズムの一覧と比較

最適化アルゴリズム最適化アルゴリズムとは、「損失関数を0に近づけるように、重みを決定する」アルゴリズムです。

ニューラルネットワークの学習で利用する最適化アルゴリズムには、以下の種類があります。

SGD

SGD とは、重みをマイナスの勾配方向に進めることで、損失関数を0に近づけるアルゴリズムです。

SGD を表す数式は以下のとおりです。(W は重み、← は代入、ηは学習率、∂L/∂W は勾配)

$$ W ← W - \eta \frac{∂L}{∂W} $$

SGD アルゴリズムについては、以下の記事でより詳しく説明しているのでご覧ください。

SGD の欠点

SGD は、関数の変数が均等でない場合は、非効率な探索を行います。

以下のように、x よりも y のほうが f(x, y) の値の変化に大きな影響を与える関数を例に考えます。

$$ f(x, y) = \frac{1}{20}x^2 + y^2 $$

この関数では y = 0 が最小ですが、y 軸方向で発散 (y = 0 を通り過ぎて行ったり来たり) し、非効率な探索をしています。

SGD のこの欠点を解決するために、以下の4つの方法が存在します。

Momentum

Momentum とは、斜面で物体が転がり落ちる概念を利用して、損失関数を0に落とすアルゴリズムです。

以下の Momentum では、斜面からボールを落とすように徐々に高度が下がっていく様子を確認できます。(SGD と比較して、y 軸で発散する回数が減っています。)

Momentum を表す数式は以下のとおりです。
(v は速度、← は代入、α は 0.0~1.0、ηは学習率、∂L/∂W は勾配、W は重み)

$$ v ← αv - \eta \frac{∂L}{∂W} $$

$$ W ← W + v $$

AdaGrad

AdaGrad とは、学習率を徐々に小さくすることで損失関数を0に近づけるアルゴリズムです。

以下の AdaGrad では、y 軸で発散せずに効率的に学習が行えていることがわかります。

  • 初めに、高い学習率で一気に y 軸を下り、学習時間を短縮しています
  • その後、徐々に学習率を落とすことで y 軸を通りすぎないようにしています
大きく更新したパラメータほど、次回以降の学習係数が小さくなります。

AdaGrad を表す数式は以下のとおりです。
(h は学習率を下げるもの、← は代入、∂L/∂W は勾配、W は重み、ηは学習率)

$$ h ← h + \frac{∂L}{∂W} \otimes \frac{∂L}{∂W}$$

$$ W ← W - \eta \frac{1}{\sqrt{h}} \frac{∂L}{∂W}$$

$$ \otimes は行列の要素ごとに掛け算$$

RMSprop

RMSprop とは、AdaGrad を改良し、損失関数を0に近づけるアルゴリズムです。
RMSprop は AdaGrad と比較して最近の勾配ほど強く影響を受けます。

AdaGrad は学習率が常に小さくなるため、学習初期に大きな更新があったパラメータは、その後ほとんど更新が行われなくなる問題があります。

これを改良するために、RMSprop は AdaGrad と比較して最近の勾配ほど強く影響を受けます。

y 軸の発散が少ない

RMSprop を表す数式は以下のとおりです。
(h は学習率を下げる値、← は代入、d は減衰率、∂L/∂W は勾配、W は重み、ηは学習率)

$$ h ← dh $$

$$ h ← h + (1 - d) * \frac{∂L}{∂W} \otimes \frac{∂L}{∂W}$$

$$ W ← W - \eta \frac{1}{\sqrt{h}} \frac{∂L}{∂W}$$

$$ \otimes は行列の要素ごとに掛け算 $$

Adam

Adam とは、Momentum と RMSProp を組み合わせたアルゴリズムです。

Adam は MomentumRMSProp を組み合わせた動きとなっていることがわかります。

Adam を表す数式は以下のとおりです。
(β1 は 0.9 が標準、β2 は 0.999 が標準、 v は Momentum ベース、h は RMSProp ベース、← は代入、∂L/∂W は勾配、W は重み、ηは学習率)

$$ v ← β_1 v + (1 - β_1) \frac{∂L}{∂W}$$

$$ h ← β_2 h + (1 - β_2) \frac{∂L}{∂W} \otimes \frac{∂L}{∂W} $$

$$ ov ← \frac{v}{(1 - β_1)} $$

$$ oh ← \frac{h}{(1 - β_2)} $$

$$ W ← W - \eta \frac{ov}{\sqrt{oh}}$$

参考数式

https://tech-lab.sios.jp/archives/21823#Adam
https://qiita.com/omiita/items/1735c1d048fe5f611f80#6-rmsprop
スポンサーリンク

重みの初期値

ニューラルネットワークの学習では、重みの初期値が学習結果に大きな影響を与えます。

悪い重みの初期値

一般的に以下の重みの初期値はニューラルネットワークの学習が進まないとされています。

重みの初期値が全て 0 (均一)

以下の式を見てわかるように、重みの初期値0の場合は 入力 X の値がすべて消えてしまうため正しく学習が行えません。

$${A_1 = \left\{\begin{array}{ll}
a_{11} = (b_1+w ^{1}_{11}x_1+w ^{1}_{12}x_1) \\
a_{12} = (b_1+w ^{1}_{21}x_2+w ^{1}_{22}x_2) \\
\end{array} \right. }$$

また、0以外の均一な重みを設定した場合も、同じ比の総和が伝播され層を深くする意味がなくなってしまいます。そのため、均一な重みは推奨されません。

活性化関数の傾きが0に近づく場合

以下の Github にあるニューラルネットワークの学習で、重みの初期値に標準偏差1の正規分布を利用した場合、各中間層の活性化関数 (シグモイド関数) の出力は上記のグラフのようになります。

https://github.com/oreilly-japan/deep-learning-from-scratch/blob/master/ch06/weight_init_activation_histogram.py

中間層の活性化関数の結果が0と1に偏っていますが、これは次に紹介する勾配消失という問題が発生します。

勾配消失
勾配消失とは、勾配 (傾き) がどんどん小さくなり消えてしまうことです。
勾配消失が発生すると、ニューラルネットワークの学習が進まなくなります。

シグモイド関数は0と1に近づくほど、傾きが0に近づきます。

https://github.com/oreilly-japan/deep-learning-from-scratch

また、層が深くなると勾配を求める際の連鎖律により、傾きの掛け算をする回数が増えます。

そのため、「傾きが0に近く」、「層が深い」ほど勾配が0に近づき、消失します。
(例:1層目 0.001 * 2層目 0.001 * 3層目 0.001 = 勾配 0.000000001)

最適化アルゴリズムの数式を見てわかるように、勾配 (∂L/∂W) が0に近づくと、ニューラルネットワークの学習は進まなくなります。

活性化関数の出力の分布が偏る場合

以下の Github にあるニューラルネットワークの学習で、重みの初期値に標準偏差 0.01 の正規分布を利用した場合、各中間層の活性化関数 (シグモイド関数) の出力は上記のグラフのようになります。

https://github.com/oreilly-japan/deep-learning-from-scratch/blob/master/ch06/weight_init_activation_histogram.py

中間層の活性化関数の出力は 0.5 付近に集中しています。

これは、中間層がほとんど同じ内容しか伝播できないため、入力の違いをあまり表現できていないことを意味します。

良い重みの初期値

悪い重みの初期値の内容から、活性化関数の出力の分布に広がりがあるように、重みの初期値を設定すればよいことがわかります。

一般的に良いとされる、重みの初期値は以下のとおりです。

良い重みの初期値適する活性化関数
Xavier の初期値シグモイド関数、tanh 関数
He の初期値ReLU 関数

Xavier の初期値

Xavier の初期値とは、1/√n の標準偏差を持つ正規分布を、重みの初期値に使うことです。
※ n は前層のノード数

これまでより、出力の結果に広がりがあるため、入力の違いを表現できています。

また、0と1に値が少ないことから、勾配消失の発生を抑えています。

He の初期値

He の初期値とは、√(2/n) を標準偏差とする正規分布を、重みの初期値に使うことです。
※ n は前層のノード数

中間層の活性化関数ReLU を利用する場合、He の初期値を利用します。

重みに各初期値を適用した結果を以下に示します。

標準偏差 0.01 の正規分布の初期値
Xavier の初期値
He の初期値

https://github.com/oreilly-japan/deep-learning-from-scratch/blob/master/ch06/weight_init_activation_histogram.py より

  • 標準偏差 0.01 の正規分布の初期値の場合、0 に集中するため勾配消失が発生します
  • Xavier の初期値の場合、層が深くなるにつれ、0 に偏るため表現が偏ったり、勾配消失が発生します
  • He の初期値の場合、層が深くなっても分布の偏りがありません。

各重みの初期値でごとに、ニューラルネットワークの学習の進み方を確認してみます。

https://github.com/oreilly-japan/deep-learning-from-scratch/blob/master/ch06/weight_init_compare.py

上の図では、損失関数 (loss) が減ると、ニューラルネットワークの学習が進んでいることを表します。

  • 標準偏差 0.01 の正規分布の初期値の場合、損失関数がほとんど減っていません
  • Xavier の初期値と He の初期値では、He の初期値のほうが学習が早く、予測精度も高いことがわかります。
スポンサーリンク

Batch Normalization とは

Batch Normalization とはBatch Normalization とは、中間層の活性化関数の前に、全結合層の出力の分布を正規化する Batch Normalization レイヤを挟む方法です。

重みの初期値を適切に設定する方法」と「Batch Normalization」を比較すると以下のとおりです。

重みの初期値を適切に設定する方法」を利用する場合

重み W の初期値を調整することで、活性化関数 ReLU の出力 Z の分布を広げる

「Batch Normalization」を利用する場合

全結合層の出力を正規化する Batch Normalization レイヤを挟むことで、活性化関数 ReLU の入力時点での分布を広げる

正規化の手順

正規化の手順は以下のとおりです。

$$ 平均 \mu _B ← \frac{1}{m} \sum_{i=1}^{m} x_i $$

$$ 分散 \sigma _B^2 ← \frac{1}{m} \sum_{i=1}^{m} (x_i - \mu _B)^2 $$

$$ 正規化 x_{i}^{'} ← \frac{x_i - \mu _B}{\sqrt{\sigma _B^2 + \epsilon}} $$

※ x は入力、m は入力データの個数、εは 0 で除算することを防ぐための限りなく小さい数

Batch Normalization の数式

Batch Normalization レイヤの処理を数式で表すと以下のとおりです。

$$ y_i = \gamma x _{i}^{'} + \beta $$

※x' は正規化した入力、パラメータ γ = 1, β = 0 (一般的な初期値は左記)

Batch Normalization の効果

Batch Normalization の利点は以下の3つです。

  • 学習を早く進行できる (学習率を大きくできる)
  • 重みの初期値にそれほど依存しない (Batch Normalization レイヤーで是正する)
  • 過学習を抑制

過学習とは

過学習とは過学習とは、訓練データ (事前学習用のデータ) だけに過度に適用した状態です。
過学習の問題は、初めてみるデータに対して推論の精度が低いことです。

過学習の起きる原因は以下の3つです。

  • パラメータを大量に持ち、表現力が高いニューラルネットワークのモデル
  • 重みパラメータが大きな値をとる
  • 訓練データが少ない

過学習を防ぐ手法は以下の2つがあります。

Weight decay (荷重減衰)

Weight decay (荷重減衰) とは、重みが大きいほどペナルティを与える (損失関数を増やす) 方法です。これにより、重みが大きな値を取りすぎることを防ぎます。

具体的には損失関数に 1/2λW^2 のノルムを加算します。
(λはペナルティの大きさを調整するパラメータ。勾配には、微分した λW を加算)

なお、加算するノルムによって、以下のように呼ばれます。

$$ L1 ノルム = |w_{1}^{2}|+ |w_{2}^{2}|+・・・+ |w_{n}^{2}| $$

$$ L2 ノルム = \sqrt{w_{1}^{2}+ w_{2}^{2}・・・+ w_{n}^{2}} $$

$$ L \infty ノルム = max|W| (w の絶対値の最大値) $$

Dropout

Dropout とは、ニューロンをランダムに消去しながらニューラルネットワークの学習をする方法です。

Weight decay は実装が簡単ですが、ニューラルネットワークのモデルが複雑になると効果が薄くなります。そこで、Dropout を利用します。

なお、訓練データを流すたびに消去するニューロンを変更します。

ハイパーパラメータの最適化

ハイパーパラメータハイパーパラメータとは、ニューラルネットワークの学習を制御するためのパラメータです。

ハイパーパラメータの例は以下のとおりです。

重みやバイアスはニューラルネットワークの学習によって自動で取得されるのに対して、ハイパーパラメータは人間が手動で設定します。

なお、「ハイパーパラメータを全通り検証し、予測精度が高いもの選ぶ」のが理想ですが、1回の計算にめちゃくちゃ時間がかかります。そのため、いくつかのハイパーパラメータを試し、予測精度が高いものを選びます。

ハイパーパラメータの決め方には、主に以下の3つの方法があります。

ランダム検索とは

ランダム検索とはとは、ランダムなハイパーパラメータに対してニューラルネットワークの学習を行い、もっとも良い予測精度を示したハイパーパラメータを採用します。

例えば、ハイパーパラメータ「学習率」をランダムに設定し、以下の結果を得たとします。

この場合は学習率 0.9 を採用します。(もしくは学習率をあげれば予測精度が上がりそうなので、学習率0.95 で検証してみる)

グリッド検索 (グリッドサーチ) とは

グリッド検索 (グリッドサーチ) とは、グリッドサーチとは、指定したハイパーパラメータの全ての組み合わせに対してニューラルネットワークの学習を行い、もっとも良い予測精度を示したハイパーパラメータを採用します。

初めにランダム検索によりなんとなく良さそうなハイパーパラメータの傾向を調べ、その値に対してグリッド検索を使う方法があります。

ベイズ最適化とは (自動化)

ベイズ最適化とは とは、正規分布を利用して、ハイパーパラメータ x に対する予測精度を出力する関数 f(x) を作成します。

砕けた言い方をすると、「確率的に考えて、平均値付近に点が集まる可能性が高いよね。」という考えです。

結局、ガウス過程って何?
任意の点集合x1,...,xNに対するy(x)の同時分布が、ガウス分布に従うもの。

https://qiita.com/typecprint/items/2745932fdaf36d763623

予測精度を出力する関数 f(x) が分かれば、f(x) の値が良くなるようにハイパーパラメータ x を決定できます。

ベイズ最適化の流れ (活用と探索による自動調整)

※超雰囲気解説です。厳密には参考資料を読んでください。

未知の関数 = 予測精度を出力する関数

上記の予測精度を出力する関数の、一番高いところを求めたい。
(予測精度を出力する関数に損失関数を使う場合は、一番低いところ)

2点はハイパーパラメータ x1, x2 で実際に学習し、予測精度を出力する関数 f(x1), f(x2) を求める
xは横軸、f(x) は縦軸
平均を元に値を f(x) を推測することを「活用」といいます
[今までの経験 (平均) からある程度予想できること]
分散を元に次の x を決定することを「探索」といいます
[経験のないこと (実際にどれぐらい分散するか) に挑戦して知識 (結果) を得ること]
ハイパーパラメータの値 x3 は x1 と x2 の真ん中の値にする
分散 + 平均を acquisition function と呼びます。acquisition function は色んな種類があります。
ハイパーパラメータ x3 で実際に学習 (探索)すると、予測精度を出力する関数 f(x3) が小さくなりました。
ハイパーパラメータ x4, x5 を両端にします
ハイパーパラメータ x4 で学習 (探索) し、予測精度を出力する関数 f(x4) を求める
ハイパーパラメータ x5 で学習 (探索) し、予測精度を出力する関数 f(x5) を求める
ハイパーパラメータ x6 で学習 (探索) し、予測精度を出力する関数 f(x6) を求める
ここが一番高そうです。
ハイパーパラメータ x6 が一番予測精度を出力する関数 f(x) の値が高い付近となっています。

「真の予測精度を出力する関数」と「ベイズ最適化で推測した予測精度を出力する関数」がだいたい同じとなり、高い予測精度を持つハイパーパラメータ x6 が求められました。

関連記事

ディープラーニング入門記事の続きは以下のとおりです。


参考資料

誰でも理解できるガウス過程とガウス過程回帰 基礎編 - Qiita
はじめにガウス過程回帰は機械学習領域ではかなりメジャーなので、皆様御存知かと思いますが、ここで本当に初学者がガウス過程回帰にいきつくような解説と、多次元まで対応したPythonのガウス過程回帰ク…
ベイズ最適化でハイパーパラメータを調整する - Qiita
ベイズ最適化とは?(雑な)A. なるべく実験をサボりつつ一番良いところを探す方法.ある関数を統計的に推定する方法「ガウス過程回帰」を用いて,なるべく良さそうなところだけの値…