Open
Description
🐛 Describe the bug
This would have been useful when exporting DeepSeek-VL2. It would let us rewrite the following section with `torch.cond` and so avoid the data-dependent error raised on the `if num_tokens == 0`
guard:
# Run each expert on its contiguous slice of ``sorted_tokens``.
# ``tokens_per_expert[i]`` is the length of expert i's slice; the slices are
# laid out back to back, and ``start_idx`` tracks where the next one begins.
outputs = []
start_idx = 0
for i, num_tokens in enumerate(tokens_per_expert):
    end_idx = start_idx + num_tokens
    # This is the guard that raises a data-dependent error under torch.export.
    # Skipping without advancing ``start_idx`` is harmless: when
    # ``num_tokens == 0``, ``end_idx == start_idx``.
    if num_tokens == 0:
        continue
    # NOTE(review): presumably maps local expert ``i`` to its global index for
    # this expert-parallel rank (``ep_rank``); confirm against the model code.
    expert = self.experts[i + self.ep_rank * self.experts_per_rank]
    tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
    expert_out = expert(tokens_for_this_expert)
    outputs.append(expert_out)
    start_idx = end_idx
Minimal repro:
import torch
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.lin = torch.nn.Linear(10, 10)
def forward(self, x):
u0 = x.item()
y = torch.empty(u0, 10)
return torch.cond(
u0 > 0,
self.lin,
lambda: torch.empty_like(y),
[y],
)
# Export Foo non-strictly with a sample input of 4 rows; the traceback below
# shows this call failing with GuardOnDataDependentSymNode inside torch.cond.
ep = torch.export.export(Foo(), (torch.tensor([4]),), strict=False)
Error:
Traceback (most recent call last):
File "/data/users/pianpwk/DeepSeek-VL2/test_cond.py", line 17, in <module>
ep = torch.export.export(Foo(), (torch.tensor([4]),), strict=False)
File "/data/users/pianpwk/pytorch/torch/export/__init__.py", line 319, in export
raise e
File "/data/users/pianpwk/pytorch/torch/export/__init__.py", line 286, in export
return _export(
File "/data/users/pianpwk/pytorch/torch/export/_trace.py", line 1159, in wrapper
raise e
File "/data/users/pianpwk/pytorch/torch/export/_trace.py", line 1125, in wrapper
ep = fn(*args, **kwargs)
File "/data/users/pianpwk/pytorch/torch/export/exported_program.py", line 123, in wrapper
return fn(*args, **kwargs)
File "/data/users/pianpwk/pytorch/torch/export/_trace.py", line 2172, in _export
ep = _export_for_training(
File "/data/users/pianpwk/pytorch/torch/export/_trace.py", line 1159, in wrapper
raise e
File "/data/users/pianpwk/pytorch/torch/export/_trace.py", line 1125, in wrapper
ep = fn(*args, **kwargs)
File "/data/users/pianpwk/pytorch/torch/export/exported_program.py", line 123, in wrapper
return fn(*args, **kwargs)
File "/data/users/pianpwk/pytorch/torch/export/_trace.py", line 2033, in _export_for_training
export_artifact = export_func(
File "/data/users/pianpwk/pytorch/torch/export/_trace.py", line 1975, in _non_strict_export
aten_export_artifact = _to_aten_func( # type: ignore[operator]
File "/data/users/pianpwk/pytorch/torch/export/_trace.py", line 1760, in _export_to_aten_ir_make_fx
gm, graph_signature = transform(_make_fx_helper)(
File "/data/users/pianpwk/pytorch/torch/export/_trace.py", line 1901, in _aot_export_non_strict
gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
File "/data/users/pianpwk/pytorch/torch/export/_trace.py", line 1679, in _make_fx_helper
gm = make_fx(
File "/data/users/pianpwk/pytorch/torch/fx/experimental/proxy_tensor.py", line 2295, in wrapped
return make_fx_tracer.trace(f, *args)
File "/data/users/pianpwk/pytorch/torch/fx/experimental/proxy_tensor.py", line 2233, in trace
return self._trace_inner(f, *args)
File "/data/users/pianpwk/pytorch/torch/fx/experimental/proxy_tensor.py", line 2204, in _trace_inner
t = dispatch_trace(
File "/data/users/pianpwk/pytorch/torch/_compile.py", line 51, in inner
return disable_fn(*args, **kwargs)
File "/data/users/pianpwk/pytorch/torch/_dynamo/eval_frame.py", line 893, in _fn
return fn(*args, **kwargs)
File "/data/users/pianpwk/pytorch/torch/fx/experimental/proxy_tensor.py", line 1221, in dispatch_trace
graph = tracer.trace(root, concrete_args) # type: ignore[arg-type]
File "/data/users/pianpwk/pytorch/torch/fx/experimental/proxy_tensor.py", line 1792, in trace
res = super().trace(root, concrete_args)
File "/data/users/pianpwk/pytorch/torch/fx/_symbolic_trace.py", line 850, in trace
(self.create_arg(fn(*args)),),
File "/data/users/pianpwk/pytorch/torch/fx/experimental/proxy_tensor.py", line 1279, in wrapped
out = f(*tensors) # type:ignore[call-arg]
File "<string>", line 1, in <lambda>
File "/data/users/pianpwk/pytorch/torch/export/_trace.py", line 1583, in wrapped_fn
return tuple(flat_fn(*args))
File "/data/users/pianpwk/pytorch/torch/_functorch/_aot_autograd/utils.py", line 184, in flat_fn
tree_out = fn(*args, **kwargs)
File "/data/users/pianpwk/pytorch/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 906, in functional_call
out = mod(*args[params_len:], **kwargs)
File "/data/users/pianpwk/pytorch/torch/fx/_symbolic_trace.py", line 825, in module_call_wrapper
return self.call_module(mod, forward, args, kwargs)
File "/data/users/pianpwk/pytorch/torch/fx/experimental/proxy_tensor.py", line 1862, in call_module
return Tracer.call_module(self, m, forward, args, kwargs)
File "/data/users/pianpwk/pytorch/torch/fx/_symbolic_trace.py", line 542, in call_module
ret_val = forward(*args, **kwargs)
File "/data/users/pianpwk/pytorch/torch/fx/_symbolic_trace.py", line 818, in forward
return _orig_module_call(mod, *args, **kwargs)
File "/data/users/pianpwk/pytorch/torch/nn/modules/module.py", line 1767, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/users/pianpwk/pytorch/torch/nn/modules/module.py", line 1778, in _call_impl
return forward_call(*args, **kwargs)
File "/data/users/pianpwk/pytorch/torch/export/_trace.py", line 1885, in forward
tree_out = mod(*args, **kwargs)
File "/data/users/pianpwk/pytorch/torch/fx/_symbolic_trace.py", line 825, in module_call_wrapper
return self.call_module(mod, forward, args, kwargs)
File "/data/users/pianpwk/pytorch/torch/fx/experimental/proxy_tensor.py", line 1862, in call_module
return Tracer.call_module(self, m, forward, args, kwargs)
File "/data/users/pianpwk/pytorch/torch/fx/_symbolic_trace.py", line 542, in call_module
ret_val = forward(*args, **kwargs)
File "/data/users/pianpwk/pytorch/torch/fx/_symbolic_trace.py", line 818, in forward
return _orig_module_call(mod, *args, **kwargs)
File "/data/users/pianpwk/pytorch/torch/nn/modules/module.py", line 1767, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/users/pianpwk/pytorch/torch/nn/modules/module.py", line 1778, in _call_impl
return forward_call(*args, **kwargs)
File "/data/users/pianpwk/DeepSeek-VL2/test_cond.py", line 10, in forward
return torch.cond(
File "/data/users/pianpwk/pytorch/torch/_higher_order_ops/cond.py", line 194, in cond
return torch.compile(_cond_op_wrapper, backend=backend, fullgraph=True)(
File "/data/users/pianpwk/pytorch/torch/_dynamo/eval_frame.py", line 699, in compile_wrapper
return fn(*args, **kwargs)
File "/data/users/pianpwk/pytorch/torch/_dynamo/convert_frame.py", line 1463, in __call__
return self._torchdynamo_orig_callable(
File "/data/users/pianpwk/pytorch/torch/_dynamo/convert_frame.py", line 624, in __call__
return _compile(
File "/data/users/pianpwk/pytorch/torch/_dynamo/convert_frame.py", line 1087, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/data/users/pianpwk/pytorch/torch/_utils_internal.py", line 97, in wrapper_function
return function(*args, **kwargs)
File "/data/users/pianpwk/pytorch/torch/_dynamo/convert_frame.py", line 778, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
File "/data/users/pianpwk/pytorch/torch/_dynamo/convert_frame.py", line 817, in _compile_inner
out_code = transform_code_object(code, transform)
File "/data/users/pianpwk/pytorch/torch/_dynamo/bytecode_transformation.py", line 1424, in transform_code_object
transformations(instructions, code_options)
File "/data/users/pianpwk/pytorch/torch/_dynamo/convert_frame.py", line 264, in _fn
return fn(*args, **kwargs)
File "/data/users/pianpwk/pytorch/torch/_dynamo/convert_frame.py", line 742, in transform
tracer.run()
File "/data/users/pianpwk/pytorch/torch/_dynamo/symbolic_convert.py", line 3508, in run
super().run()
File "/data/users/pianpwk/pytorch/torch/_dynamo/symbolic_convert.py", line 1345, in run
while self.step():
File "/data/users/pianpwk/pytorch/torch/_dynamo/symbolic_convert.py", line 1253, in step
self.dispatch_table[inst.opcode](self, inst)
File "/data/users/pianpwk/pytorch/torch/_dynamo/symbolic_convert.py", line 828, in wrapper
return inner_fn(self, inst)
File "/data/users/pianpwk/pytorch/torch/_dynamo/symbolic_convert.py", line 2254, in CALL_FUNCTION_EX
self.call_function(fn, argsvars.items, kwargsvars)
File "/data/users/pianpwk/pytorch/torch/_dynamo/symbolic_convert.py", line 1179, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
File "/data/users/pianpwk/pytorch/torch/_dynamo/variables/lazy.py", line 201, in realize_and_forward
return getattr(self.realize(), name)(*args, **kwargs)
File "/data/users/pianpwk/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 75, in graph_break_as_hard_error
return fn(*args, **kwargs)
File "/data/users/pianpwk/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 980, in call_function
args, kwargs = LazyVariableTracker.realize_all((args, kwargs))
File "/data/users/pianpwk/pytorch/torch/_dynamo/variables/lazy.py", line 137, in realize_all
result = tuple(cls.realize_all(v, cache) for v in value)
File "/data/users/pianpwk/pytorch/torch/_dynamo/variables/lazy.py", line 137, in <genexpr>
result = tuple(cls.realize_all(v, cache) for v in value)
File "/data/users/pianpwk/pytorch/torch/_dynamo/variables/lazy.py", line 135, in realize_all
result = [cls.realize_all(v, cache) for v in value]
File "/data/users/pianpwk/pytorch/torch/_dynamo/variables/lazy.py", line 135, in <listcomp>
result = [cls.realize_all(v, cache) for v in value]
File "/data/users/pianpwk/pytorch/torch/_dynamo/variables/lazy.py", line 125, in realize_all
result = cls.realize_all(value.realize(), cache)
File "/data/users/pianpwk/pytorch/torch/_dynamo/variables/lazy.py", line 133, in realize_all
value_dict[key] = cls.realize_all(value_dict[key], cache)
File "/data/users/pianpwk/pytorch/torch/_dynamo/variables/lazy.py", line 135, in realize_all
result = [cls.realize_all(v, cache) for v in value]
File "/data/users/pianpwk/pytorch/torch/_dynamo/variables/lazy.py", line 135, in <listcomp>
result = [cls.realize_all(v, cache) for v in value]
File "/data/users/pianpwk/pytorch/torch/_dynamo/variables/lazy.py", line 125, in realize_all
result = cls.realize_all(value.realize(), cache)
File "/data/users/pianpwk/pytorch/torch/_dynamo/variables/lazy.py", line 67, in realize
self._cache.realize()
File "/data/users/pianpwk/pytorch/torch/_dynamo/variables/lazy.py", line 32, in realize
self.vt = builder.VariableBuilder(tx, self.source)(self.value)
File "/data/users/pianpwk/pytorch/torch/_dynamo/variables/builder.py", line 436, in __call__
vt = self._wrap(value)
File "/data/users/pianpwk/pytorch/torch/_dynamo/variables/builder.py", line 634, in _wrap
return type_dispatch(self, value)
File "/data/users/pianpwk/pytorch/torch/_dynamo/variables/builder.py", line 1961, in wrap_tensor
example_value = wrap_to_fake_tensor_and_record(
File "/data/users/pianpwk/pytorch/torch/_dynamo/variables/builder.py", line 3260, in wrap_to_fake_tensor_and_record
fake_e = wrap_fake_exception(
File "/data/users/pianpwk/pytorch/torch/_dynamo/utils.py", line 2705, in wrap_fake_exception
return fn()
File "/data/users/pianpwk/pytorch/torch/_dynamo/variables/builder.py", line 3261, in <lambda>
lambda: tx.fake_mode.from_tensor(
File "/data/users/pianpwk/pytorch/torch/_subclasses/fake_tensor.py", line 2919, in from_tensor
return self.fake_tensor_converter.from_real_tensor(
File "/data/users/pianpwk/pytorch/torch/_subclasses/fake_tensor.py", line 398, in from_real_tensor
out = self.meta_converter(
File "/data/users/pianpwk/pytorch/torch/_subclasses/meta_utils.py", line 1913, in __call__
r = self.meta_tensor(
File "/data/users/pianpwk/pytorch/torch/_subclasses/meta_utils.py", line 1680, in meta_tensor
) = sym_sizes_strides_storage_offset(t, source, symbolic_context)
File "/data/users/pianpwk/pytorch/torch/_subclasses/meta_utils.py", line 943, in sym_sizes_strides_storage_offset
t_size = tuple(
File "/data/users/pianpwk/pytorch/torch/_subclasses/meta_utils.py", line 944, in <genexpr>
shape_env._maybe_specialize_sym_int_with_hint(sz)
File "/data/users/pianpwk/pytorch/torch/fx/experimental/symbolic_shapes.py", line 4114, in _maybe_specialize_sym_int_with_hint
return maybe_sym.node.require_hint()
File "/data/users/pianpwk/pytorch/torch/fx/experimental/sym_node.py", line 219, in require_hint
return self.shape_env.size_hint(self.expr)
File "/data/users/pianpwk/pytorch/torch/fx/experimental/symbolic_shapes.py", line 6114, in size_hint
raise self._make_data_dependent_error(result_expr, expr)
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not extract specialized integer from data-dependent expression u0 (unhinted: u0). (Size-like symbols: u0)
Caused by: return cond_op(*args, **kwargs) # _higher_order_ops/cond.py:186 in _cond_op_wrapper (_dynamo/variables/builder.py:3261 in <lambda>)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing
User Stack (most recent call last):
(snipped, see stack below for prefix)
File "/data/users/pianpwk/pytorch/torch/_higher_order_ops/cond.py", line 186, in _cond_op_wrapper
return cond_op(*args, **kwargs)
For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
from user code:
File "/data/users/pianpwk/pytorch/torch/_higher_order_ops/cond.py", line 186, in _cond_op_wrapper
return cond_op(*args, **kwargs)
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.
Versions
nightly
cc @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4