🐛 Describe the bug
Description:
When a PyTorch model containing a FourierUnit (which uses torch.fft.rfftn and torch.fft.irfftn) is run through ONNX, either with torch.compile and the "onnxrt" backend or with torch.onnx.dynamo_export, the output shape of the ONNX model is incorrect: it no longer matches the shape produced by the eager PyTorch model.
Reproduction Steps:
1. Define the FourierUnit model as in the minified repro below.
2. Create an input tensor with shape [1, 192, 64, 64].
3. Pass the input through the PyTorch model in eager mode and check the output shape (it is [1, 192, 64, 64], as expected).
4. Compile the model with torch.compile(..., backend="onnxrt"), which exports the graph to ONNX and runs it with ONNX Runtime.
5. Pass the same input through the compiled model and check the output shape (it is [1, 192, 64, 33] instead).
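The steps above amount to comparing eager and compiled output shapes for the same input. A minimal sketch of such a check follows; `compare_output_shapes` and the reduced `FFTRoundTrip` module are hypothetical names introduced here, not PyTorch APIs, and whether `FFTRoundTrip` alone triggers the bug has not been verified.

```python
import torch
import torch.nn as nn


def compare_output_shapes(reference, candidate, x):
    """Run both callables on the same input and return (reference_shape, candidate_shape, match).

    Hypothetical helper for this report, not part of PyTorch.
    """
    ref_shape = tuple(reference(x).shape)
    cand_shape = tuple(candidate(x).shape)
    return ref_shape, cand_shape, ref_shape == cand_shape


class FFTRoundTrip(nn.Module):
    """Hypothetical reduced module: rfftn followed by irfftn with an explicit output size s."""

    def forward(self, x):
        spec = torch.fft.rfftn(x, dim=(-2, -1), norm="ortho")
        return torch.fft.irfftn(spec, s=x.shape[-2:], dim=(-2, -1), norm="ortho")


model = FFTRoundTrip()
x = torch.randn(1, 8, 64, 64)

# Eager-only sanity check: the module reproduces the input's spatial size
print(compare_output_shapes(model, model, x))  # ((1, 8, 64, 64), (1, 8, 64, 64), True)
```

To exercise the ONNX path, pass `torch.compile(model, backend="onnxrt")` (or `fu` and `graph` from the repro) as `reference` and `candidate`; this requires onnxruntime and the ONNX exporter dependencies such as onnxscript to be installed.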
Expected Behavior:
The output shape of the ONNX model should match that of the eager PyTorch model: [1, 192, 64, 64].
Actual Behavior:
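The ONNX model returns an output of shape [1, 192, 64, 33]. The last dimension, 33 = 64 // 2 + 1, is the width of the one-sided spectrum passed into irfftn rather than the requested output width of 64; ONNX Runtime's MergeShapeInfo warning for the '_fft_c2r' node in the logs below reports the same mismatch.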
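For comparison, here is the same FFT round trip in eager mode, using only standard torch.fft calls and the shapes from the repro: rfftn over the last two dims returns a one-sided spectrum of width 33, and irfftn with `s` set to the original spatial size maps it back to 64 x 64.

```python
import torch

# Same layout as the repro: (batch, channels, height, width)
x = torch.randn(1, 192, 64, 64)
fft_dim = (-2, -1)

# Real-to-complex FFT keeps only the non-negative frequencies along the last dim
spec = torch.fft.rfftn(x, dim=fft_dim, norm="ortho")
print(spec.shape)  # torch.Size([1, 192, 64, 33])

# Complex-to-real inverse; s restores the original spatial size
restored = torch.fft.irfftn(spec, s=x.shape[-2:], dim=fft_dim, norm="ortho")
print(restored.shape)  # torch.Size([1, 192, 64, 64])
print(torch.allclose(restored, x, atol=1e-4))  # True
```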
Error logs
ORG torch.Size([1, 192, 64, 64])
/usr/local/lib/python3.10/dist-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
_torch_pytree._register_pytree_node(
/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/exporter.py:136: UserWarning: torch.onnx.dynamo_export only implements opset version 18 for now. If you need to use a different opset version, please register them with register_custom_op.
warnings.warn(
W0510 05:42:17.739000 139715516277184 torch/onnx/_internal/onnxruntime.py:185] [0/0] support_dict supports node.target: aten.unsqueeze.default (type: <class 'torch._ops.OpOverload'>)
W0510 05:42:17.739000 139715516277184 torch/onnx/_internal/onnxruntime.py:185] [0/0] support_dict supports node.target: aten.unsqueeze.default (type: <class 'torch._ops.OpOverload'>)
W0510 05:42:17.739000 139715516277184 torch/onnx/_internal/onnxruntime.py:185] [0/0] support_dict supports node.target: aten.unsqueeze.default (type: <class 'torch._ops.OpOverload'>)
W0510 05:42:17.739000 139715516277184 torch/onnx/_internal/onnxruntime.py:185] [0/0] support_dict supports node.target: aten.le.Scalar (type: <class 'torch._ops.OpOverload'>)
W0510 05:42:17.739000 139715516277184 torch/onnx/_internal/onnxruntime.py:185] [0/0] support_dict supports node.target: aten.detach.default (type: <class 'torch._ops.OpOverload'>)
W0510 05:42:17.739000 139715516277184 torch/onnx/_internal/onnxruntime.py:185] [0/0] support_dict supports node.target: aten.detach.default (type: <class 'torch._ops.OpOverload'>)
W0510 05:42:17.740000 139715516277184 torch/onnx/_internal/onnxruntime.py:185] [0/0] support_dict supports node.target: aten._fft_c2r.default (type: <class 'torch._ops.OpOverload'>)
W0510 05:42:17.740000 139715516277184 torch/onnx/_internal/onnxruntime.py:185] [0/0] support_dict supports node.target: aten.complex.default (type: <class 'torch._ops.OpOverload'>)
W0510 05:42:17.740000 139715516277184 torch/onnx/_internal/onnxruntime.py:185] [0/0] support_dict supports node.target: aten.select.int (type: <class 'torch._ops.OpOverload'>)
W0510 05:42:17.740000 139715516277184 torch/onnx/_internal/onnxruntime.py:185] [0/0] support_dict supports node.target: aten.select.int (type: <class 'torch._ops.OpOverload'>)
W0510 05:42:17.740000 139715516277184 torch/onnx/_internal/onnxruntime.py:185] [0/0] support_dict supports node.target: aten.clone.default (type: <class 'torch._ops.OpOverload'>)
W0510 05:42:17.740000 139715516277184 torch/onnx/_internal/onnxruntime.py:185] [0/0] support_dict supports node.target: aten.permute.default (type: <class 'torch._ops.OpOverload'>)
W0510 05:42:17.740000 139715516277184 torch/onnx/_internal/onnxruntime.py:185] [0/0] support_dict supports node.target: aten.view.default (type: <class 'torch._ops.OpOverload'>)
W0510 05:42:17.740000 139715516277184 torch/onnx/_internal/onnxruntime.py:185] [0/0] support_dict supports node.target: aten.relu.default (type: <class 'torch._ops.OpOverload'>)
W0510 05:42:17.740000 139715516277184 torch/onnx/_internal/onnxruntime.py:195] [0/0] extra_support_dict supports node.target: <built-in function getitem> (type: <class 'builtin_function_or_method'>)
W0510 05:42:17.740000 139715516277184 torch/onnx/_internal/onnxruntime.py:195] [0/0] extra_support_dict supports node.target: <built-in function getitem> (type: <class 'builtin_function_or_method'>)
W0510 05:42:17.740000 139715516277184 torch/onnx/_internal/onnxruntime.py:195] [0/0] extra_support_dict supports node.target: <built-in function getitem> (type: <class 'builtin_function_or_method'>)
W0510 05:42:17.740000 139715516277184 torch/onnx/_internal/onnxruntime.py:195] [0/0] extra_support_dict supports node.target: <built-in function getitem> (type: <class 'builtin_function_or_method'>)
W0510 05:42:17.740000 139715516277184 torch/onnx/_internal/onnxruntime.py:195] [0/0] extra_support_dict supports node.target: <built-in function getitem> (type: <class 'builtin_function_or_method'>)
W0510 05:42:17.740000 139715516277184 torch/onnx/_internal/onnxruntime.py:185] [0/0] support_dict supports node.target: aten._native_batch_norm_legit_functional.default (type: <class 'torch._ops.OpOverload'>)
W0510 05:42:17.740000 139715516277184 torch/onnx/_internal/onnxruntime.py:185] [0/0] support_dict supports node.target: aten.add.Tensor (type: <class 'torch._ops.OpOverload'>)
W0510 05:42:17.740000 139715516277184 torch/onnx/_internal/onnxruntime.py:185] [0/0] support_dict supports node.target: aten.convolution.default (type: <class 'torch._ops.OpOverload'>)
W0510 05:42:17.741000 139715516277184 torch/onnx/_internal/onnxruntime.py:185] [0/0] support_dict supports node.target: aten.view.default (type: <class 'torch._ops.OpOverload'>)
W0510 05:42:17.741000 139715516277184 torch/onnx/_internal/onnxruntime.py:185] [0/0] support_dict supports node.target: aten.clone.default (type: <class 'torch._ops.OpOverload'>)
W0510 05:42:17.741000 139715516277184 torch/onnx/_internal/onnxruntime.py:185] [0/0] support_dict supports node.target: aten.permute.default (type: <class 'torch._ops.OpOverload'>)
W0510 05:42:17.741000 139715516277184 torch/onnx/_internal/onnxruntime.py:185] [0/0] support_dict supports node.target: aten.stack.default (type: <class 'torch._ops.OpOverload'>)
W0510 05:42:17.741000 139715516277184 torch/onnx/_internal/onnxruntime.py:185] [0/0] support_dict supports node.target: aten.select.int (type: <class 'torch._ops.OpOverload'>)
W0510 05:42:17.741000 139715516277184 torch/onnx/_internal/onnxruntime.py:185] [0/0] support_dict supports node.target: aten.select.int (type: <class 'torch._ops.OpOverload'>)
W0510 05:42:17.741000 139715516277184 torch/onnx/_internal/onnxruntime.py:185] [0/0] support_dict supports node.target: aten.view_as_real.default (type: <class 'torch._ops.OpOverload'>)
W0510 05:42:17.741000 139715516277184 torch/onnx/_internal/onnxruntime.py:185] [0/0] support_dict supports node.target: aten._fft_r2c.default (type: <class 'torch._ops.OpOverload'>)
2024-05-10 05:42:17.820522491 [W:onnxruntime:, graph.cc:108 MergeShapeInfo] Error merging shape info for output. '_fft_c2r' source:{1,192,64,33} target:{1,192,64,64}. Falling back to lenient merge.
ONNXRT torch.Size([1, 192, 64, 33])
Minified repro
import torch
import torch.nn as nn
import torch.nn.functional as F


class SELayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        res = x * y.expand_as(x)
        return res


class FourierUnit(nn.Module):
    def __init__(self, in_channels, out_channels, groups=1, spatial_scale_factor=None, spatial_scale_mode='bilinear',
                 spectral_pos_encoding=False, use_se=False, se_kwargs=None, ffc3d=False, fft_norm='ortho'):
        # bn_layer not used
        super(FourierUnit, self).__init__()
        self.groups = groups
        self.conv_layer = torch.nn.Conv2d(in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0),
                                          out_channels=out_channels * 2,
                                          kernel_size=1, stride=1, padding=0, groups=self.groups, bias=False)
        self.bn = torch.nn.BatchNorm2d(out_channels * 2)
        self.relu = torch.nn.ReLU(inplace=True)

        # squeeze and excitation block
        self.use_se = use_se
        if use_se:
            if se_kwargs is None:
                se_kwargs = {}
            self.se = SELayer(self.conv_layer.in_channels, **se_kwargs)

        self.spatial_scale_factor = spatial_scale_factor
        self.spatial_scale_mode = spatial_scale_mode
        self.spectral_pos_encoding = spectral_pos_encoding
        self.ffc3d = ffc3d
        self.fft_norm = fft_norm

    def forward(self, x):
        batch = x.shape[0]

        if self.spatial_scale_factor is not None:
            orig_size = x.shape[-2:]
            x = F.interpolate(x, scale_factor=self.spatial_scale_factor, mode=self.spatial_scale_mode, align_corners=False)

        r_size = x.size()
        # (batch, c, h, w/2+1, 2)
        fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
        ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
        ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
        ffted = ffted.permute(0, 1, 4, 2, 3).contiguous()  # (batch, c, 2, h, w/2+1)
        ffted = ffted.view((batch, -1,) + ffted.size()[3:])

        if self.spectral_pos_encoding:
            height, width = ffted.shape[-2:]
            coords_vert = torch.linspace(0, 1, height)[None, None, :, None].expand(batch, 1, height, width).to(ffted)
            coords_hor = torch.linspace(0, 1, width)[None, None, None, :].expand(batch, 1, height, width).to(ffted)
            ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1)

        if self.use_se:
            ffted = self.se(ffted)

        ffted = self.conv_layer(ffted)  # (batch, c*2, h, w/2+1)
        ffted = self.relu(self.bn(ffted))

        ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(
            0, 1, 3, 4, 2).contiguous()  # (batch,c, t, h, w/2+1, 2)
        ffted = torch.complex(ffted[..., 0], ffted[..., 1])  # 1, 192, 64, 33

        ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:]
        # 2024-05-09 14:53:52.733454854 [W:onnxruntime:, graph.cc:108 MergeShapeInfo] Error merging shape info for output. '_fft_c2r' source:{1,192,64,33} target:{1,192,64,64}. Falling back to lenient merge.
        output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm)

        if self.spatial_scale_factor is not None:
            output = F.interpolate(output, size=orig_size, mode=self.spatial_scale_mode, align_corners=False)

        return output


fu = FourierUnit(192, 192, groups=1)
x = torch.randn(1, 192, 64, 64)
out = fu(x)
print("ORG", out.shape)

graph = torch.compile(fu, backend="onnxrt")
x = torch.randn(1, 192, 64, 64)
xout = graph(x)
print("ONNXRT", xout.shape)  # expect 1, 192, 64, 64, but got 1, 192, 64, 33
Versions
Collecting environment information...
PyTorch version: 2.3.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: 14.0.0-1ubuntu1.1
CMake version: version 3.27.9
Libc version: glibc-2.35
Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-6.1<EDITED>-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3070 Laptop GPU
Nvidia driver version: 550.67
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.6
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 39 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 8
On-line CPU(s) list: 0-7
Vendor ID: GenuineIntel
Model name: 11th Gen Intel(R) Core(TM) i7-11370H @ 3.30GHz
CPU family: 6
Model: 140
Thread(s) per core: 2
Core(s) per socket: 4
Socket(s): 1
Stepping: 1
CPU max MHz: 4800.0000
CPU min MHz: 400.0000
BogoMIPS: 6607.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l2 invpcid_single cdp_l2 ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves split_lock_detect dtherm ida arat pln pts hwp hwp_notify hwp_act_window hwp_epp hwp_pkg_req avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid movdiri movdir64b fsrm avx512_vp2intersect md_clear ibt flush_l1d arch_capabilities
Virtualization: VT-x
L1d cache: 192 KiB (4 instances)
L1i cache: 128 KiB (4 instances)
L2 cache: 5 MiB (4 instances)
L3 cache: 12 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-7
Vulnerability Gather data sampling: Mitigation; Microcode
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI SW loop, KVM SW loop
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] numpy==1.23.5
[pip3] pytorch-lightning==2.2.4
[pip3] pytorch-triton==3.0.0+45fff310c8
[pip3] torch==2.3.0
[pip3] torchaudio==2.3.0
[pip3] torchdata==0.7.0
[pip3] torchmetrics==1.4.0
[pip3] torchsummary==1.5.1
[pip3] torchtext==0.16.0
[pip3] torchvision==0.18.0
[pip3] triton==2.3.0
[conda] Could not collect
Additional info
The issue persists even when using the nightly version of PyTorch.
The problem also occurs when using torch.onnx.dynamo_export for ONNX conversion.
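As a possible workaround while the shape issue stands, one could avoid irfftn entirely: rebuild the full spectrum from the one-sided half using Hermitian symmetry, X[k1, k2] = conj(X[(H - k1) mod H, W - k2]), and take the real part of `torch.fft.ifftn`. This is a sketch, not a verified fix. It assumes a 2-D FFT over the last two dims with an even width W, `irfftn_via_ifftn` is a hypothetical helper name, and it has not been tried with the onnxrt backend or torch.onnx.dynamo_export (the complex-to-complex FFT may run into its own export limitations).

```python
import torch


def irfftn_via_ifftn(spec, s, norm="ortho"):
    """Hypothetical stand-in for torch.fft.irfftn(spec, s=s, dim=(-2, -1), norm=norm).

    Assumes a 2-D FFT over the last two dims and an even output width,
    so spec has shape (..., H, W // 2 + 1).
    """
    h, w = s
    if w % 2 != 0 or spec.shape[-2] != h or spec.shape[-1] != w // 2 + 1:
        raise ValueError("expected spec of shape (..., H, W // 2 + 1) with even W")
    # Missing columns k2 = W/2 + 1 .. W - 1 come from columns W - k2 (i.e. W/2 - 1 .. 1)
    # at row (H - k1) mod H: flip both axes, then roll rows by one so row 0 maps to itself.
    mirrored = torch.roll(torch.flip(spec[..., 1:w // 2], dims=(-2, -1)), shifts=1, dims=-2)
    full = torch.cat((spec, mirrored.conj()), dim=-1)  # (..., H, W)
    # .real drops the imaginary parts of the DC and Nyquist terms, which irfftn also ignores
    return torch.fft.ifftn(full, dim=(-2, -1), norm=norm).real


# Compare with torch.fft.irfftn on genuine one-sided spectra (odd H included on purpose)
for shape in [(2, 3, 64, 64), (1, 2, 5, 8)]:
    x = torch.randn(*shape)
    spec = torch.fft.rfftn(x, dim=(-2, -1), norm="ortho")
    ref = torch.fft.irfftn(spec, s=x.shape[-2:], dim=(-2, -1), norm="ortho")
    alt = irfftn_via_ifftn(spec, x.shape[-2:])
    print(shape, alt.shape == ref.shape, torch.allclose(alt, ref, atol=1e-4))
```

In FourierUnit.forward, for the default ffc3d=False case, the irfftn call would become `output = irfftn_via_ifftn(ffted, ifft_shape_slice, norm=self.fft_norm)`. Whether the resulting flip, roll and complex ifftn ops export cleanly is an open question.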
cc @ezyang @chauhang @penguinwu @msaroufim @bdhirsh @anijain2305