From efd993f09df2cc1430608d77e818ba4bdf328883 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Sat, 21 Jun 2025 11:34:43 -0700 Subject: [PATCH 1/5] make offs optional for scaled grouped mm --- .../moe_training/scaled_grouped_mm.py | 10 ++++++---- torchao/prototype/moe_training/tensor.py | 20 +++++++++++++++---- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/torchao/prototype/moe_training/scaled_grouped_mm.py b/torchao/prototype/moe_training/scaled_grouped_mm.py index 29adffd831..c4c9db9266 100644 --- a/torchao/prototype/moe_training/scaled_grouped_mm.py +++ b/torchao/prototype/moe_training/scaled_grouped_mm.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. +import logging from typing import Optional import torch @@ -18,11 +19,13 @@ _is_column_major, ) +logger: logging.Logger = logging.getLogger(__name__) + def _scaled_grouped_mm( A: torch.Tensor, B_t: torch.Tensor, - offs: torch.Tensor, + offs: Optional[torch.Tensor] = None, out_dtype: Optional[torch.dtype] = torch.bfloat16, ) -> torch.Tensor: """ @@ -38,6 +41,7 @@ def _scaled_grouped_mm( out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported. use_triton_for_per_group_scales (bool): Whether to use custom triton kernels to compute per-group scales. Default is True. """ + logger.info("Using differentiable _scaled_grouped_mm") return _Float8GroupedMM.apply( A, B_t, @@ -54,9 +58,8 @@ def forward( ctx, A: torch.Tensor, B_t: torch.Tensor, - offs: torch.Tensor, + offs: Optional[torch.Tensor] = None, out_dtype: Optional[torch.dtype] = torch.bfloat16, - use_triton_for_per_group_scales: bool = True, ) -> torch.Tensor: # torchao _scaled_grouped_mm only supports A=2D, B=3D. assert A.ndim == 2, "A must be 2D" @@ -76,7 +79,6 @@ def forward( assert B_t.dtype == torch.float32 or B_t.dtype == torch.bfloat16, ( "B must be float32 or bfloat16" ) - assert offs.dtype == torch.int32, "offs must be int32" # Assert A and B dims are compatible for a scaled grouped GEMM. assert A.size(-1) == B_t.size(-2), ( diff --git a/torchao/prototype/moe_training/tensor.py b/torchao/prototype/moe_training/tensor.py index 3ea9529237..37954879d3 100644 --- a/torchao/prototype/moe_training/tensor.py +++ b/torchao/prototype/moe_training/tensor.py @@ -1,3 +1,10 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import logging from typing import Any, Optional, Tuple import torch @@ -6,6 +13,8 @@ from torchao.prototype.moe_training import _scaled_grouped_mm +logger: logging.Logger = logging.getLogger(__name__) + _ops_to_preserve_subclass = { torch.ops.aten.empty_like.default, torch.ops.aten.new_zeros.default, @@ -27,7 +36,7 @@ class ScaledGroupedMMTensor(torch.Tensor): differentiable _scaled_grouped_mm autograd function. 
""" - grouped_mm_func_name = "_grouped_mm" + grouped_mm_func_names = {"_grouped_mm", "_grouped_mm.default"} offs_arg_name = "offs" @staticmethod @@ -57,7 +66,7 @@ def __init__( @classmethod def __torch_function__(cls, func, types, args, kwargs={}): # override the grouped mm op to use the differentiable _scaled_grouped_mm - if func.__name__ == cls.grouped_mm_func_name: + if func.__name__ in cls.grouped_mm_func_names: # Use torchao scaled grouped mm with dynamic quant for # "2d x 3d with offsets" case (used for routed experts). # Otherwise, fall back to regular grouped mm. @@ -69,7 +78,9 @@ def __torch_function__(cls, func, types, args, kwargs={}): A_is_2d = A.dim() == 2 B_is_3d = B.dim() == 3 has_offs = kwargs.get(cls.offs_arg_name) is not None - if A_is_2d and B_is_3d and has_offs: + logger.info(f"A.shape={A.shape}, B.shape={B.shape}, has_offs={has_offs}") + + if A_is_2d and B_is_3d: return _scaled_grouped_mm( *args, **kwargs, @@ -107,7 +118,8 @@ def __torch_dispatch__(cls, func, types, args, kwargs={}): ) def fsdp_pre_all_gather(self, mesh): - return (self._data,), () + metadata = () + return (self._data,), metadata def fsdp_post_all_gather( self, From accbb27951cc413c7e695f528e01848937212f2a Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Mon, 23 Jun 2025 09:49:54 -0700 Subject: [PATCH 2/5] tp on routed experts working --- .../moe_training/conversion_utils.py | 12 ++++- .../moe_training/scaled_grouped_mm.py | 51 +++++++++++++++---- torchao/prototype/moe_training/tensor.py | 8 +-- 3 files changed, 55 insertions(+), 16 deletions(-) diff --git a/torchao/prototype/moe_training/conversion_utils.py b/torchao/prototype/moe_training/conversion_utils.py index 51af0fd956..a7aaf0ebb5 100644 --- a/torchao/prototype/moe_training/conversion_utils.py +++ b/torchao/prototype/moe_training/conversion_utils.py @@ -1,3 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +import logging from typing import Callable, Optional from torch import nn @@ -8,6 +14,8 @@ register_quantize_module_handler, ) +logger: logging.Logger = logging.getLogger(__name__) + class MoETrainingConfig(AOBaseConfig): """ @@ -105,7 +113,9 @@ def post_order_traversal( ScaledGroupedMMTensor(param), requires_grad=param.requires_grad ) setattr(module, param_name, new_param) - print(f"Swapped {cur_fqn}.{param_name} to ScaledGroupedMMTensor") + logger.info( + f"Swapped {cur_fqn}.{param_name} to ScaledGroupedMMTensor" + ) post_order_traversal(root_module) return root_module diff --git a/torchao/prototype/moe_training/scaled_grouped_mm.py b/torchao/prototype/moe_training/scaled_grouped_mm.py index c4c9db9266..c2192005e6 100644 --- a/torchao/prototype/moe_training/scaled_grouped_mm.py +++ b/torchao/prototype/moe_training/scaled_grouped_mm.py @@ -39,9 +39,8 @@ def _scaled_grouped_mm( and in column-major memory layout. offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor. out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported. - use_triton_for_per_group_scales (bool): Whether to use custom triton kernels to compute per-group scales. Default is True. 
""" - logger.info("Using differentiable _scaled_grouped_mm") + logger.debug("Using differentiable _scaled_grouped_mm") return _Float8GroupedMM.apply( A, B_t, @@ -61,8 +60,8 @@ def forward( offs: Optional[torch.Tensor] = None, out_dtype: Optional[torch.dtype] = torch.bfloat16, ) -> torch.Tensor: - # torchao _scaled_grouped_mm only supports A=2D, B=3D. - assert A.ndim == 2, "A must be 2D" + # torchao _scaled_grouped_mm only supports A=2D|3D + B=3D. + assert A.ndim == 2 or A.ndim == 3, "A must be 2D or 3D" assert B_t.ndim == 3, "B must be 3D" assert A.size(-1) % 16 == 0, ( @@ -151,12 +150,25 @@ def forward( assert _is_column_major(B_t_fp8_col_major), ( "B must be column-major for output = A @ B" ) + + # TODO: remove excessive logging once prototype is more mature. + logger.debug( + ( + f"forward scaled_grouped_mm: A_fp8_row_major.shape={A_fp8_row_major.shape}, " + f"A_scale.shape={A_scales.squeeze(-1).shape}, " + f"B_t_fp8_col_major.shape={B_t_fp8_col_major.shape}, " + f"B_t_scale.shape={B_t_scales.squeeze(1).shape}, " + f"offs={offs if offs is not None else None}" + ) + ) return torch._scaled_grouped_mm( A_fp8_row_major, B_t_fp8_col_major, - A_scales.squeeze().reciprocal(), - B_t_scales.squeeze().reciprocal(), - offs, + # Squeeze A scales to: (B, S, 1) => (B, M), or (B*S, 1) => (B*S) + A_scales.squeeze(-1).reciprocal(), + # Squeeze B scales to: (B, 1, N) => (B, N) + B_t_scales.squeeze(1).reciprocal(), + offs=offs, out_dtype=out_dtype, use_fast_accum=True, ) @@ -193,12 +205,20 @@ def backward(ctx, grad_output: torch.Tensor): assert _is_column_major(B_fp8_col_major), ( "B must be column-major for grad_A = grad_output @ B" ) + logger.debug( + ( + f"backward grad_A: grad_output_fp8_row_major.shape={grad_output_fp8_row_major.shape}, " + f"grad_output_scale.shape={grad_output_scales.shape}, " + f"B_fp8_col_major.shape={B_fp8_col_major.shape}, " + f"B_scale.shape={B_scales.shape}, " + ) + ) grad_A = torch._scaled_grouped_mm( grad_output_fp8_row_major, B_fp8_col_major, - grad_output_scales.squeeze().reciprocal(), - B_scales.squeeze().reciprocal(), - offs, + grad_output_scales.squeeze(-1).reciprocal(), + B_scales.squeeze(1).reciprocal(), + offs=offs, out_dtype=out_dtype, use_fast_accum=True, ) @@ -238,12 +258,21 @@ def backward(ctx, grad_output: torch.Tensor): assert _is_column_major(A_fp8_col_major), ( "A must be column-major for grad_B = grad_output_t @ A" ) + + logger.debug( + ( + f"backward grad_B: grad_output_t_fp8_row_major.shape={grad_output_t_fp8_row_major.shape}, " + f"grad_output_t_scale.shape={grad_output_t_scales.shape}, " + f"A_fp8_col_major.shape={A_fp8_col_major.shape}, " + f"A_scale.shape={A_scales.shape}, " + ) + ) grad_B = torch._scaled_grouped_mm( grad_output_t_fp8_row_major, A_fp8_col_major, grad_output_t_scales.reciprocal(), A_scales.reciprocal(), - offs, + offs=offs, out_dtype=out_dtype, use_fast_accum=True, ) diff --git a/torchao/prototype/moe_training/tensor.py b/torchao/prototype/moe_training/tensor.py index 37954879d3..6b9df55c90 100644 --- a/torchao/prototype/moe_training/tensor.py +++ b/torchao/prototype/moe_training/tensor.py @@ -75,12 +75,12 @@ def __torch_function__(cls, func, types, args, kwargs={}): # used for shared experts. This is basically the grouped_mm # kernel handling a bmm. 
A, B = args[0], args[1] - A_is_2d = A.dim() == 2 + A_is_2d_or_3d = A.dim() in (2, 3) B_is_3d = B.dim() == 3 has_offs = kwargs.get(cls.offs_arg_name) is not None - logger.info(f"A.shape={A.shape}, B.shape={B.shape}, has_offs={has_offs}") - - if A_is_2d and B_is_3d: + logger.debug(f"A.shape={A.shape}, B.shape={B.shape}, has_offs={has_offs}") + + if A_is_2d_or_3d and B_is_3d: return _scaled_grouped_mm( *args, **kwargs, From a80b9a0f36edab2ac664cca231aa0e6d4b883466 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Tue, 24 Jun 2025 10:28:19 -0700 Subject: [PATCH 3/5] add tp integration test --- test/prototype/moe_training/test_fsdp.py | 13 ++ test/prototype/moe_training/test_tp.py | 219 +++++++++++++++++++++++ 2 files changed, 232 insertions(+) create mode 100644 test/prototype/moe_training/test_tp.py diff --git a/test/prototype/moe_training/test_fsdp.py b/test/prototype/moe_training/test_fsdp.py index 4994a76854..302256dbec 100644 --- a/test/prototype/moe_training/test_fsdp.py +++ b/test/prototype/moe_training/test_fsdp.py @@ -1,3 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +###################################################################### +# +# To run these unit tests, use the following command: +# +# torchrun --nproc_per_node=${NUM_GPUS} -m pytest test_fsdp.py +# +####################################################################### + import copy import os diff --git a/test/prototype/moe_training/test_tp.py b/test/prototype/moe_training/test_tp.py new file mode 100644 index 0000000000..3984ae2d40 --- /dev/null +++ b/test/prototype/moe_training/test_tp.py @@ -0,0 +1,219 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. 
+###################################################################### +# +# To run these unit tests, use the following command: +# +# torchrun --nproc_per_node=${NUM_GPUS} -m pytest test_tp.py +# +####################################################################### + +import copy +import os + +import pytest +import torch +from torch import distributed as dist +from torch import nn +from torch.distributed.device_mesh import DeviceMesh, init_device_mesh +from torch.nn import functional as F + +# this feature requires CUDA and SM89+ +if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 9): + pytest.skip( + "CUDA not available or compute capability < 8.9", allow_module_level=True + ) + +from torchao.float8.float8_utils import compute_error +from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig +from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor +from torchao.quantization.quant_api import quantize_ + +# this test requires torchtitan +try: + from torchtitan.experiments.llama4.infra.parallelize import apply_moe_tp + from torchtitan.experiments.llama4.model.args import TransformerModelArgs + from torchtitan.experiments.llama4.model.moe import MoE +except ImportError: + import warnings + + warnings.warn("torchtitan not installed, skipping MoE tests.") + pytest.skip(allow_module_level=True) + + +@pytest.mark.parametrize( + "target_fqns", + [ + ["experts"], + ["experts,shared_expert"], + ], +) +def test_moe_float8_training_tp_sp(target_fqns: list[str]): + assert torch.cuda.is_available() + + # setup distributed for fsdp + mesh = setup_distributed() + + # define model args + model_args = TransformerModelArgs( + moe_enabled=True, + num_experts=8, + dim=256, + vocab_size=1024, + ) + init_std = 0.02 + device = torch.device("cuda") + + # reference bf16 MoE + ref_model = MoE(model_args).to(torch.bfloat16).cuda() + torch.manual_seed(1) + ref_model.init_weights(init_std, device) + + # target MoE for testing conversion + model = copy.deepcopy(ref_model) + + # assert starting params are identical for both models + for param1, param2 in zip(model.parameters(), ref_model.parameters()): + assert torch.equal(param1, param2) + + # convert MoE to float8 training + def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: + for target_fqn in target_fqns: + if target_fqn in cur_fqn: + return True + return False + + # quantize test model + config = MoETrainingConfig() + quantize_(model, config=config, filter_fn=moe_module_filter_fn) + + # validate that only the experts were converted + _validate_model_conversion( + model, + target_fqns=target_fqns, + ) + + # apply TP + apply_moe_tp(model, mesh) + apply_moe_tp(ref_model, mesh) + + # inputs + batch, seq, dim = 8, 2048, 256 + ref_x = torch.randn( + batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device + ) + x = ref_x.detach().clone().requires_grad_(True) + + # forward pass + ref_out = ref_model(ref_x) + out = model(x) + + # validate output + out_sqnr = compute_error(out, ref_out) + assert out_sqnr.item() >= 30.0, f"SQNR must be >= 30.0, got {out_sqnr.item()}." + + # compute loss + labels = torch.ones_like(ref_out) + ref_loss = F.mse_loss(ref_out, labels) + out_loss = F.mse_loss(out, labels) + + # backward pass + ref_loss.backward() + out_loss.backward() + + # validate input gradient + input_grad_sqnr = compute_error(x.grad, ref_x.grad) + assert input_grad_sqnr.item() >= 30.0, ( + f"SQNR must be >= 30.0, got {input_grad_sqnr.item()}." 
+ ) + + # validate param gradients + for param1, param2 in zip(model.parameters(), ref_model.parameters()): + param_grad_sqnr = compute_error(param1.grad, param2.grad) + assert param_grad_sqnr.item() >= 25.0, ( + f"SQNR must be >= 25.0, got {param_grad_sqnr.item()}." + ) + + dist.destroy_process_group() + + +def _validate_model_conversion( + root_module: nn.Module, + target_fqns: list[str], +): + def _recursive_validate( + module: nn.Module, + cur_fqn: str, + ): + is_allowed_module = any([target_fqn in cur_fqn for target_fqn in target_fqns]) + + # check current module params + for param_name, param in module.named_parameters(recurse=False): + is_converted_type = isinstance(param, ScaledGroupedMMTensor) + if is_converted_type: + assert is_allowed_module, ( + f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}." + ) + if not is_allowed_module: + assert not is_converted_type, ( + f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}." + ) + + # recursively check child modules + for child_name, child_module in module.named_children(): + child_fqn = f"{cur_fqn}.{child_name}" if cur_fqn else child_name + _recursive_validate(child_module, child_fqn) + + _recursive_validate(root_module, "") + + +def setup_distributed(): + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + dist.init_process_group("nccl", rank=rank, world_size=world_size) + device_mesh = init_device_mesh("cuda", (world_size,)) + # seed must be the same in all processes + torch.manual_seed(1) + torch.cuda.set_device(rank) + return device_mesh + + +def apply_moe_tp( + model: nn.Module, + tp_mesh: DeviceMesh, +): + # base on llama4 MoE TP implementation here: https://github.com/pytorch/torchtitan/blob/d9cc6b4df341eec27768b5ab9cead87ef595dbc2/torchtitan/experiments/llama4/infra/parallelize.py#L147C1-L180C10 + from torch.distributed.tensor import Partial, Replicate, Shard + from torch.distributed.tensor.parallel import ( + PrepareModuleInputOutput, + parallelize_module, + ) + from torchtitan.experiments.llama4.infra.expert_parallel import ( + NoParallel, + TensorParallel, + ) + + moe_layer_plan = { + # input / output sharding on the seqlen dim + # all-gather for input, reduce-scatter for output + "moe": PrepareModuleInputOutput( + input_layouts=(Shard(1),), + desired_input_layouts=(Replicate(),), + use_local_input=True, + output_layouts=(Partial(),), + desired_output_layouts=(Shard(1),), + ), + # replicate computation for the router + "moe.router.gate": NoParallel(), + # input Replicate, output Partial + "moe.experts": TensorParallel(output_layout=Partial()), + "moe.shared_expert": TensorParallel(output_layout=Partial()), + } + parallelize_module( + module=model, + device_mesh=tp_mesh, + parallelize_plan=moe_layer_plan, + ) From fb0122e1eb16567c5fe702a2c9e8ae0c8a38fa15 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Tue, 24 Jun 2025 10:35:16 -0700 Subject: [PATCH 4/5] remove excessive logging --- .../moe_training/scaled_grouped_mm.py | 30 +------------------ torchao/prototype/moe_training/tensor.py | 6 ---- 2 files changed, 1 insertion(+), 35 deletions(-) diff --git a/torchao/prototype/moe_training/scaled_grouped_mm.py b/torchao/prototype/moe_training/scaled_grouped_mm.py index c2192005e6..c0460155b6 100644 --- a/torchao/prototype/moe_training/scaled_grouped_mm.py +++ b/torchao/prototype/moe_training/scaled_grouped_mm.py @@ -60,7 +60,7 @@ def forward( offs: Optional[torch.Tensor] = None, out_dtype: Optional[torch.dtype] = torch.bfloat16, ) -> 
torch.Tensor: - # torchao _scaled_grouped_mm only supports A=2D|3D + B=3D. + # torchao _scaled_grouped_mm only supports A=2D|3D and B=3D. assert A.ndim == 2 or A.ndim == 3, "A must be 2D or 3D" assert B_t.ndim == 3, "B must be 3D" @@ -150,17 +150,6 @@ def forward( assert _is_column_major(B_t_fp8_col_major), ( "B must be column-major for output = A @ B" ) - - # TODO: remove excessive logging once prototype is more mature. - logger.debug( - ( - f"forward scaled_grouped_mm: A_fp8_row_major.shape={A_fp8_row_major.shape}, " - f"A_scale.shape={A_scales.squeeze(-1).shape}, " - f"B_t_fp8_col_major.shape={B_t_fp8_col_major.shape}, " - f"B_t_scale.shape={B_t_scales.squeeze(1).shape}, " - f"offs={offs if offs is not None else None}" - ) - ) return torch._scaled_grouped_mm( A_fp8_row_major, B_t_fp8_col_major, @@ -205,14 +194,6 @@ def backward(ctx, grad_output: torch.Tensor): assert _is_column_major(B_fp8_col_major), ( "B must be column-major for grad_A = grad_output @ B" ) - logger.debug( - ( - f"backward grad_A: grad_output_fp8_row_major.shape={grad_output_fp8_row_major.shape}, " - f"grad_output_scale.shape={grad_output_scales.shape}, " - f"B_fp8_col_major.shape={B_fp8_col_major.shape}, " - f"B_scale.shape={B_scales.shape}, " - ) - ) grad_A = torch._scaled_grouped_mm( grad_output_fp8_row_major, B_fp8_col_major, @@ -258,15 +239,6 @@ def backward(ctx, grad_output: torch.Tensor): assert _is_column_major(A_fp8_col_major), ( "A must be column-major for grad_B = grad_output_t @ A" ) - - logger.debug( - ( - f"backward grad_B: grad_output_t_fp8_row_major.shape={grad_output_t_fp8_row_major.shape}, " - f"grad_output_t_scale.shape={grad_output_t_scales.shape}, " - f"A_fp8_col_major.shape={A_fp8_col_major.shape}, " - f"A_scale.shape={A_scales.shape}, " - ) - ) grad_B = torch._scaled_grouped_mm( grad_output_t_fp8_row_major, A_fp8_col_major, diff --git a/torchao/prototype/moe_training/tensor.py b/torchao/prototype/moe_training/tensor.py index 6b9df55c90..020a7ca7be 100644 --- a/torchao/prototype/moe_training/tensor.py +++ b/torchao/prototype/moe_training/tensor.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. 
-import logging from typing import Any, Optional, Tuple import torch @@ -13,8 +12,6 @@ from torchao.prototype.moe_training import _scaled_grouped_mm -logger: logging.Logger = logging.getLogger(__name__) - _ops_to_preserve_subclass = { torch.ops.aten.empty_like.default, torch.ops.aten.new_zeros.default, @@ -77,9 +74,6 @@ def __torch_function__(cls, func, types, args, kwargs={}): A, B = args[0], args[1] A_is_2d_or_3d = A.dim() in (2, 3) B_is_3d = B.dim() == 3 - has_offs = kwargs.get(cls.offs_arg_name) is not None - logger.debug(f"A.shape={A.shape}, B.shape={B.shape}, has_offs={has_offs}") - if A_is_2d_or_3d and B_is_3d: return _scaled_grouped_mm( *args, From 29be4b2e6f61111a1ff4d1cb82e60637673d4771 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Thu, 26 Jun 2025 16:06:54 -0700 Subject: [PATCH 5/5] fix dtype bug --- .../moe_training/conversion_utils.py | 5 +- .../moe_training/scaled_grouped_mm.py | 2 +- torchao/prototype/moe_training/tensor.py | 49 +++++++++++++++---- 3 files changed, 44 insertions(+), 12 deletions(-) diff --git a/torchao/prototype/moe_training/conversion_utils.py b/torchao/prototype/moe_training/conversion_utils.py index a7aaf0ebb5..72056e68b3 100644 --- a/torchao/prototype/moe_training/conversion_utils.py +++ b/torchao/prototype/moe_training/conversion_utils.py @@ -84,7 +84,7 @@ def _swap_params( f"Does not support a root nn.Parameter with children: {module}" ) if not isinstance(module.data, ScaledGroupedMMTensor): - new_data = ScaledGroupedMMTensor(module.data) + new_data = ScaledGroupedMMTensor(module.data, module.data.dtype) return nn.Parameter(new_data, requires_grad=module.requires_grad) return module @@ -110,7 +110,8 @@ def post_order_traversal( for param_name, param in module.named_parameters(recurse=False): if not isinstance(param.data, ScaledGroupedMMTensor): new_param = nn.Parameter( - ScaledGroupedMMTensor(param), requires_grad=param.requires_grad + ScaledGroupedMMTensor(param.data, param.data.dtype), + requires_grad=param.requires_grad, ) setattr(module, param_name, new_param) logger.info( diff --git a/torchao/prototype/moe_training/scaled_grouped_mm.py b/torchao/prototype/moe_training/scaled_grouped_mm.py index c0460155b6..b353bb0a55 100644 --- a/torchao/prototype/moe_training/scaled_grouped_mm.py +++ b/torchao/prototype/moe_training/scaled_grouped_mm.py @@ -40,7 +40,7 @@ def _scaled_grouped_mm( offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor. out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported. """ - logger.debug("Using differentiable _scaled_grouped_mm") + logger.info("Using differentiable _scaled_grouped_mm") return _Float8GroupedMM.apply( A, B_t, diff --git a/torchao/prototype/moe_training/tensor.py b/torchao/prototype/moe_training/tensor.py index 020a7ca7be..cd74d5535f 100644 --- a/torchao/prototype/moe_training/tensor.py +++ b/torchao/prototype/moe_training/tensor.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. 
+import logging from typing import Any, Optional, Tuple import torch @@ -12,6 +13,9 @@ from torchao.prototype.moe_training import _scaled_grouped_mm +logger: logging.Logger = logging.getLogger(__name__) + + _ops_to_preserve_subclass = { torch.ops.aten.empty_like.default, torch.ops.aten.new_zeros.default, @@ -40,6 +44,7 @@ class ScaledGroupedMMTensor(torch.Tensor): def __new__( cls, tensor: torch.Tensor, + dtype: torch.dtype, ): return torch.Tensor._make_wrapper_subclass( cls, @@ -47,7 +52,7 @@ def __new__( strides=tensor.stride(), storage_offset=tensor.storage_offset(), memory_format=suggest_memory_format(tensor), - dtype=tensor.dtype, + dtype=dtype, layout=tensor.layout, device=tensor.device, pin_memory=tensor.is_pinned(), @@ -57,8 +62,10 @@ def __new__( def __init__( self, tensor: torch.Tensor, + dtype: torch.dtype, ): self._data = tensor + self._dtype = dtype @classmethod def __torch_function__(cls, func, types, args, kwargs={}): @@ -87,12 +94,22 @@ def __torch_function__(cls, func, types, args, kwargs={}): @classmethod def __torch_dispatch__(cls, func, types, args, kwargs={}): + logger.debug(f"{func.__name__}, args={args}, kwargs={kwargs}") # detach is special case if func == torch.ops.aten.detach.default: - return ScaledGroupedMMTensor(args[0]._data) + return ScaledGroupedMMTensor(args[0]._data, args[0]._dtype) # unwrap args and kwargs - unwrap = lambda tensor: tensor._data + dtype: Optional[torch.dtype] = None + + def unwrap(t): + nonlocal dtype + if dtype is None: + dtype = t._dtype + else: + assert t._dtype == dtype + return t._data + args, kwargs = pytree.tree_map_only( ScaledGroupedMMTensor, unwrap, (args, kwargs or {}) ) @@ -107,13 +124,27 @@ def __torch_dispatch__(cls, func, types, args, kwargs={}): # wrap outputs back into ScaledGroupedMMTensor for ops that do preserve subclass return pytree.tree_map_only( torch.Tensor, - lambda x: ScaledGroupedMMTensor(x), + lambda x: ScaledGroupedMMTensor(x, dtype), out, ) + def __repr__(self): + return f"ScaledGroupedMMTensor(data={self._data}, dtype={self._dtype})" + + def __tensor_flatten__(self): + return ["_data"], {"dtype": self._dtype} + + @staticmethod + def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): + return ScaledGroupedMMTensor( + inner_tensors["_data"], + flatten_spec["dtype"], + ) + def fsdp_pre_all_gather(self, mesh): - metadata = () - return (self._data,), metadata + all_gather_inputs = (self._data,) + all_gather_metadata = () + return all_gather_inputs, all_gather_metadata def fsdp_post_all_gather( self, @@ -124,6 +155,6 @@ def fsdp_post_all_gather( out: Optional[torch.Tensor] = None, ): (data,) = all_gather_outputs - return ScaledGroupedMMTensor( - data, - ), (data,) + out = ScaledGroupedMMTensor(data, param_dtype) + inner_tensors = (data,) + return out, inner_tensors
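
For reference, below is a minimal usage sketch of the optional-`offs` API this series introduces; it is not part of the patches, and the shapes, sizes, and helper tensors are illustrative assumptions only (it presumes a CUDA device with SM89+ and dims divisible by 16). The 2D-A-with-offsets path corresponds to routed experts, while the 3D-A path without offsets corresponds to shared experts. Model-level conversion instead goes through quantize_(model, config=MoETrainingConfig(), filter_fn=...), as exercised in the new test_tp.py above.

    # Minimal sketch (assumed shapes/names, not taken from the patches).
    import torch

    from torchao.prototype.moe_training import _scaled_grouped_mm

    device = "cuda"
    num_experts, dim, hidden = 4, 256, 512
    tokens_per_expert = 64

    # Routed-experts case: A is 2D (total_tokens, dim), B_t is a 3D stack of
    # transposed expert weights (num_experts, dim, hidden) in column-major
    # layout, and int32 `offs` marks the end of each group along dim0 of A.
    A = torch.randn(
        num_experts * tokens_per_expert, dim,
        dtype=torch.bfloat16, device=device, requires_grad=True,
    )
    B = torch.randn(
        num_experts, hidden, dim,
        dtype=torch.bfloat16, device=device, requires_grad=True,
    )
    B_t = B.transpose(-2, -1)  # (num_experts, dim, hidden), column-major
    offs = torch.arange(
        tokens_per_expert,
        num_experts * tokens_per_expert + 1,
        tokens_per_expert,
        dtype=torch.int32, device=device,
    )
    out = _scaled_grouped_mm(A, B_t, offs=offs)  # (total_tokens, hidden), bf16
    out.sum().backward()

    # Shared-expert case enabled by this series: A is 3D and `offs` is omitted,
    # so the grouped GEMM behaves like a batched matmul over experts.
    A_3d = torch.randn(
        num_experts, tokens_per_expert, dim,
        dtype=torch.bfloat16, device=device, requires_grad=True,
    )
    out_3d = _scaled_grouped_mm(A_3d, B_t)  # (num_experts, tokens_per_expert, hidden)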