Description
The feature, motivation and pitch
When testing if something like vmap works on an operator, we test the following things:
- (1) if it errors out (it shouldn't)
- (2) if the output of vmap matches the output of a for loop (it should)
- (3) if there is a batching rule implemented for the operation. We check this by running vmap and seeing whether it raises any "batching rule not implemented" warnings.
We have two separate tests, test_vmap_exhaustive and test_op_has_batch_rule. The former checks (1) and (2), and the latter checks (1), (2), and (3) (because (1) and (2) are almost required for (3)). We could cut down the test time by having a single test that uses something like unittest.subTest, so that each test gets 3 subtests.
Furthermore, when a vmap test fails for an operator, e.g. torch.searchsorted, we mark it as an expected failure. It would be nice to be able to distinguish whether it failed because of (1) or because of (2); (2) is a silent correctness issue and is much higher priority to fix.
Alternatives
n/a
Additional context
cc @jbschlosser who has worked on many testing improvements in the past