Here is a really slow (pseudocode) way to greedily generate summaries:

output_tokens = [bos]
while True:
    encoder_hidden_state = model.encoder(article_input_ids)
    logits = model.decoder(encoder_hidden_state, output_tokens)
    next_word = logits[-1].argmax()  # most likely token at the last position
    output_tokens.append(next_word)
    if next_word == eos: break

Let's say N = len(article_input_ids) and M = len(output_tokens) when we return, and let's ignore the cost of encoder-decoder attention. Then the complexity of this approach, counted in tokens processed, is roughly

  • $(N \cdot M)$ (we call the encoder M times on N tokens)
  • $+ \sum_{m=1}^{M} m = \frac{M(M+1)}{2}$ (we call the decoder M times on 1 token then 2 tokens, all the way to M tokens).
  • Total: $N \cdot M + \frac{M(M+1)}{2}$

Let's say we are generating a 100-token summary of a 1024-token article. Then we have to "process" $N \cdot M + \frac{M(M+1)}{2} = 107,450$ tokens.

# show work
N, M = 1024, 100
complexity_simple = N * M + M * (M + 1) // 2
msg = "To generate a 100 token summary of a 1024 token article using this approach, we have to process {:,} tokens."
print(msg.format(complexity_simple))

Thankfully, the encoder's output depends only on the article, not on the tokens generated so far, so we can just hoist the encoder call outside the loop!

output_tokens = [bos]
encoder_hidden_state = model.encoder(article_input_ids)
while True:
    logits = model.decoder(encoder_hidden_state, output_tokens)
    next_word = logits[-1].argmax()  # most likely token at the last position
    output_tokens.append(next_word)
    if next_word == eos: break

Now the complexity is dramatically reduced: $N \cdot 1 + \frac{M(M+1)}{2} = 6,074$

But we can go even further and partially cache the keys and values computed inside the decoder's attention layers. This changes our generation loop to:

output_tokens = [bos]
encoder_hidden_state = model.encoder(article_input_ids)
cache = None
while True:
    # only the newest token goes in; earlier tokens are represented by the cache
    logits, cache = model.decoder(encoder_hidden_state, output_tokens[-1], cache=cache)
    next_word = logits[-1].argmax()
    output_tokens.append(next_word)
    if next_word == eos: break

This changes our formula to $N + M = 1,124$, since each of the M decoder calls now takes only the single newest token.
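To check the three token counts, here is a runnable sketch of all three loops against a hypothetical `CountingToyModel` (not a real library class). It does no math: it only tallies how many tokens each encoder or decoder call receives, returns the next token directly instead of logits, and emits EOS on its M-th decoder call, so the decoder sees 1, 2, ..., M tokens as in the bullets above. The cached loop passes `output_tokens[-1:]` (a one-element list) rather than `output_tokens[-1]` so that every call's cost is just a `len`.

```python
BOS, EOS = 0, 1


class CountingToyModel:
    """Hypothetical stand-in for an encoder-decoder model.

    It does no real math: it tallies how many tokens each call receives
    and emits EOS on its `num_steps`-th decoder call.
    """

    def __init__(self, num_steps):
        self.num_steps = num_steps
        self.decoder_calls = 0
        self.tokens_processed = 0

    def encoder(self, input_ids):
        self.tokens_processed += len(input_ids)
        return "encoder_hidden_state"  # placeholder, never inspected

    def decoder(self, encoder_hidden_state, decoder_input_ids, cache=None):
        self.tokens_processed += len(decoder_input_ids)
        self.decoder_calls += 1
        cache = (cache or []) + list(decoder_input_ids)  # pretend to store past keys/values
        next_word = EOS if self.decoder_calls == self.num_steps else 2
        return next_word, cache


def generate_naive(model, article_input_ids):
    output_tokens = [BOS]
    while True:
        encoder_hidden_state = model.encoder(article_input_ids)  # re-run every step
        next_word, _ = model.decoder(encoder_hidden_state, output_tokens)
        output_tokens.append(next_word)
        if next_word == EOS:
            return output_tokens


def generate_hoisted(model, article_input_ids):
    output_tokens = [BOS]
    encoder_hidden_state = model.encoder(article_input_ids)  # run once
    while True:
        next_word, _ = model.decoder(encoder_hidden_state, output_tokens)
        output_tokens.append(next_word)
        if next_word == EOS:
            return output_tokens


def generate_cached(model, article_input_ids):
    output_tokens = [BOS]
    encoder_hidden_state = model.encoder(article_input_ids)
    cache = None
    while True:
        # Feed only the newest token; earlier ones are represented by the cache.
        next_word, cache = model.decoder(encoder_hidden_state, output_tokens[-1:], cache=cache)
        output_tokens.append(next_word)
        if next_word == EOS:
            return output_tokens


N, M = 1024, 100
for generate in (generate_naive, generate_hoisted, generate_cached):
    model = CountingToyModel(num_steps=M)
    generate(model, list(range(N)))
    print(f"{generate.__name__}: {model.tokens_processed:,}")
# generate_naive: 107,450
# generate_hoisted: 6,074
# generate_cached: 1,124
```

The counts match $N \cdot M + \frac{M(M+1)}{2}$, $N + \frac{M(M+1)}{2}$, and $N + M$ for N = 1024 and M = 100.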

The nitty-gritty of this trick is explained in the Bonus section below.

Partially caching keys and values in DecoderLayer

Here is some simplified code for attention, without all the reshapes, heads, masks, and scaling. It is much simpler than BART's real implementation, but it does run:

import torch
from torch import nn

class SimplifiedAttention(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.Wq = nn.Linear(embed_dim, embed_dim)
        self.Wk = nn.Linear(embed_dim, embed_dim)
        self.Wv = nn.Linear(embed_dim, embed_dim)
        self.dense = nn.Linear(embed_dim, embed_dim)

    def forward(self, query, key, value):
        q = self.Wq(query)
        k = self.Wk(key)
        v = self.Wv(value)
        # (batch, q_len, embed_dim) @ (batch, embed_dim, k_len) -> (batch, q_len, k_len)
        matmul_qk = torch.matmul(q, k.transpose(-2, -1))
        attention_weights = matmul_qk.softmax(dim=-1)
        output = torch.matmul(attention_weights, v)
        return self.dense(output)

Now let's take a glimpse at the callers inside BART's DecoderLayer (LayerNorms and dropouts deleted for simplicity). Here's some more simplified code:

class SimplifiedDecoderLayer(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.self_attn = SimplifiedAttention(embed_dim)
        self.encoder_attn = SimplifiedAttention(embed_dim)

    def forward(self, x, last_encoder_hidden_state, *masks_etc):
        # x shape `(batch_size, tokens_generated_so_far, embed_dim)`
        # x comes from the decoder; masks_etc are ignored in this simplified version
        x = self.self_attn(query=x, key=x, value=x)  # decoder tokens attend to each other
        output = self.encoder_attn(  # pay attention to somebody else for a change!
            query=x,
            key=last_encoder_hidden_state,
            value=last_encoder_hidden_state,
        )
        return output

What did we learn?

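  • In encoder_attn, we can cache everything that doesn't depend on the query x. The key and value both come from last_encoder_hidden_state, which stays fixed for the whole generation loop, so these outputs only need to be computed once:
          k = self.Wk(key)
          v = self.Wv(value)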
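Here is a minimal NumPy sketch of that encoder-attention cache. It assumes a single head, no masks or scaling, no output projection, and random bias-free weights applied as `x @ W` (so it mirrors the idea, not `nn.Linear`'s exact math); all names are illustrative. The encoder states are projected through `Wk` and `Wv` once, and each step only needs a query for the newest decoder position.

```python
import numpy as np

rng = np.random.default_rng(0)
embed_dim, src_len, num_steps = 8, 5, 4

# Random bias-free weights standing in for encoder_attn's Wq, Wk, Wv.
Wq, Wk, Wv = (rng.standard_normal((embed_dim, embed_dim)) for _ in range(3))
encoder_hidden_state = rng.standard_normal((src_len, embed_dim))


def softmax(scores):
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)


def encoder_attn(query, k, v):
    """Single-head attention of decoder queries over projected encoder states."""
    q = query @ Wq                # (q_len, embed_dim)
    return softmax(q @ k.T) @ v   # (q_len, embed_dim)


# Cached: project the fixed encoder states once, before the generation loop.
cached_k = encoder_hidden_state @ Wk
cached_v = encoder_hidden_state @ Wv
cached_rows = 2 * src_len         # rows pushed through Wk and Wv, once
uncached_rows = 0

for step in range(num_steps):
    x_new = rng.standard_normal((1, embed_dim))  # query for the newest decoder position
    # Uncached: re-project the encoder states on every step.
    k, v = encoder_hidden_state @ Wk, encoder_hidden_state @ Wv
    uncached_rows += 2 * src_len
    assert np.allclose(encoder_attn(x_new, k, v), encoder_attn(x_new, cached_k, cached_v))

print(uncached_rows, cached_rows)  # 40 10
```

The outputs are identical either way; the cache just stops the key and value projections from scaling with the number of generation steps.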

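The more exciting optimization is in self_attn: we can cache the part of k and v that depends on x[:, :-1], the tokens we've already generated. Because the decoder's self-attention is causally masked (earlier tokens never attend to later ones), those cached keys and values never change. Then each time through the generation loop, we only pass in x[:, -1:], the newest token, and concatenate its keys and values onto the cache: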

k = torch.cat((past_key, new_k), dim=1)  # dim 1 is seq_len for (batch_size, seq_len, embed_dim)
v = torch.cat((past_value, new_v), dim=1)
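That incremental decoding with a key/value cache reproduces full recomputation can be checked with a small NumPy sketch. It assumes a single head, no scaling, random bias-free weights, and no batch dimension (so seq_len is axis 0 here rather than dim 1); the helper names are hypothetical. The reference version recomputes q, k, and v for every token and applies a causal mask; the cached version projects only the newest token and concatenates onto `past_key` and `past_value`.

```python
import numpy as np

rng = np.random.default_rng(0)
embed_dim, num_tokens = 8, 6

# Random single-head projections standing in for self_attn's Wq, Wk, Wv.
Wq, Wk, Wv = (rng.standard_normal((embed_dim, embed_dim)) for _ in range(3))
x = rng.standard_normal((num_tokens, embed_dim))  # decoder inputs, one row per token


def softmax(scores):
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)


def full_causal_self_attn(x):
    """Recompute q, k, v for every token; mask out future positions."""
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    scores = q @ k.T
    future = np.triu(np.ones((len(x), len(x))), k=1).astype(bool)
    scores = np.where(future, -np.inf, scores)
    return softmax(scores) @ v


def cached_self_attn_step(x_new, past_key, past_value):
    """Project only the newest token, then concatenate onto the cache."""
    new_k, new_v = x_new @ Wk, x_new @ Wv
    k = np.concatenate([past_key, new_k], axis=0)  # axis 0 = seq_len (no batch dim)
    v = np.concatenate([past_value, new_v], axis=0)
    q = x_new @ Wq
    return softmax(q @ k.T) @ v, k, v


past_key = np.zeros((0, embed_dim))
past_value = np.zeros((0, embed_dim))
incremental_outputs = []
for t in range(num_tokens):
    out, past_key, past_value = cached_self_attn_step(x[t:t + 1], past_key, past_value)
    incremental_outputs.append(out)

incremental = np.concatenate(incremental_outputs, axis=0)
print(np.allclose(incremental, full_causal_self_attn(x)))  # True
```

Each step pushes one row through Wk and Wv instead of every token generated so far, which is where the "almost completely cached" projections come from.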

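Of the 8 F.linear ops performed by each DecoderLayer, we've managed to completely cache 2 of them (the key and value projections in encoder_attn) and almost completely cache 2 more (the key and value projections in self_attn, which now run only on the newest token). Overall, this chops about 40% off the runtime compared with caching only encoder_outputs.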