[cond] support for unbacked symbols in compiled region · Issue #154559 · pytorch/pytorch · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content
[cond] support for unbacked symbols in compiled region #154559
Open
@pianpwk

Description

@pianpwk

🐛 Describe the bug

This would have been useful for exporting DeepSeek VL2, specifically for rewriting the following section to avoid a data-dependent error on the `if num_tokens == 0` guard:

        outputs = []
        start_idx = 0
        for i, num_tokens in enumerate(tokens_per_expert):
            end_idx = start_idx + num_tokens
            if num_tokens == 0:
                continue
            expert = self.experts[i + self.ep_rank * self.experts_per_rank]
            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
            expert_out = expert(tokens_for_this_expert)
            outputs.append(expert_out)
            start_idx = end_idx

Minimal repro:

import torch

class Foo(torch.nn.Module):
    """Minimal repro: ``torch.cond`` whose predicate and operand depend on a scalar read from the input tensor."""
    def __init__(self):
        super().__init__()
        # Linear layer that is passed to torch.cond as the true branch.
        self.lin = torch.nn.Linear(10, 10)
    def forward(self, x):
        # Read a Python scalar out of the input tensor; in the reported error
        # this value appears as the data-dependent symbol `u0`.
        u0 = x.item()
        # The first dimension of `y` is that data-dependent value.
        y = torch.empty(u0, 10)
        return torch.cond(
            u0 > 0,  # predicate
            self.lin,  # true branch
            # NOTE(review): this false-branch lambda takes no arguments even
            # though one operand (`y`) is listed below; it closes over `y`
            # instead -- confirm this is intended.
            lambda: torch.empty_like(y),
            [y],  # operands
        )

# Non-strict export of the repro module; this is the call that raises the
# GuardOnDataDependentSymNode error shown in the traceback below.
ep = torch.export.export(Foo(), (torch.tensor([4]),), strict=False)

error:

Traceback (most recent call last):
  File "/data/users/pianpwk/DeepSeek-VL2/test_cond.py", line 17, in <module>
    ep = torch.export.export(Foo(), (torch.tensor([4]),), strict=False)
  File "/data/users/pianpwk/pytorch/torch/export/__init__.py", line 319, in export
    raise e
  File "/data/users/pianpwk/pytorch/torch/export/__init__.py", line 286, in export
    return _export(
  File "/data/users/pianpwk/pytorch/torch/export/_trace.py", line 1159, in wrapper
    raise e
  File "/data/users/pianpwk/pytorch/torch/export/_trace.py", line 1125, in wrapper
    ep = fn(*args, **kwargs)
  File "/data/users/pianpwk/pytorch/torch/export/exported_program.py", line 123, in wrapper
    return fn(*args, **kwargs)
  File "/data/users/pianpwk/pytorch/torch/export/_trace.py", line 2172, in _export
    ep = _export_for_training(
  File "/data/users/pianpwk/pytorch/torch/export/_trace.py", line 1159, in wrapper
    raise e
  File "/data/users/pianpwk/pytorch/torch/export/_trace.py", line 1125, in wrapper
    ep = fn(*args, **kwargs)
  File "/data/users/pianpwk/pytorch/torch/export/exported_program.py", line 123, in wrapper
    return fn(*args, **kwargs)
  File "/data/users/pianpwk/pytorch/torch/export/_trace.py", line 2033, in _export_for_training
    export_artifact = export_func(
  File "/data/users/pianpwk/pytorch/torch/export/_trace.py", line 1975, in _non_strict_export
    aten_export_artifact = _to_aten_func(  # type: ignore[operator]
  File "/data/users/pianpwk/pytorch/torch/export/_trace.py", line 1760, in _export_to_aten_ir_make_fx
    gm, graph_signature = transform(_make_fx_helper)(
  File "/data/users/pianpwk/pytorch/torch/export/_trace.py", line 1901, in _aot_export_non_strict
    gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
  File "/data/users/pianpwk/pytorch/torch/export/_trace.py", line 1679, in _make_fx_helper
    gm = make_fx(
  File "/data/users/pianpwk/pytorch/torch/fx/experimental/proxy_tensor.py", line 2295, in wrapped
    return make_fx_tracer.trace(f, *args)
  File "/data/users/pianpwk/pytorch/torch/fx/experimental/proxy_tensor.py", line 2233, in trace
    return self._trace_inner(f, *args)
  File "/data/users/pianpwk/pytorch/torch/fx/experimental/proxy_tensor.py", line 2204, in _trace_inner
    t = dispatch_trace(
  File "/data/users/pianpwk/pytorch/torch/_compile.py", line 51, in inner
    return disable_fn(*args, **kwargs)
  File "/data/users/pianpwk/pytorch/torch/_dynamo/eval_frame.py", line 893, in _fn
    return fn(*args, **kwargs)
  File "/data/users/pianpwk/pytorch/torch/fx/experimental/proxy_tensor.py", line 1221, in dispatch_trace
    graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
  File "/data/users/pianpwk/pytorch/torch/fx/experimental/proxy_tensor.py", line 1792, in trace
    res = super().trace(root, concrete_args)
  File "/data/users/pianpwk/pytorch/torch/fx/_symbolic_trace.py", line 850, in trace
    (self.create_arg(fn(*args)),),
  File "/data/users/pianpwk/pytorch/torch/fx/experimental/proxy_tensor.py", line 1279, in wrapped
    out = f(*tensors)  # type:ignore[call-arg]
  File "<string>", line 1, in <lambda>
  File "/data/users/pianpwk/pytorch/torch/export/_trace.py", line 1583, in wrapped_fn
    return tuple(flat_fn(*args))
  File "/data/users/pianpwk/pytorch/torch/_functorch/_aot_autograd/utils.py", line 184, in flat_fn
    tree_out = fn(*args, **kwargs)
  File "/data/users/pianpwk/pytorch/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 906, in functional_call
    out = mod(*args[params_len:], **kwargs)
  File "/data/users/pianpwk/pytorch/torch/fx/_symbolic_trace.py", line 825, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
  File "/data/users/pianpwk/pytorch/torch/fx/experimental/proxy_tensor.py", line 1862, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
  File "/data/users/pianpwk/pytorch/torch/fx/_symbolic_trace.py", line 542, in call_module
    ret_val = forward(*args, **kwargs)
  File "/data/users/pianpwk/pytorch/torch/fx/_symbolic_trace.py", line 818, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "/data/users/pianpwk/pytorch/torch/nn/modules/module.py", line 1767, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/users/pianpwk/pytorch/torch/nn/modules/module.py", line 1778, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/users/pianpwk/pytorch/torch/export/_trace.py", line 1885, in forward
    tree_out = mod(*args, **kwargs)
  File "/data/users/pianpwk/pytorch/torch/fx/_symbolic_trace.py", line 825, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
  File "/data/users/pianpwk/pytorch/torch/fx/experimental/proxy_tensor.py", line 1862, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
  File "/data/users/pianpwk/pytorch/torch/fx/_symbolic_trace.py", line 542, in call_module
    ret_val = forward(*args, **kwargs)
  File "/data/users/pianpwk/pytorch/torch/fx/_symbolic_trace.py", line 818, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "/data/users/pianpwk/pytorch/torch/nn/modules/module.py", line 1767, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/users/pianpwk/pytorch/torch/nn/modules/module.py", line 1778, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/users/pianpwk/DeepSeek-VL2/test_cond.py", line 10, in forward
    return torch.cond(
  File "/data/users/pianpwk/pytorch/torch/_higher_order_ops/cond.py", line 194, in cond
    return torch.compile(_cond_op_wrapper, backend=backend, fullgraph=True)(
  File "/data/users/pianpwk/pytorch/torch/_dynamo/eval_frame.py", line 699, in compile_wrapper
    return fn(*args, **kwargs)
  File "/data/users/pianpwk/pytorch/torch/_dynamo/convert_frame.py", line 1463, in __call__
    return self._torchdynamo_orig_callable(
  File "/data/users/pianpwk/pytorch/torch/_dynamo/convert_frame.py", line 624, in __call__
    return _compile(
  File "/data/users/pianpwk/pytorch/torch/_dynamo/convert_frame.py", line 1087, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/data/users/pianpwk/pytorch/torch/_utils_internal.py", line 97, in wrapper_function
    return function(*args, **kwargs)
  File "/data/users/pianpwk/pytorch/torch/_dynamo/convert_frame.py", line 778, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/data/users/pianpwk/pytorch/torch/_dynamo/convert_frame.py", line 817, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/data/users/pianpwk/pytorch/torch/_dynamo/bytecode_transformation.py", line 1424, in transform_code_object
    transformations(instructions, code_options)
  File "/data/users/pianpwk/pytorch/torch/_dynamo/convert_frame.py", line 264, in _fn
    return fn(*args, **kwargs)
  File "/data/users/pianpwk/pytorch/torch/_dynamo/convert_frame.py", line 742, in transform
    tracer.run()
  File "/data/users/pianpwk/pytorch/torch/_dynamo/symbolic_convert.py", line 3508, in run
    super().run()
  File "/data/users/pianpwk/pytorch/torch/_dynamo/symbolic_convert.py", line 1345, in run
    while self.step():
  File "/data/users/pianpwk/pytorch/torch/_dynamo/symbolic_convert.py", line 1253, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/data/users/pianpwk/pytorch/torch/_dynamo/symbolic_convert.py", line 828, in wrapper
    return inner_fn(self, inst)
  File "/data/users/pianpwk/pytorch/torch/_dynamo/symbolic_convert.py", line 2254, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/data/users/pianpwk/pytorch/torch/_dynamo/symbolic_convert.py", line 1179, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/data/users/pianpwk/pytorch/torch/_dynamo/variables/lazy.py", line 201, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
  File "/data/users/pianpwk/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 75, in graph_break_as_hard_error
    return fn(*args, **kwargs)
  File "/data/users/pianpwk/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 980, in call_function
    args, kwargs = LazyVariableTracker.realize_all((args, kwargs))
  File "/data/users/pianpwk/pytorch/torch/_dynamo/variables/lazy.py", line 137, in realize_all
    result = tuple(cls.realize_all(v, cache) for v in value)
  File "/data/users/pianpwk/pytorch/torch/_dynamo/variables/lazy.py", line 137, in <genexpr>
    result = tuple(cls.realize_all(v, cache) for v in value)
  File "/data/users/pianpwk/pytorch/torch/_dynamo/variables/lazy.py", line 135, in realize_all
    result = [cls.realize_all(v, cache) for v in value]
  File "/data/users/pianpwk/pytorch/torch/_dynamo/variables/lazy.py", line 135, in <listcomp>
    result = [cls.realize_all(v, cache) for v in value]
  File "/data/users/pianpwk/pytorch/torch/_dynamo/variables/lazy.py", line 125, in realize_all
    result = cls.realize_all(value.realize(), cache)
  File "/data/users/pianpwk/pytorch/torch/_dynamo/variables/lazy.py", line 133, in realize_all
    value_dict[key] = cls.realize_all(value_dict[key], cache)
  File "/data/users/pianpwk/pytorch/torch/_dynamo/variables/lazy.py", line 135, in realize_all
    result = [cls.realize_all(v, cache) for v in value]
  File "/data/users/pianpwk/pytorch/torch/_dynamo/variables/lazy.py", line 135, in <listcomp>
    result = [cls.realize_all(v, cache) for v in value]
  File "/data/users/pianpwk/pytorch/torch/_dynamo/variables/lazy.py", line 125, in realize_all
    result = cls.realize_all(value.realize(), cache)
  File "/data/users/pianpwk/pytorch/torch/_dynamo/variables/lazy.py", line 67, in realize
    self._cache.realize()
  File "/data/users/pianpwk/pytorch/torch/_dynamo/variables/lazy.py", line 32, in realize
    self.vt = builder.VariableBuilder(tx, self.source)(self.value)
  File "/data/users/pianpwk/pytorch/torch/_dynamo/variables/builder.py", line 436, in __call__
    vt = self._wrap(value)
  File "/data/users/pianpwk/pytorch/torch/_dynamo/variables/builder.py", line 634, in _wrap
    return type_dispatch(self, value)
  File "/data/users/pianpwk/pytorch/torch/_dynamo/variables/builder.py", line 1961, in wrap_tensor
    example_value = wrap_to_fake_tensor_and_record(
  File "/data/users/pianpwk/pytorch/torch/_dynamo/variables/builder.py", line 3260, in wrap_to_fake_tensor_and_record
    fake_e = wrap_fake_exception(
  File "/data/users/pianpwk/pytorch/torch/_dynamo/utils.py", line 2705, in wrap_fake_exception
    return fn()
  File "/data/users/pianpwk/pytorch/torch/_dynamo/variables/builder.py", line 3261, in <lambda>
    lambda: tx.fake_mode.from_tensor(
  File "/data/users/pianpwk/pytorch/torch/_subclasses/fake_tensor.py", line 2919, in from_tensor
    return self.fake_tensor_converter.from_real_tensor(
  File "/data/users/pianpwk/pytorch/torch/_subclasses/fake_tensor.py", line 398, in from_real_tensor
    out = self.meta_converter(
  File "/data/users/pianpwk/pytorch/torch/_subclasses/meta_utils.py", line 1913, in __call__
    r = self.meta_tensor(
  File "/data/users/pianpwk/pytorch/torch/_subclasses/meta_utils.py", line 1680, in meta_tensor
    ) = sym_sizes_strides_storage_offset(t, source, symbolic_context)
  File "/data/users/pianpwk/pytorch/torch/_subclasses/meta_utils.py", line 943, in sym_sizes_strides_storage_offset
    t_size = tuple(
  File "/data/users/pianpwk/pytorch/torch/_subclasses/meta_utils.py", line 944, in <genexpr>
    shape_env._maybe_specialize_sym_int_with_hint(sz)
  File "/data/users/pianpwk/pytorch/torch/fx/experimental/symbolic_shapes.py", line 4114, in _maybe_specialize_sym_int_with_hint
    return maybe_sym.node.require_hint()
  File "/data/users/pianpwk/pytorch/torch/fx/experimental/sym_node.py", line 219, in require_hint
    return self.shape_env.size_hint(self.expr)
  File "/data/users/pianpwk/pytorch/torch/fx/experimental/symbolic_shapes.py", line 6114, in size_hint
    raise self._make_data_dependent_error(result_expr, expr)
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not extract specialized integer from data-dependent expression u0 (unhinted: u0).  (Size-like symbols: u0)

Caused by: return cond_op(*args, **kwargs)  # _higher_order_ops/cond.py:186 in _cond_op_wrapper (_dynamo/variables/builder.py:3261 in <lambda>)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

User Stack (most recent call last):
  (snipped, see stack below for prefix)
  File "/data/users/pianpwk/pytorch/torch/_higher_order_ops/cond.py", line 186, in _cond_op_wrapper
    return cond_op(*args, **kwargs)

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

from user code:
   File "/data/users/pianpwk/pytorch/torch/_higher_order_ops/cond.py", line 186, in _cond_op_wrapper
    return cond_op(*args, **kwargs)

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"


The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.

Versions

nightly

cc @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4

Metadata

Metadata

Assignees

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions

    0