[Dynamo][Custombackend]: rms_norm find inplace op when using aot_export_joint_simple · Issue #154195 · pytorch/pytorch · GitHub
[Dynamo][Custombackend]: rms_norm find inplace op when using aot_export_joint_simple #154195
Open
@GodHforever

Description

🐛 Describe the bug

I am writing a custom backend for Dynamo. Its skeleton looks like this:

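import unittest.mock
from typing import Any, Sequence

import torch
from torch._dynamo.utils import detect_fake_mode
from torch._functorch.aot_autograd import aot_export_joint_simple


def x_compile_backend(
    gm: torch.fx.GraphModule, example_inputs: Sequence[Any], **kwargs
) -> torch.nn.Module:
    try:
        # Reuse the FakeTensorMode that Dynamo attached to the example inputs.
        fake_mode = detect_fake_mode(example_inputs)

        # Temporarily allow real (non-fake) tensors while tracing.
        with unittest.mock.patch.object(
            fake_mode, "allow_non_fake_inputs", True
        ), fake_mode:
            # apply_pre_aot_passes is the reporter's own helper (not shown).
            gm = apply_pre_aot_passes(gm)
            # Forward-only export (no joint forward/backward graph).
            gm = aot_export_joint_simple(
                gm,
                example_inputs,
                trace_joint=False,
            )
    # ... (the remainder of the backend is not shown in the report)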

When I test it with rms_norm, I get the following output:

dynamo.compiler:graph():
    %l_args_0_ : torch.Tensor [num_users=1] = placeholder[target=L_args_0_]
    %fn____parameters__weight : [num_users=1] = get_attr[target=fn____parameters__weight]
    %rms_norm : [num_users=1] = call_function[target=torch.rms_norm](args = (%l_args_0_, (5,), %fn____parameters__weight, None), kwargs = {})
    return (rms_norm,)

forward:
def forward(self, L_args_0_ : torch.Tensor):
    l_args_0_ = L_args_0_
    fn____parameters__weight = self.fn____parameters__weight
    rms_norm = torch.rms_norm(l_args_0_, (5,), fn____parameters__weight, None);  l_args_0_ = fn____parameters__weight = None
    return (rms_norm,)

After pre-AOT passes: 
graph():
    %l_args_0_ : torch.Tensor [num_users=1] = placeholder[target=L_args_0_]
    %fn____parameters__weight : [num_users=1] = get_attr[target=fn____parameters__weight]
    %rms_norm : [num_users=1] = call_function[target=torch.rms_norm](args = (%l_args_0_, (5,), %fn____parameters__weight, None), kwargs = {})
    return (rms_norm,)

The run then fails with this error:

torch/_functorch/_aot_autograd/functional_utils.py", line 432, in assert_functional_graph
    not n.target._schema.is_mutable
AssertionError: aot_autograd expected to have an entirely functional graph, but found %add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Scalar](args = (%mean, 1.1920928955078125e-07), kwargs = {})
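The assertion quoted above tests `n.target._schema.is_mutable` on graph nodes. As a minimal sketch (not a confirmed fix for this issue), the same check can be run by hand on any traced graph to list the offending ops, and `torch.func.functionalize` can be used to see what a functional version of the same computation looks like. The function `f`, the helper `find_mutable_nodes`, and the input shape are hypothetical and chosen only for illustration; `f` imitates the in-place `add_` pattern rather than calling `torch.rms_norm`.

```python
import torch
from torch.func import functionalize
from torch.fx.experimental.proxy_tensor import make_fx


def f(x):
    # Hypothetical stand-in that mutates an intermediate tensor in place.
    y = x.pow(2).mean(dim=1, keepdim=True)
    y.add_(1e-6)
    return x * torch.rsqrt(y)


def find_mutable_nodes(gm):
    """Return call_function nodes whose ATen schema is marked mutable."""
    return [
        n
        for n in gm.graph.nodes
        if n.op == "call_function"
        and isinstance(n.target, torch._ops.OpOverload)
        and n.target._schema.is_mutable
    ]


x = torch.randn(2, 5)

# Plain tracing records the in-place op as-is.
traced = make_fx(f)(x)
before = find_mutable_nodes(traced)

# Tracing through functionalize rewrites mutations into out-of-place ops.
functional = make_fx(functionalize(f))(x)
after = find_mutable_nodes(functional)

print("mutable ops before:", [str(n.target) for n in before])
print("mutable ops after: ", [str(n.target) for n in after])
```

Running a check like this on the graph produced by the pre-AOT passes may help narrow down where the in-place op enters. Whether functionalization can be applied at that point in a custom backend depends on the backend's design.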

Does anyone know how to fix this?
The debug graph I get is:

graph():
    %arg0_1 : [num_users=2] = placeholder[target=arg0_1]
    %pow_1 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%arg0_1, 2), kwargs = {})
    %mean : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_1, [1], True), kwargs = {})
    %add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Scalar](args = (%mean, 1.1920928955078125e-07), kwargs = {})
    %rsqrt : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_,), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg0_1, %rsqrt), kwargs = {})
    %_param_constant0 : [num_users=1] = get_attr[target=_param_constant0]
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul, %_param_constant0), kwargs = {})
    return (mul_1,)
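For readers following the debug graph, here is a plain-Python sketch of the arithmetic it encodes: square, mean over the last dimension, add epsilon, reciprocal square root, then two multiplications. This is an illustration of the math only, not PyTorch's implementation. The name `rms_norm_reference` is hypothetical. The constant 1.1920928955078125e-07 in the graph equals 2**-23, the float32 machine epsilon, which is presumably what `rms_norm` falls back to when `eps` is passed as `None`.

```python
import math

# float32 machine epsilon; equal to the constant passed to add_ in the graph.
EPS = 2.0 ** -23


def rms_norm_reference(rows, weight, eps=EPS):
    """RMS-normalize each row, then scale elementwise by weight."""
    out = []
    for row in rows:
        # pow_1 + mean: mean of squares over the row
        mean_sq = sum(v * v for v in row) / len(row)
        # add (written out-of-place here) + rsqrt
        inv_rms = 1.0 / math.sqrt(mean_sq + eps)
        # mul (by inv_rms) and mul_1 (by weight)
        out.append([v * inv_rms * w for v, w in zip(row, weight)])
    return out


result = rms_norm_reference([[3.0, 4.0]], [1.0, 1.0])
print(result)
```

The only step in the graph that mutates its input is `add_`. Writing it as an out-of-place addition, as in the sketch, produces the same values, which is why a functional rewrite of this graph is numerically equivalent.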

Versions

Collecting environment information...
PyTorch version: 2.5.1+cpu
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (GCC) 14.1.0
Clang version: Could not collect
CMake version: version 3.28.4
Libc version: glibc-2.31

Python version: 3.11.11 (main, Dec 11 2024, 16:28:39) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.4.0-200-generic-x86_64-with-glibc2.31
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 46 bits physical, 48 bits virtual
CPU(s): 96
On-line CPU(s) list: 0-95
Thread(s) per core: 2
Core(s) per socket: 24
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 85
Model name: Intel(R) Xeon(R) Platinum 8369HB CPU @ 3.30GHz
Stepping: 11
CPU MHz: 3800.026
CPU max MHz: 4200.0000
CPU min MHz: 1200.0000
BogoMIPS: 6600.06
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 1.5 MiB
L1i cache: 1.5 MiB
L2 cache: 48 MiB
L3 cache: 66 MiB
NUMA node0 CPU(s): 0-47
NUMA node1 CPU(s): 48-95
Vulnerability Gather data sampling: Unknown: Dependent on hypervisor status
Vulnerability Itlb multihit: KVM: Vulnerable
Vulnerability L1tf: Mitigation; PTE Inversion
Vulnerability Mds: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed: Vulnerable
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; STIBP disabled; RSB filling; PBRSB-eIBRS Not affected; BHI Retpoline
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx512_bf16 ida arat avx512_vnni

Versions of relevant libraries:
[pip3] flake8==3.8.2
[pip3] flake8-bugbear==20.1.4
[pip3] flake8-comprehensions==3.3.0
[pip3] flake8-executable==2.0.4
[pip3] flake8-pyi==20.5.0
[pip3] intel_extension_for_pytorch==2.6.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==2.1.2
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-cusparselt-cu12==0.6.2
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] torch==2.5.1+cpu
[pip3] torchvision==0.20.1+cpu
[pip3] triton==3.2.0
[conda] intel-extension-for-pytorch 2.6.0 pypi_0 pypi
[conda] numpy 2.1.2 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.4.5.8 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.1.0.70 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.2.1.3 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.5.147 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.6.1.9 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.3.1.170 pypi_0 pypi
[conda] nvidia-cusparselt-cu12 0.6.2 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.21.5 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.4.127 pypi_0 pypi
[conda] torch 2.5.1+cpu pypi_0 pypi
[conda] torchvision 0.20.1+cpu pypi_0 pypi
[conda] triton 3.2.0 pypi_0 pypi

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames

Labels: module: dynamo, oncall: pt2, triaged
