PGO errors out in dynamo main path for sparse tensors · Issue #154161 · pytorch/pytorch · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content
PGO errors out in dynamo main path for sparse tensors #154161
Open
@xmfan

Description

@xmfan

🐛 Describe the bug

Repro:

import torch
# With compiled autograd enabled, the params.backward() call below is routed
# through torch._dynamo.compiled_autograd (see runtime_wrapper in the traceback),
# which is where the PGO stride lookup fails.
torch._dynamo.config.compiled_autograd = True


@torch.compile(backend="eager")
def _test_grad_tensor(
    params_grad_tensor,
    backward_grad_tensor,
    should_preserve_reference,
    create_graph,
):
    """Install a pre-existing .grad on a leaf, run backward, and check identity.

    Args:
        params_grad_tensor: tensor assigned to ``params.grad`` before backward
            (a sparse COO tensor in the call below).
        backward_grad_tensor: gradient passed to ``params.backward``.
        should_preserve_reference: expected result of comparing the identity of
            ``params.grad`` before and after backward.
        create_graph: forwarded unchanged to ``params.backward``.

    Raises:
        AssertionError: if the identity comparison does not match
            ``should_preserve_reference``.
    """
    # Fresh dense leaf tensor that requires grad.
    params = torch.tensor([1.5, 1.5]).requires_grad_()
    params.grad = params_grad_tensor
    # Keep a handle to the original .grad object so its identity can be
    # compared after backward.
    grad_saved = params.grad
    params.backward(backward_grad_tensor, create_graph=create_graph)
    # True when backward kept the same .grad object, False when it replaced it.
    assert (id(grad_saved) == id(params.grad)) == should_preserve_reference

# Repro call: a sparse COO tensor as the pre-existing grad plus a dense
# backward gradient. The sparse tensor ends up with an empty recorded stride,
# which makes PGO's is_stride_dynamic(i) raise IndexError (see traceback).
_test_grad_tensor(
    torch.sparse_coo_tensor(
        torch.tensor([[1, 1]]).long(), torch.tensor([1.0, 1.0])
    ),
    torch.tensor([1.5, 1.5]),
    False,  # never accumulates in-place
    False,
)

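The stride recorded for the sparse tensor is empty because the builder sets it to an empty list here. As a result, `is_stride_dynamic(i)` in pgo.py indexes past the end of it: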

https://github.com/pytorch/pytorch/blob/59c5fff2aa2216612ea54e287a73f2055aab84ac/torch/_dynamo/variables/builder.py#L3042C9-L3042C20

Traceback (most recent call last):
  File "/home/xmfan/core/a/pytorch/sparse.py", line 18, in <module>
    _test_grad_tensor(
  File "/home/xmfan/core/a/pytorch/torch/_dynamo/eval_frame.py", line 699, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/xmfan/core/a/pytorch/sparse.py", line 12, in _test_grad_tensor
    params = torch.tensor([1.5, 1.5]).requires_grad_()
  File "/home/xmfan/core/a/pytorch/sparse.py", line 15, in torch_dynamo_resume_in__test_grad_tensor_at_12
    params.backward(backward_grad_tensor, create_graph=create_graph)
  File "/home/xmfan/core/a/pytorch/torch/_tensor.py", line 648, in backward
    torch.autograd.backward(
  File "/home/xmfan/core/a/pytorch/torch/autograd/__init__.py", line 354, in backward
    _engine_run_backward(
  File "/home/xmfan/core/a/pytorch/torch/autograd/graph.py", line 829, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xmfan/core/a/pytorch/torch/_dynamo/compiled_autograd.py", line 1031, in runtime_wrapper
    out = compiled_fn(
          ^^^^^^^^^^^^
  File "/home/xmfan/core/a/pytorch/torch/_dynamo/eval_frame.py", line 372, in __call__
    return super().__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xmfan/core/a/pytorch/torch/nn/modules/module.py", line 1767, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xmfan/core/a/pytorch/torch/nn/modules/module.py", line 1778, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xmfan/core/a/pytorch/torch/_dynamo/eval_frame.py", line 699, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/xmfan/core/a/pytorch/torch/fx/graph_module.py", line 840, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xmfan/core/a/pytorch/torch/fx/graph_module.py", line 416, in __call__
    raise e
  File "/home/xmfan/core/a/pytorch/torch/fx/graph_module.py", line 403, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xmfan/core/a/pytorch/torch/nn/modules/module.py", line 1767, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xmfan/core/a/pytorch/torch/nn/modules/module.py", line 1778, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xmfan/core/a/pytorch/torch/_dynamo/convert_frame.py", line 1463, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xmfan/core/a/pytorch/torch/_dynamo/convert_frame.py", line 1242, in __call__
    result = self._inner_convert(
             ^^^^^^^^^^^^^^^^^^^^
  File "/home/xmfan/core/a/pytorch/torch/_dynamo/convert_frame.py", line 624, in __call__
    return _compile(
           ^^^^^^^^^
  File "/home/xmfan/core/a/pytorch/torch/_dynamo/convert_frame.py", line 1138, in _compile
    raise InternalTorchDynamoError(
  File "/home/xmfan/core/a/pytorch/torch/_dynamo/convert_frame.py", line 1087, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xmfan/core/a/pytorch/torch/_utils_internal.py", line 97, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xmfan/core/a/pytorch/torch/_dynamo/convert_frame.py", line 778, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xmfan/core/a/pytorch/torch/_dynamo/convert_frame.py", line 817, in _compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xmfan/core/a/pytorch/torch/_dynamo/bytecode_transformation.py", line 1423, in transform_code_object
    transformations(instructions, code_options)
  File "/home/xmfan/core/a/pytorch/torch/_dynamo/convert_frame.py", line 264, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/xmfan/core/a/pytorch/torch/_dynamo/convert_frame.py", line 742, in transform
    tracer.run()
  File "/home/xmfan/core/a/pytorch/torch/_dynamo/symbolic_convert.py", line 3508, in run
    super().run()
  File "/home/xmfan/core/a/pytorch/torch/_dynamo/symbolic_convert.py", line 1345, in run
    while self.step():
          ^^^^^^^^^^^
  File "/home/xmfan/core/a/pytorch/torch/_dynamo/symbolic_convert.py", line 1253, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/xmfan/core/a/pytorch/torch/_dynamo/symbolic_convert.py", line 828, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/xmfan/core/a/pytorch/torch/_dynamo/symbolic_convert.py", line 422, in impl
    self.push(fn_var.call_function(self, self.popn(nargs), {}))
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xmfan/core/a/pytorch/torch/_dynamo/variables/builtin.py", line 1160, in call_function
    return handler(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xmfan/core/a/pytorch/torch/_dynamo/variables/builtin.py", line 793, in <lambda>
    tx, [v.realize() for v in args], kwargs
         ^^^^^^^^^^^
  File "/home/xmfan/core/a/pytorch/torch/_dynamo/variables/lazy.py", line 67, in realize
    self._cache.realize()
  File "/home/xmfan/core/a/pytorch/torch/_dynamo/variables/lazy.py", line 32, in realize
    self.vt = builder.VariableBuilder(tx, self.source)(self.value)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xmfan/core/a/pytorch/torch/_dynamo/variables/builder.py", line 436, in __call__
    vt = self._wrap(value)
         ^^^^^^^^^^^^^^^^^
  File "/home/xmfan/core/a/pytorch/torch/_dynamo/variables/builder.py", line 634, in _wrap
    return type_dispatch(self, value)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xmfan/core/a/pytorch/torch/_dynamo/variables/builder.py", line 1538, in wrap_listlike
    list_variable = wrap_fx_proxy_cls(
                    ^^^^^^^^^^^^^^^^^^
  File "/home/xmfan/core/a/pytorch/torch/_dynamo/variables/builder.py", line 2567, in wrap_fx_proxy_cls
    return handle_traced_output(
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/xmfan/core/a/pytorch/torch/_dynamo/variables/builder.py", line 2734, in handle_traced_output
    wrap_fx_proxy_cls(
  File "/home/xmfan/core/a/pytorch/torch/_dynamo/variables/builder.py", line 2560, in wrap_fx_proxy_cls
    return _wrap_fx_preexisting_tensor(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xmfan/core/a/pytorch/torch/_dynamo/variables/builder.py", line 2622, in _wrap_fx_preexisting_tensor
    tensor = wrap_to_fake_tensor_and_record(tensor, tx=tx, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xmfan/core/a/pytorch/torch/_dynamo/variables/builder.py", line 3240, in wrap_to_fake_tensor_and_record
    symbolic_context = _automatic_dynamic(e, tx, source, static_shapes)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xmfan/core/a/pytorch/torch/_dynamo/variables/builder.py", line 3125, in _automatic_dynamic
    config.automatic_dynamic_shapes and frame_state_entry.is_stride_dynamic(i)
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xmfan/core/a/pytorch/torch/_dynamo/pgo.py", line 301, in is_stride_dynamic
    return self.stride[dim] is auto_dynamic
           ~~~~~~~~~~~^^^^^
torch._dynamo.exc.InternalTorchDynamoError: IndexError: tuple index out of range

Versions

main

cc @chauhang @penguinwu @ezyang @bobrenjc93

Metadata

Metadata

Assignees

No one assigned

    Labels

    module: dynamic shapes · oncall: pt2 · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions

      0