Description
🚀 The feature, motivation and pitch
DeepSeek has sparked a wave of enthusiasm for MoE (Mixture of Experts) network architectures, and I am often asked how to accelerate inference for an MoE network. Naturally, I thought of using Inductor's aot_compile to compile the model into a shared library and then calling it from C++ for acceleration.
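For context, this is the flow I had in mind for an ordinary dense model. A minimal sketch: torch._export.aot_compile is a private API that may change between releases, and the two-layer model below is just a placeholder.

import torch
import torch.nn as nn

# Placeholder dense model standing in for the real network
dense = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 64)).eval()

with torch.no_grad():
    # Compile to a shared library with AOTInductor
    so_path = torch._export.aot_compile(
        dense,
        args=(torch.randn(32, 128),),
        dynamic_shapes=({0: torch.export.Dim("batch", min=1, max=1024)},),
    )
print(so_path)  # path to the compiled .so

# On the C++ side the library can then be loaded with
# torch::inductor::AOTIModelContainerRunnerCpu and run without Python.

For a dense network this pipeline works end to end.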
Unfortunately, expert selection in an MoE works differently from a typical dense network. The routing logic is plain Python control flow that depends on runtime tensor values (indexing modules by top-k results) rather than traceable tensor ops, so it cannot be traced. Below is a simple demo I wrote, followed by a sketch of an export-friendly rewrite. Do the developers of Inductor have any plans to support MoE networks?
import torch
import torch.nn as nn


class Expert(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.linear(x)
class MoE(nn.Module):
    def __init__(self, input_dim, output_dim, num_experts=10, top_k=2):
        super().__init__()
        # num_experts - 2 = 8 experts to route between in this demo
        self.other_experts = nn.ModuleList(
            [Expert(input_dim, output_dim) for _ in range(num_experts - 2)]
        )
        # Gate network that scores each expert so the top_k can be chosen
        self.gate = nn.Linear(input_dim, num_experts - 2)
        # Final output layer over the concatenated top_k expert outputs
        self.final_linear = nn.Linear(top_k * output_dim, output_dim)
        self.top_k = top_k
    def forward(self, x):
        # Compute gating scores and pick the top-k experts per sample
        gate_scores = self.gate(x)
        topk_scores, topk_indices = torch.topk(gate_scores, self.top_k, dim=-1)
        # Collect outputs from the selected experts (the scores are unused in
        # this minimal demo). Indexing the ModuleList with runtime tensor
        # values is the data-dependent control flow that cannot be traced.
        selected_expert_outputs = torch.stack(
            [
                torch.stack([self.other_experts[i](x[idx]) for i in sample_indices], dim=0)
                for idx, sample_indices in enumerate(topk_indices)
            ],
            dim=0,
        )
        # Flatten and pass through the final linear layer
        all_expert_outputs = selected_expert_outputs.view(x.size(0), -1)
        output = self.final_linear(all_expert_outputs)
        return output
if __name__ == "__main__":
    # Example usage
    input_dim = 128
    output_dim = 64
    moe = MoE(input_dim, output_dim)
    x = torch.randn(32, input_dim)  # Batch size of 32
    output = moe(x)
    print(output.shape)  # Expected output shape: [32, 64]

    # This export fails on the data-dependent expert selection in forward()
    export_model = torch.export.export(
        mod=moe,
        args=(torch.randn(32, input_dim),),
        dynamic_shapes={"x": {0: torch.export.Dim("batch", min=1, max=1024)}},
    )
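For reference, one export-friendly formulation is sketched below (the DenseMoE name and shapes are mine, not an official recipe). It evaluates every expert densely and selects the top-k outputs with torch.gather, so no Python control flow depends on runtime tensor values; it reuses the Expert class from the demo above.

class DenseMoE(nn.Module):
    def __init__(self, input_dim, output_dim, num_experts=8, top_k=2):
        super().__init__()
        self.experts = nn.ModuleList(
            [Expert(input_dim, output_dim) for _ in range(num_experts)]
        )
        self.gate = nn.Linear(input_dim, num_experts)
        self.final_linear = nn.Linear(top_k * output_dim, output_dim)
        self.top_k = top_k

    def forward(self, x):
        gate_scores = self.gate(x)  # (batch, num_experts)
        _, topk_indices = torch.topk(gate_scores, self.top_k, dim=-1)
        # Run every expert on the full batch: (batch, num_experts, output_dim).
        # Wasteful, but free of data-dependent Python control flow.
        all_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)
        # Select the top_k expert outputs per sample with pure tensor ops
        idx = topk_indices.unsqueeze(-1).expand(-1, -1, all_outputs.size(-1))
        selected = torch.gather(all_outputs, 1, idx)  # (batch, top_k, output_dim)
        return self.final_linear(selected.flatten(1))

dense_moe = DenseMoE(128, 64)
exported = torch.export.export(
    dense_moe,
    args=(torch.randn(32, 128),),
    dynamic_shapes={"x": {0: torch.export.Dim("batch", min=1, max=1024)}},
)

This exports with the same dynamic batch dimension, but it pays for traceability by running all experts on every sample, which is exactly the cost that real MoE support in Inductor should avoid.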
Alternatives
No response
Additional context
No response
cc @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4 @desertfire @chenyang78 @yushangdi