How can I use inductor aot_compile to support a MoE network? #148747 · pytorch/pytorch · GitHub
Closed
@sujuyu

Description

🚀 The feature, motivation and pitch

DeepSeek has sparked a wave of enthusiasm for MoE (Mixture of Experts) network architectures, and I am often asked how to accelerate inference for MoE models. Naturally, I thought of using Inductor's aot_compile to compile the model into a dynamic library and call it from C++ for acceleration.
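
For a plain dense network this flow works end to end. A minimal sketch of what I mean, assuming a recent PyTorch where torch._inductor.aoti_compile_and_package is available (the packaging API has moved between releases, so treat this as illustrative):

import torch
import torch._inductor

# Sketch of the AOTInductor flow for a plain dense module.
dense = torch.nn.Linear(128, 64)
ep = torch.export.export(dense, (torch.randn(32, 128),))

# Produces a self-contained .pt2 package wrapping the compiled shared
# library; on the C++ side it can be loaded with
# torch::inductor::AOTIModelPackageLoader and run without Python.
pkg_path = torch._inductor.aoti_compile_and_package(ep)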

Unfortunately, the expert-selection step in an MoE differs from a typical dense network: it is data-dependent Python control flow (modules are indexed by runtime top-k results), so it cannot be traced. Below is a simple demo I wrote. I would like to know whether the Inductor developers have any plans to support MoE networks.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Expert(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(Expert, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.linear(x)

class MoE(nn.Module):
    def __init__(self, input_dim, output_dim, num_experts=10, top_k=2):
        super(MoE, self).__init__()
        self.top_k = top_k
        # num_experts - 2 routed experts (eight with the defaults)
        self.other_experts = nn.ModuleList([Expert(input_dim, output_dim) for _ in range(num_experts - 2)])
        # Gate network that scores the routed experts
        self.gate = nn.Linear(input_dim, num_experts - 2)
        # Final output layer over the concatenated top_k expert outputs
        self.final_linear = nn.Linear(top_k * output_dim, output_dim)

    def forward(self, x):
        # Compute gating scores
        gate_scores = self.gate(x)
        topk_scores, topk_indices = torch.topk(gate_scores, self.top_k, dim=-1)
        
        # Collect outputs from selected experts based on gating
        selected_expert_outputs = torch.stack(
            [torch.stack([self.other_experts[i](x[idx]) for i in topk_indice], dim=0)
             for idx, topk_indice in enumerate(topk_indices)],
            dim=0,
        )

        # Flatten and pass through final linear layer
        all_expert_outputs = selected_expert_outputs.view(x.size(0), -1)
        output = self.final_linear(all_expert_outputs)
        return output


if __name__ == "__main__":
    # Example usage
    input_dim = 128
    output_dim = 64
    moe = MoE(input_dim, output_dim)

    x = torch.randn(32, input_dim)  # Batch size of 32
    output = moe(x)
    print(output.shape)  # Expected output shape: [32, 64]


    # This export fails: the per-sample expert selection above is
    # data-dependent Python control flow, which torch.export cannot trace.
    export_model = torch.export.export(
        mod=moe,
        args=(torch.randn(32, input_dim),),
        dynamic_shapes={"x": {0: torch.export.Dim("batch", min=1, max=1024)}},
    )
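
For what it's worth, the selection loop can be rewritten so that every expert runs on the full batch and each sample's top-k outputs are picked out with torch.gather, which removes the data-dependent Python indexing. This is only a minimal sketch of a workaround (the GatherMoE module and its dense-compute strategy are my own, not tested against AOTInductor), and it wastes compute on unselected experts, which defeats the point of sparsity at large expert counts:

import torch
import torch.nn as nn

class GatherMoE(nn.Module):
    """Hypothetical trace-friendly variant: run all experts, gather the top-k."""
    def __init__(self, input_dim, output_dim, num_experts=8, top_k=2):
        super().__init__()
        self.experts = nn.ModuleList(nn.Linear(input_dim, output_dim) for _ in range(num_experts))
        self.gate = nn.Linear(input_dim, num_experts)
        self.final_linear = nn.Linear(top_k * output_dim, output_dim)
        self.top_k = top_k

    def forward(self, x):
        gate_scores = self.gate(x)                                     # [batch, num_experts]
        _, topk_indices = torch.topk(gate_scores, self.top_k, dim=-1)  # [batch, top_k]
        # Every expert runs on the full batch, so this loop is static Python
        # over a fixed-size ModuleList and traces fine.
        all_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)
        # Select each sample's top_k expert outputs with a gather instead of
        # data-dependent indexing into the ModuleList.
        idx = topk_indices.unsqueeze(-1).expand(-1, -1, all_outputs.size(-1))
        selected = torch.gather(all_outputs, 1, idx)                   # [batch, top_k, output_dim]
        return self.final_linear(selected.reshape(x.size(0), -1))

ep = torch.export.export(
    GatherMoE(128, 64),
    (torch.randn(32, 128),),
    dynamic_shapes={"x": {0: torch.export.Dim("batch", min=1, max=1024)}},
)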

Alternatives

No response

Additional context

No response

cc @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4 @desertfire @chenyang78 @yushangdi
