Enabling MOE Quantization using linear decomposition by HDCharles · Pull Request #2043 · pytorch/ao · GitHub

Enabling MOE Quantization using linear decomposition #2043

Merged 18 commits on May 8, 2025

Conversation

HDCharles (Contributor) commented Apr 11, 2025

Enabling MOE Quantization using linear decomposition

Summary: This PR is a first step toward optimizing MoE inference with
torchAO. The goal of this step is to let existing quantization kernels
and workflows handle MoE quantization by decomposing the group gemm
into a sequence of unbalanced linear ops that can use the existing
quantized kernels. To enable this, we had to add support for quantizing
these 3D tensors, as well as for slicing and indexing them. Two methods
of achieving this were implemented. For int8wo, int8dq, int4wo, fp8wo,
and fp8dq, the underlying quantized tensor subclass was adapted to
support 3D tensors, indexing, and slicing, and the transformation
function was updated to handle ConditionalFeedForwardAOQuantizable
modules when the filter function passed to quantize_ targets that
module. For some complex kernels that use packed data and couldn't
easily be made to work in 3D, we also added FakeExtraDimTensor, which
can wrap any quantized tensor subclass so that it supports the slice
and index operations needed for MoE quantization. This option is
enabled by using MoEQuantConfig.
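
To make the decomposition concrete, here is a minimal pure-Python sketch of the idea, not torchao code. Instead of one grouped GEMM over the stacked expert weights, tokens are grouped by the expert they were routed to, and one ordinary linear op is run per active expert on however many tokens it received (hence "unbalanced"). The names `linear`, `moe_reference`, and `moe_as_linear_sequence` are hypothetical, top-1 routing is assumed for brevity, and nested lists stand in for tensors.

```python
from collections import defaultdict


def linear(x, w):
    """y = x @ w.T for x: [tokens][in_features], w: [out_features][in_features]."""
    return [[sum(a * b for a, b in zip(row, w_row)) for w_row in w] for row in x]


def moe_reference(tokens, expert_ids, weights):
    """Reference result: every token is multiplied by its own expert's matrix."""
    return [linear([tok], weights[e])[0] for tok, e in zip(tokens, expert_ids)]


def moe_as_linear_sequence(tokens, expert_ids, weights):
    """Group tokens by expert, then run one linear op per active expert."""
    groups = defaultdict(list)
    for pos, e in enumerate(expert_ids):
        groups[e].append(pos)

    out = [None] * len(tokens)
    for e, positions in groups.items():
        w_e = weights[e]  # indexing the 3D weight yields a 2D weight
        y = linear([tokens[p] for p in positions], w_e)
        for p, row in zip(positions, y):
            out[p] = row  # scatter results back to the original token order
    return out


# 3 experts, each a 2x2 weight; expert 2 receives no tokens in this batch.
weights = [
    [[1, 0], [0, 1]],
    [[2, 0], [0, 2]],
    [[0, 1], [1, 0]],
]
tokens = [[1, 2], [3, 4], [5, 6], [7, 8]]
expert_ids = [1, 0, 1, 1]  # expert 1 gets three tokens, expert 0 gets one
result = moe_as_linear_sequence(tokens, expert_ids, weights)
```

Because each per-expert call is a plain 2D linear, it can in principle be served by an existing quantized linear kernel, provided the quantized weight supports the `weights[e]` indexing step.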

This can be applied to Hugging Face Llama 4, for instance, as shown in the
llama4_quant.py example. Since the HF MoE module is implemented in a way
that's not conducive to quantization, it first requires a module swap to
MOEFeedForwardAOQuantizable.

# API
import torch
from torchao.quantization.prototype.moe_quant.utils import cond_ffn_filter, MoEQuantConfig
from torchao.quantization.quant_api import quantize_, Int8WeightOnlyConfig

# `model` is an already-loaded MoE model; `is_single_token_inference` is a bool set by the caller
quantize_(model, MoEQuantConfig(Int8WeightOnlyConfig()), filter_fn=cond_ffn_filter)
# there's an unavoidable graph break in multi-token inference during token shuffling,
# so fullgraph is only requested for single-token inference
model = torch.compile(model, mode="reduce-overhead", fullgraph=is_single_token_inference)

Note: MoEQuantConfig has an attribute use_fake_extra_dim_tensor that defaults to AS_FALLBACK, but it can also be set to TRUE or FALSE to always enable or always disable that option.
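
A rough sketch of what the three settings could mean in practice. This is an illustration only and not torchao's implementation: `UseFakeExtraDim`, `FakeExtraDim`, `quantize_2d`, `quantize_3d_natively`, and `quantize_experts` are hypothetical stand-ins, nested lists stand in for tensors, and the reading of AS_FALLBACK as "try the native 3D path first, wrap only if that fails" is my assumption. The wrapper holds one 2D quantized object per expert and answers the indexing on the leading (expert) dimension that MoE needs.

```python
from enum import Enum


class UseFakeExtraDim(Enum):  # hypothetical mirror of the three settings
    TRUE = "true"
    FALSE = "false"
    AS_FALLBACK = "as_fallback"


class FakeExtraDim:
    """Holds one 2D 'quantized' object per expert and fakes a leading dimension."""

    def __init__(self, per_expert):
        self.per_expert = list(per_expert)

    def __len__(self):
        return len(self.per_expert)

    def __getitem__(self, idx):
        if isinstance(idx, slice):
            return FakeExtraDim(self.per_expert[idx])  # slicing keeps the wrapper
        return self.per_expert[idx]  # indexing yields the plain 2D object


def quantize_2d(w):
    """Stand-in for any existing 2D quantizer: symmetric int8-style rounding."""
    scale = max(abs(v) for row in w for v in row) / 127 or 1.0
    return {"q": [[round(v / scale) for v in row] for row in w], "scale": scale}


def quantize_3d_natively(w3d, supports_3d):
    """Pretend native path; fails for layouts that have no 3D support."""
    if not supports_3d:
        raise NotImplementedError("packed layout has no 3D support")
    return [quantize_2d(w) for w in w3d]


def quantize_experts(w3d, mode, supports_3d):
    if mode is UseFakeExtraDim.TRUE:
        return FakeExtraDim(quantize_2d(w) for w in w3d)
    try:
        return quantize_3d_natively(w3d, supports_3d)
    except NotImplementedError:
        if mode is UseFakeExtraDim.FALSE:
            raise  # wrapper explicitly disabled
        return FakeExtraDim(quantize_2d(w) for w in w3d)


w3d = [[[1.0, -2.0], [0.5, 4.0]], [[127.0, 0.0], [-127.0, 63.0]]]
packed = quantize_experts(w3d, UseFakeExtraDim.AS_FALLBACK, supports_3d=False)
```

In this sketch `packed[1]` is the second expert's 2D quantized weight, which is all the per-expert linear decomposition needs.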

Benchmarks:

| Technique | batchsize 1: tok/s | batchsize 1: memory (GB) | batchsize 8: tok/s | batchsize 8: tok/s * batch | batchsize 8: memory (GB) |
| --- | --- | --- | --- | --- | --- |
| None | 78.35 | 93.76 | 18.2 | 145.64 | 94.12 |
| int8wo-fake | 6.14 | 49.13 | 5.01 | 40.09 | 49.23 |
| int8wo-base | 98.4 | 48.87 | 4.94 | 39.56 | 49.2 |
| int4wo-fake | 14.25 | 30.21 | 11.84 | 94.75 | 30.19 |
| int4wo-base | 79.38 | 36.15 | 10.29 | 82.29 | 36.12 |
| fp8wo-fake | 3.2 | 50.31 | 2.88 | 23.08 | 50.29 |
| fp8wo-base | 59.41 | 52.07 | 2.98 | 23.81 | 52.05 |
| fp8dq-fake | 9.78 | 50.92 | 4.08 | 32.61 | 50.89 |
| fp8dq-base | 45.92 | 53.97 | 3.78 | 30.23 | 53.94 |
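
The `tok/s * batch` column is the per-sequence rate multiplied by the batch size of 8, and the numbers are easier to compare as ratios against the unquantized baseline. A small sketch using a few rows copied from the benchmark table (the helper names are mine and not part of the PR):

```python
BATCH = 8

# technique: (bs1 tok/s, bs1 memory GB, bs8 tok/s, bs8 memory GB)
rows = {
    "None": (78.35, 93.76, 18.2, 94.12),
    "int8wo-base": (98.4, 48.87, 4.94, 49.2),
    "int4wo-base": (79.38, 36.15, 10.29, 36.12),
    "int4wo-fake": (14.25, 30.21, 11.84, 30.19),
}


def aggregate_tok_s(per_sequence_tok_s, batch=BATCH):
    """Total tokens per second across the whole batch."""
    return per_sequence_tok_s * batch


def ratios(name, baseline="None"):
    """(throughput ratio, memory ratio) at batch size 1, relative to the baseline."""
    tok_s, mem, _, _ = rows[name]
    base_tok_s, base_mem, _, _ = rows[baseline]
    return round(tok_s / base_tok_s, 2), round(mem / base_mem, 2)


for name in rows:
    speed, mem = ratios(name)
    total = aggregate_tok_s(rows[name][2])
    print(f"{name:12s} bs1 speed x{speed:.2f}  memory x{mem:.2f}  bs8 total {total:.1f} tok/s")
```

For example, int8wo-base at batch size 1 comes out at about 1.26x the baseline throughput while using about 0.52x the memory, whereas int4wo-fake uses the least memory (about 0.32x) at about 0.18x the throughput.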

Note: it's unclear why the memory usage increases when going from batchsize 1 to batchsize 8. My supposition is that because batchsize 1 can be compiled with fullgraph, the kernels are likely to use heavier fusion, which can result in larger kernels.

Memory Profile of Quantization of MoE Modules:
image

Note: quantization can be done on CPU, as in llama4_quant.py, so that you can still run the model even if only the int4 version fits on CUDA. This memory trace was taken from the Mixtral MoE memory profile by first loading the whole model to CUDA.

Test Plan:
python test/quantization/test_moe_quant.py

python test/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py -k "test_moe_quant_intx"

sh torchao/_models/mixtral-moe/run.sh

Reviewers:

Subscribers:

Tasks:

Tags:

pytorch-bot bot commented Apr 11, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2043

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 1 New Failure, 1 Unrelated Failure

As of commit d927f06 with merge base b01514c (image):

NEW FAILURE - The following job has failed:

BROKEN TRUNK - The following job failed but was also failing on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label Apr 11, 2025
@HDCharles force-pushed the moe_quant branch 4 times, most recently from bc69bce to e510be8 May 6, 2025 09:13
@HDCharles changed the title from "Enabling MOE Quantization using linear decomposition" to "[WIP] Enabling MOE Quantization using linear decomposition" May 6, 2025
@HDCharles added the topic: new feature label May 6, 2025
@HDCharles force-pushed the moe_quant branch 3 times, most recently from cde56ad to 7f56c87 May 6, 2025 16:57

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import itertools
Contributor

Do we need all this?

I understand it is good to have an MOE model but is it important to have our own definition here?

Contributor Author

we've discussed adding a moe model for a while and this was the main one for testing and benchmarking and demonstrating the technique.

@@ -490,10 +492,50 @@ def _(func, types, args, kwargs):
 self.quant_max,
 self.zero_point_domain,
 dtype=self.dtype,
-strides=self.stride(),
+strides=self.stride() if len(block_size)==2 else None,
Contributor

why this?

Contributor Author
@HDCharles May 6, 2025

The stride can change for 3-dimensional tensors when doing slicing and indexing ops. I basically just didn't want to affect the existing 2D implementation functionally while making the 3D one work.
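
To illustrate the point about strides: for a contiguous 3D weight of shape (experts, out, in), the stride tuple has three entries. Indexing one expert leaves a 2D view with a shorter stride tuple, and slicing the last dimension leaves strides that no longer match what a contiguous tensor of the new shape would have. A pure-Python sketch of that bookkeeping follows; the helpers `contiguous_strides`, `index_first_dim`, and `slice_dim` are illustrative and are not torchao functions.

```python
def contiguous_strides(shape):
    """Row-major strides, in elements, for a contiguous tensor of `shape`."""
    strides = []
    step = 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def index_first_dim(shape, strides, i):
    """Shape, strides and storage offset of tensor[i]."""
    if not 0 <= i < shape[0]:
        raise IndexError(i)
    return shape[1:], strides[1:], i * strides[0]


def slice_dim(shape, strides, dim, start, stop):
    """Shape, strides and storage offset of a step-1 slice along `dim`."""
    new_shape = shape[:dim] + (stop - start,) + shape[dim + 1:]
    return new_shape, strides, start * strides[dim]


shape_3d = (8, 4096, 1024)  # (experts, out_features, in_features)
strides_3d = contiguous_strides(shape_3d)

# Indexing expert 3: the result is 2D, so a stored 3-entry stride no longer applies.
shape_2d, strides_2d, offset = index_first_dim(shape_3d, strides_3d, 3)

# Slicing half of the last dim: strides stay as they were, so the view is not contiguous.
sliced_shape, sliced_strides, _ = slice_dim(shape_3d, strides_3d, 2, 0, 512)
is_contiguous_after_slice = sliced_strides == contiguous_strides(sliced_shape)
```

Either way, a stride value captured from the original 3D tensor is not a safe thing to hand to the constructor of the sliced or indexed result, which is consistent with passing `None` for the non-2D case.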

@HDCharles force-pushed the moe_quant branch 3 times, most recently from 58d2104 to edf7488 May 7, 2025 16:40
@@ -416,11 +452,16 @@ def block_size(self):

scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero)
cur_shape = self.shape
assert len(cur_shape) == 4
if len(cur_shape) == 5:
Contributor

cc @jerryzh168 since you are also making some changes to how slicing is handled

@@ -47,5 +47,6 @@ def _transform(
 @functools.wraps(config_type)
 def decorator(func):
 _QUANTIZE_CONFIG_HANDLER[config_type] = func
+return func # needed to make the functions usable externally
Contributor

?

Contributor Author

If you wanted to use any of the transform functions in quant_api.py that use the decorator, the name would actually just be None, since nothing was returned after the function was passed into the decorator. The functions were only stored in the transform function dict, so they were only usable through quantize_. This solves that problem. Since we use these transform functions for MoE, this was necessary, and it doesn't affect anything else.

@HDCharles force-pushed the moe_quant branch 2 times, most recently from 499e3d7 to 63bb0ae May 8, 2025 04:23
@HDCharles force-pushed the moe_quant branch 6 times, most recently from 72fec5c to c024f5d May 8, 2025 06:52
HDCharles added 17 commits May 8, 2025 11:07
TODO: final benchmark numbers from run.sh; consolidate 3x implementation
of MOEFeedForwardAOQuantizable and ConditionalFeedForwardAOQuantizable;
verify hqq

@HDCharles HDCharles merged commit 7192edf into main May 8, 2025
16 of 18 checks passed
andrewor14 pushed a commit that referenced this pull request May 9, 2025
* Enabling MOE Quantization using linear decomposition
* fixing CI
* fixing CI
* fixing CI
* lint
* remove test code
* fixing exp test
* fixing experimental test
* fixing experimental CI
* fixing generate.py device stuff
* fixing tests that aren't skipping
* ruff format
* removing test code
* fixing CI
* update API and remove branching on quant_api.py transform functions
* ruff format
* fix weird ci error
* remove change to test_integration.py
Labels: CLA Signed, topic: new feature
3 participants