Description
For each of the items here, we should make sure all compositions (vmap, vmap x vjp) have a batching rule. All of these items should be actionable (in that it is possible to write a batching rule and we are not blocked on functionalization, which is coming soon).
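For context, here is a minimal sketch of the two compositions we care about, using functorch's vmap and vjp (the operator and shapes below are just an illustration):

```python
import torch
from functorch import vmap, vjp

def f(x):
    return torch.pow(x, 3)

x = torch.randn(4, 3)

# vmap: batch the forward pass over the leading dimension
out = vmap(f)(x)

# vmap x vjp: batch the backward pass as well
def vjp_of_f(x, cotangent):
    _, vjp_fn = vjp(f, x)
    return vjp_fn(cotangent)[0]

grads = vmap(vjp_of_f)(x, torch.ones_like(x))
print(out.shape, grads.shape)  # torch.Size([4, 3]) torch.Size([4, 3])
```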
Note: you may need to write an OpInfo for the operator if it doesn't exist already or wait for one to be added. A lot of folks are adding OpInfos right now, so if the OpInfo doesn't exist please ask first to see if someone is working on it.
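If you do end up writing one, OpInfo entries live in torch/testing/_internal/common_methods_invocations.py in pytorch/pytorch. A rough, unverified sketch of the shape of an entry (torch.sin is used purely as a stand-in op; the sample-input helper name is made up):

```python
import torch
from torch.testing._internal.common_methods_invocations import OpInfo, SampleInput
from torch.testing._internal.common_dtype import floating_types

# Hypothetical sample-input generator; real ones cover more shapes and kwargs.
def sample_inputs_my_op(op_info, device, dtype, requires_grad, **kwargs):
    make = lambda shape: torch.randn(
        shape, device=device, dtype=dtype, requires_grad=requires_grad)
    return [SampleInput(make((3, 4))), SampleInput(make((2, 3, 4)))]

my_op_info = OpInfo(
    'sin',                       # stand-in; use the real op's name
    dtypes=floating_types(),
    sample_inputs_func=sample_inputs_my_op,
)
```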
Note: if any of the operations decompose into in-place operations, then we need functionalization to handle them. I think I've already filtered out all of those, but please check me on that.
Parcel 1: top nn.functional.* and top torch.* ops
- torch.nn.functional.interpolate (this involves adding a batching rule for adaptive_avg_pool1d, but we might as well do adaptive_avg_pool{1, 2, 3}d as well as their backward variants while we're at it)
- nn.functional.unfold. Involves writing a batching rule for im2col and im2col_backward (Added im2col batch rule and enabled vmap for nn.functional.unfold op #262)
- nn.functional.grid_sample. Involves writing a batching rule for the backward operator. @vfdev-5
- torch.pow. The backward needs a batching rule for logical_and; use this as an opportunity to write the batching rules for logical_{and, or, xor} if those don't exist yet (see the pointwise batch rule sketch after this list). We may also need a change to PyTorch core to make the logical_* functions primitives w.r.t. autograd
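For the logical_* ops the batch rule is just the generic pointwise one. The real rules are written in functorch's C++, but the idea is simple enough to sketch in Python (the function name below is hypothetical and the rank handling is simplified):

```python
import torch

# Hypothetical name; real rules live in functorch's C++ batch rule files.
def logical_and_batch_rule(a, a_bdim, b, b_bdim):
    # Move each operand's batch dim (if any) to the front and call the op.
    # Real rules also pad ranks so broadcasting still lines up when batched
    # and unbatched operands have different ranks.
    if a_bdim is not None:
        a = a.movedim(a_bdim, 0)
    if b_bdim is not None:
        b = b.movedim(b_bdim, 0)
    return torch.logical_and(a, b), 0

out, out_bdim = logical_and_batch_rule(
    torch.randn(5, 3) > 0, 0,    # batched along dim 0
    torch.randn(3) > 0, None,    # not batched
)
print(out.shape, out_bdim)  # torch.Size([5, 3]) 0
```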
Parcel 2: new_blah
- new_empty, new_full, new_ones, new_zeros, empty_like, zeros_like, ones_like, full_like
- adaptive_max_pool{1, 2, 3}d as well as the backward variants
- diagonal_scatter, select_scatter, slice_scatter @kshitij12345
- linalg.householder_product (forward pass only)
- pixel_shuffle, pixel_unshuffle (these might just be OP_DECOMPOSE; see the decomposition sketch after this list)
- isinf, isfinite, isnan
- _cdist_forward, _cdist_backward (try to write a batching rule if possible. If not possible, we may need to write a decomposition) @vfdev-5
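On OP_DECOMPOSE: instead of writing a batch rule, we re-express the op in terms of ops that already have one, and vmap works through the decomposition. A hedged sketch of the idea for pixel_shuffle (the actual decomposition functorch uses may differ in details):

```python
import torch

def pixel_shuffle_decomposed(x, upscale_factor):
    # (b, c*r*r, h, w) -> (b, c, h*r, w*r) via reshape/permute only
    r = upscale_factor
    b, c_r2, h, w = x.shape
    c = c_r2 // (r * r)
    x = x.reshape(b, c, r, r, h, w)
    x = x.permute(0, 1, 4, 2, 5, 3)   # (b, c, h, r, w, r)
    return x.reshape(b, c, h * r, w * r)

x = torch.randn(2, 8, 4, 4)
torch.testing.assert_close(
    pixel_shuffle_decomposed(x, 2),
    torch.nn.functional.pixel_shuffle(x, 2),
)
```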
Parcel 3: linalg things
- _lu_with_info (forward pass only) @vfdev-5
- linalg.eig (forward pass only) @vfdev-5
- torch.addr @vfdev-5
- cholesky_solve (forward pass only) @vfdev-5
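"Forward pass only" means these items only need vmap of the operator itself; vmap composed with the backward pass is out of scope here. The target composition looks roughly like this (it will only work, or stop hitting the slow fallback, once the batch rule lands):

```python
import torch
from functorch import vmap

A = torch.randn(4, 3, 3)
# batch of 4 independent 3x3 eigendecompositions; keep eigenvalues only
eigvals = vmap(lambda a: torch.linalg.eig(a)[0])(A)
print(eigvals.shape, eigvals.dtype)  # torch.Size([4, 3]) torch.complex64
```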
Parcel 4:
- index_select, index_copy, etc. all need a backward formula in pytorch/pytorch (vmap over composite out-of-place ops whose in-place variant is non-composite #260)
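A sketch of the composition parcel 4 targets, using index_copy as the example (functorch's vmap/grad API; per #260 the out-of-place op currently routes through its non-composite in-place variant):

```python
import torch
from functorch import vmap, grad

def f(x, index, src):
    # out-of-place index_copy, reduced to a scalar so grad() applies
    return x.index_copy(0, index, src).sum()

x = torch.randn(4, 5, 3)
index = torch.tensor([0, 2])
src = torch.randn(4, 2, 3)

# gradient w.r.t. x, vmapped over the leading dim of x and src
grads = vmap(grad(f), in_dims=(0, None, 0))(x, index, src)
print(grads.shape)  # torch.Size([4, 5, 3])
```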