Add AOPerModuleConfig by jerryzh168 · Pull Request #2119 · pytorch/ao · GitHub

Add AOPerModuleConfig #2119

Merged · 4 commits · Apr 24, 2025
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -11,7 +11,7 @@ repos:

  - repo: https://github.com/astral-sh/ruff-pre-commit
    # Ruff version.
    rev: v0.6.8
    rev: v0.11.6
    hooks:
      # Run the linter.
      - id: ruff
74 changes: 73 additions & 1 deletion test/quantization/test_quant_api.py
@@ -19,6 +19,7 @@
    get_symmetric_quantization_config,
)
from torch.testing._internal import common_utils
from torch.testing._internal.common_quantization import TestHelperModules
from torch.testing._internal.common_utils import TestCase

from torchao import quantize_
@@ -28,9 +29,20 @@
    AffineQuantizedTensor,
    Int4CPULayout,
    Int4XPULayout,
    PlainLayout,
    QDQLayout,
    TensorCoreTiledLayout,
)
from torchao.quantization import (
    LinearActivationQuantizedTensor,
    PerGroup,
)
from torchao.quantization import LinearActivationQuantizedTensor
from torchao.quantization.quant_api import (
    AOPerModuleConfig,
    Int4WeightOnlyConfig,
    Int8DynamicActivationInt4WeightConfig,
    Int8WeightOnlyConfig,
    IntxWeightOnlyConfig,
    Quantizer,
    TwoStepQuantizer,
    _replace_with_custom_fn_if_matches_filter,
@@ -933,6 +945,66 @@ def test_workflow_e2e_numerics(self, config):
        sqnr = compute_error(y_ref, y_q)
        assert sqnr >= 16.5, f"SQNR {sqnr} is too low"

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_ao_per_module_config_default(self):
        config1 = Int4WeightOnlyConfig(group_size=32)
        config2 = Int8WeightOnlyConfig()
        config = AOPerModuleConfig({"_default": config1, "linear2": config2})
        model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
        example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
        quantize_(model, config)
        model(*example_inputs)
        assert isinstance(model.linear1.weight, AffineQuantizedTensor)
        assert isinstance(model.linear1.weight._layout, TensorCoreTiledLayout)
        assert isinstance(model.linear2.weight, AffineQuantizedTensor)
        assert isinstance(model.linear2.weight._layout, PlainLayout)

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_ao_per_module_config_module_name(self):
        config1 = Int4WeightOnlyConfig(group_size=32)
        config2 = Int8WeightOnlyConfig()
        config = AOPerModuleConfig({"linear1": config1, "linear2": config2})
        model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
        example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
        quantize_(model, config)
        model(*example_inputs)
        assert isinstance(model.linear1.weight, AffineQuantizedTensor)
        assert isinstance(model.linear1.weight._layout, TensorCoreTiledLayout)
        assert isinstance(model.linear2.weight, AffineQuantizedTensor)
        assert isinstance(model.linear2.weight._layout, PlainLayout)

    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Need torch 2.6+")
    def test_ao_per_module_config_embedding_linear(self):
        weight_dtype = torch.int8
        granularity = PerGroup(8)
        mapping_type = MappingType.SYMMETRIC
        embedding_config = IntxWeightOnlyConfig(
            weight_dtype=weight_dtype,
            granularity=granularity,
            mapping_type=mapping_type,
            scale_dtype=None,
        )
        # example model linear is Linear(16, 8)
        linear_config = Int8DynamicActivationInt4WeightConfig(group_size=16)

        config = AOPerModuleConfig({"emb": embedding_config, "linear": linear_config})
        indices = torch.randint(0, 10, (32,))
        indices = indices.unsqueeze(0)
        example_inputs = (indices,)
        model = TestHelperModules.EmbeddingConvLinearModule().eval()
        model(*example_inputs)
        quantize_(
            model,
            config,
            filter_fn=lambda x, fqn: isinstance(x, torch.nn.Linear)
            or isinstance(x, torch.nn.Embedding),
        )
        model(*example_inputs)

        assert isinstance(model.emb.weight, AffineQuantizedTensor)
        assert isinstance(model.emb.weight._layout, QDQLayout)
        assert isinstance(model.linear.weight, LinearActivationQuantizedTensor)


class TestMultiTensorFlow(TestCase):
    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
97 changes: 94 additions & 3 deletions torchao/quantization/quant_api.py
@@ -18,8 +18,8 @@
import logging
import types
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Optional, Tuple, Union
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Optional, Tuple, Union

import torch
import torch.nn as nn
@@ -307,6 +307,52 @@ def _replace_with_custom_fn_if_matches_filter(
    return model


def _replace_with_custom_fn_if_matches_filter_with_name(
    model,
    replacement_fn,
    filter_fn,
    cur_fqn="",
    device=None,
    extra_args: Optional[Tuple[Any, ...]] = (),
) -> None:
    """
    A variant of _replace_with_custom_fn_if_matches_filter where replacement_fn takes module name as well
    ...
        replacement_fn (Callable[[torch.nn.Module, str], torch.nn.Module]): The function to replace matching modules.
    ...

    Returns:
        None
    """
    if isinstance(model, Float8Linear):
        with torch.device("meta"):
            new_module = nn.Linear(model.in_features, model.out_features)
        new_module.weight = model.weight
        new_module.bias = model.bias
        model = new_module
    if filter_fn(model, cur_fqn[:-1]):
        if device is not None:
            model.to(device=device)  # move to device before quantization
        model = replacement_fn(model, cur_fqn[:-1], *extra_args)
        return model
    else:
        named_children_list = list(model.named_children())
        for name, child in named_children_list:
            new_child = _replace_with_custom_fn_if_matches_filter_with_name(
                child,
                replacement_fn,
                filter_fn,
                f"{cur_fqn}{name}.",
                device,
                extra_args,
            )
            if new_child is not child:
                setattr(model, name, new_child)
        if device is not None:
            model.to(device=device)  # move parent module to device
        return model


def _is_linear(mod, *args):
    # avoid circular dependencies
    from torchao.quantization.qat.affine_fake_quantized_tensor import (
@@ -547,13 +593,24 @@ def quantize_(
        quantize_(m, int4_weight_only(group_size=32))

    """
    filter_fn = _is_linear if filter_fn is None else filter_fn
    if isinstance(config, AOPerModuleConfig):
        _replace_with_custom_fn_if_matches_filter_with_name(
            model,
            _ao_per_module_config_handler,
            filter_fn,
            device=device,
            extra_args=(config,),
        )
        return

    if isinstance(config, AOBaseConfig):
        handler = _QUANTIZE_CONFIG_HANDLER[type(config)]
        # for each linear in the model, apply the transform if filtering passes
        _replace_with_custom_fn_if_matches_filter(
            model,
            handler,
            _is_linear if filter_fn is None else filter_fn,
            filter_fn,
            device=device,
            extra_args=(config,),
        )
@@ -1900,6 +1957,40 @@ def _fpx_weight_only_transform(
    return module


@dataclass
class AOPerModuleConfig(AOBaseConfig):
    """Per module configurations for torchao quantize_ API

    Args:
        `module_fqn_to_config`: Dict[str, Optional[AOBaseConfig]]: a dictionary from
         the fully qualified name of a module to the AOBaseConfig that we want to apply to that module.
         There is also a special key, "_default": if "_default" is present in the dictionary,
         its config will be applied to all remaining modules that do not have a
         per-module configuration specified.
    """

    module_fqn_to_config: Dict[str, Optional[AOBaseConfig]] = field(
        default_factory=dict
    )


def _ao_per_module_config_handler(
    module: torch.nn.Module, module_fqn: str, config: AOPerModuleConfig
):
    c = config.module_fqn_to_config.get(module_fqn, None)
    # Maybe: we can add module type specific config in the future, if needed
    # fallback to use default if no module specific config is provided
    default_c = config.module_fqn_to_config.get("_default", None)
    if default_c is not None and c is None:
        c = default_c

    if c is not None:
        handler = _QUANTIZE_CONFIG_HANDLER[type(c)]
        return handler(module, c)

    return module


if TORCH_VERSION_AT_LEAST_2_5:
    torch.serialization.add_safe_globals(
        [
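
For readers who want to try the new API, here is a minimal usage sketch (not part of the diff) distilled from the tests above. ToyModel and its layer sizes are illustrative stand-ins for the test suite's ToyLinearModel, and a CUDA device with bfloat16 support is assumed:

import torch
import torch.nn as nn

from torchao import quantize_
from torchao.quantization.quant_api import (
    AOPerModuleConfig,
    Int4WeightOnlyConfig,
    Int8WeightOnlyConfig,
)


class ToyModel(nn.Module):
    # Illustrative two-layer model; the attribute names linear1/linear2 are the
    # fully qualified names that AOPerModuleConfig keys against.
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(64, 32, bias=False)
        self.linear2 = nn.Linear(32, 64, bias=False)

    def forward(self, x):
        return self.linear2(self.linear1(x))


model = ToyModel().cuda().to(dtype=torch.bfloat16)

# "_default" applies to every module that passes the filter and has no explicit
# entry; "linear2" overrides the default for that one module.
config = AOPerModuleConfig(
    {
        "_default": Int4WeightOnlyConfig(group_size=32),
        "linear2": Int8WeightOnlyConfig(),
    }
)
quantize_(model, config)

x = torch.randn(16, 64, device="cuda", dtype=torch.bfloat16)
out = model(x)

As test_ao_per_module_config_embedding_linear above shows, a custom filter_fn can also be passed to quantize_ so that modules other than nn.Linear, such as nn.Embedding, are visited and matched against their per-module configs.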