Obscure error: Expected a value of type 'List[int]' for argument 'sizes' but instead found type 'immutable_list' · Issue #122129 · pytorch/pytorch · GitHub
Open
@ezyang

Description

🐛 Describe the bug

Repro:

```python
import torch
import torch._dynamo

# Trace .tolist() into the graph as unbacked SymInts (u0, u1) instead of graph-breaking.
torch._dynamo.config.capture_scalar_outputs = True

# Note: 'sizes' is declared as int[] rather than SymInt[]; this is what triggers the error.
torch.library.define("ezyang::split_with_sizes_and_clone", "(Tensor input, int[] sizes) -> Tensor[]")

def split_with_sizes_and_clone(input, sizes):
    return [t.clone() for t in torch.ops.aten.split_with_sizes.default(input, sizes)]

torch.library.impl("ezyang::split_with_sizes_and_clone", "default", split_with_sizes_and_clone)

@torch.library.impl_abstract("ezyang::split_with_sizes_and_clone")
def split_with_sizes_and_clone_abstract(input, sizes):
    # TODO: I'm lazy
    rs = torch.ops.aten.split_with_sizes.default(input, sizes)
    return [input.new_empty(r.size()) for r in rs]

@torch.compile()
def f(sz, x):
    s0, s1 = sz.tolist()
    r0, r1 = torch.ops.ezyang.split_with_sizes_and_clone.default(x, [s0, s1])
    return torch.ops.aten.sort.default(r1)

N = 7312
S0 = 420
S1 = N - S0

f(torch.tensor([S0, S1]), torch.randn(N))
```

fails with:

```
torch._dynamo.exc.TorchRuntimeError: Failed running call_function ezyang.split_with_sizes_and_clone.default(*(FakeTensor(..., size=(7312,)), [u0, u1]), **{}):
ezyang::split_with_sizes_and_clone() Expected a value of type 'List[int]' for argument 'sizes' but instead found type 'immutable_list'.
Position: 1
Value: [u0, u1]
Declaration: ezyang::split_with_sizes_and_clone(Tensor input, int[] sizes) -> Tensor[]
Cast error details: Unable to cast Python instance of type <class 'torch.fx.immutable_collections.immutable_list'> to C++ type '?' (#define PYBIND11_DETAILED_ERROR_MESSAGES or compile in debug mode for details)
```
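The type named in the message, `immutable_list`, is the FX container that holds the `sizes` argument (printed as `[u0, u1]`). A quick check, assuming a PyTorch build where `torch.fx.immutable_collections` is importable, shows it is an ordinary `list` subclass that only blocks in-place mutation. That suggests the message reports the container's type rather than the element type that actually failed to cast. The sample values are the repro's `S0` and `S1`.

```python
from torch.fx.immutable_collections import immutable_list

# Same container type that appears in the error message; the values are S0 and S1.
sizes = immutable_list([420, 6892])

print(isinstance(sizes, list))   # True: it is a list subclass
print(type(sizes).__name__)      # immutable_list

# The only behavioral difference from list: in-place mutation raises.
try:
    sizes.append(0)
    mutation_blocked = False
except NotImplementedError:
    mutation_blocked = True
print(mutation_blocked)          # True
```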

Actually, this has nothing to do with immutable_list; I should have declared `sizes` as SymInt[] in my schema rather than int[]. It would be good to have a better error message in this case.

Versions

main

cc @chauhang @penguinwu @bobrenjc93 @zou3519 @bdhirsh @msaroufim @anijain2305

Metadata

Assignees

No one assigned

    Labels

    good first issue, llm-amenable, module: dynamic shapes, module: pt2-dispatcher, oncall: pt2, triaged