E資格学習 深層学習 Day3 自然言語処理(RNN, AE, Attention)

Simple RNN

要点まとめ
  • 時系列データ処理に適したネットワーク。時系列データとは、時間的順序を追って一定間隔ごとに観察され、相互に統計的依存関係が認められるデータのこと(例)株価、音声データ、テキストデータ
  • 時間tでの中間層での出力をt+1の中間層に反映させる。(数珠つなぎに過去の情報を遡って継承されている=再帰構造)
  • RNNでは、以下3種類の重みがある。
  1. 入力層から中間層への重み W(in)
  2. 中間層から出力層への重みW(out)
  3. 中間層から次の中間層への重みW

f:id:rakurakura:20220104142935p:plain

yについて、 z^{t-1}, z^{t+1}, w_{in}, w, w_{out}を使って数式で表す。
(バイアスはb, c 活性化関数はf(x), g(x)とする。)

  •  u^t = W_{in}x^t + Wz^{t-1} + b
  •  z^t = f(W_{in}x^t + Wz^{t-1} + b) = f(u^t )
  •  v^t = W_{out}z^t + c
  •  y^t = g(W_{out}z^t + c) = g(v^t )
u[:, t+1] = np.dot(X, W_in) + np.dot(z[:, t].reshape(1, -1),W) + b
z[:, t+1] = sigmoid(u[:, t+1])  #f(u)がシグモイド関数の場合
v[:, t+1] = np.dot(z[:, t].reshape(1, -1), W_out) + c
y[:, t+1] = sigmoid(v[:, t+1])  #g(v)がシグモイド関数の場合
構文木再帰的に文章全体の表現ベクトルを得るプログラム
def traverse(node):
  if not isinstance(node, dict):
     v = node
  else:
     left = traverse(node["left"])
     right = traverse(node["right"])
     v  =_activation(W.dot(np.concatenate([left, right])) #左と右とのベクトル結合
  return v
BPTT(back propagation through time)
  • RNNでの時間軸方向の逆伝播のこと。
  • 上記の数式から、損失関数に対して重みで偏微分すると以下のようになる。

(ただし、∂E/∂u = δ^tとする)

勾配計算式

  •  ∂E/ ∂W_{in} = ∂E/ ∂u^t \cdot  ∂u^t/ ∂W_{in}  = δ^t [x^t]
  •  ∂E/ ∂W_{out} =∂E/ ∂v^t \cdot ∂u^t/ ∂W_{out} = δ^{out, t} [z^t ]
  •  ∂E/ ∂W =  ∂E/ ∂u^t \cdot ∂u^t/ ∂W = δ^{t} [z^{t--1} ]
  •  ∂E/ ∂b  =  δ^t × 1 = δ^t
  •  ∂E/ ∂c  =  δ^{out, t}
def bptt(cs, ys, W ,U ,V):

   hiddens, outputs = rnn_net(xs, W, U, V)

   dW = np.zeros_like(W)
   dU  = np.zeros_like(U)
   dV  = np.zeros_like(V)
   do  = _calculate_do(outputs, ys)

   batch_size, n_seq = ys.shape[:2] 

   for t in reversed(range(n_seq)):
         dV += np.dot(do[:, t].T, hiddens[:, t]) / batch_size
         delta_t = do[:, t].dot(V)

         for bptt_step in reversed(range(t+1)):
             dW += np.dot(delta_t.T, xs[:, bptt_step])/ batch_size
             dU  += np.dot(delta_t.T, hiddens[:, bptt_step-1])/ batch_size
             delta_t = delta_t.dot(U)
   return dW, dU, dV

RNNの課題感

勾配消失
  • RNNの課題は、時系列を遡れば勾配が消失していくことがあり、長い系列の学習が困難となる。勾配消失は、0~1の間が微分値が複数掛け合わされることで起こる。シグモイド関数の場合微分値は最大でも0.25。)
勾配爆発
  • 微分値が1よりも大きくなる場合、勾配消失とは逆に、勾配が大きくなりすぎる問題もある。これが勾配爆発と呼ぶ。
  • これを防ぐ方法として、勾配クリッピングという手法がある。
#勾配クリッピングの実装
def gradient_clipping(grad, threshold):
  norm = np.linalg.norm(grad) #与えられた勾配のノルムを算出
  rate = threshold /norm #スレッショルドよりもノルムが大か小かが決まる。
  if rate < 1:  #もしrateが1よりも小さい=ノルムの方が大きい場合は
     return grad * rate #勾配にrateをかけることでthresholdまで値を小さくして出力する
  return grad   #rateが1よりも大きい場合はそのままgradを出力する

LSTM

要点まとめ
  • 勾配消失を解決したのがLSTM。CECという機構により、メモのように情報を記憶しておく機能を導入した!
  • 特長:3つのゲート(忘却ゲート、入力ゲート、出力ゲート)とCECを持つ。
  • 学習と記憶の機能を分離させるイメージ。(CEC自体には学習機能がないため、ゲートによって重みを学習させる。)
入力ゲート
  • 情報を取捨選択しながらCECに入力するゲート。どれをどのくらい使うか。
  •  g = tanh (x_t W_x^g + h_{prev}W_h^g + b^g)
  •  i = σ (x_t W_x^i + h_{prev}W_h^i + b^i)
  •  g \odot i
忘却ゲート
  • CECから不要な記憶を忘却させるゲート。
  • あってもなくても影響がないような言葉は、忘却ゲートで処理されるように学習させる。
  •  f= σ (x_t W_x^f + h_{prev}W_h^f + b^f
出力ゲート
  • CECからの情報のうち、どれだけ次の隠れ状態に情報を渡すか調整する
  •  o= σ (x_t W_x^o + h_{prev}W_h^o + b^o
ピープホールつきLSTM(覗き穴結合)
  • CECの状態を各ゲートから覗き見る機能をつける手法も編み出された。
LSTMとCECの課題
  • パラメータ数が増えたことにより計算量が多くなるのが課題。
【実装】
def lstm(x, prev_h, prev_c, W, U, b):
	lstm_in = _activation(x.dot(W.T)+prev_h.dot(U.T)+b)
	a, i, f, o = np.hsplit(lstm_in, 4)

	a = np.tanh(a) 
	input_gate = _sigmoid(i)
	forget_gate = _sigmoid(f)
	output_gate = _sigmoid(o)

	#セルの状態を更新し、中間層の出力を計算する
	c=input_gate * a + forget_gate *c
	h=output_gate * np.tanh(c)

	return c,h

GRU

要点まとめ
  • LSTMとの違い:CECは持たない。LSTMはパラメータ多いがGRUはパラメータ少ないため計算量が少なく済むのがメリット。LSTMでは3つのゲートであったが、GRUではresetゲートとupdateゲートの2つになった。
resetゲート
  • 過去の隠れ状態をどれだけ反映させるかを決めるゲート
update gateの2つのゲートに集約。
  • 隠れ状態を更新するゲート

双方向RNN (Bi-directional RNN)

要点まとめ
  • 過去→未来の情報だけでなく未来→過去の情報も加味することで精度を向上させる。
  • 順方向へ接続するRNNと、逆方向へ接続するRNNの両方を使い結果をマージする。
  • 実用例としては、機械翻訳など。

Seq2Seq

要点まとめ
  • 入力データ(時系列)から異なるデータ(時系列)を生成するモデル。RNNを用いた、Encoder-Decoderモデルの一種であり、EncoderRNNとDecoderRNNの2つのニューラルネットワークを組み合わせる。機械翻訳に使われる。
  • EncoderRNNによって最終的に文脈(コンテキスト)がベクトルとして表現され、その意味ベクトルから、新たな系列データをDecoderRNNが生成する。
EncoderRNNの意味集約とは
  • 自然言語処理では、単語をベクトルに変換して入力値とする。単語ごとに数字を割り振るとone-hotベクトルとして表現することができるが、これではあまりにもスパースで無駄が多いので、Embeddingにより数万→数百程度までベクトルの大きさが小さくできる。
  • このEmbedding処理には機械学習が用いられ、うまくできると意味の近さを抽出することができる。
DecoderRNNの系列生成
  • EncoderRNNの最終状態(意味ベクトル)をもとに系列を生成する。
  • Embeddingの逆方向に変換していく。=Detokenize
HREDとVHRED
  • Seq2Seqでは、1問1答しかできないが、少し前の会話の文脈に沿った文章を生成できるようになったのがHRED。
  • HREDでは、結構ありがちな会話しかできなくなった。そのような当たり障りのない単語以上の出力を得られるようになったのがVHRED。

AE(オートエンコーダ=自己符号化器)とVAE

  • 画像などをEncoderで潜在変数zに変換し、その後Decoderで元に戻す。
  • AEのように、データを変換して結局元に戻してなんの意味があるのか?→ 入力データを情報量を落とさずに圧縮変換することに成功できれば、次元削減に使うことができる技術。
  • VAEでは、上記の自動符号化器に確率分布を導入した手法。潜在変数zに、確率分布を仮定し、似たような入力データは似たようなベクトルのまま押し込めることを可能にした。Encoderの出力にノイズを与えた上で潜在変数zを出力することで、より汎用性が高くなる。

Word2Vecの概要

  • 文字データを分散表現ベクトルに変換する手法。ワンホットベクトルをそのまま使うとスパースすぎるので、Embedding表現による数百ベクトル程度で表現できる。機械学習
  • 実際にはWord2VecなどでEmbedding表現したデータをSeq2Seq等の入力データとして取り扱うことが多い。

Attentionの概要

  • Seq2Seqでは、中間層の出力において、固定長のベクトルに変換していた。そのため1万語の入力も10単語の入力も同様の固定長ベクトルだったため、長文の場合、うまく翻訳することが難しかった。
  • Attention機構では、上記の課題を解決し固定長ベクトルではなくなった。また、入出力単語間の、相互の関連度を重みとして学習するため、長文への対応を可能にした。

その他調べたこと等

BPTTの課題の解決アプローチ(教師強制)
  • BPTTでは、並列処理が不可能。(全ての中間状態を保存しておく必要があるためメモリコスト高い)そこで、前時刻の正解ラベルを予め入力として使うことで、並列処理が可能にする「教師強制」という手法がある。
  • 教師強制では、「訓練時」に前時刻の正解ラベルを参照し、「テスト時」には前時刻の出力層の状態を参照する。教師強制がない場合は、いずれも前時刻のRNNユニットの状態を参照する。
BPTTの課題の解決アプローチ(Truncated BPTT)
  • 全ての時刻の中間状態を保存しておく課題を解決するために、RNNユニット間の逆伝播の接続を切る方法(順伝播の接続はそのままにしておく)
Leaky接続
  • Leaky接続では、前時刻からの情報の入力をα倍、入力層からの接続を1-α倍して過去の情報と現在の情報の割合を調節する手法。αが0になるほど過去の情報を破棄させる効果がある。
PyTorchでの実装例

Simple RNN

import torch
import torch.nn as nn

SimpleRnn = nn.RNN( input_size=3, hidden_size=2)

input = torch.randn(5,1,3)
h0  = torch.randn(1, 1, 2)

output, hn = SimpleRnn(input, h0)

print(output.shape)
print(hn.shape)


LSTM

import torch
import torch.nn as nn

lstm= nn.LSTM( input_size=10, hidden_size=20) #入力ベクトルと隠れベクトル次元数

input = torch.randn(5,3,10) #(単語数、バッチサイズ、入力ベクトル次元数)
h0 = torch.randn(1, 3, 20)  #(RNNユニット数、バッチサイズ、隠れベクトル次元数)
c0 = torch.randn(1, 3, 20) #(RNNユニット数、バッチサイズ、隠れベクトル次元数)

output, (hn, cn) = lstm (input, (h0, c0))

print(output.shape)
print(hn.shape)


双方向RNN

import torch
import torch.nn as nn

birnn= nn.RNN( input_size=3, hidden_size=2, bidirectional=True)

input = torch.randn(5,1,3)
h0  = torch.randn(2, 1, 2) #双方向のためRNNユニット数は2になっている

output, hn = birnn(input, h0)

print(output.shape)
print(hn.shape)