Compiled `nn.Module` with tensor subclass can't be moved to another device · Issue #141548 · pytorch/pytorch · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content
Compiled nn.Module with tensor subclass can't be moved to another device #141548
Open
@gau-nernst

Description

@gau-nernst

🐛 Describe the bug

import torch

# Short alias for the ATen operator namespace; used to name the ops that
# Subclass.__torch_dispatch__ handles (aten._to_copy.default, aten.detach.default).
aten = torch.ops.aten

class Subclass(torch.Tensor):
    """Minimal wrapper tensor subclass around a single inner tensor.

    The outer tensor is built with ``torch.Tensor._make_wrapper_subclass`` and
    mirrors the shape, dtype and device of ``data``; the wrapped tensor itself
    is stored in ``self._data``.

    Supported operations:

    * ``torch.nn.functional.linear`` is intercepted in ``__torch_function__``
      and evaluated on the unwrapped inner tensor(s);
    * ``aten._to_copy.default`` and ``aten.detach.default`` are handled in
      ``__torch_dispatch__`` by applying the op to the inner tensor and
      re-wrapping the result in ``Subclass``.

    Any other ATen op that reaches ``__torch_dispatch__`` raises
    ``NotImplementedError``.
    """

    def __new__(cls, data):
        return torch.Tensor._make_wrapper_subclass(cls, data.shape, dtype=data.dtype, device=data.device)

    def __init__(self, data):
        # The real values; the outer wrapper was created in __new__ from this
        # tensor's shape, dtype and device only.
        self._data = data

    def __repr__(self):
        return f"{self.__class__.__name__}(data={self._data})"

    def __tensor_flatten__(self):
        # Inner-tensor attribute names (the keys __tensor_unflatten__ reads back
        # from ``inner_tensors``) plus an empty context.
        return ["_data"], []

    @classmethod
    def __tensor_unflatten__(cls, inner_tensors, ctx, outer_size, outer_stride):
        # ctx, outer_size and outer_stride are unused: the wrapper's metadata is
        # re-derived from the inner tensor in __new__.
        return cls(inner_tensors["_data"])

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Intercept ``F.linear``; run every other call with default handling.

        Defined as a classmethod because plain-method ``__torch_function__``
        implementations are deprecated by PyTorch.
        """
        kwargs = {} if kwargs is None else kwargs
        if func == torch.nn.functional.linear:
            # Unwrap every Subclass argument, positional or keyword, so that
            # keyword arguments such as ``bias=`` are forwarded rather than
            # dropped, and the wrapped weight is found wherever it was passed.
            args = [x._data if isinstance(x, Subclass) else x for x in args]
            kwargs = {k: (v._data if isinstance(v, Subclass) else v) for k, v in kwargs.items()}
            return func(*args, **kwargs)

        # Fall back to the regular implementation with subclass
        # __torch_function__ handling disabled, to avoid re-entering this hook.
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        """Apply ``_to_copy`` / ``detach`` to the inner tensor; reject other ops.

        Raises:
            NotImplementedError: for any ATen op other than the two above.
        """
        kwargs = {} if kwargs is None else kwargs
        if func in (aten._to_copy.default, aten.detach.default):
            # Run the op on the unwrapped inner tensor and re-wrap the result.
            args = [x._data if isinstance(x, Subclass) else x for x in args]
            return Subclass(func(*args, **kwargs))

        raise NotImplementedError(f"{func=}")


# Reproducer: give a Linear layer a Subclass-wrapped weight parameter.
linear = torch.nn.Linear(2, 2)
# Wrap the detached original weight in Subclass and re-register it as a Parameter.
linear.weight = torch.nn.Parameter(Subclass(linear.weight.detach()))
# Compile the module in place, then call it once so the compiled forward runs
# (NOTE(review): compilation itself presumably happens lazily on this first call).
linear.compile()
linear(torch.randn(1, 2))
# Fails with "RuntimeError: _apply(): Couldn't swap Linear.weight", caused by
# "Cannot swap t1 because it has weakref associated with it" raised from
# torch.utils.swap_tensors inside Module._apply.
linear.cpu()  # doesn't work even for the same device
RuntimeError                              Traceback (most recent call last)
File ~/xxx/python3.10/site-packages/torch/nn/modules/module.py:948, in Module._apply(self, fn, recurse)
    945     param_applied = torch.nn.Parameter(
    946         param_applied, requires_grad=param.requires_grad
    947     )
--> 948     torch.utils.swap_tensors(param, param_applied)
    949 except Exception as e:

File ~/xxx/python3.10/site-packages/torch/utils/__init__.py:51, in swap_tensors(t1, t2)
     50 if weakref.getweakrefs(t1):
---> 51     raise RuntimeError("Cannot swap t1 because it has weakref associated with it")
     52 if weakref.getweakrefs(t2):

RuntimeError: Cannot swap t1 because it has weakref associated with it

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
Cell In[1], line 42
     40 linear.compile()
     41 linear(torch.randn(1, 2))
---> 42 linear.cpu()

File ~/xxx/python3.10/site-packages/torch/nn/modules/module.py:1121, in Module.cpu(self)
   1112 def cpu(self: T) -> T:
   1113     r"""Move all model parameters and buffers to the CPU.
   1114 
   1115     .. note::
   (...)
   1119         Module: self
   1120     """
-> 1121     return self._apply(lambda t: t.cpu())

File ~/xxx/python3.10/site-packages/torch/nn/modules/module.py:952, in Module._apply(self, fn, recurse)
    950         if param_grad is not None:
    951             param.grad = param_grad
--> 952         raise RuntimeError(
    953             f"_apply(): Couldn't swap {self._get_name()}.{key}"
    954         ) from e
    955     out_param = param
    956 elif p_should_use_set_data:

RuntimeError: _apply(): Couldn't swap Linear.weight

As the title says: a compiled module whose parameters are tensor subclasses can't be moved to another device — even calling `.cpu()` on a module that is already on the CPU fails. cc: pytorch/ao#1309

Versions

torch==2.6.0.dev20241125+cu124

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @albanD @chauhang @penguinwu @bdhirsh @yf225

Metadata

Metadata

Assignees

No one assigned

    Labels

    high priority · module: pt2-dispatcher (PT2 dispatcher-related issues, e.g., aotdispatch, functionalization, faketensor, custom-op) · oncall: pt2 · tensor subclass (Related to tensor subclasses) · triage review · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions

      0