ブログではないです

ブログでなくはないです

Variational Transformers for Diverse Response Generation (Lin+, arXiv'20)

https://arxiv.org/pdf/2003.12738.pdf

自分が対話モデルを訓練する時は主にメジャーなフレームワーク中で実装されていることが多いという理由で通常のTransformerを使うことが多く,CVAEなどVariationalな系統のモデルは触ってはみたもののあんまり深入りして来なかった.自分が以前その辺り調べていたときはKL-Vanishing problemとかが囁かれだして,まだ道具として使うには枯れていないのかなあという印象だったが,今どんなもんなのか,対話やるならとりあえず使っといて損はない感じなのか把握したいので新しめのを読んでみた.

名前の通り,TransformerにVariationalな要素を入れたよという研究.コードも公開されている. 提案モデルは2つあり,1つ目がGloval Variational Transformer (GVT) というモデル. 普通にやるとCVAEで分布のパラメータを計算する際にTransformer-Encoderのどの出力ベクトルを用いるかが問題となるわけだが,ここではエンコーダの先頭に特殊なトークン ("CLS") を追加して,その位置の出力ベクトル(≒ CLSを中心に入力全体に対してself-attentionしたもの)を用いている.RNN-Encoderで最後のトークンについての出力を使うとか,全出力の平均を取るとかに比べてself-attentionしてるからまあまだマシ?という印象.それ以外は基本的にCVAEと大きくは変わらず,分布からサンプルされた潜在変数をデコーダの初めの入力にして,Transformerと同じdecodingをする.

2つ目はSequential Variational Transformer (SVT) .GVTではCVAEと同様にUtterance, ResponseそれぞれのTransformer-Encoderを用意して,入力をエンコードしたベクトルからpriorを,入出力両方をエンコードしたベクトルからposteriorを一度だけ計算していた. 一方で,SVTではそれをdecodingの各ステップで行う(詳細はfigure 2).Priorを計算する際は応答のまだ出力していない部分についてはマスクして,入力+出力済みの応答からpriorを,入力 + 応答全体から posteriorを計算するような仕組みになっている.結構ごちゃごちゃしていて把握が大変.

また,どちらのモデルも訓練時に,後述するKL annealing (Bowman+, CoNLL'16) *1 とbag-of-words loss (Zhao+, ACL'17) を使って学習を安定させているらしい. *2

実験ではMojiTalk, PersonaChat, Empathetic-Dialoguesなどのデータセットで,perplexity, dist-N, 人手付与したCoference, Engagednessなどで評価.C-VAE, Transformerと比べても各指標はそこそこ良くなってはいるが・・・?対話の評価指標って複数サンプルをじっくり見たり自分でInteractiveに会話してみないとどこまで良くなってんのかなんとも(英語だから自分の言語感覚が乏しいというのもある).

同様の発想に基づくモデルとして T-CVAE: Transformer-Based Conditioned Variational Autoencoder for Story Completion (Wang+, IJCAI'19) というのがあるらしいので後で読む.

*1:cross entropy lossとKLの損失を足し合わせる際,KLの方の重みを0->1 で増加させる.おそらく学習後半でKLがほとんど無視されてしまうのを防ごうとしている?

*2: bag-of-words lossは正解応答内の単語を順序関係なしにpredict出来ているかを表す損失で, (Zhao+, ACL'17) eq. (5) によるとbag-of-words lossは { \displaystyle \log p(x_{bow}|z, c) = \log \Pi^{|x|}_{t=1}  \frac{e^{f_{x_t}}}{\sum_{j}^{V} e^{f_j} } } と表される. { \displaystyle x}は応答,{ \displaystyle x_{bow} } を応答に含まれる単語のBoW,{ \displaystyle V}はdecoderの語彙セット, { \displaystyle f = MLP_b(z, x) \in R^V} は潜在変数と入力発話だけから推定されるdecoder語彙内の各単語へのスコア.MLPは別途この損失のためのレイヤを追加している?おそらく意図としてはBoWの推定にdecoderの状態を関与させず,入力と潜在変数だけから出力すべき単語を推定させることで,decoderの言語モデルとしての能力に支配された汎用的な応答生成を防ぐ,ということだと思う.