diff --git a/test/dtypes/test_nf4.py b/test/dtypes/test_nf4.py
index f52644cdf3..0a04197464 100644
--- a/test/dtypes/test_nf4.py
+++ b/test/dtypes/test_nf4.py
@@ -435,7 +435,17 @@ def test_tensor_view_valid(self, input_size: Union[Tuple[int], int]):
             inner_tensor = getattr(viewed_tensor, attr)
             self.assertEqual(inner_tensor.size(0), inner_tensor.numel())
 
-    @parametrize("input_size", [(512 * 512,), (512, 512)])
+    @parametrize("input_size", [(512, 512)])
+    def test_tensor_2d_view_valid(self, input_size: Tuple[int]):
+        nf4_tensor = to_nf4(torch.randn(input_size))
+        viewed_tensor = nf4_tensor.view(input_size)
+        self.assertEqual(viewed_tensor.dim(), 2)
+        self.assertEqual(viewed_tensor.numel(), math.prod(input_size))
+        for attr in _INNER_TENSOR_NAMES_FOR_SHARDING:
+            inner_tensor = getattr(viewed_tensor, attr)
+            self.assertEqual(inner_tensor.size(0), inner_tensor.numel())
+
+    @parametrize("input_size", [(512 * 512,)])
     def test_tensor_view_invalid(self, input_size: Union[Tuple[int], int]):
         nf4_tensor = to_nf4(torch.randn(input_size))
         if len(input_size) == 1:
@@ -443,11 +453,6 @@ def test_tensor_view_invalid(self, input_size: Union[Tuple[int], int]):
                 NotImplementedError, "aten.view\\(NF4Tensor\\) with size"
             ):
                 nf4_tensor.view(input_size)
-        if len(input_size) == 2:
-            with self.assertRaisesRegex(
-                NotImplementedError, "aten.view\\(NF4Tensor\\) with len\\(size\\)"
-            ):
-                nf4_tensor.view(input_size)
 
     @parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)])
     def test_tensor_as_strided_valid(self, input_size: Union[Tuple[int], int]):
@@ -741,6 +746,42 @@ def _test_qlora_fsdp2(
             self.assertEqual(fsdp_loss, base_loss)
 
 
+class TestComm(FSDPTest):
+    @property
+    def world_size(self) -> int:
+        return 2
+
+    @skip_if_lt_x_gpu(2)
+    def test_comm(self):
+        self.run_subtests(
+            {"input_size": [512, 2048]},
+            self._test_comm,
+        )
+
+    def _test_comm(self, input_size: int):
+        from torch.distributed._composable.fsdp import fully_shard
+        from torch.distributed._tensor import distribute_tensor
+
+        model = nn.Linear(input_size, input_size, device="cuda")
+        origin_tensor = model.weight
+        origin_nf4_tensor = to_nf4(origin_tensor)
+        model = fully_shard(model)
+        sharded_tensor = model.weight
+        sharded_origin_nf4_tensor = distribute_tensor(
+            origin_nf4_tensor,
+            sharded_tensor.device_mesh,
+            sharded_tensor.placements,
+        )
+
+        sharded_nf4_detach = sharded_origin_nf4_tensor.detach()
+        resumed_full_tensor = sharded_nf4_detach.full_tensor()
+
+        self.assertEqual(
+            origin_nf4_tensor.get_original_weight(),
+            resumed_full_tensor.get_original_weight(),
+        )
+
+
 instantiate_parametrized_tests(TestNF4Linear)
 instantiate_parametrized_tests(TestFSDPOps)
 
diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py
index 698a9391bd..e6662b350a 100644
--- a/torchao/dtypes/nf4tensor.py
+++ b/torchao/dtypes/nf4tensor.py
@@ -22,7 +22,51 @@
 c10d_functional = torch.ops.c10d_functional
 
 
-NF4_OPS_TABLE: Dict[Any, Any] = {}
+def nf4_all_gather_into_tensor(func, *args, **kwargs):
+    assert len(args) > 1, "Expected valid input"
+    assert len(args[0]) == 3, "Expected 3 input args"
+    nf4tensor = args[0][0]
+    group_size = args[0][1]
+    name = args[0][2]
+    updated_attrs = {}
+    for attr in _INNER_TENSOR_NAMES_FOR_SHARDING:
+        updated_attrs[attr] = func(getattr(nf4tensor, attr), group_size, name)
+    updated_attrs.update(
+        {
+            "size": torch.Size((nf4tensor.size()[0] * group_size, nf4tensor.size()[1])),
+        }
+    )
+    updatedNF4Tensor = NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))
+    return updatedNF4Tensor
+
+
+def scatter_nf4tensor(func, *args, **kwargs):
+    assert len(args) > 1, "Expected valid input"
+    assert len(args[0][0]) == 1, "Expected 1 output tensor"
+    output_tensor = args[0][0][0]
+    input_tensors = args[0][1]
+    new_attr, update_work = [], []
+    for attr in _INNER_TENSOR_NAMES_FOR_SHARDING:
+        input_attrs = []
+        if input_tensors:
+            for input_tensor in input_tensors[0]:
+                assert input_tensor.size() == output_tensor.size(), (
+                    "Input tensor size must match output tensor size, tensors are not evenly divided."
+                )
+                if hasattr(input_tensor, attr):
+                    input_attrs.append(getattr(input_tensor, attr))
+            input_attrs = [input_attrs]
+        new_attr, update_work = func(
+            [getattr(output_tensor, attr)], input_attrs, *args[0][2:]
+        )
+    # there are 3 works, return one of them, same as the tensor to fit the required output format
+    return new_attr, update_work
+
+
+NF4_OPS_TABLE: Dict[Any, Any] = {
+    torch.ops._c10d_functional.all_gather_into_tensor.default: nf4_all_gather_into_tensor,
+    torch.ops.c10d.scatter_.default: scatter_nf4tensor,
+}
 
 
 _INNER_TENSOR_NAMES_FOR_SHARDING = [
@@ -233,7 +277,6 @@ def nf4_split(aten_op, args, kwargs=None):
 def nf4_new_zeros(aten_op, args, kwargs=None):
     nf4tensor = args[0]
     new_size = tuple(args[1])
-
     if nf4tensor.numel() % math.prod(new_size) != 0:
         raise NotImplementedError(f"aten.new_zeros(NF4Tensor) with new size {new_size}")
     ratio = nf4tensor.numel() // math.prod(new_size)
@@ -273,19 +316,37 @@ def nf4_slice(aten_op, args, kwargs=None):
         aten.view.default,
     ]
 )
-@expect_args_len_at_k(1, CompareOp.EQ, 1, "aten.view(NF4Tensor) with len(size)=")
+@expect_args_len_at_k(1, CompareOp.LT, 3, "aten.view(NF4Tensor) with len(size)=")
 def nf4_view(aten_op, args, kwargs=None):
     nf4tensor = args[0]
     size = args[1]
-    if size[0] != -1:
-        raise NotImplementedError(f"aten.view(NF4Tensor) with size={size}")
-    updated_attrs = apply_to_inner_tensors(nf4tensor, aten_op, args[1:], kwargs)
-    updated_attrs.update(
-        {
-            "size": [nf4tensor.numel()],
-            "stride": (1,),
-        }
-    )
+    if len(size) == 1:
+        if size[0] != -1:
+            raise NotImplementedError(f"aten.view(NF4Tensor) with size={size}")
+        else:
+            updated_attrs = apply_to_inner_tensors(nf4tensor, aten_op, args[1:], kwargs)
+            updated_attrs.update(
+                {
+                    "size": [nf4tensor.numel()],
+                    "stride": (1,),
+                }
+            )
+    elif len(size) == 2:
+        if nf4tensor.numel() != size[0] * size[1]:
+            raise NotImplementedError("NF4Tensor size does not match view size.")
+        updated_attrs = {}
+        for attr in _INNER_TENSOR_NAMES_FOR_SHARDING:
+            attr_size = [getattr(nf4tensor, attr).size()]
+            updated_attrs[attr] = aten_op(
+                getattr(nf4tensor, attr), *attr_size, **kwargs
+            )
+        updated_attrs.update(
+            {
+                "stride": (size[1], 1),
+            }
+        )
+    else:
+        raise NotImplementedError("aten.view(NF4Tensor) with empty size")
     return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))
 
 
@@ -457,6 +518,20 @@ def nf4_cat(aten_op: torch._ops.OpOverload, args, kwargs=None):
     return tensors
 
 
+@implements(
+    [
+        torch.ops._c10d_functional.wait_tensor.default,
+    ]
+)
+def wait_tensor(func, *args, **kwargs):
+    nf4tensor = args[0][0]
+    updated_attrs = {}
+    for attr in _INNER_TENSOR_NAMES_FOR_SHARDING:
+        updated_attrs[attr] = func(getattr(nf4tensor, attr))
+    updatedNF4Tensor = NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))
+    return updatedNF4Tensor
+
+
 @dataclass(frozen=True)
 class SubclassTensorArgs:
     original_shape: torch.Size