`kl_div`: fix for grads wrt `target`, double backward, forward-over-reverse AD support. by nikitaved · Pull Request #79007 · pytorch/pytorch · GitHub

kl_div: fix for grads wrt target, double backward, forward-over-reverse AD support. #79007


Closed
wants to merge 14 commits

Conversation

nikitaved
Collaborator

Fixes #78867,
fixes #65466.
Adds forward-over-reverse AD support.
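A minimal sketch of what the fixed behavior enables, with assumed shapes and seed (not code from the PR): first- and second-order gradient checks of `kl_div` with a `target` that requires grad.

```python
import torch
import torch.nn.functional as F

# Illustrative repro of the issues this PR addresses: gradients w.r.t.
# `target` and double backward through kl_div. Shapes/seed are assumptions.
torch.manual_seed(0)
inp = torch.randn(3, 4, dtype=torch.double).log_softmax(-1).requires_grad_(True)
tgt = torch.rand(3, 4, dtype=torch.double).softmax(-1).detach().requires_grad_(True)

def fn(inp, tgt):
    # log_target=False (the default): target enters as target * (log(target) - input)
    return F.kl_div(inp, tgt, reduction='mean')

# With the fix, both checks pass even though `tgt` requires grad
assert torch.autograd.gradcheck(fn, (inp, tgt))
assert torch.autograd.gradgradcheck(fn, (inp, tgt))
```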

@nikitaved nikitaved added the module: autograd Related to torch.autograd, and the autograd engine in general label Jun 7, 2022
@facebook-github-bot
Contributor
facebook-github-bot commented Jun 7, 2022

✅ No Failures (0 Pending)

As of commit 35e4465 (more details on the Dr. CI page):

💚 💚 Looks good so far! There are no failures yet. 💚 💚



@nikitaved nikitaved requested review from jbschlosser and removed request for albanD, ngimel and mruberry June 7, 2022 10:52
@nikitaved nikitaved changed the title from Nikitaved/kl div bacward boost to kl_div: fix for grads wrt target, double backward, forward-over-reverse AD support. Jun 7, 2022
@nikitaved nikitaved added the ciflow/trunk Trigger trunk jobs on your pull request label Jun 7, 2022
Contributor
@jbschlosser jbschlosser left a comment


Thanks for the fixes! I'll suggest alternatively tightening up the input.shape == target.shape check if possible (see comment here). If there are actual use cases that depend on the broadcasting (unlikely for loss calculation, slightly more likely if kl_div is used in a non-loss context), then the approach taken here makes sense.

else {
  g = at::where(target > 0, -grad * grad_output, at::zeros({}, grad.options()));
  if (reduction == at::Reduction::Mean) {
    g = areAnyTensorSubclassLike({g}) ? g / input.numel() : g.mul_(input.numel());
Contributor


This doesn't seem correct to me - should the in-place version multiply by 1 / input.numel() instead?

Collaborator Author
@nikitaved nikitaved Jun 7, 2022


Good catch! Should be div_ instead of mul_ indeed.
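A small sketch (with assumed shapes) of why the Mean-reduction backward must divide rather than multiply: `reduction='mean'` is exactly `reduction='sum'` scaled by `1 / input.numel()`, so its gradient carries the same `1 / input.numel()` factor.

```python
import torch
import torch.nn.functional as F

# Illustrative check that mean reduction == sum reduction / input.numel(),
# which is the factor the backward pass has to apply. Shapes are assumptions.
inp = torch.randn(2, 3).log_softmax(-1)
tgt = torch.rand(2, 3).softmax(-1)
mean_loss = F.kl_div(inp, tgt, reduction='mean')
sum_loss = F.kl_div(inp, tgt, reduction='sum')
assert torch.allclose(mean_loss, sum_loss / inp.numel())
```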

Contributor


Interesting that the tests didn't catch this - do we need to improve coverage somewhere?

Collaborator Author
@nikitaved nikitaved Jun 7, 2022


cc @soulitzer . I think this path should have been chosen here... So I wonder just the same. The coverage is otherwise there.

Contributor


I checked that we are indeed going into the isTensorSubclass(g) == true path. See #79079

Collaborator Author


@mruberry , could you please tell whether OpInfo tests these paths?

Contributor


I think OpInfo intends to test both the isTensorSubclass(g) == true and false paths (test_fn_gradgrad runs gradgradcheck, which checks with both batched and non-batched tensors); it's just that, due to a bug, only the true path is currently being reached.

Contributor


I'd definitely feel better about correctness if that bug was fixed :p Can we at least run it manually before merging this PR?

Collaborator Author


I have tested these two paths separately, both work.

Collaborator


Interesting question about OpInfo coverage -- I don't think we currently test with a tensor subclass in any systematic way, but I could be mistaken. @ezyang would probably know better

@nikitaved
Collaborator Author

@jbschlosser, thank you for your comments! I have addressed them. I have decided to keep broadcasting, as kl_div already has a warning about potential issues when broadcasting is used. It suggests, however, that the grads could be wrong. I think it would make sense to change it to something different, or remove it entirely. What do you think about that?

@jbschlosser
Contributor

> @jbschlosser , thank you for your comments! I have addressed them. I have decided to keep broadcasting as kl_div already has a warning about potential issues when broadcasting is used. It suggests, however, that the grads could be wrong. I think it would make sense to change it to something different, or remove entirely. What do you think about that?

Ah interesting - can you point me to that warning please? I only see a warning regarding reduction='mean' vs. reduction='batchmean'.
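The `reduction='mean'` vs. `reduction='batchmean'` warning mentioned above boils down to which count the summed loss is divided by. A hedged sketch with assumed shapes (not code from the PR):

```python
import torch
import torch.nn.functional as F

# 'batchmean' divides the summed loss by the batch size (the mathematical
# KL divergence definition); 'mean' divides by the total element count.
inp = torch.randn(2, 3).log_softmax(-1)
tgt = torch.rand(2, 3).softmax(-1)
s = F.kl_div(inp, tgt, reduction='sum')
assert torch.allclose(F.kl_div(inp, tgt, reduction='batchmean'), s / inp.shape[0])
assert torch.allclose(F.kl_div(inp, tgt, reduction='mean'), s / inp.numel())
```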

@nikitaved
Collaborator Author
nikitaved commented Jun 7, 2022

Sorry, I confused kl_div with other loss functions. l1_loss and mse_loss have a warning, for example. I suspected them to have the same issue as kl_div, but they are fine wrt autograd since the kernels receive broadcasted inputs. No action is needed, unless we indeed decide to impose shape restrictions.

@jbschlosser
Contributor

> Sorry, I confused kl_div with other loss functions. l1_loss and mse_loss have a warning, for example. I suspected them to have the same issue as kl_div, but they are fine wrt autograd since the kernels receive broadcasted inputs. No action is needed, unless we indeed decide to impose shape restrictions.

Okay awesome. I see the discussion in #16045 regarding broadcasting for loss functions that led to a warning instead of a hard error. Following the historical precedent set there, I think we should move forward with your broadcasting fix. Maybe in a future PR, it'd also make sense to add a similar warning to kl_div, but let's leave that out for now.

@@ -2869,7 +2869,8 @@ def kl_div(
    else:
        reduction_enum = _Reduction.get_enum(reduction)

    reduced = torch.kl_div(input, target, reduction_enum, log_target=log_target)
    expanded_input, expanded_target = torch.broadcast_tensors(input, target)
Contributor


I see that this matches what was done for e.g. nn.MSELoss. With manual broadcasting here, do we still need the kernel changes?

Collaborator Author


No, not really. But I like to keep it consistent, since forward is using input.numel, not target.numel.
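What the python-level `torch.broadcast_tensors` change above buys: both operands are expanded to a common shape before the kernel runs, so autograd sees the broadcast explicitly and reduces each gradient back to its operand's original shape. An illustrative sketch with assumed shapes:

```python
import torch

# Expanding before the op makes the broadcast visible to autograd; the
# backward of expand() sums over broadcast dims, restoring original shapes.
a = torch.randn(5, 1, requires_grad=True)
b = torch.randn(1, 4, requires_grad=True)
ea, eb = torch.broadcast_tensors(a, b)
assert ea.shape == eb.shape == torch.Size([5, 4])
(ea * eb).sum().backward()
assert a.grad.shape == a.shape and b.grad.shape == b.shape
```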

@nikitaved
Collaborator Author

JIT issues seem relevant.

@@ -61,12 +61,14 @@ void binary_cross_entropy_backward_out_kernel(Tensor& grad_input, const Tensor&
namespace at { namespace native {

Tensor kl_div_backward_cuda(const Tensor& grad, const Tensor& input, const Tensor& target, int64_t reduction, bool log_target) {
  auto grad_input = at::empty_like(input);
  const auto grad_expanded = grad.expand_as(input);
Contributor


Hmm backward wasn't failing previously right?

Collaborator Author
@nikitaved nikitaved Jun 8, 2022


No, it was not, because the tests only used inputs with target.shape() == input.shape().

Contributor


Oh, I was thinking that because grad has the same size as the potentially reduced output, we'd need to do some kind of expanding always, and whether target was broadcasted to input or not doesn't matter. But maybe TensorIterator is doing something?

Collaborator Author


TI was complaining about a shape mismatch in some cases, and in others it would segfault. Probably worth investigating...

Contributor


Yeah tensor iterator is supposed to handle broadcasting, but maybe there is an issue specifically with this case?

Collaborator Author


Either that, or the grad forwarding. I will try to isolate the issue; maybe it is something worth reporting.

Contributor


My understanding is that, before the fix, TI didn't know to broadcast because it was handed an unbroadcasted target and grad.

Collaborator Author
@nikitaved nikitaved Jun 8, 2022


TI should have broadcasted inputs with shapes input=(5, 5, 5) and output=(5, 5) (where it complains about an output-input mismatch), and shapes input=(5, 5, 5), output=(5,) should not have caused any segfaults. Back then TI used to broadcast inputs and outputs to the same shape but, apparently, this behavior might have changed, CC @ngimel .
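The reason the explicit `grad.expand_as(input)` in the kernel is safe for autograd, sketched with illustrative shapes: the backward of `expand()` sums over the broadcast dimensions, so gradients land back in the original (smaller) shape.

```python
import torch

# expand() creates a broadcast view without copying; its backward sums
# gradient contributions over the expanded dimension.
t = torch.randn(3, dtype=torch.double, requires_grad=True)
e = t.expand(4, 3)  # view of shape (4, 3), no copy
e.sum().backward()
# Each element of t contributed to 4 output elements, so its grad is 4
assert torch.allclose(t.grad, torch.full((3,), 4.0, dtype=torch.double))
```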

Contributor
@soulitzer soulitzer left a comment


Thanks!

@nikitaved
Collaborator Author
nikitaved commented Jun 8, 2022

@jbschlosser , do you have more comments and/or suggestions? Otherwise, if it is fine with you, I will merge it.

@jbschlosser
Contributor

@nikitaved I'm good with it, but is there any way to manually test the case that @soulitzer pointed out is wrongly being skipped?

@nikitaved
Collaborator Author

@jbschlosser , I can confirm gradcheck is happy with both paths.

@jbschlosser
Contributor

@nikitaved good to merge on my end then!

@nikitaved
Collaborator Author
nikitaved commented Jun 9, 2022

@pytorchbot merge

lezcano added a commit that referenced this pull request Jun 27, 2022
Fixes #80158
Fixes #78867
Fixes #69230

Supersedes #79007
Supersedes #69212

[ghstack-poisoned]
lezcano added a commit that referenced this pull request Jun 27, 2022
Benchmarks: #80334 (comment)

Fixes #80158
Fixes #78867
Fixes #69230

Supersedes #79007
Supersedes #69212
Supersedes #19659

[ghstack-poisoned]
pytorchmergebot pushed a commit that referenced this pull request Jul 4, 2022
Benchmarks: #80334 (comment)

Fixes #80158
Fixes #78867
Fixes #69230

Supersedes #79007
Supersedes #69212
Supersedes #19659
Pull Request resolved: #80334
Approved by: https://github.com/ezyang
facebook-github-bot pushed a commit that referenced this pull request Jul 6, 2022
Summary:
Benchmarks: #80334 (comment)

Fixes #80158
Fixes #78867
Fixes #69230

Supersedes #79007
Supersedes #69212
Supersedes #19659

Pull Request resolved: #80334
Approved by: https://github.com/ezyang

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/828c787ea98da39eb786925eedcb8527aae07153

Reviewed By: mehtanirav

Differential Revision: D37604775

Pulled By: mehtanirav

fbshipit-source-id: b188d47df5a3a820e5c15d9ce18b1a2c3f31f287
pytorchmergebot pushed a commit that referenced this pull request Jul 13, 2022
Benchmarks: #80334 (comment)

Fixes #80158
Fixes #78867
Fixes #69230

Supersedes #79007
Supersedes #69212
Supersedes #19659
Pull Request resolved: #80334
Approved by: https://github.com/ezyang
facebook-github-bot pushed a commit that referenced this pull request Jul 14, 2022
Summary:
Benchmarks: #80334 (comment)

Fixes #80158
Fixes #78867
Fixes #69230

Supersedes #79007
Supersedes #69212
Supersedes #19659

Pull Request resolved: #80334
Approved by: https://github.com/ezyang

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/b5b9db9f844f4f100651c6afa57124fa5851edec

Reviewed By: DanilBaibak

Differential Revision: D37847477

Pulled By: DanilBaibak

fbshipit-source-id: a04919bbd2b746c30c654b971efcf76ef27ac5a6
@github-actions
Contributor

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Aug 15, 2022
@lezcano
Collaborator
lezcano commented Aug 15, 2022

Superseded by #80334

@lezcano lezcano closed this Aug 15, 2022
@rgommers rgommers removed the Merged label Sep 13, 2022
@github-actions github-actions bot deleted the nikitaved/kl_div_bacward_boost branch February 17, 2024 01:50
Labels
ciflow/trunk (Trigger trunk jobs on your pull request), cla signed, module: autograd (Related to torch.autograd, and the autograd engine in general), open source, release notes: autograd (release notes category), Reverted, Stale
Development

Successfully merging this pull request may close these issues.

kl_div will crash in the backward pass on cuda
torch.nn.functional.kl_div fails gradgradcheck if the target requires a gradient