Support Qwen3 and Gemma3 #81
Updated from 9ab6c0f to e9d2ad9.
Signed-off-by: alessiodevoto <devoto.alessio@gmail.com>
Updated from e9d2ad9 to ec53981.
kvpress/presses/base_press.py (outdated)
@@ -127,13 +129,23 @@ def __call__(self, model: PreTrainedModel) -> Generator:
        model : PreTrainedModel
            Model to apply the compression method to
        """

        if not isinstance(model, (LlamaForCausalLM, MistralForCausalLM, Phi3ForCausalLM, Qwen2ForCausalLM)):
        supported_models = (
Maybe move this variable to the top of the file as a global constant named SUPPORTED_MODELS?
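A minimal sketch of that suggestion, assuming the tuple is only used for the "not tested" warning. To keep it runnable without transformers installed, empty stand-in classes replace the real model classes, and `check_model_supported` is a hypothetical helper name, not part of kvpress.

```python
import logging

logger = logging.getLogger(__name__)


# Stand-ins for the transformers model classes, so the sketch runs on its own.
class LlamaForCausalLM: ...
class MistralForCausalLM: ...
class Phi3ForCausalLM: ...
class Qwen2ForCausalLM: ...
class Qwen3ForCausalLM: ...
class Gemma3ForCausalLM: ...
class GPT2LMHeadModel: ...  # an example of a model that is not in the list


# Module-level constant: built once at import time instead of on every call.
SUPPORTED_MODELS = (
    LlamaForCausalLM,
    MistralForCausalLM,
    Phi3ForCausalLM,
    Qwen2ForCausalLM,
    Qwen3ForCausalLM,
    Gemma3ForCausalLM,
)


def check_model_supported(model) -> bool:
    """Hypothetical helper: warn (but do not fail) for untested model classes."""
    if not isinstance(model, SUPPORTED_MODELS):
        logger.warning(f"Model {type(model)} not tested")
        return False
    return True
```

Since `isinstance` accepts a tuple of classes directly, the constant can be passed as-is, and adding a new architecture becomes a one-line change at the top of the file.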
            logger.warning(f"Model {type(model)} not tested")

        hooks = []
        try:
            for layer in model.model.layers:
                if isinstance(model, Gemma3ForCausalLM) and layer.is_sliding:
                    # Skip layers with sliding window attention, only for Gemma3
Should we add a warning such as
logger.warning("Compression in Gemma3 is only applied to layers without sliding window attention")
here?
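One way to act on this, sketched under assumptions: emit the warning once, before iterating, rather than once per skipped layer. `FakeLayer` is a hypothetical stand-in that only carries the `is_sliding` flag the diff reads from each decoder layer, and `select_layers_to_compress` is a hypothetical helper name, not kvpress code.

```python
import logging
from dataclasses import dataclass

logger = logging.getLogger(__name__)


@dataclass
class FakeLayer:
    # Stand-in for a decoder layer; only the is_sliding flag matters here.
    is_sliding: bool


def select_layers_to_compress(layers, is_gemma3: bool):
    """Return the layers a press would hook, skipping sliding-window layers for Gemma3."""
    if is_gemma3:
        # Warn once up front instead of logging for every skipped layer.
        logger.warning(
            "Compression in Gemma3 is only applied to layers without sliding window attention"
        )
        return [layer for layer in layers if not layer.is_sliding]
    return list(layers)
```

Filtering first and then registering hooks on the returned list keeps the model-specific condition out of the hook-registration loop.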
@@ -161,12 +163,16 @@ def duo_attention_on_the_fly(model, num_samples=50, q_len=500):
        # Mean query
        q = module.self_attn.q_proj(h)
        q = q.view(1, q.shape[1], -1, d)
        if isinstance(module, (Gemma3Attention, Qwen3Attention)):
Much better to check module types than layer names, thanks!
@@ -12,6 +12,10 @@
from kvpress.presses.scorer_press import ScorerPress
Please put the kvpress imports after the transformers imports (good practice: third-party imports before local ones).
kvpress/presses/snapkv_press.py (outdated)
@@ -12,6 +12,10 @@
Same here: the kvpress imports should come after the transformers imports.
Also, @alessiodevoto, please check whether issue #80 calls for any other changes in this PR.
Signed-off-by: alessiodevoto <devoto.alessio@gmail.com>
This addresses #76 by supporting the QK normalization used in Gemma3 and Qwen3 (and bumps the library version).
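For context on the `Gemma3Attention`/`Qwen3Attention` branch in `duo_attention_on_the_fly`: these models normalize queries (and keys) per head with an RMSNorm over `head_dim`, applied after reshaping the `q_proj` output into heads and before RoPE. The NumPy sketch below is a hypothetical illustration of that step, not the PR's code. It assumes the "mean query" is an average over sequence positions, uses the plain `x / rms * weight` form (Gemma's RMSNorm variant scales by `1 + weight` instead), and ignores RoPE.

```python
import numpy as np


def rms_norm(x: np.ndarray, weight: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """RMSNorm over the last axis (head_dim), as used for per-head QK norm."""
    variance = np.mean(x ** 2, axis=-1, keepdims=True)
    return x / np.sqrt(variance + eps) * weight


def mean_query(q_flat: np.ndarray, head_dim: int, q_norm_weight=None) -> np.ndarray:
    """Hypothetical helper: per-head mean query, optionally QK-normalized.

    q_flat has shape (1, seq_len, num_heads * head_dim), i.e. the q_proj output.
    """
    # Split the hidden dimension into heads: (1, seq_len, num_heads, head_dim).
    q = q_flat.reshape(1, q_flat.shape[1], -1, head_dim)
    if q_norm_weight is not None:
        # QK-norm models: normalize each head's query vector before averaging.
        q = rms_norm(q, q_norm_weight)
    # Average over sequence positions -> (1, num_heads, head_dim).
    return q.mean(axis=1)
```

Skipping the normalization for these models would average vectors on a different scale than the ones the attention actually uses, which is presumably why the branch is needed.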