Support Qwen3 and Gemma3 by alessiodevoto · Pull Request #81 · NVIDIA/kvpress

Open

alessiodevoto wants to merge 3 commits into main from feat-qwen3-gemma3
Conversation

alessiodevoto (Collaborator)

This addresses #76 to support the QK normalization used in Gemma3 and Qwen3 (+ updates library version).
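For context: in the transformers implementations of Qwen3 and Gemma3, queries and keys are RMS-normalized per head before rotary embeddings are applied, an extra step that Llama, Mistral, Phi-3, and Qwen2 do not have. A minimal sketch of the pattern (illustrative only, not the PR's code, assuming PyTorch >= 2.4 for nn.RMSNorm):

import torch
from torch import nn

class QKNormAttentionSketch(nn.Module):
    # Illustrative only: the projection + QK-norm pattern of Qwen3/Gemma3
    def __init__(self, hidden_size: int, num_heads: int, head_dim: int):
        super().__init__()
        self.head_dim = head_dim
        self.q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
        # Per-head RMSNorm on queries and keys: the "QK normalization" this PR supports
        self.q_norm = nn.RMSNorm(head_dim)
        self.k_norm = nn.RMSNorm(head_dim)

    def project(self, hidden_states: torch.Tensor):
        b, s, _ = hidden_states.shape
        q = self.q_proj(hidden_states).view(b, s, -1, self.head_dim)
        k = self.k_proj(hidden_states).view(b, s, -1, self.head_dim)
        q, k = self.q_norm(q), self.k_norm(k)  # applied before RoPE
        return q, k

Any kvpress hook that recomputes queries or keys from hidden states therefore has to apply the same norm for these two model families.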

@SimJeg SimJeg self-assigned this Jun 12, 2025
@alessiodevoto alessiodevoto force-pushed the feat-qwen3-gemma3 branch 2 times, most recently from 9ab6c0f to e9d2ad9 Compare June 12, 2025 12:24
Signed-off-by: alessiodevoto <devoto.alessio@gmail.com>
Signed-off-by: alessiodevoto <devoto.alessio@gmail.com>
@@ -127,13 +129,23 @@ def __call__(self, model: PreTrainedModel) -> Generator:
model : PreTrainedModel
Model to apply the compression method to
"""

if not isinstance(model, (LlamaForCausalLM, MistralForCausalLM, Phi3ForCausalLM, Qwen2ForCausalLM)):
supported_models = (
Review comment (Collaborator):

Maybe move this variable to the beginning of the file as a global variable named SUPPORTED_MODELS?
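The suggested refactor might look like this (a sketch; the tuple contents are my assumption, combining the previous isinstance check with the two families this PR adds):

from transformers import (
    Gemma3ForCausalLM,
    LlamaForCausalLM,
    MistralForCausalLM,
    Phi3ForCausalLM,
    Qwen2ForCausalLM,
    Qwen3ForCausalLM,
)

# Module-level constant, as suggested in the review
SUPPORTED_MODELS = (
    LlamaForCausalLM,
    MistralForCausalLM,
    Phi3ForCausalLM,
    Qwen2ForCausalLM,
    Qwen3ForCausalLM,
    Gemma3ForCausalLM,
)

# ... and inside __call__ the check becomes:
# if not isinstance(model, SUPPORTED_MODELS):
#     (warn as in the diff below)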

logger.warning(f"Model {type(model)} not tested")

hooks = []
try:
for layer in model.model.layers:
if isinstance(model, Gemma3ForCausalLM) and layer.is_sliding:
# Skip layers with sliding window attention, only for Gemma3
Review comment (Collaborator):

Should we add

logger.warning("Compression in Gemma3 is only applied to layers without sliding window attention")

?
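Slotted into the layer loop shown above, the suggestion might read as follows (a sketch; since the warning sits inside the loop it would fire once per skipped layer, so logging it once before the loop may be preferable):

for layer in model.model.layers:
    if isinstance(model, Gemma3ForCausalLM) and layer.is_sliding:
        # Skip layers with sliding window attention, only for Gemma3
        logger.warning("Compression in Gemma3 is only applied to layers without sliding window attention")
        continue
    # ... register the compression hook on this layer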

@@ -161,12 +163,16 @@ def duo_attention_on_the_fly(model, num_samples=50, q_len=500):
# Mean query
q = module.self_attn.q_proj(h)
q = q.view(1, q.shape[1], -1, d)
if isinstance(module, (Gemma3Attention, Qwen3Attention)):
Review comment (Collaborator):

Much better to use modules than layer names, thanks!
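As a sketch of what the new branch does (the helper name mean_query is hypothetical, and here the isinstance check runs on the attention submodule; Gemma3Attention and Qwen3Attention are the transformers classes):

from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention
from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention

def mean_query(module, h, d):
    # Recompute the per-head query for one decoder layer from hidden states h
    attn = module.self_attn
    q = attn.q_proj(h)
    q = q.view(1, q.shape[1], -1, d)
    # Gemma3/Qwen3 RMS-normalize per-head queries before RoPE, so the same
    # norm must be applied when recomputing queries outside the forward pass
    if isinstance(attn, (Gemma3Attention, Qwen3Attention)):
        q = attn.q_norm(q)
    return q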

@@ -12,6 +12,10 @@

from kvpress.presses.scorer_press import ScorerPress
Review comment (Collaborator):

Please put the kvpress import after the transformers import (good practice).
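The convention being requested is the usual stdlib / third-party / first-party grouping, e.g. (illustrative imports, not the file's actual ones):

from dataclasses import dataclass  # standard library first

import torch  # third-party next
from transformers.models.llama.modeling_llama import LlamaAttention

from kvpress.presses.scorer_press import ScorerPress  # first-party last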

@@ -12,6 +12,10 @@

Review comment (Collaborator):

Same here, the kvpress imports should come after the transformers imports.

SimJeg (Collaborator) commented Jun 13, 2025

Also @alessiodevoto, please check whether issue #80 requires other changes in this PR.

Signed-off-by: alessiodevoto <devoto.alessio@gmail.com>