어텐션 시각화(Attention Visualization)

머신러닝과 딥러닝 · 2020. 1. 23. 22:39

“November 5, 2016”, “5th November 2016”와 같은  human dates에서

“2016–11–05”와 같은 스탠다드한 포맷으로 번역하는 어텐션 모델을 살펴보고, 그 결과를 시각화해본다.

 

신경망이 어떻게 번역하는지 직관을 얻고자 

어떤 입력문자가 결과문자를 예측하는데 중요한지 보여주는 맵(map)을 생성해 볼 수 있다.

 

16을 사용하여 연도가 2016년임을 결정하고, "Ja"는 month가 01를 결정하는 것을 볼 수 있다.

How to Visualize Your Recurrent Neural Network with Attention in Keras에서는 직접 디코더를 구현했다.

인코더의 경우는 keras에서 제공하는 BLSTM을 사용하였다.

 

인코더

BLSTM = Bidirectional(LSTM(encoder_units, return_sequences=True))

encoder_units는 weight 행렬의 사이즈이다. 

또한 return_sequences=True로 설정하여, 모든 시점의 인코딩된 sequence에 접근하길 원한다.

(return_sequences=False의 경우에는 마지막 스텝의 인코딩된 sequence만 출력된다.)

 

입력 문장에서 character는 $ x=(x1,x2,...,xT) $이고,

인코딩된 sequence는 $ h = (h1, h2,...,hT)$이다.

T는 date의 문자(character)의 갯수이다. 

 

디코더

인코더와 연결된 디코더 네트워크의 구조는 아래와 같다. 

어텐션 메커니즘은 context vectors를 생성하고, 

디코더 네트워크는 이 context vectors와 이전의 예측값을 사용하여 다음 예측을 알려준다.

빨간색 화살표는 어텐션 메커니즘이 출력문자 "1" 과 "6"을 생성할 때 

어느 문자를 강하게 가중치를 주었는지 표시한 것이다.

 

모델의 코드는 아래와 같다.

여기서 AttentionDecoder가 직접 customizing한 부분이다.

import numpy as np
import os
from keras.models import Model
from keras.layers import Dense, Embedding, Activation, Permute
from keras.layers import Input, Flatten, Dropout
from keras.layers.recurrent import LSTM
from keras.layers.wrappers import TimeDistributed, Bidirectional
from .custom_recurrents import AttentionDecoder

def simpleNMT(pad_length=100,
              n_chars=105,
              n_labels=6,
              embedding_learnable=False,
              encoder_units=256,
              decoder_units=256,
              trainable=True,
              return_probabilities=False):
    """
    Builds a Neural Machine Translator that has alignment attention
    :param pad_length: the size of the input sequence
    :param n_chars: the number of characters in the vocabulary
    :param n_labels: the number of possible labelings for each character
    :param embedding_learnable: decides if the one hot embedding should be refinable.
    :return: keras.models.Model that can be compiled and fit'ed
    *** REFERENCES ***
    Lee, Jason, Kyunghyun Cho, and Thomas Hofmann. 
    "Neural Machine Translation By Jointly Learning To Align and Translate" 
    """
    input_ = Input(shape=(pad_length,), dtype='float32')
    input_embed = Embedding(n_chars, n_chars,
                            input_length=pad_length,
                            trainable=embedding_learnable,
                            weights=[np.eye(n_chars)],
                            name='OneHot')(input_)

    rnn_encoded = Bidirectional(LSTM(encoder_units, return_sequences=True),
                                name='bidirectional_1',
                                merge_mode='concat',
                                trainable=trainable)(input_embed)

    y_hat = AttentionDecoder(decoder_units,
                             name='attention_decoder_1',
                             output_dim=n_labels,
                             return_probabilities=return_probabilities,
                             trainable=trainable)(rnn_encoded)

    model = Model(inputs=input_, outputs=y_hat)

    return model

decoder_units의 개수는 256개이고,

n_labels는 스탠다드한 포맷으로 출력되기 위한 문자의 수이다.

즉, 0,1,2,3,4,5,6,7,8,9,-,<unk>,<eot> 총 13개이다.

 

AttentionDecoder 부분을 자세히 살펴보면 다음과 같다.

 

이전 문자인 ytm의 shape는 (None, 13)이고,

hidden state인 stm의 shape는 (None, 256)이다.

 

아래 수식에서 대문자는 학습가능한 파라메터를 의미한다.

Equation 1 (top) 문자 t를 예측하는데 문자 j의 중요성을 계산하는 feed-forward neural network , Equation 2 소프트맥스 함수( bottom) 

시퀀스 길이만큼 hidden state를 반복하여 _stm을 만든다.

_stm의 shape는 (None, 50, 256)이다. 시퀀스(문자의 수)는 50이다.

 

W_a는 (decoder units, decoder units)의 shape를 갖는다.

_stm과 W_a의 dot product 연산은 다음과 같다.

$ (None, 50, 256) \cdot (256, 256) = (None, 50, 256)$

 

et는 $ \tanh ( (None, 50, 256) + (None, 50, 256) ) \cdot (256, 1) = (None, 50, 1)$ 이다.

 

et는 attention probability를 의미한다. 시퀀스 길이만큼 있다.

여기서 각 시퀀스의 중요도를 알기 위해 어텐션 분포(attention distribution)를 구한다.

소프트맥스 함수를 취해주는 것이 equation 2 이다.

 

이제 이 소프트맥스의 결과값과 인코더의 hidden state를 multiplication하고 sum하는 것이

context vector를 구하는 것이다.

 

x_seq는 인코더의 hidden state와 동일하다.

x_seq의 shape는 ( None, 50, 512 ) 이고,

 

at의 shape는 ( None, 50, 1)이다.

batch dot을 사용해서 연산하는 의미는 

 

attention distribution을 인코더의 hidden state와 곱해서 더한다는 것이므로

가중치를 곱해서 집중해서 볼 단어들을  더 잘 보겠다는 의미이다.

 

batch_dot 결과값은 (None, 1, 512)의 모양을 갖고

squeeze를 해주면 (None, 512)의 모양을 갖는다.

이것이 context vector이다.

    def step(self, x, states):

        ytm, stm = states

        # equation 1
        
        # repeat the hidden state to the length of the sequence
        _stm = K.repeat(stm, self.timesteps)

        # now multiplty the weight matrix with the repeated hidden state
        _Wxstm = K.dot(_stm, self.W_a)

        # calculate the attention probabilities
        # this relates how much other timesteps contributed to this one.
        et = K.dot(activations.tanh(_Wxstm + self._uxpb),
                   K.expand_dims(self.V_a))
                   
        # equation 2
                   
        at = K.exp(et)
        at_sum = K.sum(at, axis=1)
        at_sum_repeated = K.repeat(at_sum, self.timesteps)
        at /= at_sum_repeated  # vector of size (batchsize, timesteps, 1)

        # equation 3

        # calculate the context vector
        context = K.squeeze(K.batch_dot(at, self.x_seq, axes=1), axis=1)
        # ~~~> calculate new hidden state
        # first calculate the "r" gate:

        # equation 4 (reset gate)
   
        rt = activations.sigmoid(
            K.dot(ytm, self.W_r)
            + K.dot(stm, self.U_r)
            + K.dot(context, self.C_r)
            + self.b_r)
       
        # equation 5 (update gate)

        # now calculate the "z" gate
        zt = activations.sigmoid(
            K.dot(ytm, self.W_z)
            + K.dot(stm, self.U_z)
            + K.dot(context, self.C_z)
            + self.b_z)
            
        # equation 6 (proposal state)

        # calculate the proposal hidden state:
        s_tp = activations.tanh(
            K.dot(ytm, self.W_p)
            + K.dot((rt * stm), self.U_p)
            + K.dot(context, self.C_p)
            + self.b_p)

        # equation 7 (new hidden states)

        # new hidden state:
        st = (1-zt)*stm + zt * s_tp

        # equation 8 (the probability of having each character)

        yt = activations.softmax(
            K.dot(ytm, self.W_o)
            + K.dot(stm, self.U_o)
            + K.dot(context, self.C_o)
            + self.b_o)

        if self.return_probabilities:
            return at, [yt, st]
        else:
            return yt, [yt, st]

나머지 수식은 LSTM과 유사하다.

 

여기서는 이미 학습된 모델을 사용하여

human date형식을 입력값으로 사용해본다.

python visualize.py -e 'Saturday 9 May 2018'

입력값 9가 day를 결정하는데 활성화가 많이 되었고

May의 M 글자가 month를 결정하는데 활성화가 많이 된 것을 확인 할 수 있다.

 

 

https://medium.com/datalogue/attention-in-keras-1892773a4f22

 

How to Visualize Your Recurrent Neural Network with Attention in Keras

A technical discussion and tutorial

medium.com