Understanding Incremental Decoding
Walks through some of the performance hacks in BartForConditionalGeneration
Here is a really slow (pseudocode) way to greedily generate summaries:
```python
# Pseudocode: `model`, `bos`, `eos`, and `article_input_ids` are placeholders.
output_tokens = [bos]
while True:
    encoder_hidden_state = model.encoder(article_input_ids)
    logits = model.decoder(encoder_hidden_state, output_tokens)
    next_word = logits.argmax()
    output_tokens.append(next_word)
    if next_word == eos:
        break
```
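To make the cost concrete, here is a runnable toy version of the loop above. `CountingToyModel` is a hypothetical stand-in, not BART: its `encoder` and `decoder` only tally how many tokens they are fed, and the decoder deterministically predicts a new token until it has been called `summary_len` times, then predicts EOS. The loop runs once per decoder call, so in this sketch $M$ counts decoder calls.

```python
BOS, EOS = 0, 1  # hypothetical special-token ids


class CountingToyModel:
    """Hypothetical stand-in for a seq2seq model that only counts tokens processed."""

    def __init__(self, summary_len):
        self.summary_len = summary_len  # decoder calls before it predicts EOS
        self.encoder_tokens = 0
        self.decoder_tokens = 0

    def encoder(self, input_ids):
        self.encoder_tokens += len(input_ids)
        return input_ids  # a real encoder would return hidden states

    def decoder(self, encoder_hidden_state, output_tokens):
        self.decoder_tokens += len(output_tokens)
        done = len(output_tokens) >= self.summary_len
        next_word = EOS if done else len(output_tokens) + 1
        logits = [0.0] * (next_word + 1)
        logits[next_word] = 1.0  # one-hot "logits", so argmax picks next_word
        return logits


def greedy_slow(model, article_input_ids):
    output_tokens = [BOS]
    while True:
        encoder_hidden_state = model.encoder(article_input_ids)  # re-run every step
        logits = model.decoder(encoder_hidden_state, output_tokens)
        next_word = logits.index(max(logits))  # argmax over a plain list
        output_tokens.append(next_word)
        if next_word == EOS:
            break
    return output_tokens


model = CountingToyModel(summary_len=100)
summary = greedy_slow(model, list(range(1024)))  # a fake 1024-token article
print(model.encoder_tokens, model.decoder_tokens)  # 102400 5050
```

The encoder tally grows by the full article length on every step, which is exactly the waste the next version of the loop removes.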
Let's say $N$ = `len(article_input_ids)` and $M$ = `len(output_tokens)` when we return. Let's also ignore the cost of encoder-decoder attention.
Then the complexity of this approach is roughly
- $N \cdot M$ (we call the encoder $M$ times on $N$ tokens)
- $+ \sum_{m=1}^{M} m = \frac{M(M+1)}{2}$ (we call the decoder $M$ times: on 1 token, then 2 tokens, all the way up to $M$ tokens)
- Total: $N \cdot M + \frac{M(M+1)}{2}$
Say we generate a 100-token summary of a 1024-token article. Then we have to "process" $N \cdot M + \frac{M(M+1)}{2} = 102,400 + 5,050 = 107,450$ tokens.
```python
# Show work
N, M = 1024, 100
complexity_simple = N * M + (M * (M + 1)) // 2
msg = "To generate a 100 token summary of a 1024 token article using this approach, we have to process {:,} tokens."
print(msg.format(complexity_simple))
```
Thankfully, we can just hoist the encoder call outside the loop!
```python
output_tokens = [bos]
encoder_hidden_state = model.encoder(article_input_ids)  # computed once, before the loop
while True:
    logits = model.decoder(encoder_hidden_state, output_tokens)
    next_word = logits.argmax()
    output_tokens.append(next_word)
    if next_word == eos:
        break
```
Now the complexity is dramatically reduced: $N \cdot 1 + \frac{M(M+1)}{2} = 6,074$
But we can go even further and partially cache the keys and values computed inside the decoder's attention layers, so that each decoder call only has to process the newest token. This changes our generation loop to:
```python
output_tokens = [bos]
encoder_hidden_state = model.encoder(article_input_ids)
cache = None
while True:
    # Only the newest token is passed in; earlier tokens live in `cache`.
    logits, cache = model.decoder(encoder_hidden_state, output_tokens[-1], cache=cache)
    next_word = logits.argmax()
    output_tokens.append(next_word)
    if next_word == eos:
        break
```
Our formula then becomes $N + M = 1,124$, since the encoder still runs once and each of the $M$ decoder calls now processes a single token.
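The three token counts can also be checked by simulating the loops instead of using the closed forms. `tokens_processed` below is a hypothetical helper that follows the same accounting as the text: every token fed to the encoder or decoder counts as 1, and attention cost is ignored.

```python
def tokens_processed(N, M, hoist_encoder=False, cache_decoder=False):
    """Count tokens fed to the encoder and decoder over M greedy steps.

    Ignores attention cost, as in the text: processing one token counts as 1.
    """
    total = N if hoist_encoder else 0  # a hoisted encoder runs once, before the loop
    for m in range(1, M + 1):  # step m: m output tokens exist so far
        if not hoist_encoder:
            total += N  # encoder re-run on the whole article
        total += 1 if cache_decoder else m  # a cached decoder sees only the newest token
    return total


N, M = 1024, 100
for name, kwargs in [
    ("naive", {}),
    ("hoisted encoder", {"hoist_encoder": True}),
    ("hoisted encoder + decoder cache", {"hoist_encoder": True, "cache_decoder": True}),
]:
    print(f"{name}: {tokens_processed(N, M, **kwargs):,}")
```

Caching does not make attention itself free: at each step the newest token still attends over all cached keys, which is the cost these formulas deliberately leave out.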
The nitty-gritty of this trick is explained in the section below.
Partially caching keys and values in DecoderLayer
Here is some simplified code for attention, without all the reshapes, heads, masks, and scaling. It is a simplification, not BART's actual implementation.
```python
import torch
import torch.nn as nn


class SimplifiedAttention(nn.Module):
    """Single-head attention with no masking, scaling, or dropout."""

    def __init__(self, embed_dim):
        super().__init__()
        self.Wq = nn.Linear(embed_dim, embed_dim)
        self.Wk = nn.Linear(embed_dim, embed_dim)
        self.Wv = nn.Linear(embed_dim, embed_dim)
        self.dense = nn.Linear(embed_dim, embed_dim)  # output projection

    def forward(self, query, key, value):
        q = self.Wq(query)
        k = self.Wk(key)
        v = self.Wv(value)
        # (batch, q_len, dim) @ (batch, dim, k_len) -> (batch, q_len, k_len)
        matmul_qk = torch.matmul(q, k.transpose(-2, -1))
        attention_weights = matmul_qk.softmax(dim=-1)
        output = torch.matmul(attention_weights, v)
        return self.dense(output)
```
Now let's glance at the callers inside BART's `DecoderLayer` (LayerNorms and dropouts deleted for simplicity). Here's some more pseudocode:
```python
class SimplifiedDecoderLayer(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.self_attn = SimplifiedAttention(embed_dim)
        self.encoder_attn = SimplifiedAttention(embed_dim)

    def forward(self, x, last_encoder_hidden_state, *masks_etc):
        # x shape `(batch_size, tokens_generated_so_far, embed_dim)`
        # x comes from the decoder
        x = self.self_attn(query=x, key=x, value=x)
        output = self.encoder_attn(  # pay attention to somebody else for a change!
            query=x,
            key=last_encoder_hidden_state,  # could be None
            value=last_encoder_hidden_state,
        )
        return output
```
What did we learn?
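- In `encoder_attn`, we can cache everything that doesn't depend on the query, namely the key and value projections of the encoder output, `k = self.Wk(key)` and `v = self.Wv(value)`. These are identical at every step of the generation loop, so they only need to be computed once.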
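A minimal NumPy sketch of this first point, assuming single-head attention without biases, masking, or scaling (as in `SimplifiedAttention`) and random, hypothetical weights. The encoder-attention keys and values are projected once before the loop; reusing them at every step gives the same output as recomputing them. NumPy is used here instead of PyTorch only to keep the sketch dependency-light.

```python
import numpy as np

rng = np.random.default_rng(0)
embed_dim, src_len, n_steps = 8, 5, 4

# Hypothetical random matrices standing in for encoder_attn's Wq, Wk, Wv
# (no biases, no heads, no scaling).
Wq, Wk, Wv = (rng.standard_normal((embed_dim, embed_dim)) for _ in range(3))
encoder_hidden_state = rng.standard_normal((src_len, embed_dim))


def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def encoder_attn(x, k, v):
    """Attention of decoder states x, shape (q_len, embed_dim), over projected k, v."""
    q = x @ Wq
    return softmax(q @ k.T) @ v


# Cached once, before the generation loop: Wk and Wv only see the encoder output.
k_cache = encoder_hidden_state @ Wk
v_cache = encoder_hidden_state @ Wv

max_diff = 0.0
for step in range(n_steps):
    x_new = rng.standard_normal((1, embed_dim))  # stand-in for the newest decoder state
    cached = encoder_attn(x_new, k_cache, v_cache)
    recomputed = encoder_attn(x_new, encoder_hidden_state @ Wk, encoder_hidden_state @ Wv)
    max_diff = max(max_diff, float(np.abs(cached - recomputed).max()))

print(max_diff)
```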
The more exciting optimization is in `self_attn`: we can cache the part of `k` and `v` that depends on `x[:, :-1]`, the tokens we've already generated. Then each time through the generation loop, we only pass in `x[:, -1:]` (the newest token) and concatenate its projections onto the cache:
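```python
# past_key, past_value: cached projections of x[:, :-1], shape (batch_size, seq_len - 1, embed_dim)
# new_k, new_v: projections of the newest token x[:, -1:], shape (batch_size, 1, embed_dim)
k = torch.cat((past_key, new_k), dim=1)  # dim=1 is the seq_len dimension
v = torch.cat((past_value, new_v), dim=1)
```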
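To check that the cache doesn't change the answer, the sketch below compares ordinary causal self-attention over a whole sequence with an incremental version that feeds one token at a time and concatenates onto `past_key`/`past_value` along the seq_len axis. It uses NumPy and the same simplifications as `SimplifiedAttention` (single head, no biases, no scaling); the weights are random and hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
embed_dim, seq_len = 8, 6
# Hypothetical random matrices standing in for self_attn's Wq, Wk, Wv.
Wq, Wk, Wv = (rng.standard_normal((embed_dim, embed_dim)) for _ in range(3))
x = rng.standard_normal((1, seq_len, embed_dim))  # (batch, seq_len, embed_dim)


def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def full_causal_self_attn(x):
    """Recompute everything for all tokens, hiding the future with a causal mask."""
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    scores = q @ k.transpose(0, 2, 1)  # (batch, seq, seq)
    n = x.shape[1]
    future = np.triu(np.ones((n, n), dtype=bool), k=1)  # True above the diagonal
    scores = np.where(future, -np.inf, scores)
    return softmax(scores) @ v


def incremental_self_attn(x):
    """Feed one token per step, growing past_key/past_value along seq_len."""
    past_key = np.zeros((x.shape[0], 0, x.shape[2]))
    past_value = np.zeros_like(past_key)
    outputs = []
    for t in range(x.shape[1]):
        x_new = x[:, t:t + 1]  # only the newest token, (batch, 1, embed_dim)
        q = x_new @ Wq
        new_k, new_v = x_new @ Wk, x_new @ Wv
        past_key = np.concatenate((past_key, new_k), axis=1)  # seq_len axis
        past_value = np.concatenate((past_value, new_v), axis=1)
        scores = q @ past_key.transpose(0, 2, 1)  # (batch, 1, t + 1)
        outputs.append(softmax(scores) @ past_value)
    return np.concatenate(outputs, axis=1)


full = full_causal_self_attn(x)
incremental = incremental_self_attn(x)
print(np.allclose(full, incremental))
```

The incremental version needs no causal mask: at step `t` the cache only holds tokens `0..t`, so there is nothing in the future to hide.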
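Of the 8 `F.linear` ops performed by the two attention modules in each `DecoderLayer`, we've managed to completely cache 2 of them (the key and value projections in `encoder_attn`) and almost completely cache 2 more (the key and value projections in `self_attn`, which now only run on the newest token). Overall, this chops off about 40% of the runtime compared with caching only `encoder_outputs`.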