Description
🚀 The feature, motivation and pitch
It'd be cool if, when working with data-dependent values, we could use tensor constructor calls to inform the compiler about their values. For example, in an ideal world, the `x[0] == 1` check below shouldn't raise a data-dependent error:
```python
import torch

@torch.compile(fullgraph=True)
def f():
    # x is constructed as all ones, so x[0] == 1 is knowable at trace time
    x = torch.ones(4, dtype=torch.int64)
    if x[0] == 1:
        return x * 2
    return x + 1
```
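For contrast, the explicit opt-in available today looks something like the hypothetical `f_checked` below (a hedged sketch: it relies on `torch._check` refining the unbacked symbol's value range, and on `capture_scalar_outputs` so that `.item()` traces through):

```python
import torch

@torch._dynamo.config.patch(capture_scalar_outputs=True)
@torch.compile(fullgraph=True)
def f_checked():
    x = torch.ones(4, dtype=torch.int64)
    u = x[0].item()       # traced as an unbacked SymInt u0
    torch._check(u == 1)  # user asserts what the constructor already implied
    if u == 1:            # now statically evaluable instead of erroring
        return x * 2
    return x + 1
```

The ask is for the compiler to derive `u0 == 1` from `torch.ones` itself, making the `torch._check` unnecessary.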
Neither should the slice operation here:
```python
import torch

@torch._dynamo.config.patch(capture_scalar_outputs=True)
@torch.compile(fullgraph=True)
def g():
    x = torch.randn(10)
    idxs = torch.randint(0, 10, size=(4,))
    return [x[:idx.item()] for idx in idxs]
```
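Again for contrast, roughly what users have to write today to get this through (a hedged sketch with a hypothetical `g_checked`; the exact bounds the slice needs may vary across versions):

```python
import torch

@torch._dynamo.config.patch(capture_scalar_outputs=True)
@torch.compile(fullgraph=True)
def g_checked():
    x = torch.randn(10)
    idxs = torch.randint(0, 10, size=(4,))
    out = []
    for idx in idxs:
        u = idx.item()                # unbacked SymInt
        torch._check_is_size(u)       # opt in: u is non-negative and size-like
        torch._check(u <= x.size(0))  # upper bound, so the slice needn't clamp
        out.append(x[:u])
    return out
```

With value-informed tracing, the `torch.randint(0, 10, ...)` call already bounds every element to `[0, 10)`, so both checks could be derived automatically.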
I'm not sure what the ROI is, but the work around data-dependent errors has a tendency to assume unbacked symbols aren't 0/1 (size-oblivious semantics, `guard_or_false`). As the `guard_size_oblivious` migration moves us away from a world where users must explicitly call `_check_is_size` to opt in, this kind of work would help defend the compiler's decisions about why we assume something is or is not 0/1.
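To make the 0/1 assumption concrete (an illustration only; `guard_or_false` lives in `torch.fx.experimental.symbolic_shapes`, and `pick_expand_path` is a made-up helper):

```python
from torch.fx.experimental.symbolic_shapes import guard_or_false

# For an unbacked SymInt u0, forcing bool(u0 == 1) raises a data-dependent
# error, whereas guard_or_false(u0 == 1) simply answers False when it can't
# decide -- the compiler proceeds as if u0 != 1, whether or not that holds
# at runtime.
def pick_expand_path(u0):
    if guard_or_false(u0 == 1):
        return "broadcast"  # stride-0 view
    return "general"        # assumes u0 matches the target size
```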
e.g. it's more difficult to defend that this program should, by default, produce an output with non-zero stride, which is what assuming `u0 != 1` and `u0 == u1` gives you:

```python
x = torch.ones(1, dtype=torch.int64)
u0 = x.item()                   # provably 1, but traced as unbacked u0
y = torch.randn(u0).expand(u1)  # u1: some other unbacked SymInt
```
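For concreteness, eager semantics when the size really is 1 (using a literal target size in place of `u1`):

```python
import torch

x = torch.ones(1, dtype=torch.int64)
u0 = x.item()                  # == 1 in eager
y = torch.randn(u0).expand(5)
print(y.stride())              # (0,) -- a broadcast view, not a copy
```

If the compiler knew `u0 == 1` from the constructor, it could confidently emit the stride-0 broadcast instead of assuming `u0 == u1`.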
cc @chauhang @penguinwu @ezyang @bobrenjc93 @laithsakka
Alternatives
No response
Additional context
No response