Description
🐛 Describe the bug
I am writing a custom backend for Dynamo; the frame compiler looks like this:
import unittest.mock
from typing import Any, Sequence

import torch
from torch._dynamo.utils import detect_fake_mode
from torch._functorch.aot_autograd import aot_export_joint_simple


def x_compile_backend(
    gm: torch.fx.GraphModule, example_inputs: Sequence[Any], **kwargs
) -> torch.nn.Module:
    try:
        fake_mode = detect_fake_mode(example_inputs)
        with unittest.mock.patch.object(
            fake_mode, "allow_non_fake_inputs", True
        ), fake_mode:
            gm = apply_pre_aot_passes(gm)  # my own pre-AOT FX passes
            gm = aot_export_joint_simple(
                gm,
                example_inputs,
                trace_joint=False,
            )
            return gm
    except Exception:
        raise  # error handling elided
When I test with torch.rms_norm, Dynamo captures the following graph:
dynamo.compiler: graph():
    %l_args_0_ : torch.Tensor [num_users=1] = placeholder[target=L_args_0_]
    %fn____parameters__weight : [num_users=1] = get_attr[target=fn____parameters__weight]
    %rms_norm : [num_users=1] = call_function[target=torch.rms_norm](args = (%l_args_0_, (5,), %fn____parameters__weight, None), kwargs = {})
    return (rms_norm,)

and its readable forward:

def forward(self, L_args_0_ : torch.Tensor):
    l_args_0_ = L_args_0_
    fn____parameters__weight = self.fn____parameters__weight
    rms_norm = torch.rms_norm(l_args_0_, (5,), fn____parameters__weight, None); l_args_0_ = fn____parameters__weight = None
    return (rms_norm,)
After my pre-AOT passes (the graph is unchanged):

graph():
    %l_args_0_ : torch.Tensor [num_users=1] = placeholder[target=L_args_0_]
    %fn____parameters__weight : [num_users=1] = get_attr[target=fn____parameters__weight]
    %rms_norm : [num_users=1] = call_function[target=torch.rms_norm](args = (%l_args_0_, (5,), %fn____parameters__weight, None), kwargs = {})
    return (rms_norm,)
aot_export_joint_simple then fails with:
  File "torch/_functorch/_aot_autograd/functional_utils.py", line 432, in assert_functional_graph
    not n.target._schema.is_mutable
AssertionError: aot_autograd expected to have an entirely functional graph, but found %add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Scalar](args = (%mean, 1.1920928955078125e-07), kwargs = {})
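The assertion comes from assert_functional_graph rejecting any node whose op schema is marked mutable, and the in-place add_ that the rms_norm decomposition introduces trips it. A quick way to see the distinction (just an illustration, not part of the backend):

import torch

# The in-place variant that appears in the exported graph has a mutable schema:
print(torch.ops.aten.add_.Scalar._schema.is_mutable)  # True
# Its functional counterpart would pass the check:
print(torch.ops.aten.add.Scalar._schema.is_mutable)   # False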
Does anyone know how to fix this?
The debug graph I get is:
graph():
    %arg0_1 : [num_users=2] = placeholder[target=arg0_1]
    %pow_1 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%arg0_1, 2), kwargs = {})
    %mean : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_1, [1], True), kwargs = {})
    %add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Scalar](args = (%mean, 1.1920928955078125e-07), kwargs = {})
    %rsqrt : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_,), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg0_1, %rsqrt), kwargs = {})
    %_param_constant0 : [num_users=1] = get_attr[target=_param_constant0]
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul, %_param_constant0), kwargs = {})
    return (mul_1,)
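For context, this debug graph is the rms_norm decomposition; in eager terms it computes the following, where 1.1920928955078125e-07 is torch.finfo(torch.float32).eps. This sketch replaces the offending in-place add_ with a functional add (shapes are assumptions):

import torch

def decomposed_rms_norm(x, weight, eps=torch.finfo(torch.float32).eps):
    # mean of squares along the normalized dim, then rsqrt and scale
    var = torch.mean(x.pow(2), dim=-1, keepdim=True)
    return x * torch.rsqrt(var + eps) * weight

x = torch.randn(3, 5)
weight = torch.ones(5)
assert torch.allclose(decomposed_rms_norm(x, weight), torch.rms_norm(x, (5,), weight))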
Versions
Collecting environment information...
PyTorch version: 2.5.1+cpu
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (GCC) 14.1.0
Clang version: Could not collect
CMake version: version 3.28.4
Libc version: glibc-2.31
Python version: 3.11.11 (main, Dec 11 2024, 16:28:39) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.4.0-200-generic-x86_64-with-glibc2.31
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 46 bits physical, 48 bits virtual
CPU(s): 96
On-line CPU(s) list: 0-95
Thread(s) per core: 2
Core(s) per socket: 24
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 85
Model name: Intel(R) Xeon(R) Platinum 8369HB CPU @ 3.30GHz
Stepping: 11
CPU MHz: 3800.026
CPU max MHz: 4200.0000
CPU min MHz: 1200.0000
BogoMIPS: 6600.06
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 1.5 MiB
L1i cache: 1.5 MiB
L2 cache: 48 MiB
L3 cache: 66 MiB
NUMA node0 CPU(s): 0-47
NUMA node1 CPU(s): 48-95
Vulnerability Gather data sampling: Unknown: Dependent on hypervisor status
Vulnerability Itlb multihit: KVM: Vulnerable
Vulnerability L1tf: Mitigation; PTE Inversion
Vulnerability Mds: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed: Vulnerable
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; STIBP disabled; RSB filling; PBRSB-eIBRS Not affected; BHI Retpoline
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx512_bf16 ida arat avx512_vnni
Versions of relevant libraries:
[pip3] flake8==3.8.2
[pip3] flake8-bugbear==20.1.4
[pip3] flake8-comprehensions==3.3.0
[pip3] flake8-executable==2.0.4
[pip3] flake8-pyi==20.5.0
[pip3] intel_extension_for_pytorch==2.6.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==2.1.2
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-cusparselt-cu12==0.6.2
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] torch==2.5.1+cpu
[pip3] torchvision==0.20.1+cpu
[pip3] triton==3.2.0
[conda] intel-extension-for-pytorch 2.6.0 pypi_0 pypi
[conda] numpy 2.1.2 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.4.5.8 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.1.0.70 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.2.1.3 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.5.147 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.6.1.9 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.3.1.170 pypi_0 pypi
[conda] nvidia-cusparselt-cu12 0.6.2 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.21.5 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.4.127 pypi_0 pypi
[conda] torch 2.5.1+cpu pypi_0 pypi
[conda] torchvision 0.20.1+cpu pypi_0 pypi
[conda] triton 3.2.0 pypi_0 pypi
cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames