torch.nn.functional.kl_div fails gradgradcheck if the target requires a gradient · Issue #65466 · pytorch/pytorch · GitHub

torch.nn.functional.kl_div fails gradgradcheck if the target requires a gradient #65466


Closed
pmeier opened this issue Sep 22, 2021 · 11 comments
Assignees
Labels
module: autograd Related to torch.autograd, and the autograd engine in general module: nn Related to torch.nn triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@pmeier
Collaborator
pmeier commented Sep 22, 2021

🐛 Bug

torch.nn.functional.kl_div fails gradgradcheck if the target requires a gradient.

To Reproduce

import torch
from torch.autograd import gradgradcheck
from torch.nn.functional import softmax, log_softmax, kl_div

torch.manual_seed(0)
# input should be log probabilities
input = log_softmax(torch.randn(3, dtype=torch.float64, requires_grad=True), dim=-1)
# target should be probabilities
target = softmax(torch.randn_like(input, requires_grad=True), dim=-1)

gradgradcheck(kl_div, inputs=(input, target))
torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for output 0 with respect to input 1,
numerical:tensor([[-0.1344,  0.0000,  0.0000],
        [ 0.0000, -0.1344,  0.0000],
        [ 0.0000,  0.0000, -0.1344]], dtype=torch.float64)
analytical:tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]], dtype=torch.float64)

The error only shows up if the target requires a gradient.
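For comparison (an illustrative sketch, not part of the original report): a reference implementation of kl_div written purely in terms of differentiable primitives passes the same check, which suggests the problem is in the hand-written backward formula rather than in the math. The helper name kl_div_ref is made up for this sketch.

```python
import torch
from torch.autograd import gradgradcheck
from torch.nn.functional import softmax, log_softmax

def kl_div_ref(input, target):
    # Pointwise KL term as PyTorch defines it: target * (log(target) - input),
    # with `input` given as log-probabilities, mean-reduced like the default
    # reduction='mean'. Built from primitives, so autograd derives all
    # higher-order gradients itself instead of using a custom backward.
    return (target * (target.log() - input)).mean()

torch.manual_seed(0)
input = log_softmax(torch.randn(3, dtype=torch.float64, requires_grad=True), dim=-1)
target = softmax(torch.randn_like(input, requires_grad=True), dim=-1)

gradgradcheck(kl_div_ref, inputs=(input, target))
```

Unlike the snippet above with torch.nn.functional.kl_div, this raises no GradcheckError, because autograd differentiates through target.log() for the target gradient automatically.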

Additional context

cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @soulitzer @lezcano @Varal7 @mruberry @jbschlosser @walterddr

@pmeier pmeier added module: autograd Related to torch.autograd, and the autograd engine in general module: nn Related to torch.nn labels Sep 22, 2021
@nikitaved
Collaborator
nikitaved commented Sep 22, 2021

It is not a bug, I guess; the issue is that the tensors are created with type float rather than double.

@pmeier
Collaborator Author
pmeier commented Sep 22, 2021

My bad, I need to recheck why the gradcheck tests in test/test_ops.py failed. They are run in float64.

@ezyang ezyang added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Sep 22, 2021
@pmeier pmeier changed the title torch.nn.functional.kl_div fails gradcheck torch.nn.functional.kl_div fails gradgradcheck if the target requires a gradient Sep 22, 2021
@pmeier
Collaborator Author
pmeier commented Sep 22, 2021

@nikitaved

It is not a bug I guess, the issue is that tensors of type float are created, rather than of double.

You were right about that. The tests in test_ops.py failed because I used an unreasonable lower limit for the input values. Fixing that, a gradgradcheck error now shows up. I've edited the original comment with the new reproduction / error message.

@nikitaved
Collaborator
nikitaved commented Sep 22, 2021

OK, it looks like the backward of kl_div has a custom backward implemented which sets the grad wrt target (parameter 1) to zero, as can be seen here:

- name: kl_div_backward(Tensor grad_output, Tensor self, Tensor target, int reduction=Mean, *, bool log_target=False) -> Tensor

Which, I guess, is wrong, as log is infinitely many times differentiable with a non-zero derivative for positive values.
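To make the point concrete (an illustrative check, not from the thread): for the pointwise term ℓ(x, t) = t·(log t − x), the partial derivatives wrt the target are ∂ℓ/∂t = log t − x + 1 and ∂²ℓ/∂t² = 1/t, which is nonzero for positive t. Building the loss from primitives lets autograd confirm both:

```python
import torch

torch.manual_seed(0)
x = torch.randn(3, dtype=torch.float64)                   # log-probabilities (input)
t = torch.rand(3, dtype=torch.float64).requires_grad_()   # positive target values

# Pointwise KL term as PyTorch defines it: t * (log t - x), sum-reduced.
loss = (t * (t.log() - x)).sum()

# First derivative wrt target: log t - x + 1
(g,) = torch.autograd.grad(loss, t, create_graph=True)
assert torch.allclose(g, t.log() - x + 1)

# Second derivative wrt target: 1 / t, clearly nonzero for positive t
(h,) = torch.autograd.grad(g.sum(), t)
assert torch.allclose(h, 1 / t)
```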
@pmeier , unless you want to have a look into this, I could try to resolve it.

@lezcano
Collaborator
lezcano commented Sep 22, 2021

Note: Interestingly enough, kl_div_backward does have the second derivative implemented by hand, but the derivative wrt the other input Q, that is, kl_div_target_backward, is implicitly differentiable. This is a bit odd.
Given how simple the implementation of the backward is using TensorIterator, we could consider implementing the backward and double backward for kl_div with it.

@nikitaved
Collaborator
nikitaved commented Sep 22, 2021

In order to fix that, we need to implement a kl_div_backward which considers input and target jointly (as a function of two variables), without the assumption that either of them is constant, as in this case, where only partial derivatives are defined. Many loss functions suffer from this issue of having a wrong double backward, because the double backward of a forward function depends on both partial derivatives, grad_input and grad_target, not just a single one of them.
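As an illustrative sketch of treating the two arguments jointly (the helper names kl_div_fwd and kl_div_joint_backward are made up, for reduction='sum'), the backward should produce both partial derivatives from the same formula:

```python
import torch

def kl_div_fwd(input, target):
    # Sum-reduced KL: sum(target * (log(target) - input)),
    # with `input` given as log-probabilities.
    return (target * (target.log() - input)).sum()

def kl_div_joint_backward(grad_output, input, target):
    # Both partial derivatives of the joint function of (input, target):
    grad_input = -target * grad_output                       # d/d input
    grad_target = (target.log() - input + 1) * grad_output   # d/d target
    return grad_input, grad_target

# Check both partials against autograd.
torch.manual_seed(0)
x = torch.randn(3, dtype=torch.float64, requires_grad=True)
t = torch.rand(3, dtype=torch.float64, requires_grad=True)
loss = kl_div_fwd(x, t)
gx, gt = torch.autograd.grad(loss, (x, t))
ref_gx, ref_gt = kl_div_joint_backward(
    torch.tensor(1.0, dtype=torch.float64), x, t
)
assert torch.allclose(gx, ref_gx)
assert torch.allclose(gt, ref_gt)
```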

@pmeier pmeier self-assigned this Sep 22, 2021
@mruberry
Collaborator

@albanD what do you think?

@albanD
Collaborator
albanD commented Sep 23, 2021

Many loss functions suffer from this issue of having a wrong double backward, because the double backward of a forward function depends on both partial derivatives, grad_input and grad_target, not just a single one of them.

I can see individual formulas being wrong, but I don't think it is a problem to have the formulas separated out.
You can see in the image below that both partial derivatives properly influence both inputs.
Or am I missing something?

For example if you do

import torch
# !pip install torchviz
import torchviz


a = torch.rand(1, 10, requires_grad=True)
t = torch.rand(1, 10, requires_grad=True)

loss = torch.kl_div(a, t).sum()

ga, gt = torch.autograd.grad(loss, (a, t), create_graph=True)

torchviz.make_dot((loss, ga, gt), params={k:v for k,v in locals().items() if isinstance(v, torch.Tensor)})

[image: torchviz graph of loss, ga, and gt]

@lezcano
Collaborator
lezcano commented Sep 23, 2021

fwiw, @pmeier will submit a patch for this later today / tomorrow. We found that having the formulas separated was really the way to go.

On a completely unrelated topic, that's some very cool graph @albanD :D

@nikitaved
Collaborator
nikitaved commented Sep 23, 2021

@albanD, when only partial derivatives are defined and I want to double backward, the effects of the differentiated backwards will be accumulated, right? Judging from how kl_div_backward is implemented, it sets the derivative for input, and the backward of kl_div_backward sets the grad wrt target to zero; yet kl_div_backward is a function of target, so it must have a non-zero derivative (neither kl_div nor its derivatives are constant functions wrt either target or input), or am I missing something? So, apparently, the engine does accumulate grads from the backwards of partial derivatives...

@albanD
Collaborator
albanD commented Sep 23, 2021

@nikitaved each backward formula is only responsible to provide the gradient flowing through the forward function they define. If some other function is also using target, then this other function is responsible for computing that part of the gradient.

the engine does accumulate grads from the backwards of partial derivatives

Yes, if a variable is re-used multiple times, the engine will make sure that the gradients from all the usages are added before processing further.
In the graph above, you can see 3 arrows coming out of t's AccumulateGrad Node. That's because it is used by 3 different functions.
Most of the complexity of the execution engine is to actually know how many of these arrows we should expect gradients from for a given backward() call and accumulate them properly.
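The accumulation behaviour described above can be seen in a minimal sketch (not from the thread): a tensor consumed by two independent functions receives the sum of both gradient contributions.

```python
import torch

t = torch.ones(3, requires_grad=True)

# t is consumed by two independent functions; the engine sums the
# gradient contributions flowing back through both before writing t.grad.
loss = (2 * t).sum() + (t * t).sum()
loss.backward()

# d/dt [2t] = 2 and d/dt [t^2] = 2t = 2 at t = 1; accumulated: 4
assert torch.equal(t.grad, torch.full((3,), 4.0))
```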

@pmeier pmeier linked a pull request Sep 27, 2021 that will close this issue
@pmeier pmeier linked a pull request Dec 1, 2021 that will close this issue
facebook-github-bot pushed a commit that referenced this issue Jun 10, 2022
Adds forward-over-reverse AD support. (#79007)

Summary:
Fixes #78867,
fixes #65466.
Adds forward-over-reverse AD support.

Pull Request resolved: #79007
Approved by: https://github.com/soulitzer, https://github.com/jbschlosser

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/72ad222cff59cbe730a49dd828cb0a25d2a18417

Reviewed By: osalpekar

Differential Revision: D37058939

Pulled By: osalpekar

fbshipit-source-id: 28ee709c47bc5fcb82ae31dd4a30e9ecac573709