don't require recompiles when switching between torch.Tensor vs AsyncCollectiveTensor graph inputs #154847
Open

Description

@bdhirsh

See this recompile reason:

 expected type of 'L['x']._local_tensor' to be a tensor type, ' but found <class 'torch.distributed._functional_collectives.AsyncCollectiveTensor'>"

Today, if you compile a region where:
(1) on the first run, you pass a regular torch.Tensor input
(2) on the second run, your input to the compiled region is the output of a functional collective op that has not yet been synchronized (an AsyncCollectiveTensor)

then we are forced to recompile for the two cases (a minimal repro is sketched below).
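
For concreteness, here is a minimal sketch of the two-run pattern. The single-rank gloo setup and the toy compiled function are illustrative only (they are not from this issue), and running with TORCH_LOGS=recompiles will surface the guard failure:

    import os
    import torch
    import torch.distributed as dist
    from torch.distributed._functional_collectives import all_reduce

    # Single-process gloo setup, just so the functional collective can run.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    @torch.compile
    def f(x):
        return x + 1

    x = torch.ones(4)
    f(x)  # run 1: plain torch.Tensor input

    y = all_reduce(x, "sum", dist.group.WORLD)  # returns an AsyncCollectiveTensor
    f(y)  # run 2: AsyncCollectiveTensor input -> guard failure, recompile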

You could argue that this is a bit silly: functionally, the compiled artifacts we generate in the two cases are almost identical. The only difference is that in the second case, we need to issue a wait_tensor() somewhere in the compiled region to synchronize the input data before reading from it.

You can manually avoid the recompile today by calling wait_tensor() on your input before the compiled region, but this can have perf downsides: you are not giving the compiler the chance to move the synchronization later, to the first site where the tensor data is actually read.
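
Continuing the repro above, the manual workaround is just an eager wait_tensor() before the compiled call (wait_tensor also lives in torch.distributed._functional_collectives):

    from torch.distributed._functional_collectives import wait_tensor

    y = all_reduce(x, "sum", dist.group.WORLD)
    # Synchronize eagerly so the compiled region always sees a plain torch.Tensor.
    # This avoids the recompile, at the cost of waiting here rather than letting
    # the compiler sink the wait to the first real read of y inside the graph.
    y = wait_tensor(y)
    f(y)  # no recompile: the guard sees a regular tensor input again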

Can we support this properly? The most obvious option would be to unconditionally insert wait_tensor() calls on all of our graph inputs, which will become no-ops if the input did not come from a pending collective (aka is not an AsyncCollectiveTensor).
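
As a sanity check on why this would be safe, here is a rough sketch of such a prologue (an illustration of the proposal, not existing compiler behavior; the helper name is made up):

    from torch.distributed._functional_collectives import wait_tensor

    def _sync_graph_inputs(*graph_inputs):
        # Hypothetical prologue over every graph input: for a plain torch.Tensor
        # wait_tensor() is a no-op, while for an AsyncCollectiveTensor it
        # synchronizes the pending collective. A compiler pass could then sink
        # each wait down to the first real read of that input.
        return tuple(wait_tensor(t) for t in graph_inputs)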

A few options are:

(1) do this, but gate it behind a config that is off by default, so the user needs to manually enable the behavior. This is the safest option, but the user needs to know about the config (we could potentially mention the config in the recompile reason).

(2) do this, on-by-default for torch.compile with the inductor backend. This might be reasonable: it has a somewhat higher blast radius, given that collectives are not that common in graphs today, but it could be a good hardening exercise.

(3) do this automatically, but only when we detect functional collectives inside of the graph. There's no particular reason for this option other than that it tries to strike a middle ground.

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @chauhang @penguinwu

Metadata

    Labels

    module: c10d, oncall: distributed, oncall: pt2, triaged
