333 add torch compile uni test by fromm-m · Pull Request #344 · Modalities/modalities

Merged · merged 6 commits on Apr 14, 2025
10 changes: 9 additions & 1 deletion in src/modalities/models/model_factory.py
@@ -261,7 +261,15 @@ def get_parent_module_and_child_name(child_module: nn.Module, model: nn.Module)
                    return parent_candidate, child_name
            raise ModelStateError("No valid parent candidate")

-        block_types = tuple([get_module_class_from_name(model, b) for b in block_names])
+        block_types = []
+        for name in block_names:
+            module_class = get_module_class_from_name(model, name)
+            if module_class is not None:
+                block_types.append(module_class)
+            else:
+                raise ValueError("None of the provided block_names match any modules in the model")
+
+        block_types = tuple(block_types)

        for _, module in model.named_modules():
            if isinstance(module, block_types):
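The net effect of this change is fail-fast validation of block_names. A hypothetical usage sketch: ModelFactory.get_compiled_model, "GPT2Block", and "Conv2d" come from this PR's diff and tests; the model variable is assumed.

# Hypothetical usage sketch; only the names taken from this PR are confirmed.
from modalities.models.model_factory import ModelFactory

# Compiles every GPT2Block in place and returns the same model instance.
model = ModelFactory.get_compiled_model(model, block_names=["GPT2Block"])

# A name that resolves to no module class in the model now raises a clear
# ValueError immediately, instead of failing later in an obscure way.
ModelFactory.get_compiled_model(model, block_names=["Conv2d"])  # raises ValueError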
90 changes: 90 additions & 0 deletions in tests/test_torch_compile.py
@@ -0,0 +1,90 @@
import pytest
import torch.nn as nn

from modalities.models.gpt2.gpt2_model import (
    GPT2LLM,
    ActivationType,
    AttentionConfig,
    AttentionImplementation,
    LayerNorms,
    LayerNormWrapperConfig,
    PositionTypes,
    QueryKeyValueTransformType,
)
from modalities.models.model_factory import ModelFactory


def create_gpt2_configs():
    attention_config = AttentionConfig(
        qkv_transforms=[
            AttentionConfig.QueryKeyValueTransformConfig(
                type_hint=QueryKeyValueTransformType.RotaryTransform.name,
                config=AttentionConfig.QueryKeyValueTransformConfig.RotaryTransformConfig(
                    n_embd=512, n_head=8, seq_length_dim=-2, base_freq=10000
                ),
            )
        ]
    )
    norm_config = LayerNormWrapperConfig(norm_type=LayerNorms.layer_norm, config={"normalized_shape": 512})
    return attention_config, norm_config


@pytest.fixture
def gpt2_model():
    attention_config, norm_config = create_gpt2_configs()
    model = GPT2LLM(
        sample_key="input_ids",
        prediction_key="logits",
        poe_type=PositionTypes.NOPE,
        sequence_length=256,
        vocab_size=1024,
        n_layer=4,
        n_head_q=8,
        n_head_kv=4,
        n_embd=512,
        ffn_hidden=2048,
        dropout=0.1,
        bias=True,
        activation_type=ActivationType.SWIGLU,
        attention_implementation=AttentionImplementation.PYTORCH_FLASH,
        attention_config=attention_config,
        attention_norm_config=norm_config,
        ffn_norm_config=norm_config,
        lm_head_norm_config=norm_config,
        use_weight_tying=True,
Review comment on use_weight_tying=True (Member):
does torch.compile support weight tying now?

Reply (Member, Author):
The tests run through. As I understand it, compile supports weight tying (and other techniques); they are just not as "performant", as they introduce graph breaks. (See the graph-break sketch after the fixture below.)

    )
    return model
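The graph-break concern raised in the review thread above can be checked empirically. A minimal sketch, assuming torch._dynamo.explain is available in the installed PyTorch version; the helper name count_graph_breaks is hypothetical, and batch is whatever the model's forward expects (for this fixture, a dict keyed by "input_ids").

import torch

def count_graph_breaks(model, batch):
    # torch._dynamo.explain traces the model and reports how many subgraphs
    # Dynamo produced and why it had to split the graph.
    explanation = torch._dynamo.explain(model)(batch)
    print(f"graphs: {explanation.graph_count}, breaks: {explanation.graph_break_count}")
    for reason in explanation.break_reasons:
        print(reason)
    return explanation.graph_break_count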


def test_get_compiled_model_compiles_blocks(gpt2_model):
    original_blocks = list(gpt2_model.transformer.h)
    original_wte = gpt2_model.transformer.wte
    original_lm_head = gpt2_model.transformer.lm_head

    block_names = ["GPT2Block"]
    result_model = ModelFactory.get_compiled_model(gpt2_model, block_names)

    assert len(result_model.transformer.h) == 4, "Should still have four blocks"
    for i, (original_block, new_block) in enumerate(zip(original_blocks, result_model.transformer.h)):
        assert new_block is not original_block, f"Block {i} should be a compiled version"
        assert isinstance(new_block, nn.Module), f"Block {i} should be an nn.Module"
    assert result_model.transformer.wte is original_wte, "Embedding layer should remain unchanged"
    assert result_model.transformer.lm_head is original_lm_head, "LM head should remain unchanged"
    assert result_model is gpt2_model, "Should return the same model instance"


def test_get_compiled_model_no_matching_blocks(gpt2_model):
    """
    Test that get_compiled_model raises a ValueError if no blocks match the specified types.
    """
    with pytest.raises(ValueError, match="None of the provided block_names match any modules in the model"):
        ModelFactory.get_compiled_model(gpt2_model, block_names=["Conv2d"])


def test_get_compiled_model_empty_block_names(gpt2_model):
    original_model_dict = dict(gpt2_model.named_modules())
    result_model = ModelFactory.get_compiled_model(gpt2_model, block_names=[])

    new_model_dict = dict(result_model.named_modules())
    assert new_model_dict == original_model_dict, "Model should remain unchanged with empty block_names"
    assert result_model is gpt2_model, "Should return the same model instance"
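The identity assertions in test_get_compiled_model_compiles_blocks rely on torch.compile returning a wrapper module rather than mutating its argument. A minimal self-contained sketch of that behavior, independent of this repository and assuming a recent PyTorch version:

import torch
import torch.nn as nn

block = nn.Linear(8, 8)
compiled_block = torch.compile(block)

# torch.compile wraps the module (in an OptimizedModule), so object identity
# changes while the wrapper still behaves like, and delegates to, the original.
assert compiled_block is not block
assert isinstance(compiled_block, nn.Module)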