refactor transformer decoder and revamp the left padding attention mask by vince62s · Pull Request #178 · eole-nlp/eole · GitHub

refactor transformer decoder and revamp the left padding attention mask #178


Merged
merged 12 commits into eole-nlp:main from vince62s:refactrfdecoder
Jan 13, 2025

Conversation

vince62s
Contributor
@vince62s vince62s commented Jan 7, 2025

The initial goal was to merge the transformer_lm and transformer decoder types into a single class:

  • remove decoder_type = "transformer_lm"
  • keep it at the "architecture" level: ["transformer", "transformer_lm"]

After careful review I realized that left padding was not handled 100% properly.
After this PR:

For training:
We rely on PyTorch scaled_dot_product_attention in memory_efficient mode (see the sketch below).
We may need to investigate flash_attn_varlen_func() (which seems slower for now); see Dao-AILab/flash-attention#1125.
==> self_attn_backend has no impact at training time for the moment.
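A minimal sketch of what that training path amounts to, assuming PyTorch >= 2.3 for torch.nn.attention.sdpa_kernel; tensor names and shapes are illustrative, not the actual eole code:

```python
# Sketch (assumed names/shapes): force the memory-efficient SDPA backend and
# pass an explicit boolean attention mask instead of is_causal=True.
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

def train_attention(query, key, value, attn_mask):
    # query/key/value: (batch, heads, seq_len, head_dim)
    # attn_mask: bool, True = may attend, False = masked (e.g. pad positions)
    with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
        return torch.nn.functional.scaled_dot_product_attention(
            query, key, value, attn_mask=attn_mask
        )
```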

For inference:
For EncoderDecoder models we also rely on PyTorch scaled_dot_product_attention in memory_efficient mode, since it is much faster than flash_attn_with_kvcache() due to the short sequences / large batches.
For LLMs (decoder-only models) we rely on flash_attn_with_kvcache(), which is much faster.

When we use SDPA we always build the attention_mask ourselves and never rely on is_causal=True.
The main reason is that we need to handle padding properly and not attend padded tokens (see the sketch below).
With flash_attn_with_kvcache() we use the "recent" cache_leftpad functionality, which requires flash-attn > 2.6.1.
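To illustrate building the mask explicitly, here is a hedged sketch (hypothetical helper, assumed shapes) that combines the causal constraint with a key-padding mask so padded tokens are never attended:

```python
# Sketch: explicit causal + padding mask for scaled_dot_product_attention.
import torch

def build_attn_mask(tgt, pad_idx):
    # tgt: (batch, seq_len) token ids, possibly left-padded with pad_idx
    batch, seq_len = tgt.shape
    causal = torch.tril(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=tgt.device)
    )                                  # (seq_len, seq_len)
    not_pad = tgt.ne(pad_idx)          # (batch, seq_len), True for real tokens
    # (batch, 1, seq_len, seq_len): a query may attend a key only if the key is
    # a real token AND the causal constraint holds; broadcasts over heads when
    # passed as attn_mask to scaled_dot_product_attention.
    return causal[None, None, :, :] & not_pad[:, None, None, :]
```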

Results are much more consistent between the two options now.

==> self_attn_backend: if "flash", flash_attn_with_kvcache() will be used; if "pytorch", SDPA in memory_efficient mode will be used.
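For the decoder-only inference path, a hedged sketch of a single decoding step with flash_attn_with_kvcache() and its cache_leftpad argument (flash-attn > 2.6.1); names and shapes are illustrative, not the exact eole call site:

```python
# Sketch: one decoding step against a pre-allocated KV cache with left padding.
from flash_attn import flash_attn_with_kvcache

def decode_step(q, k_cache, v_cache, k_new, v_new, cache_seqlens, left_pad):
    # q, k_new, v_new: (batch, 1, n_heads, head_dim) for the current step
    # k_cache, v_cache: (batch, max_len, n_kv_heads, head_dim), updated in place
    # cache_seqlens: (batch,) int32, current lengths of the cache
    # left_pad: (batch,) int32, number of left-pad tokens per sequence
    return flash_attn_with_kvcache(
        q, k_cache, v_cache,
        k=k_new, v=v_new,
        cache_seqlens=cache_seqlens,
        cache_leftpad=left_pad,
        causal=True,
    )
```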

@vince62s vince62s changed the title refactor transformer decoder refactor transformer decoder and revamp the left padding attention mask Jan 10, 2025
Member
@francoishernandez francoishernandez left a comment


Nicely done! This simplification had been overdue for quite some time.
(Did not test changes on my end yet, but I trust the existing tests should cover most of the changes here.)

Minor comments to address.
Also, for reference, what is the rationale behind embedding the left_pad flag within the batch?

@vince62s
Contributor Author
vince62s commented Jan 13, 2025

Also, for reference, what is the rationale behind embedding the left_pad flag within the batch?

When using left padding we need to mask the pad tokens at EACH time step when decoding.
So knowing that your batch is in left_pad mode means that at step 0 you store the pad token mask (from the prompt) and re-use it at each decoding step, as sketched below.

Also, it was previously hardcoded in the generator in an ugly way.
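A rough sketch of that idea (hypothetical class and names, not the eole implementation): the pad mask is taken from the prompt once at step 0 and re-used to build the key mask at every decoding step:

```python
# Sketch: keep the prompt pad mask for the whole decoding loop.
import torch

class LeftPadState:
    """Hypothetical holder for the pad mask computed once from the prompt."""

    def __init__(self, prompt_ids, pad_idx):
        # True for real prompt tokens, False for left-pad tokens (step 0).
        self.prompt_keep = prompt_ids.ne(pad_idx)  # (batch, prompt_len)

    def key_mask(self, n_generated):
        # Generated tokens are never padding, so they are always attendable;
        # the stored prompt mask is simply re-used at every step.
        batch = self.prompt_keep.size(0)
        gen_keep = torch.ones(
            batch, n_generated,
            dtype=torch.bool, device=self.prompt_keep.device,
        )
        return torch.cat([self.prompt_keep, gen_keep], dim=1)  # (batch, total)
```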

@vince62s vince62s merged commit db3144c into eole-nlp:main Jan 13, 2025
2 checks passed
vince62s added a commit to vince62s/eole that referenced this pull request Jan 14, 2025
vince62s added a commit that referenced this pull request Jan 14, 2025
@vince62s vince62s deleted the refactrfdecoder branch April 3, 2025 16:14