From b9a55b559e18a3ec7968ea9b1c0b353139fd8806 Mon Sep 17 00:00:00 2001 From: mori360 Date: Thu, 26 Jun 2025 11:18:37 -0700 Subject: [PATCH 1/5] support nf4 tensor shard and gather --- test/dtypes/test_nf4.py | 52 ++++++++++-- torchao/dtypes/nf4tensor.py | 152 +++++++++++++++++++++++++++--------- 2 files changed, 159 insertions(+), 45 deletions(-) diff --git a/test/dtypes/test_nf4.py b/test/dtypes/test_nf4.py index f52644cdf3..cceb1c8168 100644 --- a/test/dtypes/test_nf4.py +++ b/test/dtypes/test_nf4.py @@ -14,28 +14,28 @@ import pytest import torch import torch.nn.functional as F + +import torchao +from packaging import version from torch import nn from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( - CheckpointWrapper, apply_activation_checkpointing, + CheckpointWrapper, ) from torch.distributed.fsdp.wrap import ModuleWrapPolicy from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import FSDPTest from torch.testing._internal.common_utils import ( - TestCase, instantiate_parametrized_tests, parametrize, run_tests, + TestCase, ) - -import torchao -from packaging import version from torchao.dtypes._nf4tensor_api import nf4_weight_only from torchao.dtypes.nf4tensor import ( _INNER_TENSOR_NAMES_FOR_SHARDING, - NF4Tensor, linear_nf4, + NF4Tensor, to_nf4, ) from torchao.testing.utils import skip_if_rocm @@ -741,6 +741,46 @@ def _test_qlora_fsdp2( self.assertEqual(fsdp_loss, base_loss) +class TestComm(FSDPTest): + @property + def world_size(self) -> int: + return 2 + + @pytest.mark.skipif( + version.parse(torch.__version__).base_version < "2.4.0", + reason="torch >= 2.4 required", + ) + @skip_if_lt_x_gpu(2) + def test_comm(self): + self.run_subtests( + {"input_size": [512, 2048]}, + self._test_comm, + ) + + def _test_comm(self, input_size: int): + from torch.distributed._composable.fsdp import fully_shard + from torch.distributed._tensor import distribute_tensor + + model = nn.Linear(input_size, input_size, device="cuda") + origin_tensor = model.weight + origin_nf4_tensor = to_nf4(origin_tensor) + model = fully_shard(model) + sharded_tensor = model.weight + sharded_origin_nf4_tensor = distribute_tensor( + origin_nf4_tensor, + sharded_tensor.device_mesh, + sharded_tensor.placements, + ) + + sharded_nf4_detach = sharded_origin_nf4_tensor.detach() + resumed_full_tensor = sharded_nf4_detach.full_tensor() + + self.assertEqual( + origin_nf4_tensor.get_original_weight(), + resumed_full_tensor.get_original_weight(), + ) + + instantiate_parametrized_tests(TestNF4Linear) instantiate_parametrized_tests(TestFSDPOps) diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py index 698a9391bd..7415878f3f 100644 --- a/torchao/dtypes/nf4tensor.py +++ b/torchao/dtypes/nf4tensor.py @@ -7,7 +7,7 @@ import math import sys from dataclasses import dataclass, replace -from enum import Enum, auto +from enum import auto, Enum from typing import Any, Dict, Optional, Tuple, Union import torch @@ -22,7 +22,51 @@ c10d_functional = torch.ops.c10d_functional -NF4_OPS_TABLE: Dict[Any, Any] = {} +def nf4_all_gather_into_tensor(func, *args, **kwargs): + nf4tensor = args[0][0] + group_size = args[0][1] + name = args[0][2] + updated_attrs = {} + for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: + updated_attrs[attr] = func(getattr(nf4tensor, attr), group_size, name) + updated_attrs.update( + { + "size": torch.Size((nf4tensor.size()[0] * group_size, nf4tensor.size()[1])), + } + ) + updatedNF4Tensor = NF4Tensor(*construct_nf4_args(nf4tensor, 
updated_attrs)) + return updatedNF4Tensor + + +def scatter_nf4tensor(func, *args, **kwargs): + output_tensor = args[0][0][0] + input_tensors = args[0][1] + process_group = args[0][2] + src = args[0][3] + asyncOp = args[0][4] + new_attr, update_work = [], [] + for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: + input_attrs = [] + if input_tensors: + for input_tensor in input_tensors[0]: + if hasattr(input_tensor, attr): + input_attrs.append(getattr(input_tensor, attr)) + input_attrs = [input_attrs] + new_attr, update_work = func( + [getattr(output_tensor, attr)], + input_attrs, + process_group, + src, + asyncOp, + ) + # there are 3 works, return one of them, same as the tensor to fit the required output format + return new_attr, update_work + + +NF4_OPS_TABLE: Dict[Any, Any] = { + torch.ops._c10d_functional.all_gather_into_tensor.default: nf4_all_gather_into_tensor, + torch.ops.c10d.scatter_.default: scatter_nf4tensor, +} _INNER_TENSOR_NAMES_FOR_SHARDING = [ @@ -198,9 +242,9 @@ def nf4_split(aten_op, args, kwargs=None): attr_to_chunks = {} for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: inner_tensor = getattr(nf4tensor, attr) - assert inner_tensor.numel() % num_chunks == 0, ( - f"{attr}.numel() not divisible by {num_chunks}" - ) + assert ( + inner_tensor.numel() % num_chunks == 0 + ), f"{attr}.numel() not divisible by {num_chunks}" chunks = aten_op(inner_tensor, inner_tensor.numel() // num_chunks, **kwargs) attr_to_chunks[attr] = chunks @@ -241,9 +285,9 @@ def nf4_new_zeros(aten_op, args, kwargs=None): updated_attrs = {} for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: inner_tensor = getattr(nf4tensor, attr) - assert inner_tensor.size(0) % ratio == 0, ( - f"{attr}.numel() must be divisible by {ratio}" - ) + assert ( + inner_tensor.size(0) % ratio == 0 + ), f"{attr}.numel() must be divisible by {ratio}" inner_tensor = aten_op(inner_tensor, [inner_tensor.size(0) // ratio], **kwargs) updated_attrs[attr] = inner_tensor updated_attrs["size"] = new_size @@ -273,19 +317,35 @@ def nf4_slice(aten_op, args, kwargs=None): aten.view.default, ] ) -@expect_args_len_at_k(1, CompareOp.EQ, 1, "aten.view(NF4Tensor) with len(size)=") +@expect_args_len_at_k(1, CompareOp.LT, 3, "aten.view(NF4Tensor) with len(size)=") def nf4_view(aten_op, args, kwargs=None): nf4tensor = args[0] size = args[1] - if size[0] != -1: - raise NotImplementedError(f"aten.view(NF4Tensor) with size={size}") - updated_attrs = apply_to_inner_tensors(nf4tensor, aten_op, args[1:], kwargs) - updated_attrs.update( - { - "size": [nf4tensor.numel()], - "stride": (1,), - } - ) + if len(size) == 1: + if size[0] != -1: + raise NotImplementedError(f"aten.view(NF4Tensor) with size={size}") + else: + updated_attrs = apply_to_inner_tensors(nf4tensor, aten_op, args[1:], kwargs) + updated_attrs.update( + { + "size": [nf4tensor.numel()], + "stride": (1,), + } + ) + else: + updated_attrs = {} + if nf4tensor.numel() != nf4tensor.size()[0] * nf4tensor.size()[1]: + raise NotImplementedError("NF4Tensor size does not match numel.") + for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: + attr_size = [getattr(nf4tensor, attr).size()] + updated_attrs[attr] = aten_op( + getattr(nf4tensor, attr), *attr_size, **kwargs + ) + updated_attrs.update( + { + "stride": (nf4tensor.size()[1], 1), + } + ) return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs)) @@ -457,6 +517,20 @@ def nf4_cat(aten_op: torch._ops.OpOverload, args, kwargs=None): return tensors +@implements( + [ + torch.ops._c10d_functional.wait_tensor.default, + ] +) +def wait_tensor(func, *args, **kwargs): + nf4tensor 
= args[0][0] + updated_attrs = {} + for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: + updated_attrs[attr] = func(getattr(nf4tensor, attr)) + updatedNF4Tensor = NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs)) + return updatedNF4Tensor + + @dataclass(frozen=True) class SubclassTensorArgs: original_shape: torch.Size @@ -478,9 +552,9 @@ def get_block_absmax(input_tensor: torch.Tensor, block_size: int) -> torch.Tenso torch.Tensor: Tensor of scalers for each block """ assert input_tensor.dim() == 1, "Input tensor must be flattened" - assert (input_tensor.numel() % block_size) == 0, ( - f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {block_size}" - ) + assert ( + input_tensor.numel() % block_size + ) == 0, f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {block_size}" n_blocks = input_tensor.numel() // block_size blocks = input_tensor.view(n_blocks, block_size) @@ -563,12 +637,12 @@ def from_tensor( block_size: int, scaler_block_size: int, ): - assert input_tensor.dim() <= 2, ( - f"expect input tensor dim <= 2 but got dim = {input_tensor.dim()}" - ) - assert input_tensor.numel() % block_size == 0, ( - f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {block_size}" - ) + assert ( + input_tensor.dim() <= 2 + ), f"expect input tensor dim <= 2 but got dim = {input_tensor.dim()}" + assert ( + input_tensor.numel() % block_size == 0 + ), f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {block_size}" assert input_tensor.is_contiguous, "Input tensor must be contiguous!" # I think I want do this # assert not input_tensor.requires_grad, "Input tensor must not require grad" @@ -649,9 +723,9 @@ def double_quantize_scalers( size: (n_scaler_blocks) """ assert input_tensor.dim() == 1, "Input tensor must be flattened" - assert (input_tensor.numel() % scaler_block_size) == 0, ( - f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {scaler_block_size}" - ) + assert ( + input_tensor.numel() % scaler_block_size + ) == 0, f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {scaler_block_size}" # First round of quantization # Produces: A tensor of size (n_blocks) of input_tensor.dtype @@ -659,9 +733,9 @@ def double_quantize_scalers( scalers_1_mean = scalers_1.mean() scalers_1 = scalers_1 - scalers_1_mean # Second round of quantization - assert scalers_1.numel() % scaler_block_size == 0, ( - f"Number of scalers must be divisible by scaler block size, got {scalers_1.numel()} scaler_block_size {scaler_block_size} " - ) + assert ( + scalers_1.numel() % scaler_block_size == 0 + ), f"Number of scalers must be divisible by scaler block size, got {scalers_1.numel()} scaler_block_size {scaler_block_size} " n_scaler_blocks = scalers_1.numel() // scaler_block_size scaler_blocks = scalers_1.view(n_scaler_blocks, scaler_block_size) @@ -703,9 +777,9 @@ def dequantize_scalers( """ assert input_tensor.dim() == 1, "Input tensor must be flattened" - assert (input_tensor.numel() % scaler_block_size) == 0, ( - f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {scaler_block_size}" - ) + assert ( + input_tensor.numel() % scaler_block_size + ) == 0, f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {scaler_block_size}" n_scaler_blocks = input_tensor.numel() // scaler_block_size input_tensor = input_tensor.view(n_scaler_blocks, scaler_block_size) dequantized = (input_tensor / 
quantization_factor.unsqueeze(-1)).flatten().to( @@ -721,9 +795,9 @@ def convert_to_norm_float_weight( flattened_tensor = input_tensor.flatten() # Since we are using uint8 we will encode 2 entries per byte numel = input_tensor.numel() - assert numel % 2 == 0, ( - "Number of elements must be even just to not have to think about the end" - ) + assert ( + numel % 2 == 0 + ), "Number of elements must be even just to not have to think about the end" # Reshape the flattened tensor into blocks of size self.block_size blocks = flattened_tensor.view(n_blocks, block_size) From 23b9ebf5d2acbe36e3ca4f3512323f1853c745b7 Mon Sep 17 00:00:00 2001 From: mori360 Date: Thu, 26 Jun 2025 11:29:15 -0700 Subject: [PATCH 2/5] adjust unit test --- test/dtypes/test_nf4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/dtypes/test_nf4.py b/test/dtypes/test_nf4.py index cceb1c8168..be4b66284d 100644 --- a/test/dtypes/test_nf4.py +++ b/test/dtypes/test_nf4.py @@ -435,7 +435,7 @@ def test_tensor_view_valid(self, input_size: Union[Tuple[int], int]): inner_tensor = getattr(viewed_tensor, attr) self.assertEqual(inner_tensor.size(0), inner_tensor.numel()) - @parametrize("input_size", [(512 * 512,), (512, 512)]) + @parametrize("input_size", [(512 * 512,)]) def test_tensor_view_invalid(self, input_size: Union[Tuple[int], int]): nf4_tensor = to_nf4(torch.randn(input_size)) if len(input_size) == 1: From aba9245f5385dd844b6932bc8c94fdeb5f374507 Mon Sep 17 00:00:00 2001 From: mori360 Date: Thu, 26 Jun 2025 11:31:31 -0700 Subject: [PATCH 3/5] lint --- test/dtypes/test_nf4.py | 12 ++++---- torchao/dtypes/nf4tensor.py | 56 ++++++++++++++++++------------------- 2 files changed, 34 insertions(+), 34 deletions(-) diff --git a/test/dtypes/test_nf4.py b/test/dtypes/test_nf4.py index be4b66284d..04840837be 100644 --- a/test/dtypes/test_nf4.py +++ b/test/dtypes/test_nf4.py @@ -14,28 +14,28 @@ import pytest import torch import torch.nn.functional as F - -import torchao -from packaging import version from torch import nn from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( - apply_activation_checkpointing, CheckpointWrapper, + apply_activation_checkpointing, ) from torch.distributed.fsdp.wrap import ModuleWrapPolicy from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import FSDPTest from torch.testing._internal.common_utils import ( + TestCase, instantiate_parametrized_tests, parametrize, run_tests, - TestCase, ) + +import torchao +from packaging import version from torchao.dtypes._nf4tensor_api import nf4_weight_only from torchao.dtypes.nf4tensor import ( _INNER_TENSOR_NAMES_FOR_SHARDING, - linear_nf4, NF4Tensor, + linear_nf4, to_nf4, ) from torchao.testing.utils import skip_if_rocm diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py index 7415878f3f..63dee371be 100644 --- a/torchao/dtypes/nf4tensor.py +++ b/torchao/dtypes/nf4tensor.py @@ -7,7 +7,7 @@ import math import sys from dataclasses import dataclass, replace -from enum import auto, Enum +from enum import Enum, auto from typing import Any, Dict, Optional, Tuple, Union import torch @@ -242,9 +242,9 @@ def nf4_split(aten_op, args, kwargs=None): attr_to_chunks = {} for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: inner_tensor = getattr(nf4tensor, attr) - assert ( - inner_tensor.numel() % num_chunks == 0 - ), f"{attr}.numel() not divisible by {num_chunks}" + assert inner_tensor.numel() % num_chunks == 0, ( + f"{attr}.numel() not divisible by {num_chunks}" + ) chunks 
= aten_op(inner_tensor, inner_tensor.numel() // num_chunks, **kwargs) attr_to_chunks[attr] = chunks @@ -285,9 +285,9 @@ def nf4_new_zeros(aten_op, args, kwargs=None): updated_attrs = {} for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: inner_tensor = getattr(nf4tensor, attr) - assert ( - inner_tensor.size(0) % ratio == 0 - ), f"{attr}.numel() must be divisible by {ratio}" + assert inner_tensor.size(0) % ratio == 0, ( + f"{attr}.numel() must be divisible by {ratio}" + ) inner_tensor = aten_op(inner_tensor, [inner_tensor.size(0) // ratio], **kwargs) updated_attrs[attr] = inner_tensor updated_attrs["size"] = new_size @@ -552,9 +552,9 @@ def get_block_absmax(input_tensor: torch.Tensor, block_size: int) -> torch.Tenso torch.Tensor: Tensor of scalers for each block """ assert input_tensor.dim() == 1, "Input tensor must be flattened" - assert ( - input_tensor.numel() % block_size - ) == 0, f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {block_size}" + assert (input_tensor.numel() % block_size) == 0, ( + f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {block_size}" + ) n_blocks = input_tensor.numel() // block_size blocks = input_tensor.view(n_blocks, block_size) @@ -637,12 +637,12 @@ def from_tensor( block_size: int, scaler_block_size: int, ): - assert ( - input_tensor.dim() <= 2 - ), f"expect input tensor dim <= 2 but got dim = {input_tensor.dim()}" - assert ( - input_tensor.numel() % block_size == 0 - ), f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {block_size}" + assert input_tensor.dim() <= 2, ( + f"expect input tensor dim <= 2 but got dim = {input_tensor.dim()}" + ) + assert input_tensor.numel() % block_size == 0, ( + f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {block_size}" + ) assert input_tensor.is_contiguous, "Input tensor must be contiguous!" 
# I think I want do this # assert not input_tensor.requires_grad, "Input tensor must not require grad" @@ -723,9 +723,9 @@ def double_quantize_scalers( size: (n_scaler_blocks) """ assert input_tensor.dim() == 1, "Input tensor must be flattened" - assert ( - input_tensor.numel() % scaler_block_size - ) == 0, f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {scaler_block_size}" + assert (input_tensor.numel() % scaler_block_size) == 0, ( + f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {scaler_block_size}" + ) # First round of quantization # Produces: A tensor of size (n_blocks) of input_tensor.dtype @@ -733,9 +733,9 @@ def double_quantize_scalers( scalers_1_mean = scalers_1.mean() scalers_1 = scalers_1 - scalers_1_mean # Second round of quantization - assert ( - scalers_1.numel() % scaler_block_size == 0 - ), f"Number of scalers must be divisible by scaler block size, got {scalers_1.numel()} scaler_block_size {scaler_block_size} " + assert scalers_1.numel() % scaler_block_size == 0, ( + f"Number of scalers must be divisible by scaler block size, got {scalers_1.numel()} scaler_block_size {scaler_block_size} " + ) n_scaler_blocks = scalers_1.numel() // scaler_block_size scaler_blocks = scalers_1.view(n_scaler_blocks, scaler_block_size) @@ -777,9 +777,9 @@ def dequantize_scalers( """ assert input_tensor.dim() == 1, "Input tensor must be flattened" - assert ( - input_tensor.numel() % scaler_block_size - ) == 0, f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {scaler_block_size}" + assert (input_tensor.numel() % scaler_block_size) == 0, ( + f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {scaler_block_size}" + ) n_scaler_blocks = input_tensor.numel() // scaler_block_size input_tensor = input_tensor.view(n_scaler_blocks, scaler_block_size) dequantized = (input_tensor / quantization_factor.unsqueeze(-1)).flatten().to( @@ -795,9 +795,9 @@ def convert_to_norm_float_weight( flattened_tensor = input_tensor.flatten() # Since we are using uint8 we will encode 2 entries per byte numel = input_tensor.numel() - assert ( - numel % 2 == 0 - ), "Number of elements must be even just to not have to think about the end" + assert numel % 2 == 0, ( + "Number of elements must be even just to not have to think about the end" + ) # Reshape the flattened tensor into blocks of size self.block_size blocks = flattened_tensor.view(n_blocks, block_size) From a54089f29d1ccf80be69fd4c270a7ea01c13f0c6 Mon Sep 17 00:00:00 2001 From: mori360 Date: Thu, 26 Jun 2025 14:10:33 -0700 Subject: [PATCH 4/5] add timeout --- torchao/dtypes/nf4tensor.py | 65 +++++++++++++++++-------------------- 1 file changed, 29 insertions(+), 36 deletions(-) diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py index 63dee371be..4ce3f60372 100644 --- a/torchao/dtypes/nf4tensor.py +++ b/torchao/dtypes/nf4tensor.py @@ -7,7 +7,7 @@ import math import sys from dataclasses import dataclass, replace -from enum import Enum, auto +from enum import auto, Enum from typing import Any, Dict, Optional, Tuple, Union import torch @@ -41,9 +41,6 @@ def nf4_all_gather_into_tensor(func, *args, **kwargs): def scatter_nf4tensor(func, *args, **kwargs): output_tensor = args[0][0][0] input_tensors = args[0][1] - process_group = args[0][2] - src = args[0][3] - asyncOp = args[0][4] new_attr, update_work = [], [] for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: input_attrs = [] @@ -53,11 +50,7 @@ def scatter_nf4tensor(func, *args, **kwargs): 
input_attrs.append(getattr(input_tensor, attr)) input_attrs = [input_attrs] new_attr, update_work = func( - [getattr(output_tensor, attr)], - input_attrs, - process_group, - src, - asyncOp, + [getattr(output_tensor, attr)], input_attrs, *args[0][2:] ) # there are 3 works, return one of them, same as the tensor to fit the required output format return new_attr, update_work @@ -242,9 +235,9 @@ def nf4_split(aten_op, args, kwargs=None): attr_to_chunks = {} for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: inner_tensor = getattr(nf4tensor, attr) - assert inner_tensor.numel() % num_chunks == 0, ( - f"{attr}.numel() not divisible by {num_chunks}" - ) + assert ( + inner_tensor.numel() % num_chunks == 0 + ), f"{attr}.numel() not divisible by {num_chunks}" chunks = aten_op(inner_tensor, inner_tensor.numel() // num_chunks, **kwargs) attr_to_chunks[attr] = chunks @@ -285,9 +278,9 @@ def nf4_new_zeros(aten_op, args, kwargs=None): updated_attrs = {} for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: inner_tensor = getattr(nf4tensor, attr) - assert inner_tensor.size(0) % ratio == 0, ( - f"{attr}.numel() must be divisible by {ratio}" - ) + assert ( + inner_tensor.size(0) % ratio == 0 + ), f"{attr}.numel() must be divisible by {ratio}" inner_tensor = aten_op(inner_tensor, [inner_tensor.size(0) // ratio], **kwargs) updated_attrs[attr] = inner_tensor updated_attrs["size"] = new_size @@ -552,9 +545,9 @@ def get_block_absmax(input_tensor: torch.Tensor, block_size: int) -> torch.Tenso torch.Tensor: Tensor of scalers for each block """ assert input_tensor.dim() == 1, "Input tensor must be flattened" - assert (input_tensor.numel() % block_size) == 0, ( - f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {block_size}" - ) + assert ( + input_tensor.numel() % block_size + ) == 0, f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {block_size}" n_blocks = input_tensor.numel() // block_size blocks = input_tensor.view(n_blocks, block_size) @@ -637,12 +630,12 @@ def from_tensor( block_size: int, scaler_block_size: int, ): - assert input_tensor.dim() <= 2, ( - f"expect input tensor dim <= 2 but got dim = {input_tensor.dim()}" - ) - assert input_tensor.numel() % block_size == 0, ( - f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {block_size}" - ) + assert ( + input_tensor.dim() <= 2 + ), f"expect input tensor dim <= 2 but got dim = {input_tensor.dim()}" + assert ( + input_tensor.numel() % block_size == 0 + ), f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {block_size}" assert input_tensor.is_contiguous, "Input tensor must be contiguous!" 
# I think I want do this # assert not input_tensor.requires_grad, "Input tensor must not require grad" @@ -723,9 +716,9 @@ def double_quantize_scalers( size: (n_scaler_blocks) """ assert input_tensor.dim() == 1, "Input tensor must be flattened" - assert (input_tensor.numel() % scaler_block_size) == 0, ( - f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {scaler_block_size}" - ) + assert ( + input_tensor.numel() % scaler_block_size + ) == 0, f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {scaler_block_size}" # First round of quantization # Produces: A tensor of size (n_blocks) of input_tensor.dtype @@ -733,9 +726,9 @@ def double_quantize_scalers( scalers_1_mean = scalers_1.mean() scalers_1 = scalers_1 - scalers_1_mean # Second round of quantization - assert scalers_1.numel() % scaler_block_size == 0, ( - f"Number of scalers must be divisible by scaler block size, got {scalers_1.numel()} scaler_block_size {scaler_block_size} " - ) + assert ( + scalers_1.numel() % scaler_block_size == 0 + ), f"Number of scalers must be divisible by scaler block size, got {scalers_1.numel()} scaler_block_size {scaler_block_size} " n_scaler_blocks = scalers_1.numel() // scaler_block_size scaler_blocks = scalers_1.view(n_scaler_blocks, scaler_block_size) @@ -777,9 +770,9 @@ def dequantize_scalers( """ assert input_tensor.dim() == 1, "Input tensor must be flattened" - assert (input_tensor.numel() % scaler_block_size) == 0, ( - f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {scaler_block_size}" - ) + assert ( + input_tensor.numel() % scaler_block_size + ) == 0, f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {scaler_block_size}" n_scaler_blocks = input_tensor.numel() // scaler_block_size input_tensor = input_tensor.view(n_scaler_blocks, scaler_block_size) dequantized = (input_tensor / quantization_factor.unsqueeze(-1)).flatten().to( @@ -795,9 +788,9 @@ def convert_to_norm_float_weight( flattened_tensor = input_tensor.flatten() # Since we are using uint8 we will encode 2 entries per byte numel = input_tensor.numel() - assert numel % 2 == 0, ( - "Number of elements must be even just to not have to think about the end" - ) + assert ( + numel % 2 == 0 + ), "Number of elements must be even just to not have to think about the end" # Reshape the flattened tensor into blocks of size self.block_size blocks = flattened_tensor.view(n_blocks, block_size) From f781ddb216eb53300cfa822384f9faf9cd82e7d1 Mon Sep 17 00:00:00 2001 From: mori360 Date: Thu, 26 Jun 2025 14:12:21 -0700 Subject: [PATCH 5/5] add lint --- torchao/dtypes/nf4tensor.py | 56 ++++++++++++++++++------------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py index 4ce3f60372..079b5feba4 100644 --- a/torchao/dtypes/nf4tensor.py +++ b/torchao/dtypes/nf4tensor.py @@ -7,7 +7,7 @@ import math import sys from dataclasses import dataclass, replace -from enum import auto, Enum +from enum import Enum, auto from typing import Any, Dict, Optional, Tuple, Union import torch @@ -235,9 +235,9 @@ def nf4_split(aten_op, args, kwargs=None): attr_to_chunks = {} for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: inner_tensor = getattr(nf4tensor, attr) - assert ( - inner_tensor.numel() % num_chunks == 0 - ), f"{attr}.numel() not divisible by {num_chunks}" + assert inner_tensor.numel() % num_chunks == 0, ( + f"{attr}.numel() not divisible by {num_chunks}" + ) chunks = aten_op(inner_tensor, 
inner_tensor.numel() // num_chunks, **kwargs) attr_to_chunks[attr] = chunks @@ -278,9 +278,9 @@ def nf4_new_zeros(aten_op, args, kwargs=None): updated_attrs = {} for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: inner_tensor = getattr(nf4tensor, attr) - assert ( - inner_tensor.size(0) % ratio == 0 - ), f"{attr}.numel() must be divisible by {ratio}" + assert inner_tensor.size(0) % ratio == 0, ( + f"{attr}.numel() must be divisible by {ratio}" + ) inner_tensor = aten_op(inner_tensor, [inner_tensor.size(0) // ratio], **kwargs) updated_attrs[attr] = inner_tensor updated_attrs["size"] = new_size @@ -545,9 +545,9 @@ def get_block_absmax(input_tensor: torch.Tensor, block_size: int) -> torch.Tenso torch.Tensor: Tensor of scalers for each block """ assert input_tensor.dim() == 1, "Input tensor must be flattened" - assert ( - input_tensor.numel() % block_size - ) == 0, f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {block_size}" + assert (input_tensor.numel() % block_size) == 0, ( + f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {block_size}" + ) n_blocks = input_tensor.numel() // block_size blocks = input_tensor.view(n_blocks, block_size) @@ -630,12 +630,12 @@ def from_tensor( block_size: int, scaler_block_size: int, ): - assert ( - input_tensor.dim() <= 2 - ), f"expect input tensor dim <= 2 but got dim = {input_tensor.dim()}" - assert ( - input_tensor.numel() % block_size == 0 - ), f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {block_size}" + assert input_tensor.dim() <= 2, ( + f"expect input tensor dim <= 2 but got dim = {input_tensor.dim()}" + ) + assert input_tensor.numel() % block_size == 0, ( + f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {block_size}" + ) assert input_tensor.is_contiguous, "Input tensor must be contiguous!" 
# I think I want do this # assert not input_tensor.requires_grad, "Input tensor must not require grad" @@ -716,9 +716,9 @@ def double_quantize_scalers( size: (n_scaler_blocks) """ assert input_tensor.dim() == 1, "Input tensor must be flattened" - assert ( - input_tensor.numel() % scaler_block_size - ) == 0, f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {scaler_block_size}" + assert (input_tensor.numel() % scaler_block_size) == 0, ( + f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {scaler_block_size}" + ) # First round of quantization # Produces: A tensor of size (n_blocks) of input_tensor.dtype @@ -726,9 +726,9 @@ def double_quantize_scalers( scalers_1_mean = scalers_1.mean() scalers_1 = scalers_1 - scalers_1_mean # Second round of quantization - assert ( - scalers_1.numel() % scaler_block_size == 0 - ), f"Number of scalers must be divisible by scaler block size, got {scalers_1.numel()} scaler_block_size {scaler_block_size} " + assert scalers_1.numel() % scaler_block_size == 0, ( + f"Number of scalers must be divisible by scaler block size, got {scalers_1.numel()} scaler_block_size {scaler_block_size} " + ) n_scaler_blocks = scalers_1.numel() // scaler_block_size scaler_blocks = scalers_1.view(n_scaler_blocks, scaler_block_size) @@ -770,9 +770,9 @@ def dequantize_scalers( """ assert input_tensor.dim() == 1, "Input tensor must be flattened" - assert ( - input_tensor.numel() % scaler_block_size - ) == 0, f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {scaler_block_size}" + assert (input_tensor.numel() % scaler_block_size) == 0, ( + f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {scaler_block_size}" + ) n_scaler_blocks = input_tensor.numel() // scaler_block_size input_tensor = input_tensor.view(n_scaler_blocks, scaler_block_size) dequantized = (input_tensor / quantization_factor.unsqueeze(-1)).flatten().to( @@ -788,9 +788,9 @@ def convert_to_norm_float_weight( flattened_tensor = input_tensor.flatten() # Since we are using uint8 we will encode 2 entries per byte numel = input_tensor.numel() - assert ( - numel % 2 == 0 - ), "Number of elements must be even just to not have to think about the end" + assert numel % 2 == 0, ( + "Number of elements must be even just to not have to think about the end" + ) # Reshape the flattened tensor into blocks of size self.block_size blocks = flattened_tensor.view(n_blocks, block_size)
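
A minimal usage sketch of the shard-and-gather flow these handlers enable, mirroring the _test_comm test added in PATCH 1/5. This is an illustration, not part of the patch: it assumes torch >= 2.4 and an already-initialized multi-GPU process group (the test provides this via FSDPTest with world_size=2), and the 512x512 layer size is taken from the test.

    # Sketch only: quantize a weight to NF4, shard it as a DTensor, then gather it back.
    # Requires torchrun (or an equivalent launcher) with at least 2 GPUs and torch >= 2.4.
    import torch
    from torch import nn
    from torch.distributed._composable.fsdp import fully_shard
    from torch.distributed._tensor import distribute_tensor

    from torchao.dtypes.nf4tensor import to_nf4

    model = nn.Linear(512, 512, device="cuda")
    nf4_weight = to_nf4(model.weight)        # NF4-quantize a copy of the full weight

    model = fully_shard(model)               # model.weight becomes a sharded DTensor
    sharded_nf4 = distribute_tensor(         # shard the NF4 tensor with the same mesh/placements
        nf4_weight,
        model.weight.device_mesh,
        model.weight.placements,
    )

    # detach() + full_tensor() gather the shards back into a full NF4Tensor
    gathered_nf4 = sharded_nf4.detach().full_tensor()

    # Round trip: the dequantized weights should match on every rank
    torch.testing.assert_close(
        nf4_weight.get_original_weight(),
        gathered_nf4.get_original_weight(),
    )

Here distribute_tensor routes through the new scatter_ handler, while full_tensor() exercises the all_gather_into_tensor and wait_tensor handlers registered in the NF4 ops table above.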