Transformers

Tags: Architecture, CS 224N, CS 231N

Transformer

At its heart, a transformer is a stack of layers built around attention.

A transformer takes in a set of feature vectors, encodes them, and then decodes using attention over the encoded values. Every operation in a transformer is applied to each vector independently, except where attention is applied.

Encoder block

There are n stacked encoder blocks. Positional encodings are added to the input vectors once, before the first block. Within each block, we have

  1. multi-head self-attention
  1. residual connection
  1. layernorm over each vector
  1. MLP over each vector
  1. another residual connection and layer norm

The MLP is critical: if we stack encoders without it, the value transformations in attention compose roughly linearly, so we don’t gain much expressivity.

We add residual connections to help gradients flow through the deep stack.
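The encoder block above can be sketched in PyTorch. This is a minimal illustration, not a faithful reimplementation: the class name, dimensions, and the ReLU MLP are my own choices, and positional encoding is assumed to have been added to the inputs already.

```python
import torch
import torch.nn as nn

class EncoderBlock(nn.Module):
    """One encoder block: self-attention -> add & norm -> per-vector MLP -> add & norm."""
    def __init__(self, d_model=64, n_heads=4, d_ff=256):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        # the MLP acts on each vector in the set independently
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model)
        )
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)   # self-attention mixes the vectors
        x = self.norm1(x + attn_out)       # residual connection helps gradients
        x = self.norm2(x + self.mlp(x))    # per-vector MLP adds nonlinearity
        return x

block = EncoderBlock()
x = torch.randn(2, 10, 64)   # batch of 2 sets of 10 feature vectors
y = block(x)                 # same shape out: the set is transformed in place
```

Note that only the `self.attn` call mixes information across the set; everything else is applied to each vector separately.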

Decoder block

There are n stacked decoder blocks. Positional encodings are added to the decoder inputs once, before the first block. Within each block, we have

  1. masked multi-head self-attention (to prevent peeking at future positions!)
  1. residual connection and layer norm
  1. multi-head cross-attention (keys and values come from the encoded set of vectors, which means the data that passes through this attention is purely from the transformer encoder, if it weren’t for the residual connection)
  1. residual connection and layer norm
  1. MLP over each vector
  1. residual connection and layernorm

After the final block, a fully-connected layer projects each vector to output scores.

The multi-head attention allows you to mix together the encoded vector with the decoding operation.

See how there is no recurrence needed! It’s all just a big pot of things and it’s stirred slowly with the transformers.

Training a transformer

For the decoder, you can just feed in the ground-truth sequence, use masked self-attention (SUPER important), and try to match the output to the shifted input as closely as possible. Use the multi-head cross-attention to contextualize with the encoder's information. Essentially, you feed in a sequence x_0, x_1, ..., x_n and you expect an output of x_1, x_2, ..., x_{n+1}.
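The input/target shift above (often called teacher forcing) is just a one-position offset over the same token sequence. A tiny sketch with made-up token ids:

```python
# teacher forcing: the decoder input is tokens[:-1], the target is tokens[1:],
# so at every position the model predicts the *next* token in parallel
tokens = [0, 5, 7, 2, 9]        # hypothetical token ids for one sequence
inp, tgt = tokens[:-1], tokens[1:]
print(inp)  # [0, 5, 7, 2]
print(tgt)  # [5, 7, 2, 9]
```

Because of the causal mask, position t of the input can only use tokens 0..t when predicting target t, even though the whole sequence is fed in at once.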

Using a transformer

You run the decoder step by step, feeding the transformer's output back into itself (and modifying it if necessary). For example, in a captioning task, you would take the predicted scores for x_{t+1}, pick the highest-scoring token, embed it, and feed it as x_{t+1} into the input of the decoder. Keep doing this until you reach the end token.
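That feedback loop is greedy autoregressive decoding. A sketch under stated assumptions: `decoder`, `start_id`, and `end_id` are hypothetical, and `decoder(seq, enc)` is assumed to return one row of scores per position, shaped `(len(seq), vocab_size)`.

```python
import torch

def greedy_decode(decoder, enc, start_id, end_id, max_len=20):
    """Greedy decoding: repeatedly take the highest-scoring next token
    and feed it back in as the next decoder input."""
    seq = [start_id]
    for _ in range(max_len):
        logits = decoder(torch.tensor(seq), enc)  # scores for every position
        next_id = int(logits[-1].argmax())        # best token at the last step
        seq.append(next_id)
        if next_id == end_id:                     # stop at the end token
            break
    return seq
```

Greedy argmax is the simplest choice; beam search or sampling would slot into the same loop by replacing the `argmax` line.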