Open
Description
🐛 Describe the bug
Repro:
import torch
import torch._dynamo
# Let Dynamo trace through tensor-to-Python-scalar conversions such as
# .tolist(); the resulting scalars become symbolic ints (u0, u1 in the error).
torch._dynamo.config.capture_scalar_outputs = True
# NOTE(repro): `sizes` is deliberately declared as `int[]`. f() passes
# symbolic sizes obtained from sz.tolist(), so the schema should say
# `SymInt[]`; using `int[]` is what yields the misleading
# "immutable_list" cast error this issue is about.
torch.library.define("ezyang::split_with_sizes_and_clone", "(Tensor input, int[] sizes) -> Tensor[]")
def split_with_sizes_and_clone(input, sizes):
    """Eager implementation of ezyang::split_with_sizes_and_clone.

    Splits ``input`` into consecutive pieces whose lengths are given by
    ``sizes`` (via aten.split_with_sizes, which splits along dim 0 by
    default) and returns a list containing a clone of each piece.

    The pieces are cloned so that the returned tensors do not share storage
    with ``input``.
    """
    pieces = torch.ops.aten.split_with_sizes.default(input, sizes)
    return [piece.clone() for piece in pieces]
# Register split_with_sizes_and_clone as the real (eager) implementation of
# the op, for the "default" device type(s).
torch.library.impl("ezyang::split_with_sizes_and_clone", "default", split_with_sizes_and_clone)
@torch.library.impl_abstract("ezyang::split_with_sizes_and_clone")
def split_with_sizes_and_clone_abstract(input, sizes):
    """Abstract (fake-tensor) implementation used while tracing the op.

    ``input`` is a FakeTensor during Dynamo tracing. Output shapes are taken
    from running the real split on it, and each output is a freshly allocated
    tensor of that shape (``new_empty`` keeps ``input``'s dtype and device).
    """
    # TODO: I'm lazy -- a hand-written shape rule would avoid calling the real op.
    pieces = torch.ops.aten.split_with_sizes.default(input, sizes)
    shapes = [piece.size() for piece in pieces]
    return [input.new_empty(shape) for shape in shapes]
# Compiled function under test. sz.tolist() produces data-dependent symbolic
# ints, which are then passed as the custom op's `sizes` argument; this is
# where the int[] vs SymInt[] schema mismatch surfaces during tracing.
@torch.compile()
def f(sz, x):
    s0, s1 = sz.tolist()
    r0, r1 = torch.ops.ezyang.split_with_sizes_and_clone.default(x, [s0, s1])
    # aten.sort returns a (values, indices) pair.
    return torch.ops.aten.sort.default(r1)
# S0 + S1 == N, so the two split sizes exactly partition the length-N input.
N = 7312
S0 = 420
S1 = N - S0
# Fails during Dynamo tracing with the TorchRuntimeError quoted in this issue.
f(torch.tensor([S0, S1]), torch.randn(N))
This fails with:
torch._dynamo.exc.TorchRuntimeError: Failed running call_function ezyang.split_with_sizes_and_clone.default(*(FakeTensor(..., size=(7312,)), [u0, u1]), **{}):
ezyang::split_with_sizes_and_clone() Expected a value of type 'List[int]' for argument 'sizes' but instead found type 'immutable_list'.
Position: 1
Value: [u0, u1]
Declaration: ezyang::split_with_sizes_and_clone(Tensor input, int[] sizes) -> Tensor[]
Cast error details: Unable to cast Python instance of type <class 'torch.fx.immutable_collections.immutable_list'> to C++ type '?' (#define PYBIND11_DETAILED_ERROR_MESSAGES or compile in debug mode for details)
Actually, this has nothing to do with `immutable_list`: I should have declared `sizes` as `SymInt[]` rather than `int[]` in my schema. It would be good to have a better error message in this case.
Versions
main
cc @chauhang @penguinwu @bobrenjc93 @zou3519 @bdhirsh @msaroufim @anijain2305