Description
The following operations:
- index_add
- index_copy
- index_fill
- masked_fill
- masked_scatter
have this quirk where they are composite, but their in-place variants have autograd formulas. For example:
```cpp
Tensor index_fill(const Tensor & self, int64_t dim, const Tensor & index, const Scalar& source) {
  return self.clone(at::MemoryFormat::Preserve).index_fill_(dim, index, source);
}
```
We have trouble doing `vmap(grad(foo))` where `foo` includes one of the above operations. This is because Autograd ends up decomposing e.g. `index_fill` into `tensor.clone().index_fill_(...)`, and the in-place operation is not vmap-compatible.
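For reference, a minimal repro sketch of the failure mode, assuming the functorch-style `vmap`/`grad` transforms (`torch.func` in newer releases); `foo` and the shapes here are just illustrative:

```python
import torch
from functorch import vmap, grad  # or torch.func in newer PyTorch releases

def foo(x):
    # index_fill is composite: under autograd it is recorded as
    # x.clone().index_fill_(...), and that in-place op is not vmap-compatible.
    index = torch.tensor([0])
    return x.index_fill(0, index, 1.0).sum()

x = torch.randn(3, 5)
# Computing per-sample gradients hits the in-place index_fill_ from the
# decomposition and errors out.
per_sample_grads = vmap(grad(foo))(x)
```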
I've brainstormed two ways to solve this. There are tradeoffs for each, and I'm wondering if someone else has a preference.
Approach 1: DECOMPOSE_FUNCTIONAL
(cc @bdhirsh)
We use the functionalization pass to functionalize `index_fill`. Unfortunately, this results in the following code:

```cpp
self.clone(...).index_fill(dim, index, source)
```

which contains an unnecessary `clone()` and is not good for performance. IDK if we want to make the functionalization pass smarter in the future; that sounds complicated.
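To spell out where the extra copy comes from, here is a schematic comparison (Python pseudocode, not the real functionalization output):

```python
def index_fill_decomposed(t, dim, index, value):
    # Today's composite decomposition: clone, then mutate the clone in place.
    return t.clone().index_fill_(dim, index, value)

def index_fill_functionalized(t, dim, index, value):
    # After functionalization, index_fill_ is rewritten to the out-of-place
    # index_fill, but the clone from the decomposition is still there, even
    # though index_fill would not have mutated its input anyway.
    return t.clone().index_fill(dim, index, value)
```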
Approach 2: Add backward formulas for `index_fill` and all of the operations above (aka, turn them into primitive operations)
(cc @albanD)
This means that both `index_fill` and `index_fill_` get backward formulas (could we get away with only giving `index_fill` a backward formula?). This is a pretty simple solution; the tradeoff is that we need to duplicate the formulas and we are setting the precedent that "operations that have an out-of-place variant must have a backward formula defined on the out-of-place variant".
Discussion
I prefer Approach 2 for its simplicity. To address the code duplication, we can put the formulas into helper functions. Thoughts?
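If we go the helper-function route, the shared piece might look something like this (hypothetical name; in practice the sharing would happen in `derivatives.yaml` / the C++ gradient helpers):

```python
def _index_fill_grad_self(grad, dim, index):
    # Shared between the index_fill and index_fill_ backward formulas:
    # zero out the gradient at the filled positions.
    return grad.index_fill(dim, index, 0)
```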