refactor transformer decoder and revamp the left padding attention mask #178
Conversation
Nicely done! This simplification had been overdue for quite some time.
(Did not test changes on my end yet, but I trust the existing tests should cover most of the changes here.)
Minor comments to address.
Also, for reference, what is the rationale behind embedding the `left_pad` flag within the batch?
When using `left_padding` we need to mask the pad tokens at EACH time step when decoding. Also, this was previously hardcoded in the generator in an ugly way.
The initial goal was to merge the `transformer_lm` and transformer decoder types into a single class.
After careful review I realized that left padding was not handled 100% properly.
After this PR:
For training:

- We rely on PyTorch `scaled_dot_product_attention` in `memory_efficient` mode (see the sketch below). We may need to investigate `flash_attn_varlen_func()` (which seems slower for now), see Dao-AILab/flash-attention#1125.
- ==> `self_attn_backend`: has no impact at training at the moment.
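For illustration, a minimal sketch of forcing the memory-efficient SDPA backend with an explicit boolean mask (toy shapes; assumes PyTorch >= 2.3 for `torch.nn.attention.sdpa_kernel`, older versions expose `torch.backends.cuda.sdp_kernel` instead):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

# toy shapes: (batch, heads, seq_len, head_dim)
B, H, L, D = 2, 8, 16, 64
q = torch.randn(B, H, L, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, H, L, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, H, L, D, device="cuda", dtype=torch.float16)

# explicit boolean mask built by the caller (True = attend); is_causal is never used
attn_mask = torch.ones(B, 1, L, L, dtype=torch.bool, device="cuda").tril()

with sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
```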
For inference:

- For EncoderDecoder models we also rely on PyTorch `scaled_dot_product_attention` in `memory_efficient` mode, since it is much faster than `flash_attn_with_kvcache()` due to short sequences / large batches.
- For LLMs (decoder-only models) we rely on `flash_attn_with_kvcache()`, which is much faster.
When we use SDPA we always build the `attention_mask` ourselves and never rely on `is_causal=True`. The main reason is that we need to handle padding properly and never attend to padded tokens.
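To make the padding handling concrete, here is a minimal sketch of how such a mask can be built for a left-padded batch (a hypothetical `build_decoder_mask` helper for illustration, not the PR's exact code; letting pad positions attend to themselves is just one way to keep SDPA outputs finite):

```python
import torch

def build_decoder_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    """Boolean mask of shape (batch, 1, max_len, max_len): True = attend.

    Assumes left padding: each sequence occupies the LAST `lengths[b]`
    positions of its row, the first `max_len - lengths[b]` slots are pads.
    Hypothetical helper for illustration, not the PR's implementation.
    """
    device = lengths.device
    positions = torch.arange(max_len, device=device)
    # key j is a real token of row b iff j >= max_len - lengths[b]
    key_is_token = positions.unsqueeze(0) >= (max_len - lengths).unsqueeze(1)      # (B, L)
    causal = torch.ones(max_len, max_len, dtype=torch.bool, device=device).tril()  # (L, L)
    mask = causal.unsqueeze(0) & key_is_token.unsqueeze(1)                         # (B, L, L)
    # fully masked pad-query rows would yield NaNs in SDPA; letting every
    # position attend to itself keeps outputs finite (they are ignored later)
    mask |= torch.eye(max_len, dtype=torch.bool, device=device)
    return mask.unsqueeze(1)                                                       # (B, 1, L, L)

# example: batch of 2, max_len 6, sequences of length 4 and 6 (left-padded)
mask = build_decoder_mask(torch.tensor([4, 6]), max_len=6)
```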
With `flash_attn_with_kvcache()` we added a "recent" functionality, `cache_leftpad`, which requires flash-attn > 2.6.1. Results are much more consistent between the two options now.
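A sketch of what one decoding step could look like with this API, assuming flash-attn >= 2.6.1 and its documented (batch, seqlen, heads, head_dim) layout; the concrete lengths below are illustrative only:

```python
import torch
from flash_attn import flash_attn_with_kvcache  # requires flash-attn > 2.6.1

B, H, D = 2, 8, 64          # batch, heads, head dim
max_cache = 128             # pre-allocated KV cache length

# pre-allocated caches, layout (batch, cache_len, heads, head_dim)
k_cache = torch.zeros(B, max_cache, H, D, device="cuda", dtype=torch.float16)
v_cache = torch.zeros(B, max_cache, H, D, device="cuda", dtype=torch.float16)

# one decoding step: a single new token per sequence
q = torch.randn(B, 1, H, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, 1, H, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, 1, H, D, device="cuda", dtype=torch.float16)

# per-row used length of the cache (where the new k/v get written) and the
# number of left-pad slots at the start of each row (skipped when attending)
cache_seqlens = torch.tensor([10, 16], device="cuda", dtype=torch.int32)
cache_leftpad = torch.tensor([6, 0], device="cuda", dtype=torch.int32)

out = flash_attn_with_kvcache(
    q, k_cache, v_cache, k=k, v=v,
    cache_seqlens=cache_seqlens,
    cache_leftpad=cache_leftpad,
    causal=True,
)  # out: (batch, 1, heads, head_dim)
```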
==> `self_attn_backend`: if "flash" then `flash_attn_with_kvcache()` will be used; if "pytorch" then SDPA in `memory_efficient` mode will be used.
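For reference, a condensed sketch of the dispatch this option implies at decoding time (hypothetical wrapper and argument names; only the option values "flash" / "pytorch" come from the PR text):

```python
import torch.nn.functional as F

def decode_step_attention(q, k_cache, v_cache, attn_mask=None,
                          cache_seqlens=None, cache_leftpad=None,
                          self_attn_backend="flash"):
    """Route one decoding step to the configured attention backend (illustrative)."""
    if self_attn_backend == "flash":
        from flash_attn import flash_attn_with_kvcache  # flash-attn > 2.6.1
        # (batch, seqlen, heads, head_dim) layout; padding handled through
        # cache_seqlens / cache_leftpad instead of an explicit mask
        return flash_attn_with_kvcache(
            q, k_cache, v_cache,
            cache_seqlens=cache_seqlens,
            cache_leftpad=cache_leftpad,
            causal=True,
        )
    # "pytorch": SDPA in memory_efficient mode, (batch, heads, seqlen, head_dim)
    # layout, with the explicit boolean mask built by the caller
    from torch.nn.attention import sdpa_kernel, SDPBackend
    with sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION]):
        return F.scaled_dot_product_attention(q, k_cache, v_cache, attn_mask=attn_mask)
```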