Description
See this recompile reason:
```
expected type of 'L['x']._local_tensor' to be a tensor type, ' but found <class 'torch.distributed._functional_collectives.AsyncCollectiveTensor'>"
```
Today, if you compile a region where:
(1) first run: you pass a regular input
(2) second run: your input to the compiled region is the output of a functional collective op that we have not yet synchronized on
then we will be forced to recompile for the two different cases.
You could argue that this is a bit silly: functionally, the compiled artifact that we generate in both cases is almost identical. The only difference is that in the second case, we need to issue a wait_tensor() somewhere in the compiled region to synchronize the input data before reading from it.
You can avoid the recompile today by manually calling wait_tensor() on your input before the compiled region, but this can have perf issues: you are not giving the compiler the chance to move the synchronization later, to the first site where the tensor data is actually read.
Can we support this properly? The most obvious option would be to unconditionally insert wait_tensor() calls on all of our graph inputs, which become no-ops if the input did not come from a pending collective (i.e. is not an AsyncCollectiveTensor).
A few options are:
(1) do this, but gate it behind a config that is off by default, so the user has to opt in. This is the safest option, but the user needs to know about the config (we could potentially include the config name in the recompile reason).
(2) do this, on by default, for torch.compile with the inductor backend. This might be reasonable to do? It has a bit higher blast radius, since collectives are not that common in graphs today, but it could be a good hardening exercise.
(3) do this automatically, but only when we detect functional collectives inside of the graph. There's no particular reason to prefer this option, other than that it strikes a middle ground.
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @chauhang @penguinwu