
move Functionalize dispatch key closer to backends #77132


Closed
wants to merge 21 commits

Conversation

@bdhirsh bdhirsh commented May 10, 2022

facebook-github-bot commented May 10, 2022

✅ No Failures (3 Pending)

As of commit 1d4e8fa (more details on the Dr. CI page):

💚 💚 Looks good so far! There are no failures yet. 💚 💚

This comment was automatically generated by Dr. CI.

bdhirsh added a commit that referenced this pull request May 10, 2022
ghstack-source-id: 9a0e787
Pull Request resolved: #77132
bdhirsh added a commit that referenced this pull request May 10, 2022
ghstack-source-id: 5429f44
Pull Request resolved: #77132
bdhirsh added a commit that referenced this pull request May 11, 2022
ghstack-source-id: c78ae84
Pull Request resolved: #77132
bdhirsh added a commit that referenced this pull request May 13, 2022
ghstack-source-id: 10b59d7
Pull Request resolved: #77132
@bdhirsh bdhirsh requested review from ezyang and zou3519 May 25, 2022 18:15
@zou3519 zou3519 left a comment

lgtm, but please check that functorch functionalize tests pass after this change

Need this to get functionalize to work with backends (LTC/XLA). Now that we can kill the `DECOMPOSE_FUNCTIONAL` code in functorch (see pytorch/functorch#814), this should be ok to land once that PR merges.
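
As background for the functionalize + backends motivation above, here is a minimal sketch of what functionalization does from the Python side. It is illustrative only and not part of this PR; the import paths are an assumption, since the entry point has moved between releases (functorch.experimental vs. torch.func).

```
# Illustrative sketch, not from this PR. The functionalize entry point has
# moved across releases, so both import paths are tried.
import torch
try:
    from torch.func import functionalize              # newer PyTorch
except ImportError:
    from functorch.experimental import functionalize  # functorch circa 2022

def f(x):
    y = x.clone()
    y.add_(1)      # in-place mutation: awkward for functional-graph backends
    return y * 2

# functionalize(f) runs f under the Functionalize dispatch key, replacing the
# in-place op with an out-of-place equivalent; numerics are unchanged.
g = functionalize(f)
x = torch.randn(2, 3)
assert torch.allclose(f(x), g(x))
```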




@facebook-github-bot facebook-github-bot deleted the gh/bdhirsh/230/head branch May 30, 2022 14:17
zou3519 added a commit to pytorch/functorch that referenced this pull request May 31, 2022
…ont,Back}

zou3519 added a commit to pytorch/functorch that referenced this pull request May 31, 2022
…ont,Back}

zou3519 added a commit to pytorch/functorch that referenced this pull request May 31, 2022
…ont,Back} (#843)

Fixes #842

The Diagnosis
=============
As Brian pointed out:

For jvp(sub, ...), the chain of dispatch should be:

```
DynamicLayerFrontMode -> at::sub autograd kernel -> DynamicLayerBackMode
```

Instead, what we're doing today is
```
JVP dynamic layer -> at::sub autograd kernel -> at::sub zero_kernel
```
(the zero_tensor kernel errors out, because the inputs are
BatchedTensorImpl objects)
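
For reference, a sketch of the kind of call the diagnosis above concerns; the actual reproducer is in pytorch/functorch#842, and this snippet is an illustrative assumption rather than a copy of it.

```
# Hypothetical sketch of the jvp(sub, ...) path discussed above; the real
# reproducer lives in pytorch/functorch#842.
import torch
from functorch import jvp

x, y = torch.randn(3), torch.randn(3)
tx, ty = torch.randn(3), torch.randn(3)

# Dispatch should flow DynamicLayerFrontMode -> at::sub autograd kernel ->
# DynamicLayerBackMode. Before the fix described below, dispatch could fall
# through to at::sub's ZeroTensor kernel, which errors on BatchedTensorImpl
# inputs.
out, tangent_out = jvp(torch.sub, (x, y), (tx, ty))
assert torch.allclose(tangent_out, tx - ty)  # d/dt (x - y) = tx - ty
```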

The Problem
=============

functorch's behavior on dispatch keys between DynamicLayerFrontMode and
DynamicLayerBackMode should be:
- upon entering a dynamic layer (a.k.a. an Interpreter), we zero out all
dispatch keys* between FrontMode and BackMode;
- then, the dynamic layer (Interpreter) decides to re-enable some
dispatch keys. For example, the JVPInterpreter re-enables the
autograd keys;
- next, we do a dispatcher call, which ends up hitting one of the
Autograd keys (in the JVPInterpreter case).

The bug is that functorch has a hardcoded list of dispatch keys that it
zeros out. This list does not include ZeroTensor, because before
pytorch/pytorch#77132, the ZeroTensor key was
not between DynamicLayer{Front,Back}Mode.

*There is an exception for Autocast and VmapMode, described in the next section.

The Solution
============

Change functorch to programmatically zero out keys between
DynamicLayerBackMode and DynamicLayerFrontMode, with the exception of
Autocast and VmapMode.
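
A toy model of the change described above, in plain Python rather than the real c10/functorch code: the key names below are real DispatchKey names, but their ordering and the set operations are simplified assumptions, purely to contrast the stale hardcoded list with the programmatic masking.

```
# Toy model only; not the actual c10::DispatchKeySet / functorch implementation.
# The ordering below is a simplified assumption for illustration.
ASSUMED_KEY_ORDER = [
    "DynamicLayerFrontMode",
    "AutogradOther",   # stand-in for the autograd keys
    "ZeroTensor",      # sits between the markers after pytorch/pytorch#77132
    "Autocast",
    "VmapMode",
    "DynamicLayerBackMode",
]

# Old approach: a hardcoded list of keys to zero out, which silently went stale.
HARDCODED_ZERO_LIST = {"AutogradOther"}   # ZeroTensor is missing -> the bug

def keys_zeroed_old():
    return set(HARDCODED_ZERO_LIST)

# New approach: zero out everything strictly between the two markers,
# except Autocast and VmapMode, no matter which keys live there.
def keys_zeroed_new():
    between = ASSUMED_KEY_ORDER[1:-1]
    return {k for k in between if k not in {"Autocast", "VmapMode"}}

# An interpreter (e.g. the JVPInterpreter) then re-enables the keys it needs,
# such as the autograd keys, before making the dispatcher call.
def keys_active_for_jvp():
    return set(ASSUMED_KEY_ORDER) - keys_zeroed_new() | {"AutogradOther"}

assert "ZeroTensor" in keys_zeroed_new() - keys_zeroed_old()
assert "AutogradOther" in keys_active_for_jvp()
```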

This means that in the future, if someone adds a dispatch key between
DynamicLayerBackMode and DynamicLayerFrontMode, we will (probably) be
handling it "correctly": the model for dispatch is:
- [functorch] -> [regular pytorch dispatcher]
- a key like ZeroTensor gets handled in the [regular pytorch dispatcher]
section.
- functorch transforms get handled in the [functorch] section.

We do not change the autocast <-> functorch interaction in this PR
(i.e. functorch does not zero it out) because I'm not sure what the
correct thing to do here is.

We do not change how kVmapMode works because... it needs to be active
to ban random operations in transforms later down the line :/

Test Plan
============
Wait for tests
facebook-github-bot pushed a commit that referenced this pull request May 31, 2022
Summary:
Pull Request resolved: #77132

Approved by: https://github.com/ezyang, https://github.com/zou3519

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/7ff091fc4e66f18c2fd463ca038688b67548a6b0

Reviewed By: seemethere

Differential Revision: D36783103

Pulled By: bdhirsh

fbshipit-source-id: 4b25d31257384588b4b1644f7d45adff683eb025
zou3519 added a commit to zou3519/pytorch that referenced this pull request Jul 20, 2022
…amicLayer{Front,Back} (pytorch/functorch#843)

bigfootjon pushed a commit that referenced this pull request Jul 21, 2022
…amicLayer{Front,Back} (pytorch/functorch#843)
