[Summary] Attention is all you need

Abstract

Introduction

Model Architecture

Encoder

# [2] https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Layers.py
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

class EncoderLayer(nn.Module):
    '''Compose with two layers.'''

    def __init__(
        self,
        d_model: int,
        d_inner: int,
        n_head: int,
        d_k: int,
        d_v: int,
        dropout: float = 0.1
    ) -> None:
        super().__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)


    def forward(
        self,
        enc_input: torch.Tensor,
        slf_attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        enc_output, enc_slf_attn = self.slf_attn(
             enc_input, enc_input, enc_input, mask=slf_attn_mask,
        )
        enc_output = self.pos_ffn(enc_output)
        return enc_output, enc_slf_attn

Decoder

# [2] https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Layers.py
class DecoderLayer(nn.Module):
    '''Compose with three layers.'''

    def __init__(
        self,
        d_model: int,
        d_inner: int,
        n_head: int,
        d_k: int,
        d_v: int,
        dropout: float=0.1
    ) -> None:
        super().__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

    def forward(
        self,
        dec_input: torch.Tensor,
        enc_output: torch.Tensor,
        slf_attn_mask: Optional[torch.Tensor] = None,
        dec_enc_attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        dec_output, dec_slf_attn = self.slf_attn(
            dec_input, dec_input, dec_input, mask=slf_attn_mask,
        )
        dec_output, dec_enc_attn = self.enc_attn(
            dec_output, enc_output, enc_output, mask=dec_enc_attn_mask,
        )
        dec_output = self.pos_ffn(dec_output)
        return dec_output, dec_slf_attn, dec_enc_attn

Attention

Scaled Dot-Product Attention

\[\text{Attention}(Q,K,V)=\text{softmax}(\frac{QK^T}{\sqrt{d_k}})V\\ d_k: \text{Dimension of queries and keys}.\]
# [2] https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Modules.py
class ScaledDotProductAttention(nn.Module):
    '''Scaled Dot-Product Attention.'''

    def __init__(
        self,
        temperature: float,
        attn_dropout: float=0.1,
    ) -> None:
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:

        attn = torch.matmul(q / self.temperature, k.transpose(2, 3))

        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e9)

        attn = self.dropout(F.softmax(attn, dim=-1))
        output = torch.matmul(attn, v)

        return output, attn

Multi-Head Attention

The authors argue that running \(h\) attention operations in parallel, each with its own learned weights, is more effective than using a single attention function. With multi-head attention, the model can jointly attend to information at different positions from multiple representation subspaces.

\[\text{MultiHead}(Q,K,V) = Concat(head_1, ..., head_h)W^O\\ \text{where } head_i = \text{Attention}(QW^Q_i, KW^K_i, VW^V_i)\]
# [2] https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/SubLayers.py
class MultiHeadAttention(nn.Module):
    '''Multi-Head Attention module.'''

    def __init__(
        self,
        n_head: int,
        d_model: int,
        d_k: int,
        d_v: int,
        dropout: float=0.1,
    ) -> None:
        super().__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
        self.fc = nn.Linear(n_head * d_v, d_model, bias=False)

        self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)


    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask: Optional[torch.Tensor]=None,
    ) -> torch.Tensor:
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)

        residual = q

        # Pass through the pre-attention projection: b x lq x (n*dv)
        # Separate different heads: b x lq x n x dv
        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)

        # Transpose for attention dot product: b x n x lq x dv
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        if mask is not None:
            mask = mask.unsqueeze(1)   # For head axis broadcasting.

        q, attn = self.attention(q, k, v, mask=mask)

        # Transpose to move the head dimension back: b x lq x n x dv
        # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv)
        q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
        q = self.dropout(self.fc(q))
        q += residual

        q = self.layer_norm(q)

        return q, attn
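
For reference, here is a quick usage sketch of MultiHeadAttention with the base-model hyperparameters from the paper (\(h=8\), \(d_{model}=512\), \(d_k=d_v=64\)); the input shapes are illustrative.

# Illustrative usage (not from [2]): self-attention over a (batch, seq_len, d_model) tensor.
mha = MultiHeadAttention(n_head=8, d_model=512, d_k=64, d_v=64)
x = torch.randn(2, 10, 512)   # (batch=2, seq_len=10, d_model=512)
out, attn = mha(x, x, x)      # out: (2, 10, 512), attn: (2, 8, 10, 10)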

Position-wise Feed-Forward Networks

Alongside the attention sub-layer, each layer contains a fully-connected feed-forward network. This sub-layer consists of two linear transformations with a ReLU activation in between.

\[\text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2\]
# [2] https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/SubLayers.py
class PositionwiseFeedForward(nn.Module):
    '''A two-feed-forward-layer module.'''

    def __init__(
        self,
        d_in: int,
        d_hid: int,
        dropout: float=0.1,
    ) -> None:
        super().__init__()
        self.w_1 = nn.Linear(d_in, d_hid) # position-wise
        self.w_2 = nn.Linear(d_hid, d_in) # position-wise
        self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        residual = x

        x = self.w_2(F.relu(self.w_1(x)))
        x = self.dropout(x)
        x += residual

        x = self.layer_norm(x)

        return x

Embeddings and Softmax

A learned embedding converts the input tokens and output tokens into vectors of dimension \(d_{model}\), and a learned linear transformation is applied right before the softmax. What if each of these three used its own separate weight matrix? Take, for example, a matrix for a 30k-token vocabulary with an embedding size of 512: it alone accounts for 15.3 million parameters. The problem is that three such matrices add up to roughly 46 million parameters [6].
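
As a quick check of those numbers:

\[30{,}000 \times 512 = 15{,}360{,}000 \approx 15.3\text{M}, \qquad 3 \times 15.36\text{M} \approx 46\text{M}\]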

Following the proposal of Press, O. and Wolf, L. [5], the paper uses the same weights for the encoder and decoder embedding layers and for the linear transformation before the softmax (a.k.a. weight tying).

# [6] https://github.com/jsbaan/transformer-from-scratch/blob/main/transformer.py#L33
class Transformer(nn.Module):
    def __init__(
        self,
...
        self.embed = nn.Embedding(vocab_size, hidden_dim, padding_idx=padding_idx)
        self.encoder = TransformerEncoder(
            self.embed, hidden_dim, ff_dim, num_heads, num_layers, dropout_p
        )
        self.decoder = TransformerDecoder(
            self.embed,
            hidden_dim,
...

# [6] https://github.com/jsbaan/transformer-from-scratch/blob/main/decoder.py#L44
class TransformerDecoder(nn.Module):
    def __init__(
        self,
        embedding: torch.nn.Embedding,
        hidden_dim: int,
...
        self.output_layer = nn.Linear(hidden_dim, vocab_size, bias=False)

        # Note: a linear layer multiplies the input with a transpose of the weight matrix, so no need to do that here.
        if tie_output_to_embedding:
            self.output_layer.weight = nn.Parameter(self.embed.weight)
...

The authors also mention multiplying the embedding weights by the constant \(\sqrt{d_{model}}\). Opinions differ on why this is done; see [7] for the most common explanations.
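
A minimal sketch of that scaling step, assuming a hypothetical nn.Embedding named embed and the paper's \(d_{model}=512\) (the variable names are illustrative, not taken from [2] or [6]):

# Illustrative sketch: scale token embeddings by sqrt(d_model) before adding positional encodings.
d_model = 512
embed = nn.Embedding(30000, d_model)        # token embedding table (30k-token vocabulary)
tokens = torch.randint(0, 30000, (2, 10))   # (batch, seq_len) of token ids
x = embed(tokens) * math.sqrt(d_model)      # scaled embeddings: (2, 10, d_model)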

Positional Encoding (Sinusoids)

Because the model contains no recurrence, information about the order of the sequence has to be injected. The paper proposes adding a positional vector of the same dimension as the embedding to the embedding itself. For a word at position \(pos\) and vector dimension \(i\), the positional vector is defined by the following sine and cosine functions.

\[PE_{(pos,2i)}=\sin(pos/10000^{2i/d_{model}})\\ PE_{(pos,2i+1)}=\cos(pos/10000^{2i/d_{model}})\]

The positional vector should satisfy two conditions [8]: its values must stay within a bounded range, and it must uniquely identify each position.

Defining the positional vector with periodic functions as above keeps its values within a fixed range, so they neither grow too large nor vary too widely. And because a different periodic function is used depending on the word position and each dimension of the vector, the encoding can serve as an identifier for positional information.

# [6] https://github.com/jsbaan/transformer-from-scratch/blob/main/positional_encodings.py
class SinusoidEncoding(torch.nn.Module):
    """
    Mostly copied from
    https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial6/Transformers_and_MHAttention.html
    """

    def __init__(
        self,
        hidden_dim: int,
        max_len: int=5000,
    ) -> None:
        """
        Inputs
            d_model - Hidden dimensionality of the input.
            max_len - Maximum length of a sequence to expect.
        """
        super().__init__()

        # Create matrix of [SeqLen, HiddenDim] representing the positional encoding for max_len inputs
        pos_embed = torch.zeros(max_len, hidden_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, hidden_dim, 2).float() * (-math.log(10000.0) / hidden_dim)
        )
        pos_embed[:, 0::2] = torch.sin(position * div_term)  # 2i
        pos_embed[:, 1::2] = torch.cos(position * div_term)  # 2i + 1
        pos_embed = pos_embed.unsqueeze(0)

        # register_buffer => Tensor which is not a parameter, but should be part of the modules state.
        # Used for tensors that need to be on the same device as the module.
        # persistent=False tells PyTorch to not add the buffer to the state dict (e.g. when we save the model)
        self.register_buffer("pos_embed", pos_embed, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Adds positional embeddings to token embeddings.
        N = batch size
        L = sequence length
        E = embedding dim
        :param x: token embeddings. Shape: (N, L, E)
        :return: token_embeddings + positional embeddings. Shape: (N, L, E)
        """
        x = x + self.pos_embed[:, : x.size(1)]
        return x
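
A quick usage sketch (shapes are illustrative): the module simply adds the precomputed sinusoids for the first L positions to the token embeddings.

# Illustrative usage (not from [6]): add positional encodings to a batch of token embeddings.
pos_enc = SinusoidEncoding(hidden_dim=512)
token_embeddings = torch.randn(2, 10, 512)   # (N=2, L=10, E=512)
out = pos_enc(token_embeddings)              # same shape: (2, 10, 512)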

Computational Complexity

Training

Training Data and Batching

Hardware and Schedule

Optimizer

\[\text{lrate}=d_{model}^{-0.5} \cdot \min(step\_num^{-0.5},\ step\_num \cdot warmup\_steps^{-1.5})\]
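
A minimal sketch of this schedule, assuming the paper's \(d_{model}=512\), \(warmup\_steps=4000\), and the Adam hyperparameters from Section 5.3; the function name and the use of LambdaLR here are illustrative, not taken from the referenced implementations.

# Illustrative sketch: warmup-then-decay learning-rate schedule from the paper.
def transformer_lrate(step: int, d_model: int = 512, warmup_steps: int = 4000) -> float:
    step = max(step, 1)  # guard against step 0
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

model = nn.Linear(512, 512)  # placeholder module for demonstration
optimizer = torch.optim.Adam(model.parameters(), lr=1.0, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=transformer_lrate)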

Regularization

Label Smoothing

Results

Conclusion

Appendix

References

  1. Vaswani, A. et al. (2017). Attention Is All You Need. arXiv preprint arXiv:1706.03762.

  2. Huang, Y. (2020). Attention is all you need: A Pytorch Implementation. [Online] Available at: https://github.com/jadore801120/attention-is-all-you-need-pytorch [Accessed 01 Apr. 2023].

  3. Huang, A. et al. (2022). The Annotated Transformer. [Online] Available at: http://nlp.seas.harvard.edu/annotated-transformer [Accessed 01 Apr. 2023].

  4. Dontloo et al. (2019). What exactly are keys, queries, and values in attention mechanisms?. [Online] Available at: https://stats.stackexchange.com/questions/421935/what-exactly-are-keys-queries-and-values-in-attention-mechanisms [Accessed 02 Apr. 2023].

  5. Press, O. and Wolf, L. (2016). Using the output embedding to improve language models. arXiv preprint arXiv:1608.05859.

  6. Baan, J. (2022). Implementing a Transformer from Scratch. [Online] Available at: https://towardsdatascience.com/7-things-you-didnt-know-about-the-transformer-a70d93ced6b2 [Accessed 02 Apr. 2023].

  7. Noe et al. (2021). Transformer model: Why are word embeddings scaled before adding positional encodings?. [Online] Available at: https://datascience.stackexchange.com/questions/87906/transformer-model-why-are-word-embeddings-scaled-before-adding-positional-encod [Accessed 02 Apr. 2023].

  8. Lee, M. (2022). 트랜스포머(Transformer) 파헤치기-1.Positional Encoding. [Online] Available at: https://www.blossominkyung.com/deeplearning/transfomer-positional-encoding [Accessed 02 Apr. 2023].