[NF4] Support nf4 tensor shard and gather by mori360 · Pull Request #2449 · pytorch/ao · GitHub

[NF4] Support nf4 tensor shard and gather #2449

Draft · mori360 wants to merge 5 commits into base: main
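
At a high level, this PR teaches NF4Tensor to survive the collectives that FSDP2/DTensor issue when a quantized weight is sharded and gathered back. Below is a minimal sketch of that round-trip, distilled from the new TestComm test in this diff; it is an illustration rather than a documented API, and it assumes a 2-rank process group is already initialized (for example via torchrun) with one CUDA device per rank.

# Sketch of the shard/gather round-trip this PR enables (mirrors TestComm._test_comm below).
# Assumes a 2-rank process group is already initialized and one CUDA device per rank.
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._tensor import distribute_tensor

from torchao.dtypes.nf4tensor import to_nf4

model = nn.Linear(512, 512, device="cuda")
nf4_weight = to_nf4(model.weight)          # quantize the full 2-D weight to NF4

fully_shard(model)                         # FSDP2 turns model.weight into a DTensor
sharded_weight = model.weight              # gives us the device mesh and placements

# Shard the NF4 tensor the same way (exercises the new c10d scatter_ handler) ...
sharded_nf4 = distribute_tensor(
    nf4_weight, sharded_weight.device_mesh, sharded_weight.placements
)

# ... then gather it back (exercises all_gather_into_tensor + wait_tensor).
restored = sharded_nf4.detach().full_tensor()
assert torch.equal(nf4_weight.get_original_weight(), restored.get_original_weight())

The handlers added in torchao/dtypes/nf4tensor.py below are what make the distribute_tensor and full_tensor calls in this sketch work on an NF4Tensor.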
42 changes: 41 additions & 1 deletion test/dtypes/test_nf4.py
@@ -435,7 +435,7 @@ def test_tensor_view_valid(self, input_size: Union[Tuple[int], int]):
             inner_tensor = getattr(viewed_tensor, attr)
             self.assertEqual(inner_tensor.size(0), inner_tensor.numel())

-    @parametrize("input_size", [(512 * 512,), (512, 512)])
+    @parametrize("input_size", [(512 * 512,)])
     def test_tensor_view_invalid(self, input_size: Union[Tuple[int], int]):
         nf4_tensor = to_nf4(torch.randn(input_size))
         if len(input_size) == 1:
@@ -741,6 +741,46 @@ def _test_qlora_fsdp2(
         self.assertEqual(fsdp_loss, base_loss)


+class TestComm(FSDPTest):
+    @property
+    def world_size(self) -> int:
+        return 2
+
+    @pytest.mark.skipif(
+        version.parse(torch.__version__).base_version < "2.4.0",
+        reason="torch >= 2.4 required",
+    )
+    @skip_if_lt_x_gpu(2)
+    def test_comm(self):
+        self.run_subtests(
+            {"input_size": [512, 2048]},
+            self._test_comm,
+        )
+
+    def _test_comm(self, input_size: int):
+        from torch.distributed._composable.fsdp import fully_shard
+        from torch.distributed._tensor import distribute_tensor
+
+        model = nn.Linear(input_size, input_size, device="cuda")
+        origin_tensor = model.weight
+        origin_nf4_tensor = to_nf4(origin_tensor)
+        model = fully_shard(model)
+        sharded_tensor = model.weight
+        sharded_origin_nf4_tensor = distribute_tensor(
+            origin_nf4_tensor,
+            sharded_tensor.device_mesh,
+            sharded_tensor.placements,
+        )
+
+        sharded_nf4_detach = sharded_origin_nf4_tensor.detach()
+        resumed_full_tensor = sharded_nf4_detach.full_tensor()
+
+        self.assertEqual(
+            origin_nf4_tensor.get_original_weight(),
+            resumed_full_tensor.get_original_weight(),
+        )
+
+
 instantiate_parametrized_tests(TestNF4Linear)
 instantiate_parametrized_tests(TestFSDPOps)

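A note on running the new test (not part of the diff): TestComm subclasses FSDPTest, which spawns its own worker processes, so on a machine with at least two GPUs something like "pytest test/dtypes/test_nf4.py -k TestComm" should be enough; with fewer GPUs, skip_if_lt_x_gpu(2) simply skips it.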
89 changes: 78 additions & 11 deletions torchao/dtypes/nf4tensor.py
@@ -22,7 +22,44 @@
 c10d_functional = torch.ops.c10d_functional


-NF4_OPS_TABLE: Dict[Any, Any] = {}
+def nf4_all_gather_into_tensor(func, *args, **kwargs):
+    nf4tensor = args[0][0]
+    group_size = args[0][1]
+    name = args[0][2]
+    updated_attrs = {}
+    for attr in _INNER_TENSOR_NAMES_FOR_SHARDING:
+        updated_attrs[attr] = func(getattr(nf4tensor, attr), group_size, name)
+    updated_attrs.update(
+        {
+            "size": torch.Size((nf4tensor.size()[0] * group_size, nf4tensor.size()[1])),
+        }
+    )
+    updatedNF4Tensor = NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))
+    return updatedNF4Tensor
+
+
+def scatter_nf4tensor(func, *args, **kwargs):
+    output_tensor = args[0][0][0]
+    input_tensors = args[0][1]
+    new_attr, update_work = [], []
+    for attr in _INNER_TENSOR_NAMES_FOR_SHARDING:
+        input_attrs = []
+        if input_tensors:
+            for input_tensor in input_tensors[0]:
+                if hasattr(input_tensor, attr):
+                    input_attrs.append(getattr(input_tensor, attr))
+            input_attrs = [input_attrs]
+        new_attr, update_work = func(
+            [getattr(output_tensor, attr)], input_attrs, *args[0][2:]
+        )
+    # each inner-tensor scatter returns a work handle; return one of them, together with the tensor, to fit the required output format
+    return new_attr, update_work
+
+
+NF4_OPS_TABLE: Dict[Any, Any] = {
+    torch.ops._c10d_functional.all_gather_into_tensor.default: nf4_all_gather_into_tensor,
+    torch.ops.c10d.scatter_.default: scatter_nf4tensor,
+}


 _INNER_TENSOR_NAMES_FOR_SHARDING = [
@@ -273,19 +310,35 @@ def nf4_slice(aten_op, args, kwargs=None):
         aten.view.default,
     ]
 )
-@expect_args_len_at_k(1, CompareOp.EQ, 1, "aten.view(NF4Tensor) with len(size)=")
+@expect_args_len_at_k(1, CompareOp.LT, 3, "aten.view(NF4Tensor) with len(size)=")
 def nf4_view(aten_op, args, kwargs=None):
     nf4tensor = args[0]
     size = args[1]
-    if size[0] != -1:
-        raise NotImplementedError(f"aten.view(NF4Tensor) with size={size}")
-    updated_attrs = apply_to_inner_tensors(nf4tensor, aten_op, args[1:], kwargs)
-    updated_attrs.update(
-        {
-            "size": [nf4tensor.numel()],
-            "stride": (1,),
-        }
-    )
+    if len(size) == 1:
+        if size[0] != -1:
+            raise NotImplementedError(f"aten.view(NF4Tensor) with size={size}")
+        else:
+            updated_attrs = apply_to_inner_tensors(nf4tensor, aten_op, args[1:], kwargs)
+            updated_attrs.update(
+                {
+                    "size": [nf4tensor.numel()],
+                    "stride": (1,),
+                }
+            )
+    else:
+        updated_attrs = {}
+        if nf4tensor.numel() != nf4tensor.size()[0] * nf4tensor.size()[1]:
+            raise NotImplementedError("NF4Tensor size does not match numel.")
+        for attr in _INNER_TENSOR_NAMES_FOR_SHARDING:
+            attr_size = [getattr(nf4tensor, attr).size()]
+            updated_attrs[attr] = aten_op(
+                getattr(nf4tensor, attr), *attr_size, **kwargs
+            )
+        updated_attrs.update(
+            {
+                "stride": (nf4tensor.size()[1], 1),
+            }
+        )
     return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))


@@ -457,6 +510,20 @@ def nf4_cat(aten_op: torch._ops.OpOverload, args, kwargs=None):
     return tensors


+@implements(
+    [
+        torch.ops._c10d_functional.wait_tensor.default,
+    ]
+)
+def wait_tensor(func, *args, **kwargs):
+    nf4tensor = args[0][0]
+    updated_attrs = {}
+    for attr in _INNER_TENSOR_NAMES_FOR_SHARDING:
+        updated_attrs[attr] = func(getattr(nf4tensor, attr))
+    updatedNF4Tensor = NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))
+    return updatedNF4Tensor
+
+
 @dataclass(frozen=True)
 class SubclassTensorArgs:
     original_shape: torch.Size
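For context on how the new handlers are reached (this is a reading of the existing nf4tensor.py machinery, not something the diff itself spells out): ops issued on an NF4Tensor are intercepted by its __torch_dispatch__ and looked up in NF4_OPS_TABLE, so the collectives that DTensor issues during distribute_tensor (c10d scatter_) and full_tensor (the functional all_gather_into_tensor followed by wait_tensor) land on the handlers registered above, which re-apply each collective to the inner tensors and rebuild an NF4Tensor from the results. A toy sketch of that wrapper pattern follows; every name in it is invented for illustration and it depends on nothing beyond stock PyTorch.

# Toy stand-in for the pattern above: intercept an op, fan it out over the inner
# tensors, rebuild the wrapper. All names here are invented for the sketch.
from typing import Any, Callable, Dict

import torch

OPS_TABLE: Dict[Any, Callable] = {}

def implements(ops):
    # Register a handler for each op, like torchao's @implements decorator.
    def decorator(fn):
        for op in ops:
            OPS_TABLE[op] = fn
        return fn
    return decorator

class PairTensor:
    # Minimal wrapper holding two inner tensors (think: quantized data + scalers).
    def __init__(self, data: torch.Tensor, scalers: torch.Tensor):
        self.data, self.scalers = data, scalers

@implements([torch.ops.aten.detach.default])
def pair_detach(func, args, kwargs=None):
    t = args[0]
    # Same shape as nf4_all_gather_into_tensor / wait_tensor above: apply the op
    # to each inner tensor, then rebuild the wrapper from the updated pieces.
    return PairTensor(func(t.data), func(t.scalers))

wrapped = PairTensor(torch.randn(8), torch.randn(2))
op = torch.ops.aten.detach.default
detached = OPS_TABLE[op](op, (wrapped,))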