ブログではないです

ブログでなくはないです

Overcoming catastrophic forgetting in neural networks [Kirkpatrick+, arXiv'17]

解説スライドを見つけたのでメモ。

以前読んだ論文でもそうだったが、Multi-task Learningにおいて普通にそれぞれのタスクの最適化をしてしまうと前のタスクの情報を忘れてしまう、そのためにどうにかして前のタスクの重みを覚えておこうという話。
全タスクを含むデータの分布Dとパラメータ{ \displaystyle \theta}に対して最適化対象となる

{ \displaystyle
\log p (\theta| D) = \log p (D|\theta) + \log p (\theta) - \log p(D)
}

はタスク別にデータを考えた時ベイズの定理から

{\displaystyle \log p (\theta| D) = \log p (D_B|\theta) + \log p (\theta|D_A) - \log p(D_B) }
となる。第三項はconstとして無視,問題は第二項をどうするかだが[MacKay+, '92] によるとそれぞれのタスクAを訓練した後のパラメータ{ \displaystyle \theta^*_A} を平均としたガウス分布を仮定し,フィッシャー情報行列Fでこの項を近似する。最終的な目的関数は以下のようになるらしい。

{ \displaystyle
L(\theta) = L_B (\theta) + \sum_i {\lambda \over 2} F_i (\theta_i - \theta^*_{A,i})^{2}
}

実験ではMNISTの手書き数字認識タスクで評価。ピクセルをそれぞれのタスクごとに共通のランダムな置換パターンを用いてシャッフルして,擬似的に複数のタスクを作っている。結果は以下の通り。図Aでは普通にSGDした場合とL2正規化をかけた場合と提案手法(EWC)を比較。L2正規化の場合は逆に制限が強すぎて新しいタスクについて学習されない。図BではEWC, single taskの場合, [Goodfellow+, '14] の手法の比較.
また,この手法で結果的に得られたネットワークはそれぞれのパラメータがどちらにも使えるような形で学習されているのか,それともタスク間で使うパラメータを上手いこと分けているのかを確認するためにタスク間でのoverlapを図Cで確認している。(Fisher overlap についてはAppendix参照) 置換したピクセル数が多い(=それぞれのタスクが異なる)ほど低レイヤでのoverlapは低くなるが,出力ラベルが各タスクで共通のため高レイヤではoverlapが大きくなるのではないかとのこと。

f:id:jack_and_rozz:20170509165622p:plain

また、MNISTだけではなく最近盛んなDRLでAtariを解くタスクでも手法の評価を行っている。この辺りあまり詳しくないので採用している既存研究の手法との絡みが分からなかったが、グラフを見る限り明らかに上手くいっているっぽい。 フィッシャー情報行列で近似できる云々の根拠がよくわからなかったので要確認。