Weight Shape Corruption in Sequential Modules During Stateful Execution · Issue #155246 · pytorch/pytorch · GitHub
Closed
@LiSsHhUuAaIi

Description

🐛 Describe the bug

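When compiling a model that runs an nn.Sequential statefully (each layer's input is the previous layer's output, read back as y[-1] from a list that is being extended), Dynamo fails during fake tensor propagation with a channel mismatch error that does not occur in eager mode. The failure looks like weight shape corruption, but the error message below reports the second convolution's weight with its correct shape [64, 128, 1, 1]. What is wrong is its input: it receives the original 64-channel x instead of the 128-channel output of the first convolution.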
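The difference hinges on how list.extend consumes a generator. In CPython, extend pulls one item at a time and appends it before requesting the next, so y[-1] inside the generator sees the most recently appended output. The traceback below passes through Dynamo's force_unpack_var_sequence, which suggests the generator is unpacked in full before anything is appended, so every layer would read the original x. The pure-Python sketch below contrasts the two orderings; the helper names and the string "layers" are hypothetical stand-ins for the convolutions, and the second ordering is an assumption inferred from the traceback, not taken from Dynamo's source.

```python
# Hypothetical stand-ins for the conv layers: each records the input it received.
def make_layer(name, log):
    def layer(inp):
        log.append((name, inp))
        return f"{name}({inp})"
    return layer


def lazy_extend(names):
    """Eager-mode ordering: list.extend appends each item before pulling the next."""
    log = []
    y = ["x"]
    fns = [make_layer(n, log) for n in names]
    y.extend(f(y[-1]) for f in fns)
    return y, log


def unpack_then_extend(names):
    """Assumed compiled-mode ordering: materialize the whole generator first."""
    log = []
    y = ["x"]
    fns = [make_layer(n, log) for n in names]
    items = list(f(y[-1]) for f in fns)  # y is not modified while this runs
    y.extend(items)
    return y, log


print(lazy_extend(["conv1", "conv2"])[0][-1])         # conv2(conv1(x))
print(unpack_then_extend(["conv1", "conv2"])[0][-1])  # conv2(x)
```

In the second ordering the second layer is applied to x directly, which matches the 64-channel input shown in the error message.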

To Reproduce

```python
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.bottlenecks = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),  # Correct shape: [128,64,3,3]
            nn.Conv2d(128, 64, kernel_size=1)              # Correct shape: [64,128,1,1]
        )

    def forward(self, x):
        y = [x]
        y.extend(m(y[-1]) for m in self.bottlenecks)  # Stateful execution
        return y[-1]

x = torch.randn(32, 64, 224, 224)
model = Model()
model(x)  # Eager mode works
torch.compile(model, backend="eager")(x)  # Compiled mode fails
```
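One possible workaround, offered as a sketch rather than a verified fix for every PyTorch version, is to replace the generator passed to extend with an explicit loop, so each layer's output is appended before the next layer reads y[-1]. The class name LoopModel and the smaller input shape are illustrative choices, not part of the original report.

```python
import torch
import torch.nn as nn


class LoopModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.bottlenecks = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),  # weight [128, 64, 3, 3]
            nn.Conv2d(128, 64, kernel_size=1),             # weight [64, 128, 1, 1]
        )

    def forward(self, x):
        y = [x]
        # Explicit loop: each output is appended before the next layer reads y[-1].
        for m in self.bottlenecks:
            y.append(m(y[-1]))
        return y[-1]


# Smaller input than in the report, to keep the check fast.
x = torch.randn(2, 64, 16, 16)
model = LoopModel()
eager_out = model(x)
compiled_out = torch.compile(model, backend="eager")(x)
print(eager_out.shape)  # torch.Size([2, 64, 16, 16])
```

This keeps the list of intermediate outputs (useful if later code reads earlier entries of y) while avoiding the generator-inside-extend pattern entirely.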

Error logs

```
Traceback (most recent call last):
  File "test.py", line 20, in <module>
    torch.compile(model, backend="eager")(x) # Compiled mode fails
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\nn\modules\module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\nn\modules\module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\eval_frame.py", line 655, in _fn
    return fn(*args, **kwargs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\nn\modules\module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\nn\modules\module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\convert_frame.py", line 1432, in __call__
    return self._torchdynamo_orig_callable(
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\convert_frame.py", line 1213, in __call__
    result = self.inner_convert(
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\convert_frame.py", line 598, in __call__
    return _compile(
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\convert_frame.py", line 1059, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_utils_internal.py", line 97, in wrapper_function
    return function(*args, **kwargs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\convert_frame.py", line 761, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\convert_frame.py", line 797, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\bytecode_transformation.py", line 1422, in transform_code_object
    transformations(instructions, code_options)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\convert_frame.py", line 257, in _fn
    return fn(*args, **kwargs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\convert_frame.py", line 715, in transform
    tracer.run()
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\symbolic_convert.py", line 3500, in run
    super().run()
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\symbolic_convert.py", line 1337, in run
    while self.step():
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\symbolic_convert.py", line 1246, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\symbolic_convert.py", line 819, in wrapper
    return inner_fn(self, inst)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\symbolic_convert.py", line 2168, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\symbolic_convert.py", line 1170, in call_function
    self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\variables\misc.py", line 903, in call_function
    return self.obj.call_method(tx, self.name, args, kwargs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\variables\lists.py", line 516, in call_method
    return super().call_method(tx, name, args, kwargs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\variables\lists.py", line 376, in call_method
    seq = arg.force_unpack_var_sequence(tx)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\variables\functions.py", line 543, in force_unpack_var_sequence
    result.append(self.next_variable(tx))
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\variables\functions.py", line 521, in next_variable
    return tracer.inline_call_()
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\symbolic_convert.py", line 3905, in inline_call_
    self.run()
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\symbolic_convert.py", line 1337, in run
    while self.step():
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\symbolic_convert.py", line 1246, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\symbolic_convert.py", line 819, in wrapper
    return inner_fn(self, inst)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\symbolic_convert.py", line 2168, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\symbolic_convert.py", line 1170, in call_function
    self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\variables\nn_module.py", line 952, in call_function
    return variables.UserFunctionVariable(fn, source=source).call_function(
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\variables\functions.py", line 404, in call_function
    return super().call_function(tx, args, kwargs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\variables\functions.py", line 185, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\symbolic_convert.py", line 1187, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\symbolic_convert.py", line 3726, in inline_call
    return tracer.inline_call_()
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\symbolic_convert.py", line 3905, in inline_call_
    self.run()
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\symbolic_convert.py", line 1337, in run
    while self.step():
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\symbolic_convert.py", line 1246, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\symbolic_convert.py", line 819, in wrapper
    return inner_fn(self, inst)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\symbolic_convert.py", line 2168, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\symbolic_convert.py", line 1170, in call_function
    self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\variables\functions.py", line 926, in call_function
    return super().call_function(tx, args, kwargs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\variables\functions.py", line 404, in call_function
    return super().call_function(tx, args, kwargs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\variables\functions.py", line 185, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\symbolic_convert.py", line 1187, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\symbolic_convert.py", line 3726, in inline_call
    return tracer.inline_call_()
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\symbolic_convert.py", line 3905, in inline_call_
    self.run()
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\symbolic_convert.py", line 1337, in run
    while self.step():
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\symbolic_convert.py", line 1246, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\symbolic_convert.py", line 819, in wrapper
    return inner_fn(self, inst)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\symbolic_convert.py", line 2168, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\symbolic_convert.py", line 1170, in call_function
    self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\variables\torch.py", line 1181, in call_function
    tensor_variable = wrap_fx_proxy(
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\variables\builder.py", line 2302, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\variables\builder.py", line 2368, in wrap_fx_proxy_cls
    return _wrap_fx_proxy(
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\variables\builder.py", line 2464, in _wrap_fx_proxy
    example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\utils.py", line 3229, in get_fake_value
    raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\utils.py", line 3127, in get_fake_value
    ret_val = wrap_fake_exception(
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\utils.py", line 2641, in wrap_fake_exception
    return fn()
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\utils.py", line 3128, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\utils.py", line 3325, in run_node
    raise RuntimeError(make_error_message(e)).with_traceback(
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_dynamo\utils.py", line 3284, in run_node
    return node.target(*args, **kwargs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\utils\_stats.py", line 27, in wrapper
    return fn(*args, **kwargs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_subclasses\fake_tensor.py", line 1282, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_subclasses\fake_tensor.py", line 1823, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_subclasses\fake_tensor.py", line 1384, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_subclasses\fake_tensor.py", line 2397, in _dispatch_impl
    op_impl_out = op_impl(self, func, *args, **kwargs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_subclasses\fake_impls.py", line 160, in dispatch_to_op_implementations_dict
    return op_implementations_dict[func](fake_mode, func, args, **kwargs)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\_subclasses\fake_impls.py", line 762, in conv
    conv_backend = torch._C._select_conv_backend(**kwargs)
torch._dynamo.exc.TorchRuntimeError: Dynamo failed to run FX node with fake tensors: call_function <built-in method conv2d of type object at 0x00007FFE44271440>(*(FakeTensor(..., size=(32, 64, 224, 224)), Parameter(FakeTensor(..., size=(64, 128, 1, 1), requires_grad=True)), Parameter(FakeTensor(..., size=(64,), requires_grad=True)), (1, 1), (0, 0), (1, 1), 1), **{}): got RuntimeError('Given groups=1, weight of size [64, 128, 1, 1], expected input[32, 64, 224, 224] to have 128 channels, but got 64 channels instead')

from user code:
  File "test.py", line 14, in forward
    y.extend(m(y[-1]) for m in self.bottlenecks) # Stateful execution
  File "test.py", line 14, in <genexpr>
    y.extend(m(y[-1]) for m in self.bottlenecks) # Stateful execution
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\nn\modules\conv.py", line 554, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "D:\Programs\Python\virtualenvs\torch_code-afvE469o\lib\site-packages\torch\nn\modules\conv.py", line 549, in _conv_forward
    return F.conv2d(
```
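The shapes in this error can be cross-checked by tracking channel counts by hand. The helper below is a hypothetical, simplified stand-in for the groups=1 channel check quoted in the log (its message format is copied from the log, not from PyTorch source). It runs the two weights from the report under both orderings: feeding each layer the previous output passes, while feeding every layer the original x reproduces the mismatch on the second convolution with the weight shape unchanged.

```python
from typing import List, Optional, Tuple

Shape = Tuple[int, int, int, int]

# Weight shapes (out_channels, in_channels, kH, kW) of the two convs in the report.
WEIGHTS: List[Shape] = [(128, 64, 3, 3), (64, 128, 1, 1)]


def check_conv(inp: Shape, weight: Shape) -> Optional[str]:
    """Hypothetical groups=1 channel check; returns an error string or None."""
    if inp[1] != weight[1]:
        return (f"Given groups=1, weight of size {list(weight)}, expected input{list(inp)} "
                f"to have {weight[1]} channels, but got {inp[1]} channels instead")
    return None


def output_shape(inp: Shape, weight: Shape) -> Shape:
    # Stride 1 keeps the spatial size: padding=1 for the 3x3 conv, none for the 1x1 conv.
    return (inp[0], weight[0], inp[2], inp[3])


def trace(x: Shape, chained: bool) -> Optional[str]:
    """chained=True feeds y[-1] to each layer; chained=False feeds the original x."""
    y = [x]
    for w in WEIGHTS:
        inp = y[-1] if chained else x
        err = check_conv(inp, w)
        if err is not None:
            return err
        y.append(output_shape(inp, w))
    return None


print(trace((32, 64, 224, 224), chained=True))   # None
print(trace((32, 64, 224, 224), chained=False))  # same message as the RuntimeError above
```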

Versions

Collecting environment information...
PyTorch version: 2.7.0+cpu
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 11
GCC version: Could not collect
Clang version: Could not collect
CMake version: version 4.0.2
Libc version: N/A

Python version: 3.10.10 (tags/v3.10.10:aad5f6a, Feb 7 2023, 17:20:36) [MSC v.1929 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.26100-SP0
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Name: AMD Ryzen 5 7500F 6-Core Processor
Manufacturer: AuthenticAMD
Family: 107
Architecture: 9
ProcessorType: 3
DeviceID: CPU0
CurrentClockSpeed: 3701
MaxClockSpeed: 3701
L2CacheSize: 6144
L2CacheSpeed: None
Revision: 24834

Versions of relevant libraries:
[pip3] mypy_extensions==1.1.0
[pip3] numpy==2.2.4
[pip3] nvidia-cuda-runtime-cu12==12.9.37
[pip3] onnx-coreml==1.3
[pip3] onnx-tf==1.10.0
[pip3] optree==0.15.0
[pip3] pytorch-lightning==2.5.1
[pip3] torch==2.7.0
[pip3] torch-directml==0.2.5.dev240914
[pip3] torch-geometric==2.6.1
[pip3] torch_scatter==2.1.2
[pip3] torch_tensorrt==2.7.0
[pip3] torchani==2.2.4
[pip3] torchao==0.11.0
[pip3] torchaudio==2.7.0
[pip3] torchdata==0.11.0
[pip3] torchfx==0.1.0
[pip3] torchmetrics==1.7.1
[pip3] torchrec==0.1.0
[pip3] torchsummary==1.5.1
[pip3] torchtext==0.18.0
[pip3] torchvision==0.19.1
[pip3] torchviz==0.0.3
[pip3] torchx-nightly==2025.5.28
[pip3] vector-quantize-pytorch==1.22.16
[conda] Could not collect

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames
