[float8 moe training] Add TP support by danielvegamyhre · Pull Request #2425 · pytorch/ao · GitHub

[float8 moe training] Add TP support #2425


Open · danielvegamyhre wants to merge 5 commits into main from optional-offs

Conversation

@danielvegamyhre (Contributor) commented Jun 23, 2025

Note: this should be merged AFTER this bug fix: #2451. I will rebase and retest all of this once that's merged.

Summary

  • Add TP support for routed experts and the shared expert.
    • Make the target dim of the scale squeeze() ops explicit to handle both 2D and 3D "A" tensors (the routed experts case has a 2D "A", the shared expert has a 3D "A"); see the short example after this list.
    • Make offs optional to handle the shared_expert case where num_experts=1 (the scaled grouped GEMM only processes 1 expert).
  • Add debug logging
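A minimal illustration of the squeeze() point above (shapes are illustrative assumptions, not the actual torchao internals): a bare squeeze() on the shared expert's rowwise scale would also collapse the size-1 expert dim, while an explicit target dim handles both the 2D and 3D cases.

```python
import torch

scale_2d = torch.rand(8, 1)     # routed experts: 2D "A" -> rowwise scale (M, 1)
scale_3d = torch.rand(1, 8, 1)  # shared expert: 3D "A" with num_experts=1 -> (E, M, 1)

# A bare squeeze() also drops the size-1 expert dim of the shared expert scale:
assert scale_3d.squeeze().shape == (8,)
# Squeezing an explicit target dim keeps the expert dim intact in both cases:
assert scale_2d.squeeze(-1).shape == (8,)
assert scale_3d.squeeze(-1).shape == (1, 8)
```

Similarly, making offs optional just lets the grouped GEMM path treat a missing offs as the single-expert (shared expert) case.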

Test plan

  • Added an integration test using the torchtitan llama4 TP implementation, with test cases for (1) routed experts, and (2) routed experts + shared expert.
  • Manual testing with the torchtitan llama4 debug model with TP=2, targeting routed experts AND shared experts, works (logs).
  • Manual testing with the torchtitan llama4 debug model with FSDP=2 + TP=2 confirms this 2D parallelism is working for routed experts (logs).

Limitations

  • 2D parallelism with FSDP+TP for shared experts is not yet supported (see comment below). I need to debug this, which I will do in a subsequent PR.

pytorch-bot bot commented Jun 23, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2425

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 2 New Failures

As of commit 29be4b2 with merge base 2898903:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label Jun 23, 2025
@danielvegamyhre marked this pull request as draft June 23, 2025 17:10
@danielvegamyhre added the topic: not user facing label Jun 23, 2025
@danielvegamyhre changed the title [float8 moe training] TP support for routed experts to [float8 moe training] Add TP support Jun 23, 2025
@danielvegamyhre (Contributor, Author) commented Jun 23, 2025

Error with FSDP=2, TP=2 targeting both routed experts AND the shared expert. The issue is specific to the shared expert using 2D parallelism. I will debug and resolve it in a separate PR.

The logs are a bit confusing. I first see an error in meta registration saying the "B" tensor is fp32 instead of bf16. This is odd, since I'm not using torch.compile and I thought meta registrations were only used for compile.

File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/_meta_registrations.py", line 7527, in _meta_grouped_mm_common
...
RuntimeError: Expected inputs of BF16 type but got mat_a.dtype=torch.bfloat16 and mat_b.dtype=torch.float32.

Then a few lines later, I see my log lines during the forward pass, just before the grouped mm, confirming the "B" tensor (W1) is bf16, not fp32:

[rank0]:X dtype: torch.bfloat16
[rank0]:W1 dtype: torch.bfloat16
[rank0]:W1 type: <class 'torch.distributed.tensor.DTensor'>
[rank0]:W1.to_local() type: <class 'torchao.prototype.moe_training.tensor.ScaledGroupedMMTensor'>
[rank0]:W1.to_local() dtype: torch.bfloat16
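For reference, a minimal sketch of debug logging that would produce lines like the above (the helper name and call site are assumptions): it inspects the DTensor wrapper and its local ScaledGroupedMMTensor just before the grouped GEMM.

```python
import logging

import torch
from torch.distributed.tensor import DTensor

logger = logging.getLogger(__name__)


def _log_grouped_mm_inputs(x: torch.Tensor, w1: torch.Tensor) -> None:
    # Hypothetical helper mirroring the log lines above.
    logger.debug("X dtype: %s", x.dtype)
    logger.debug("W1 dtype: %s", w1.dtype)
    logger.debug("W1 type: %s", type(w1))
    if isinstance(w1, DTensor):
        local_w1 = w1.to_local()
        logger.debug("W1.to_local() type: %s", type(local_w1))
        logger.debug("W1.to_local() dtype: %s", local_w1.dtype)
```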

(As an aside, it's strange that these log lines appear AFTER the error has already occurred. I assume it must be due to how log writes are buffered.)

Then at the end of the logs, I see a different error related to strides/sizes not matching a storage of size 0, but I'm guessing this is a downstream effect of the first error:

RuntimeError: setStorage: sizes [1, 256, 256], strides [65536, 256, 1], storage offset 0, and itemsize 2 requiring a storage size of 131072 are out of bounds for storage of size 0

Full logs: https://www.internalfb.com/phabricator/paste/view/P1850071143

@danielvegamyhre force-pushed the optional-offs branch 2 times, most recently from 80cf6d4 to 44778d0, June 23, 2025 21:40
@danielvegamyhre changed the title [float8 moe training] Add TP support to [float8 moe training] Add TP and FSDP+TP support Jun 24, 2025
@danielvegamyhre changed the title [float8 moe training] Add TP and FSDP+TP support to [float8 moe training] Add TP support Jun 24, 2025
@danielvegamyhre marked this pull request as ready for review June 24, 2025 17:32
@danielvegamyhre force-pushed the optional-offs branch 2 times, most recently from e4ff51d to 074b423, June 24, 2025 17:36
@danielvegamyhre (Contributor, Author) commented:
cc @drisspg @vkuzo for review

fyi @tianyu-l @lessw2020 @ngimel for awareness as well

@danielvegamyhre requested review from vkuzo and drisspg June 24, 2025 17:36
from torch.nn import functional as F

# this feature requires CUDA and SM89+
if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 9):
Reviewer (Contributor):
nit: we have some helpers for this in ao/utils
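A sketch of that suggestion, assuming a helper along the lines of is_sm_at_least_89 in torchao.utils (the helper name is an assumption):

```python
import pytest
import torch

from torchao.utils import is_sm_at_least_89  # assumed helper name

# Skip the whole test module when the SM89+ requirement is not met.
if not torch.cuda.is_available() or not is_sm_at_least_89():
    pytest.skip("this feature requires CUDA and SM89+", allow_module_level=True)
```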


# this test requires torchtitan
try:
    from torchtitan.experiments.llama4.infra.parallelize import apply_moe_tp
Reviewer (Contributor):
we should add this test to test_float8 ->

dist.destroy_process_group()


def _validate_model_conversion(
Reviewer (Contributor):
Did I review another PR that had the same util? If so, maybe put it into torchao.testing so we can reuse it.
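If it does get moved, a hedged sketch of what a shared validator in torchao.testing could look like (the signature and the exact isinstance check are assumptions, based only on the validator's name and the ScaledGroupedMMTensor class seen in the logs above):

```python
import torch.nn as nn

from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor


def validate_model_conversion(model: nn.Module, target_fqns: list[str]) -> None:
    # Assert that parameters under the target FQNs were swapped to the
    # ScaledGroupedMMTensor subclass and that all other parameters were left alone.
    for name, param in model.named_parameters():
        in_target = any(fqn in name for fqn in target_fqns)
        converted = isinstance(param.data, ScaledGroupedMMTensor)
        assert converted == in_target, f"unexpected parameter type for {name}"
```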

return device_mesh


def apply_moe_tp(
Reviewer (Contributor):
This is always specific to the module structure, e.g. the FQNs, right?
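For context, a TP plan of this kind is keyed by module FQNs, so it has to match the concrete module tree it is applied to. A generic sketch (the FQNs and parallel styles below are placeholders, not the actual torchtitan llama4 layout):

```python
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


def apply_moe_tp_sketch(model, tp_mesh):
    # The plan maps FQNs (relative to `model`) to parallel styles, so any
    # change to the module structure requires updating these keys.
    tp_plan = {
        "shared_expert.w1": ColwiseParallel(),  # placeholder FQNs
        "shared_expert.w2": RowwiseParallel(),
        "shared_expert.w3": ColwiseParallel(),
    }
    parallelize_module(model, tp_mesh, tp_plan)
    return model
```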

@@ -8,6 +14,8 @@
register_quantize_module_handler,
)

logger: logging.Logger = logging.getLogger(__name__)
@drisspg (Contributor) commented Jun 26, 2025:
Side note: we should set up better logging in torchao. See https://docs.python.org/3/howto/logging.html#configuring-logging-for-a-library

Just getting the root library logger going with a NullHandler.
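For reference, the pattern that link describes, sketched against a hypothetical torchao package root:

```python
import logging

# In the library's top-level __init__: attach a NullHandler so torchao emits
# nothing unless the application configures logging itself.
logging.getLogger("torchao").addHandler(logging.NullHandler())

# An application that wants the debug logs can then opt in, e.g.:
# logging.basicConfig(level=logging.DEBUG)
```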
