Advanced Indexing does not trace correctly for tensor shape that has leading 1s · Issue #49852 · pytorch/pytorch
Closed
@ppwwyyxx

Description


🐛 Bug

A statement like a[b] = c, where a, b, and c are all tensors and a[b].shape == c.shape, should trace correctly because no reshape is needed. However, the implementation still reshapes c under the hood, so when c.shape has leading 1s the trace bakes in a fixed shape and does not generalize to inputs of other sizes.
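Where the leading 1 comes from: when a single 1-D integer index tensor is applied to the first dimension, the result shape is the index's shape followed by the remaining dimensions of the indexed tensor. The helper below is a hypothetical pure-Python model of that rule for this one case only (it is not a PyTorch API); it shows that a one-element idx, as in the failing repro, gives c a leading 1.

```python
def advanced_index_shape(x_shape, idx_shape):
    """Shape of x[idx] when idx is a single integer index tensor on dim 0.

    Hypothetical helper modelling the advanced-indexing shape rule for this
    case only: the index's shape replaces the indexed (first) dimension.
    """
    return tuple(idx_shape) + tuple(x_shape[1:])


# torch.tensor([1, 2]) has shape (2,); torch.tensor([1]) has shape (1,)
print(advanced_index_shape((3, 4), (2,)))   # (2, 4): no leading 1
print(advanced_index_shape((3, 4), (1,)))   # (1, 4): leading 1
print(advanced_index_shape((10, 5), (1,)))  # (1, 5): leading 1
```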

To Reproduce

Steps to reproduce the behavior:

import torch

def tensor_setitem(x, idx, y):
    x[idx] = y + 1
    return x

x = torch.randn(3, 4)
idx = torch.tensor([1, 2])  # works
idx = torch.tensor([1])  # fails: x[idx].shape == (1, 4) has a leading 1
traced = torch.jit.trace(tensor_setitem, (x, idx, x[idx]))

# Run the trace on a differently shaped input and compare against eager mode.
x = torch.randn(10, 5)
assert torch.allclose(traced(x.clone(), idx, x[idx]),
                      tensor_setitem(x.clone(), idx, x[idx]))

fails with:

RuntimeError: shape '[4]' is invalid for input of size 5

Environment

PyTorch master

Additional context

An old commit, #9424, aimed to skip the reshape to make the op traceable. However, it only skips the unnecessary reshape when c.shape does not have leading 1s, which is what leads to this issue. This logic was later mirrored into ATen in #32841.
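The failure mode can be modelled in a few lines. This is a simplified pure-Python sketch, assuming that the setitem path strips leading size-1 dimensions from the value and that, whenever it does so, the tracer records the resulting target shape as a constant. All function names are hypothetical; this is not the actual ATen code, only a model that reproduces the error above.

```python
from math import prod


def strip_leading_ones(shape):
    """Drop leading size-1 dims, e.g. (1, 4) -> (4,). Hypothetical helper."""
    i = 0
    while i < len(shape) and shape[i] == 1:
        i += 1
    return tuple(shape[i:])


def record_trace(value_shape):
    """Model what is recorded for x[idx] = value at trace time.

    Returns None when no reshape is needed (nothing shape-specific is
    frozen); otherwise returns the reshape target, kept as a constant.
    """
    target = strip_leading_ones(value_shape)
    if target == tuple(value_shape):
        return None
    return target


def replay(recorded_target, value_shape):
    """Model running the traced graph on a value of a new shape."""
    if recorded_target is None:
        return tuple(value_shape)  # shape flows through unchanged
    numel = prod(value_shape)
    if prod(recorded_target) != numel:
        raise RuntimeError(
            f"shape '{list(recorded_target)}' is invalid for input of size {numel}"
        )
    return recorded_target


# Traced with x of shape (3, 4), then replayed with x of shape (10, 5):
print(replay(record_trace((2, 4)), (2, 5)))  # idx of length 2 -> (2, 5)
frozen = record_trace((1, 4))                # idx of length 1 -> (4,) frozen
try:
    replay(frozen, (1, 5))
except RuntimeError as err:
    print(err)  # shape '[4]' is invalid for input of size 5
```

The point of the model is that only the leading-1s branch freezes a concrete shape; the branch that skips the reshape keeps the trace shape-agnostic.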

#45828 fixes a similar problem for basic indexing. Advanced indexing seems trickier.

cc @gmagogsfm

Metadata

Labels

module: advanced indexing (Related to x[i] = y, index functions)
oncall: jit (Add this issue/PR to JIT oncall triage queue)
