vmap over composite out-of-place ops whose in-place variant is non-composite · Issue #260 · pytorch/functorch · GitHub
vmap over composite out-of-place ops whose in-place variant is non-composite #260
Open
@zou3519

Description

The following operations:

  • index_add
  • index_copy
  • index_fill
  • masked_fill
  • masked_scatter

share a quirk: the out-of-place op is composite (it has no autograd formula of its own and is implemented in terms of other ops), but its in-place variant does have an autograd formula. For example,

Tensor index_fill(const Tensor & self, int64_t dim, const Tensor & index, const Scalar& source) {
  return self.clone(at::MemoryFormat::Preserve).index_fill_(dim, index, source);
}
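The decomposition above can be checked from Python. The following is a minimal sketch, assuming a recent PyTorch build; the tensor shapes and fill value are arbitrary choices for illustration. It compares the out-of-place call against the clone-then-mutate sequence that the composite implementation performs.

```python
import torch

x = torch.arange(12.0).reshape(3, 4)
index = torch.tensor([0, 2])

# Out-of-place call: returns a new tensor and leaves x untouched.
out_of_place = x.index_fill(1, index, -1.0)

# Mirror of the composite implementation: copy the input,
# then mutate the copy with the in-place variant.
decomposed = x.clone(memory_format=torch.preserve_format).index_fill_(1, index, -1.0)

same_values = torch.equal(out_of_place, decomposed)
input_untouched = torch.equal(x, torch.arange(12.0).reshape(3, 4))
print(same_values, input_untouched)
```

Because the only op that autograd records here is the in-place index_fill_, any transform layered on top of autograd sees that in-place call rather than the out-of-place one.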

We have trouble doing vmap(grad(foo)) when foo includes one of the above operations. This is because Autograd ends up decomposing, e.g., index_fill into tensor.clone().index_fill_(...), and the in-place operation is not vmap-compatible.
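A small reproduction attempt might look like the sketch below. The function foo, the batch shape, and the index values are hypothetical and chosen only for illustration. Whether the vmap call fails depends on the installed version: the failure described here was observed on functorch at the time of this issue, and newer releases may handle it, so the sketch catches the error instead of assuming either outcome.

```python
import torch

try:
    from torch.func import grad, vmap   # PyTorch 2.x
except ImportError:
    from functorch import grad, vmap    # standalone functorch package

index = torch.tensor([0, 2])

def foo(x):
    # Hypothetical scalar-valued function that uses index_fill.
    return (x.index_fill(0, index, 0.0) ** 2).sum()

xs = torch.randn(5, 4)  # batch of 5 inputs, each of shape (4,)

# Reference: per-sample gradients computed one example at a time.
per_sample = torch.stack([grad(foo)(x) for x in xs])

try:
    batched = vmap(grad(foo))(xs)
    vmap_ok = torch.allclose(batched, per_sample)
except RuntimeError as err:
    # Assumed failure mode: the in-place op has no vmap support.
    batched, vmap_ok = None, False
    print("vmap(grad(foo)) failed:", err)
```

The per-sample loop serves as a reference result: if vmap succeeds, its output should match it.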

I've brainstormed two ways to solve this. There are tradeoffs for each and I'm wondering if someone else has a preference.

Approach 1: DECOMPOSE_FUNCTIONAL

(cc @bdhirsh)

We use the functionalization pass to functionalize index_fill. Unfortunately this results in the following code:

self.clone(...).index_fill(dim, index, source)

which performs an unnecessary clone() and is not good for performance. I don't know whether we want to make the functionalization pass smarter in the future; that sounds complicated.

Approach 2: Add backward formulas for index_fill and all of the operations above (aka, turn them into primitive operations).

(cc @albanD)

This means that both index_fill and index_fill_ get backward formulas (could we get away with only giving index_fill a backward formula?). This is a pretty simple solution. The tradeoff is that we need to duplicate the formulas, and we are setting the precedent that "operations that have an out-of-place variant must have a backward formula defined on the out-of-place variant".

Discussion

I prefer Approach 2 for its simplicity. To address the code duplication we can put the formulas into helper functions. Thoughts?
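As a rough Python analogy for Approach 2 with a shared helper, the sketch below wraps index_fill in a torch.autograd.Function so that the out-of-place op carries its own backward formula. The names IndexFill and index_fill_backward are hypothetical, and in PyTorch itself the formula would be registered alongside the other autograd derivative definitions rather than written in Python. The helper is the piece that an in-place variant could reuse.

```python
import torch

def index_fill_backward(grad_out, dim, index):
    # Shared helper (hypothetical name): filled positions do not depend
    # on the input, so the gradient flowing to them is zeroed out.
    return grad_out.index_fill(dim, index, 0)

class IndexFill(torch.autograd.Function):
    """Out-of-place index_fill with an explicit backward formula."""

    @staticmethod
    def forward(ctx, inp, dim, index, value):
        ctx.dim = dim
        ctx.save_for_backward(index)
        return inp.index_fill(dim, index, value)

    @staticmethod
    def backward(ctx, grad_out):
        (index,) = ctx.saved_tensors
        # One gradient slot per forward argument; only inp is differentiable.
        return index_fill_backward(grad_out, ctx.dim, index), None, None, None

x = torch.randn(3, 4, requires_grad=True)
index = torch.tensor([1, 3])
weights = torch.randn(3, 4)

(custom_grad,) = torch.autograd.grad(
    (IndexFill.apply(x, 1, index, 2.0) * weights).sum(), x)
(builtin_grad,) = torch.autograd.grad(
    (x.index_fill(1, index, 2.0) * weights).sum(), x)

grads_match = torch.allclose(custom_grad, builtin_grad)
print(grads_match)
```

With the formula attached to the out-of-place op, the forward no longer has to be recorded as clone() followed by an in-place call, which is the step that causes the trouble under vmap.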

Metadata

Labels

actionable: It is clear what should be done for this issue
