diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 4bf3a56bc5..d76ba70f74 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -11,7 +11,7 @@ repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
     # Ruff version.
-    rev: v0.6.8
+    rev: v0.11.6
     hooks:
       # Run the linter.
      - id: ruff
diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index e6cf65e28c..86631e58b2 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -19,6 +19,7 @@
     get_symmetric_quantization_config,
 )
 from torch.testing._internal import common_utils
+from torch.testing._internal.common_quantization import TestHelperModules
 from torch.testing._internal.common_utils import TestCase
 
 from torchao import quantize_
@@ -28,9 +29,20 @@
     AffineQuantizedTensor,
     Int4CPULayout,
     Int4XPULayout,
+    PlainLayout,
+    QDQLayout,
+    TensorCoreTiledLayout,
+)
+from torchao.quantization import (
+    LinearActivationQuantizedTensor,
+    PerGroup,
 )
-from torchao.quantization import LinearActivationQuantizedTensor
 from torchao.quantization.quant_api import (
+    AOPerModuleConfig,
+    Int4WeightOnlyConfig,
+    Int8DynamicActivationInt4WeightConfig,
+    Int8WeightOnlyConfig,
+    IntxWeightOnlyConfig,
     Quantizer,
     TwoStepQuantizer,
     _replace_with_custom_fn_if_matches_filter,
@@ -933,6 +945,66 @@ def test_workflow_e2e_numerics(self, config):
         sqnr = compute_error(y_ref, y_q)
         assert sqnr >= 16.5, f"SQNR {sqnr} is too low"
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_ao_per_module_config_default(self):
+        config1 = Int4WeightOnlyConfig(group_size=32)
+        config2 = Int8WeightOnlyConfig()
+        config = AOPerModuleConfig({"_default": config1, "linear2": config2})
+        model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
+        example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
+        quantize_(model, config)
+        model(*example_inputs)
+        assert isinstance(model.linear1.weight, AffineQuantizedTensor)
+        assert isinstance(model.linear1.weight._layout, TensorCoreTiledLayout)
+        assert isinstance(model.linear2.weight, AffineQuantizedTensor)
+        assert isinstance(model.linear2.weight._layout, PlainLayout)
+
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_ao_per_module_config_module_name(self):
+        config1 = Int4WeightOnlyConfig(group_size=32)
+        config2 = Int8WeightOnlyConfig()
+        config = AOPerModuleConfig({"linear1": config1, "linear2": config2})
+        model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
+        example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
+        quantize_(model, config)
+        model(*example_inputs)
+        assert isinstance(model.linear1.weight, AffineQuantizedTensor)
+        assert isinstance(model.linear1.weight._layout, TensorCoreTiledLayout)
+        assert isinstance(model.linear2.weight, AffineQuantizedTensor)
+        assert isinstance(model.linear2.weight._layout, PlainLayout)
+
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Need torch 2.6+")
+    def test_ao_per_module_config_embedding_linear(self):
+        weight_dtype = torch.int8
+        granularity = PerGroup(8)
+        mapping_type = MappingType.SYMMETRIC
+        embedding_config = IntxWeightOnlyConfig(
+            weight_dtype=weight_dtype,
+            granularity=granularity,
+            mapping_type=mapping_type,
+            scale_dtype=None,
+        )
+        # the example model's linear is Linear(16, 8)
+        linear_config = Int8DynamicActivationInt4WeightConfig(group_size=16)
+
+        config = AOPerModuleConfig({"emb": embedding_config, "linear": linear_config})
+        indices = torch.randint(0, 10, (32,))
+        indices = indices.unsqueeze(0)
+        example_inputs = (indices,)
+        model = TestHelperModules.EmbeddingConvLinearModule().eval()
+        model(*example_inputs)
+        quantize_(
+            model,
+            config,
+            filter_fn=lambda x, fqn: isinstance(x, torch.nn.Linear)
+            or isinstance(x, torch.nn.Embedding),
+        )
+        model(*example_inputs)
+
+        assert isinstance(model.emb.weight, AffineQuantizedTensor)
+        assert isinstance(model.emb.weight._layout, QDQLayout)
+        assert isinstance(model.linear.weight, LinearActivationQuantizedTensor)
+
 
 class TestMultiTensorFlow(TestCase):
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 903ce22987..4b2cd53024 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -18,8 +18,8 @@
 import logging
 import types
 import warnings
-from dataclasses import dataclass
-from typing import Any, Callable, Optional, Tuple, Union
+from dataclasses import dataclass, field
+from typing import Any, Callable, Dict, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -307,6 +307,52 @@ def _replace_with_custom_fn_if_matches_filter(
     return model
 
 
+def _replace_with_custom_fn_if_matches_filter_with_name(
+    model,
+    replacement_fn,
+    filter_fn,
+    cur_fqn="",
+    device=None,
+    extra_args: Optional[Tuple[Any, ...]] = (),
+) -> torch.nn.Module:
+    """
+    A variant of _replace_with_custom_fn_if_matches_filter where replacement_fn also takes the module's fully qualified name.
+    ...
+    replacement_fn (Callable[[torch.nn.Module, str], torch.nn.Module]): The function to replace matching modules.
+    ...
+
+    Returns:
+        The model with matching modules replaced.
+    """
+    if isinstance(model, Float8Linear):
+        with torch.device("meta"):
+            new_module = nn.Linear(model.in_features, model.out_features)
+        new_module.weight = model.weight
+        new_module.bias = model.bias
+        model = new_module
+    if filter_fn(model, cur_fqn[:-1]):
+        if device is not None:
+            model.to(device=device)  # move to device before quantization
+        model = replacement_fn(model, cur_fqn[:-1], *extra_args)
+        return model
+    else:
+        named_children_list = list(model.named_children())
+        for name, child in named_children_list:
+            new_child = _replace_with_custom_fn_if_matches_filter_with_name(
+                child,
+                replacement_fn,
+                filter_fn,
+                f"{cur_fqn}{name}.",
+                device,
+                extra_args,
+            )
+            if new_child is not child:
+                setattr(model, name, new_child)
+        if device is not None:
+            model.to(device=device)  # move parent module to device
+        return model
+
+
 def _is_linear(mod, *args):
     # avoid circular dependencies
     from torchao.quantization.qat.affine_fake_quantized_tensor import (
@@ -547,13 +593,24 @@ def quantize_(
         quantize_(m, int4_weight_only(group_size=32))
 
     """
+    filter_fn = _is_linear if filter_fn is None else filter_fn
+    if isinstance(config, AOPerModuleConfig):
+        _replace_with_custom_fn_if_matches_filter_with_name(
+            model,
+            _ao_per_module_config_handler,
+            filter_fn,
+            device=device,
+            extra_args=(config,),
+        )
+        return
+
     if isinstance(config, AOBaseConfig):
         handler = _QUANTIZE_CONFIG_HANDLER[type(config)]
         # for each linear in the model, apply the transform if filtering passes
         _replace_with_custom_fn_if_matches_filter(
             model,
             handler,
-            _is_linear if filter_fn is None else filter_fn,
+            filter_fn,
             device=device,
             extra_args=(config,),
         )
@@ -1900,6 +1957,40 @@ def _fpx_weight_only_transform(
     return module
 
 
+@dataclass
+class AOPerModuleConfig(AOBaseConfig):
+    """Per-module configurations for the torchao quantize_ API
+
+    Args:
+        `module_fqn_to_config`: Dict[str, Optional[AOBaseConfig]]: a dictionary from
+         the fully qualified name of a module to the AOBaseConfig that we want to apply to it.
+         Also supports a special key: "_default"; if "_default" is present in the dictionary,
+         its config is applied to all remaining modules that do not have a per-module
+         configuration specified.
+    """
+
+    module_fqn_to_config: Dict[str, Optional[AOBaseConfig]] = field(
+        default_factory=dict
+    )
+
+
+def _ao_per_module_config_handler(
+    module: torch.nn.Module, module_fqn: str, config: AOPerModuleConfig
+):
+    c = config.module_fqn_to_config.get(module_fqn, None)
+    # Maybe: we can add module-type-specific config in the future, if needed
+    # fall back to the default if no module-specific config is provided
+    default_c = config.module_fqn_to_config.get("_default", None)
+    if default_c is not None and c is None:
+        c = default_c
+
+    if c is not None:
+        handler = _QUANTIZE_CONFIG_HANDLER[type(c)]
+        return handler(module, c)
+
+    return module
+
+
 if TORCH_VERSION_AT_LEAST_2_5:
     torch.serialization.add_safe_globals(
         [