🐛 Describe the bug
```python
import torch

aten = torch.ops.aten

class Subclass(torch.Tensor):
    def __new__(cls, data):
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=data.dtype, device=data.device
        )

    def __init__(self, data):
        self._data = data

    def __repr__(self):
        return f"{self.__class__.__name__}(data={self._data})"

    def __tensor_flatten__(self):
        return ["_data"], []

    @classmethod
    def __tensor_unflatten__(cls, inner_tensors, ctx, outer_size, outer_stride):
        return cls(inner_tensors["_data"])

    def __torch_function__(self, func, types, args, kwargs=None):
        if func == torch.nn.functional.linear:
            return func(args[0], args[1]._data, *args[2:])
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **(kwargs or dict()))

    def __torch_dispatch__(self, func, types, args, kwargs):
        if func in (aten._to_copy.default, aten.detach.default):
            args = [x._data if isinstance(x, Subclass) else x for x in args]
            out = func(*args, **kwargs)
            return Subclass(out)
        raise NotImplementedError(f"{func=}")

linear = torch.nn.Linear(2, 2)
linear.weight = torch.nn.Parameter(Subclass(linear.weight.detach()))
linear.compile()
linear(torch.randn(1, 2))
linear.cpu()  # doesn't work even when moving to the same device
```
```
RuntimeError                              Traceback (most recent call last)
File ~/xxx/python3.10/site-packages/torch/nn/modules/module.py:948, in Module._apply(self, fn, recurse)
    945     param_applied = torch.nn.Parameter(
    946         param_applied, requires_grad=param.requires_grad
    947     )
--> 948     torch.utils.swap_tensors(param, param_applied)
    949 except Exception as e:

File ~/xxx/python3.10/site-packages/torch/utils/__init__.py:51, in swap_tensors(t1, t2)
     50 if weakref.getweakrefs(t1):
---> 51     raise RuntimeError("Cannot swap t1 because it has weakref associated with it")
     52 if weakref.getweakrefs(t2):

RuntimeError: Cannot swap t1 because it has weakref associated with it

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
Cell In[1], line 42
     40 linear.compile()
     41 linear(torch.randn(1, 2))
---> 42 linear.cpu()

File ~/xxx/python3.10/site-packages/torch/nn/modules/module.py:1121, in Module.cpu(self)
   1112 def cpu(self: T) -> T:
   1113     r"""Move all model parameters and buffers to the CPU.
   1114
   1115     .. note::
   (...)
   1119         Module: self
   1120     """
-> 1121     return self._apply(lambda t: t.cpu())

File ~/xxx/python3.10/site-packages/torch/nn/modules/module.py:952, in Module._apply(self, fn, recurse)
    950 if param_grad is not None:
    951     param.grad = param_grad
--> 952 raise RuntimeError(
    953     f"_apply(): Couldn't swap {self._get_name()}.{key}"
    954 ) from e
    955 out_param = param
    956 elif p_should_use_set_data:

RuntimeError: _apply(): Couldn't swap Linear.weight
```
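The check that trips here can be hit directly, without `torch.compile` or a tensor subclass involved. A minimal sketch of the underlying restriction (two plain tensors and a manually created weakref, not part of the original repro):

```python
import weakref
import torch

t1 = torch.randn(2, 2)
t2 = torch.randn(2, 2)
ref = weakref.ref(t1)  # any live weakref to t1 is enough to block the swap

# Raises: RuntimeError: Cannot swap t1 because it has weakref associated with it
torch.utils.swap_tensors(t1, t2)
```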
Per the title: a compiled module with a tensor-subclass parameter can't be moved to another device. As the repro shows, it fails even when "moving" to the device the module is already on, because `Module._apply` goes through `torch.utils.swap_tensors`, which refuses to swap a parameter that already has a weakref associated with it. cc: pytorch/ao#1309
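My guess is that compiling and running the module is what leaves the weakref on the parameter in the first place; the assumption that `torch.compile` holds weakrefs to parameters (e.g. for guards) is mine, not something the traceback states. A hypothetical diagnostic, independent of the subclass:

```python
import weakref
import torch

linear = torch.nn.Linear(2, 2)
print(weakref.getweakrefs(linear.weight))  # expected: [] on a fresh parameter

linear.compile()
linear(torch.randn(1, 2))
# If torch.compile keeps weakrefs to the parameter, this list is non-empty, and any
# later Module._apply that takes the swap_tensors path will fail as above.
print(weakref.getweakrefs(linear.weight))
```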
Versions
torch==2.6.0.dev20241125+cu124
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @albanD @chauhang @penguinwu @bdhirsh @yf225