Inconsistent export behavior for nonzero+grid_sample between CUDA and CPU/MPS backends · Issue #152791 · pytorch/pytorch · GitHub
Open
@sachin-skyline

Description


🐛 Describe the bug

I am trying to export a model that calls torch.nonzero followed by grid_sample (for use with aoti_compile_and_package). Exporting for CPU or MPS succeeds without errors, but exporting for CUDA raises "torch._dynamo.exc.UserError: Could not guard on data-dependent expression Eq(2*u0, 0) (unhinted: Eq(2*u0, 0)). (Size-like symbols: u0)".
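To see where the symbol `u0` comes from: `torch.nonzero` returns one row per nonzero element, so its first dimension is only known once the data is known. The eager-mode sketch below (the score values are arbitrary) builds `grid` the same way as the reproducer and shows that it has shape `(1, n, 1, 2)` and therefore `2 * n` elements. That matches the `2*u0` in the failing guard `Eq(2*u0, 0)`, so the guard presumably asks whether `grid` is empty; this interpretation is an assumption, not something the error message states.

```python
import torch

# Arbitrary 2x2 score map; three entries exceed the 0.5 threshold.
scores = torch.tensor([[0.9, 0.1],
                       [0.7, 0.6]])

# One row of (row, col) indices per selected element: shape (n, 2).
pos = torch.nonzero(scores > 0.5)

# Same reshaping as in the reproducer: shape (1, n, 1, 2).
grid = pos.to(torch.float32)[None, :, None, :]

n = pos.shape[0]
print(pos.shape)              # torch.Size([3, 2])
print(grid.shape)             # torch.Size([1, 3, 1, 2])
print(grid.numel() == 2 * n)  # True: the element count is 2 * n
```

During export, `n` has no concrete value and is represented by the unbacked symbol `u0`, which is why a check on the element count cannot be decided without extra information.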

Adding manual checks to the model definition works around the error, for example:

torch._check(grid.shape[1] > 0)
torch._check(grid.shape[1] + (grid.shape[1] - 1) % grid.shape[1] < 2147483647)

However, the actual model code needs a few more such checks. I would expect the need for manual checks to be consistent between the CPU/MPS and CUDA backends; in this case, requiring no checks at all would be the correct, or at least the ideal, behavior.

Reproducer:

import torch

class Test(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, pos):
        pos = torch.nonzero(pos[0, 0] > 0.5)
        grid = pos.to(x.dtype)[None, :, None, :]
        return torch.nn.functional.grid_sample(x, grid, align_corners=False)

device = "cuda"

torch.export.export(
    Test(),
    (
        torch.randn(1, 1, 16, 16, device=device),
        torch.randn(1, 1, 32, 32, device=device),
    ),
)

Running with TORCHDYNAMO_VERBOSE=1 produces:

W0504 23:29:04.745000 3388865 site-packages/torch/fx/experimental/symbolic_shapes.py:6679] [0/0] failed during evaluate_expr(Eq(2*u0, 0), hint=None, size_oblivious=False, forcing_spec=False
E0504 23:29:04.746000 3388865 site-packages/torch/fx/experimental/recording.py:299] [0/0] failed while running evaluate_expr(*(Eq(2*u0, 0), None, False, False), **{})
Traceback (most recent call last):
  File "site-packages/torch/_dynamo/utils.py", line 3284, in run_node
    return node.target(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "site-packages/torch/nn/functional.py", line 5023, in grid_sample
    return torch.grid_sampler(input, grid, mode_enum, padding_mode_enum, align_corners)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "site-packages/torch/fx/experimental/sym_node.py", line 536, in guard_bool
    r = self.evaluate()
        ^^^^^^^^^^^^^^^
  File "site-packages/torch/fx/experimental/sym_node.py", line 510, in evaluate
    return self.shape_env.evaluate_sym_node(self, size_oblivious)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "site-packages/torch/fx/experimental/symbolic_shapes.py", line 6655, in evaluate_sym_node
    return self.evaluate_expr(
           ^^^^^^^^^^^^^^^^^^^
  File "site-packages/torch/fx/experimental/recording.py", line 263, in wrapper
    return retlog(fn(*args, **kwargs))
                  ^^^^^^^^^^^^^^^^^^^
  File "site-packages/torch/fx/experimental/symbolic_shapes.py", line 6671, in evaluate_expr
    return self._evaluate_expr(
           ^^^^^^^^^^^^^^^^^^^^
  File "site-packages/torch/fx/experimental/symbolic_shapes.py", line 6894, in _evaluate_expr
    raise self._make_data_dependent_error(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(2*u0, 0) (unhinted: Eq(2*u0, 0)).  (Size-like symbols: u0)

ATTENTION: guard_size_oblivious would fix the error, evaluating expression to False.
Maybe you need to add guard_size_oblivious to framework code, see doc below for more guidance.

Caused by: return torch.nn.functional.grid_sample(x, grid, align_corners=False)  # repro.py:10 in forward (nn/functional.py:5023 in grid_sample)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

User Stack (most recent call last):
  (snipped, see stack below for prefix)
  File "repro.py", line 10, in forward
    return torch.nn.functional.grid_sample(x, grid, align_corners=False)

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "site-packages/torch/_dynamo/utils.py", line 3127, in get_fake_value
    ret_val = wrap_fake_exception(
              ^^^^^^^^^^^^^^^^^^^^
  File "site-packages/torch/_dynamo/utils.py", line 2641, in wrap_fake_exception
    return fn()
           ^^^^
  File "site-packages/torch/_dynamo/utils.py", line 3128, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "site-packages/torch/_dynamo/utils.py", line 3325, in run_node
    raise RuntimeError(make_error_message(e)).with_traceback(
  File "site-packages/torch/_dynamo/utils.py", line 3284, in run_node
    return node.target(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "site-packages/torch/nn/functional.py", line 5023, in grid_sample
    return torch.grid_sampler(input, grid, mode_enum, padding_mode_enum, align_corners)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "site-packages/torch/fx/experimental/sym_node.py", line 536, in guard_bool
    r = self.evaluate()
        ^^^^^^^^^^^^^^^
  File "site-packages/torch/fx/experimental/sym_node.py", line 510, in evaluate
    return self.shape_env.evaluate_sym_node(self, size_oblivious)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "site-packages/torch/fx/experimental/symbolic_shapes.py", line 6655, in evaluate_sym_node
    return self.evaluate_expr(
           ^^^^^^^^^^^^^^^^^^^
  File "site-packages/torch/fx/experimental/recording.py", line 263, in wrapper
    return retlog(fn(*args, **kwargs))
                  ^^^^^^^^^^^^^^^^^^^
  File "site-packages/torch/fx/experimental/symbolic_shapes.py", line 6671, in evaluate_expr
    return self._evaluate_expr(
           ^^^^^^^^^^^^^^^^^^^^
  File "site-packages/torch/fx/experimental/symbolic_shapes.py", line 6894, in _evaluate_expr
    raise self._make_data_dependent_error(
RuntimeError: Dynamo failed to run FX node with fake tensors: call_function <function grid_sample at 0x7a11819c1080>(*(FakeTensor(..., device='cuda:0', size=(1, 1, 16, 16)), FakeTensor(..., device='cuda:0', size=(1, u0, 1, 2))), **{'align_corners': False}): got GuardOnDataDependentSymNode('Could not guard on data-dependent expression Eq(2*u0, 0) (unhinted: Eq(2*u0, 0)).  (Size-like symbols: u0)\n\nATTENTION: guard_size_oblivious would fix the error, evaluating expression to False.\nMaybe you need to add guard_size_oblivious to framework code, see doc below for more guidance.\n\nCaused by: return torch.nn.functional.grid_sample(x, grid, align_corners=False)  # repro.py:10 in forward (nn/functional.py:5023 in grid_sample)\nFor more information, run with TORCH_LOGS="dynamic"\nFor extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"\nIf you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\nFor more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing\n\nUser Stack (most recent call last):\n  (snipped, see stack below for prefix)\n  File "repro.py", line 10, in forward\n    return torch.nn.functional.grid_sample(x, grid, align_corners=False)\n\nFor C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1')

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "repro.py", line 14, in <module>
    torch.export.export(
  File "site-packages/torch/export/__init__.py", line 360, in export
    return _export(
           ^^^^^^^^
  File "site-packages/torch/export/_trace.py", line 1092, in wrapper
    raise e
  File "site-packages/torch/export/_trace.py", line 1065, in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
  File "site-packages/torch/export/exported_program.py", line 121, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "site-packages/torch/export/_trace.py", line 2112, in _export
    ep = _export_for_training(
         ^^^^^^^^^^^^^^^^^^^^^
  File "site-packages/torch/export/_trace.py", line 1092, in wrapper
    raise e
  File "site-packages/torch/export/_trace.py", line 1065, in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
  File "site-packages/torch/export/exported_program.py", line 121, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "site-packages/torch/export/_trace.py", line 1975, in _export_for_training
    export_artifact = export_func(  # type: ignore[operator]
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "site-packages/torch/export/_trace.py", line 1344, in _strict_export_lower_to_aten_ir
    gm_torch_level = _export_to_torch_ir(
                     ^^^^^^^^^^^^^^^^^^^^
  File "site-packages/torch/export/_trace.py", line 739, in _export_to_torch_ir
    gm_torch_level, _ = torch._dynamo.export(
                        ^^^^^^^^^^^^^^^^^^^^^
  File "site-packages/torch/_dynamo/eval_frame.py", line 1677, in inner
    result_traced = opt_f(*args, **kwargs)
                    ^^^^^^^^^^^^^^^^^^^^^^
  File "site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "site-packages/torch/_dynamo/eval_frame.py", line 655, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "site-packages/torch/_dynamo/convert_frame.py", line 1432, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "site-packages/torch/_dynamo/convert_frame.py", line 598, in __call__
    return _compile(
           ^^^^^^^^^
  File "site-packages/torch/_dynamo/convert_frame.py", line 1059, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "site-packages/torch/_utils_internal.py", line 97, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "site-packages/torch/_dynamo/convert_frame.py", line 761, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "site-packages/torch/_dynamo/convert_frame.py", line 797, in _compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "site-packages/torch/_dynamo/bytecode_transformation.py", line 1422, in transform_code_object
    transformations(instructions, code_options)
  File "site-packages/torch/_dynamo/convert_frame.py", line 257, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "site-packages/torch/_dynamo/convert_frame.py", line 715, in transform
    tracer.run()
  File "site-packages/torch/_dynamo/symbolic_convert.py", line 3500, in run
    super().run()
  File "site-packages/torch/_dynamo/symbolic_convert.py", line 1337, in run
    while self.step():
          ^^^^^^^^^^^
  File "site-packages/torch/_dynamo/symbolic_convert.py", line 1246, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "site-packages/torch/_dynamo/symbolic_convert.py", line 819, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "site-packages/torch/_dynamo/symbolic_convert.py", line 2933, in CALL
    self._call(inst)
  File "site-packages/torch/_dynamo/symbolic_convert.py", line 2927, in _call
    self.call_function(fn, args, kwargs)
  File "site-packages/torch/_dynamo/symbolic_convert.py", line 1170, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "site-packages/torch/_dynamo/variables/torch.py", line 1181, in call_function
    tensor_variable = wrap_fx_proxy(
                      ^^^^^^^^^^^^^^
  File "site-packages/torch/_dynamo/variables/builder.py", line 2302, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "site-packages/torch/_dynamo/variables/builder.py", line 2368, in wrap_fx_proxy_cls
    return _wrap_fx_proxy(
           ^^^^^^^^^^^^^^^
  File "site-packages/torch/_dynamo/variables/builder.py", line 2464, in _wrap_fx_proxy
    example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "site-packages/torch/_dynamo/utils.py", line 3214, in get_fake_value
    raise UserError(  # noqa: B904
torch._dynamo.exc.UserError: Could not guard on data-dependent expression Eq(2*u0, 0) (unhinted: Eq(2*u0, 0)).  (Size-like symbols: u0)

ATTENTION: guard_size_oblivious would fix the error, evaluating expression to False.
Maybe you need to add guard_size_oblivious to framework code, see doc below for more guidance.

Caused by: return torch.nn.functional.grid_sample(x, grid, align_corners=False)  # repro.py:10 in forward (nn/functional.py:5023 in grid_sample)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

User Stack (most recent call last):
  (snipped, see stack below for prefix)
  File "repro.py", line 10, in forward
    return torch.nn.functional.grid_sample(x, grid, align_corners=False)

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#constrain-as-size-example

from user code:
   File "repro.py", line 10, in forward
    return torch.nn.functional.grid_sample(x, grid, align_corners=False)

Versions

Collecting environment information...
PyTorch version: 2.7.0+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A

OS: Ubuntu 24.04.1 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Clang version: Could not collect
CMake version: version 3.28.3
Libc version: glibc-2.39

Python version: 3.12.3 (main, Feb  4 2025, 14:48:35) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-6.8.0-1027-gcp-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: 12.0.140
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA A100-SXM4-40GB
Nvidia driver version: 535.183.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Address sizes:                        46 bits physical, 48 bits virtual
Byte Order:                           Little Endian
CPU(s):                               12
On-line CPU(s) list:                  0-11
Vendor ID:                            GenuineIntel
Model name:                           Intel(R) Xeon(R) CPU @ 2.20GHz
CPU family:                           6
Model:                                85
Thread(s) per core:                   2
Core(s) per socket:                   6
Socket(s):                            1
Stepping:                             7
BogoMIPS:                             4400.37
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves arat avx512_vnni md_clear arch_capabilities
Hypervisor vendor:                    KVM
Virtualization type:                  full
L1d cache:                            192 KiB (6 instances)
L1i cache:                            192 KiB (6 instances)
L2 cache:                             6 MiB (6 instances)
L3 cache:                             38.5 MiB (1 instance)
NUMA node(s):                         1
NUMA node0 CPU(s):                    0-11
Vulnerability Gather data sampling:   Not affected
Vulnerability Itlb multihit:          Not affected
Vulnerability L1tf:                   Not affected
Vulnerability Mds:                    Not affected
Vulnerability Meltdown:               Not affected
Vulnerability Mmio stale data:        Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed:               Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow:   Not affected
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI SW loop, KVM SW loop
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown

Versions of relevant libraries:
[pip3] numpy==2.2.5
[pip3] nvidia-cublas-cu12==12.6.4.1
[pip3] nvidia-cuda-cupti-cu12==12.6.80
[pip3] nvidia-cuda-nvrtc-cu12==12.6.77
[pip3] nvidia-cuda-runtime-cu12==12.6.77
[pip3] nvidia-cudnn-cu12==9.5.1.17
[pip3] nvidia-cufft-cu12==11.3.0.4
[pip3] nvidia-curand-cu12==10.3.7.77
[pip3] nvidia-cusolver-cu12==11.7.1.2
[pip3] nvidia-cusparse-cu12==12.5.4.2
[pip3] nvidia-cusparselt-cu12==0.6.3
[pip3] nvidia-nccl-cu12==2.26.2
[pip3] nvidia-nvjitlink-cu12==12.6.85
[pip3] nvidia-nvtx-cu12==12.6.77
[pip3] torch==2.7.0
[pip3] triton==3.3.0
[conda] Could not collect

cc @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4
