### 🐛 Describe the bug
I am trying to export a model that contains a `nonzero` call followed by a `grid_sample` call (for use in `aoti_compile_and_package`). When exporting for CPU or MPS, no error is thrown, but on CUDA, export fails with `torch._dynamo.exc.UserError: Could not guard on data-dependent expression Eq(2*u0, 0) (unhinted: Eq(2*u0, 0)). (Size-like symbols: u0)`.
Adding manual checks to the user model definition solves the issue, e.g.

```python
torch._check(grid.shape[1] > 0)
torch._check(grid.shape[1] + (grid.shape[1] - 1) % grid.shape[1] < 2147483647)
```

but in my actual model code a couple more checks are required. I would expect the need for manual user checks to be consistent between the CPU/MPS and CUDA backends; here, requiring no checks at all seems like the correct, or at least the ideal, behavior.
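Applied to the minimal reproducer below, the workaround amounts to inserting the checks between the `nonzero` and the `grid_sample` call; the two checks above suffice there. A sketch of the patched `forward`:

```python
def forward(self, x, pos):
    pos = torch.nonzero(pos[0, 0] > 0.5)
    grid = pos.to(x.dtype)[None, :, None, :]
    # grid.shape[1] is the unbacked size u0 produced by nonzero; assert it
    # is positive and bounded so export can discharge its guards.
    torch._check(grid.shape[1] > 0)
    torch._check(grid.shape[1] + (grid.shape[1] - 1) % grid.shape[1] < 2147483647)
    return torch.nn.functional.grid_sample(x, grid, align_corners=False)
```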
Reproducer:

```python
import torch

class Test(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, pos):
        pos = torch.nonzero(pos[0, 0] > 0.5)
        grid = pos.to(x.dtype)[None, :, None, :]
        return torch.nn.functional.grid_sample(x, grid, align_corners=False)


device = "cuda"
torch.export.export(
    Test(),
    (
        torch.randn(1, 1, 16, 16, device=device),
        torch.randn(1, 1, 32, 32, device=device),
    ),
)
```
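For context, the exported program is ultimately destined for AOTInductor packaging; the downstream call looks roughly like this sketch (the `torch._inductor.aoti_compile_and_package` entry point and its `package_path` argument are per the 2.7 docs; on CUDA the `export` call itself already fails as shown below):

```python
import torch

# Sketch of the intended downstream use, assuming the Test module from the
# reproducer above; not needed to reproduce the bug itself.
ep = torch.export.export(
    Test(),
    (
        torch.randn(1, 1, 16, 16, device="cuda"),
        torch.randn(1, 1, 32, 32, device="cuda"),
    ),
)
# Compile the exported program ahead of time and write a .pt2 package.
torch._inductor.aoti_compile_and_package(ep, package_path="model.pt2")
```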
Running with `TORCHDYNAMO_VERBOSE=1` produces:

```
W0504 23:29:04.745000 3388865 site-packages/torch/fx/experimental/symbolic_shapes.py:6679] [0/0] failed during evaluate_expr(Eq(2*u0, 0), hint=None, size_oblivious=False, forcing_spec=False
E0504 23:29:04.746000 3388865 site-packages/torch/fx/experimental/recording.py:299] [0/0] failed while running evaluate_expr(*(Eq(2*u0, 0), None, False, False), **{})
Traceback (most recent call last):
File "site-packages/torch/_dynamo/utils.py", line 3284, in run_node
return node.target(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "site-packages/torch/nn/functional.py", line 5023, in grid_sample
return torch.grid_sampler(input, grid, mode_enum, padding_mode_enum, align_corners)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "site-packages/torch/fx/experimental/sym_node.py", line 536, in guard_bool
r = self.evaluate()
^^^^^^^^^^^^^^^
File "site-packages/torch/fx/experimental/sym_node.py", line 510, in evaluate
return self.shape_env.evaluate_sym_node(self, size_oblivious)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "site-packages/torch/fx/experimental/symbolic_shapes.py", line 6655, in evaluate_sym_node
return self.evaluate_expr(
^^^^^^^^^^^^^^^^^^^
File "site-packages/torch/fx/experimental/recording.py", line 263, in wrapper
return retlog(fn(*args, **kwargs))
^^^^^^^^^^^^^^^^^^^
File "site-packages/torch/fx/experimental/symbolic_shapes.py", line 6671, in evaluate_expr
return self._evaluate_expr(
^^^^^^^^^^^^^^^^^^^^
File "site-packages/torch/fx/experimental/symbolic_shapes.py", line 6894, in _evaluate_expr
raise self._make_data_dependent_error(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(2*u0, 0) (unhinted: Eq(2*u0, 0)). (Size-like symbols: u0)
ATTENTION: guard_size_oblivious would fix the error, evaluating expression to False.
Maybe you need to add guard_size_oblivious to framework code, see doc below for more guidance.
Caused by: return torch.nn.functional.grid_sample(x, grid, align_corners=False) # repro.py:10 in forward (nn/functional.py:5023 in grid_sample)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing
User Stack (most recent call last):
(snipped, see stack below for prefix)
File "repro.py", line 10, in forward
return torch.nn.functional.grid_sample(x, grid, align_corners=False)
For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "site-packages/torch/_dynamo/utils.py", line 3127, in get_fake_value
ret_val = wrap_fake_exception(
^^^^^^^^^^^^^^^^^^^^
File "site-packages/torch/_dynamo/utils.py", line 2641, in wrap_fake_exception
return fn()
^^^^
File "site-packages/torch/_dynamo/utils.py", line 3128, in <lambda>
lambda: run_node(tx.output, node, args, kwargs, nnmodule)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "site-packages/torch/_dynamo/utils.py", line 3325, in run_node
raise RuntimeError(make_error_message(e)).with_traceback(
File "site-packages/torch/_dynamo/utils.py", line 3284, in run_node
return node.target(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "site-packages/torch/nn/functional.py", line 5023, in grid_sample
return torch.grid_sampler(input, grid, mode_enum, padding_mode_enum, align_corners)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "site-packages/torch/fx/experimental/sym_node.py", line 536, in guard_bool
r = self.evaluate()
^^^^^^^^^^^^^^^
File "site-packages/torch/fx/experimental/sym_node.py", line 510, in evaluate
return self.shape_env.evaluate_sym_node(self, size_oblivious)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "site-packages/torch/fx/experimental/symbolic_shapes.py", line 6655, in evaluate_sym_node
return self.evaluate_expr(
^^^^^^^^^^^^^^^^^^^
File "site-packages/torch/fx/experimental/recording.py", line 263, in wrapper
return retlog(fn(*args, **kwargs))
^^^^^^^^^^^^^^^^^^^
File "site-packages/torch/fx/experimental/symbolic_shapes.py", line 6671, in evaluate_expr
return self._evaluate_expr(
^^^^^^^^^^^^^^^^^^^^
File "site-packages/torch/fx/experimental/symbolic_shapes.py", line 6894, in _evaluate_expr
raise self._make_data_dependent_error(
RuntimeError: Dynamo failed to run FX node with fake tensors: call_function <function grid_sample at 0x7a11819c1080>(*(FakeTensor(..., device='cuda:0', size=(1, 1, 16, 16)), FakeTensor(..., device='cuda:0', size=(1, u0, 1, 2))), **{'align_corners': False}): got GuardOnDataDependentSymNode('Could not guard on data-dependent expression Eq(2*u0, 0) (unhinted: Eq(2*u0, 0)). (Size-like symbols: u0)\n\nATTENTION: guard_size_oblivious would fix the error, evaluating expression to False.\nMaybe you need to add guard_size_oblivious to framework code, see doc below for more guidance.\n\nCaused by: return torch.nn.functional.grid_sample(x, grid, align_corners=False) # repro.py:10 in forward (nn/functional.py:5023 in grid_sample)\nFor more information, run with TORCH_LOGS="dynamic"\nFor extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"\nIf you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\nFor more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing\n\nUser Stack (most recent call last):\n (snipped, see stack below for prefix)\n File "repro.py", line 10, in forward\n return torch.nn.functional.grid_sample(x, grid, align_corners=False)\n\nFor C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1')
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "repro.py", line 14, in <module>
torch.export.export(
File "site-packages/torch/export/__init__.py", line 360, in export
return _export(
^^^^^^^^
File "site-packages/torch/export/_trace.py", line 1092, in wrapper
raise e
File "site-packages/torch/export/_trace.py", line 1065, in wrapper
ep = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "site-packages/torch/export/exported_program.py", line 121, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "site-packages/torch/export/_trace.py", line 2112, in _export
ep = _export_for_training(
^^^^^^^^^^^^^^^^^^^^^
File "site-packages/torch/export/_trace.py", line 1092, in wrapper
raise e
File "site-packages/torch/export/_trace.py", line 1065, in wrapper
ep = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "site-packages/torch/export/exported_program.py", line 121, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "site-packages/torch/export/_trace.py", line 1975, in _export_for_training
export_artifact = export_func( # type: ignore[operator]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "site-packages/torch/export/_trace.py", line 1344, in _strict_export_lower_to_aten_ir
gm_torch_level = _export_to_torch_ir(
^^^^^^^^^^^^^^^^^^^^
File "site-packages/torch/export/_trace.py", line 739, in _export_to_torch_ir
gm_torch_level, _ = torch._dynamo.export(
^^^^^^^^^^^^^^^^^^^^^
File "site-packages/torch/_dynamo/eval_frame.py", line 1677, in inner
result_traced = opt_f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^
File "site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "site-packages/torch/_dynamo/eval_frame.py", line 655, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "site-packages/torch/_dynamo/convert_frame.py", line 1432, in __call__
return self._torchdynamo_orig_callable(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "site-packages/torch/_dynamo/convert_frame.py", line 598, in __call__
return _compile(
^^^^^^^^^
File "site-packages/torch/_dynamo/convert_frame.py", line 1059, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "site-packages/torch/_utils_internal.py", line 97, in wrapper_function
return function(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "site-packages/torch/_dynamo/convert_frame.py", line 761, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "site-packages/torch/_dynamo/convert_frame.py", line 797, in _compile_inner
out_code = transform_code_object(code, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "site-packages/torch/_dynamo/bytecode_transformation.py", line 1422, in transform_code_object
transformations(instructions, code_options)
File "site-packages/torch/_dynamo/convert_frame.py", line 257, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "site-packages/torch/_dynamo/convert_frame.py", line 715, in transform
tracer.run()
File "site-packages/torch/_dynamo/symbolic_convert.py", line 3500, in run
super().run()
File "site-packages/torch/_dynamo/symbolic_convert.py", line 1337, in run
while self.step():
^^^^^^^^^^^
File "site-packages/torch/_dynamo/symbolic_convert.py", line 1246, in step
self.dispatch_table[inst.opcode](self, inst)
File "site-packages/torch/_dynamo/symbolic_convert.py", line 819, in wrapper
return inner_fn(self, inst)
^^^^^^^^^^^^^^^^^^^^
File "site-packages/torch/_dynamo/symbolic_convert.py", line 2933, in CALL
self._call(inst)
File "site-packages/torch/_dynamo/symbolic_convert.py", line 2927, in _call
self.call_function(fn, args, kwargs)
File "site-packages/torch/_dynamo/symbolic_convert.py", line 1170, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "site-packages/torch/_dynamo/variables/torch.py", line 1181, in call_function
tensor_variable = wrap_fx_proxy(
^^^^^^^^^^^^^^
File "site-packages/torch/_dynamo/variables/builder.py", line 2302, in wrap_fx_proxy
return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "site-packages/torch/_dynamo/variables/builder.py", line 2368, in wrap_fx_proxy_cls
return _wrap_fx_proxy(
^^^^^^^^^^^^^^^
File "site-packages/torch/_dynamo/variables/builder.py", line 2464, in _wrap_fx_proxy
example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "site-packages/torch/_dynamo/utils.py", line 3214, in get_fake_value
raise UserError( # noqa: B904
torch._dynamo.exc.UserError: Could not guard on data-dependent expression Eq(2*u0, 0) (unhinted: Eq(2*u0, 0)). (Size-like symbols: u0)
ATTENTION: guard_size_oblivious would fix the error, evaluating expression to False.
Maybe you need to add guard_size_oblivious to framework code, see doc below for more guidance.
Caused by: return torch.nn.functional.grid_sample(x, grid, align_corners=False) # repro.py:10 in forward (nn/functional.py:5023 in grid_sample)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing
User Stack (most recent call last):
(snipped, see stack below for prefix)
File "repro.py", line 10, in forward
return torch.nn.functional.grid_sample(x, grid, align_corners=False)
For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#constrain-as-size-example
from user code:
  File "repro.py", line 10, in forward
    return torch.nn.functional.grid_sample(x, grid, align_corners=False)
```
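Regarding the "ATTENTION: guard_size_oblivious would fix the error" hint: since `grid` propagates as a fake tensor of size `(1, u0, 1, 2)`, the offending `Eq(2*u0, 0)` looks like an empty-tensor check on `grid.numel()` somewhere in the CUDA `grid_sampler` path. A size-oblivious guard in framework code, along the lines of this sketch, would evaluate the expression to False (illustrative only; `grid_is_empty` is a hypothetical stand-in, and I have not pinpointed the exact guard that fires):

```python
import torch
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

def grid_is_empty(grid: torch.Tensor) -> bool:
    # Hypothetical framework-side check. A plain `grid.numel() == 0` guards
    # on the unbacked expression Eq(2*u0, 0) and raises during export; the
    # size-oblivious variant treats u0 as a generic large size and returns
    # False instead.
    return guard_size_oblivious(grid.numel() == 0)
```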
### Versions

```
Collecting environment information...
PyTorch version: 2.7.0+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A
OS: Ubuntu 24.04.1 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Clang version: Could not collect
CMake version: version 3.28.3
Libc version: glibc-2.39
Python version: 3.12.3 (main, Feb 4 2025, 14:48:35) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-6.8.0-1027-gcp-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: 12.0.140
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA A100-SXM4-40GB
Nvidia driver version: 535.183.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 12
On-line CPU(s) list: 0-11
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) CPU @ 2.20GHz
CPU family: 6
Model: 85
Thread(s) per core: 2
Core(s) per socket: 6
Socket(s): 1
Stepping: 7
BogoMIPS: 4400.37
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves arat avx512_vnni md_clear arch_capabilities
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 192 KiB (6 instances)
L1i cache: 192 KiB (6 instances)
L2 cache: 6 MiB (6 instances)
L3 cache: 38.5 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-11
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI SW loop, KVM SW loop
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Versions of relevant libraries:
[pip3] numpy==2.2.5
[pip3] nvidia-cublas-cu12==12.6.4.1
[pip3] nvidia-cuda-cupti-cu12==12.6.80
[pip3] nvidia-cuda-nvrtc-cu12==12.6.77
[pip3] nvidia-cuda-runtime-cu12==12.6.77
[pip3] nvidia-cudnn-cu12==9.5.1.17
[pip3] nvidia-cufft-cu12==11.3.0.4
[pip3] nvidia-curand-cu12==10.3.7.77
[pip3] nvidia-cusolver-cu12==11.7.1.2
[pip3] nvidia-cusparse-cu12==12.5.4.2
[pip3] nvidia-cusparselt-cu12==0.6.3
[pip3] nvidia-nccl-cu12==2.26.2
[pip3] nvidia-nvjitlink-cu12==12.6.85
[pip3] nvidia-nvtx-cu12==12.6.77
[pip3] torch==2.7.0
[pip3] triton==3.3.0
[conda] Could not collect
```

cc @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4