kl_div: fix for grads wrt target, double backward, forward-over-reverse AD support. #79007
Conversation
✅ Dr. CI: No failures (0 pending) as of commit 35e4465.
Thanks for the fixes! I'll suggest alternatively tightening up the `input.shape == target.shape` check if possible (see comment here). If there are actual use cases that depend on the broadcasting (unlikely for loss calculation, slightly more likely if `kl_div` is used in a non-loss context), then the approach taken here makes sense.
else {
  g = at::where(target > 0, -grad * grad_output, at::zeros({}, grad.options()));
  if (reduction == at::Reduction::Mean) {
    g = areAnyTensorSubclassLike({g}) ? g / input.numel() : g.mul_(input.numel());
This doesn't seem correct to me - should the in-place version multiply by `1 / input.numel()` instead?
Good catch! Should be `div_` instead of `mul_` indeed.
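For context, here is a minimal sketch (not from the PR itself) of the kind of numerical check that should expose a wrong mean-reduction scaling: gradgradcheck compares the analytic double backward against finite differences, so multiplying by `input.numel()` where a division is needed would show up as a mismatch.

```python
import torch
import torch.nn.functional as F

inp = torch.randn(3, 4, dtype=torch.double, requires_grad=True)
# Keep target away from 0 so finite differences of log(target) stay well conditioned.
tgt = (torch.rand(3, 4, dtype=torch.double) * 0.9 + 0.1).requires_grad_()

def fn(i, t):
    return F.kl_div(i, t, reduction="mean")

# First- and second-order gradients wrt both input and target.
torch.autograd.gradcheck(fn, (inp, tgt))
torch.autograd.gradgradcheck(fn, (inp, tgt))
```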
Interesting that the tests didn't catch this - do we need to improve coverage somewhere?
cc @soulitzer. I think this path should have been chosen here... So I wonder just the same. The coverage is otherwise there.
I checked that we are indeed going into the `isTensorSubclass(g) == true` path. See #79079
@mruberry, could you please tell whether OpInfo tests these paths?
I think OpInfo intends to test both the `isTensorSubclass(g) == true` and `false` paths (test_fn_gradgrad runs gradgradcheck, which checks with both batched and non-batched tensors); it's just that, due to a bug, only the `true` path is being reached currently.
I'd definitely feel better about correctness if that bug was fixed :p Can we at least run it manually before merging this PR?
I have tested these two paths separately, both work.
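For what it's worth, a hedged sketch of how the two paths can be exercised by hand: the batched checks route the double backward through the tensor-subclass (BatchedTensor) branch, while the plain checks take the ordinary-tensor branch; `check_batched_grad` is the public gradcheck knob for this.

```python
import torch
import torch.nn.functional as F

inp = torch.randn(2, 3, dtype=torch.double, requires_grad=True)
tgt = (torch.rand(2, 3, dtype=torch.double) * 0.9 + 0.1).requires_grad_()

def fn(i, t):
    return F.kl_div(i, t, reduction="mean")

# Ordinary-tensor branch only.
torch.autograd.gradgradcheck(fn, (inp, tgt), check_batched_grad=False)
# Additionally runs the batched (vmap/BatchedTensor subclass) branch.
torch.autograd.gradgradcheck(fn, (inp, tgt), check_batched_grad=True)
```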
Interesting question about OpInfo coverage -- I don't think we currently test with a tensor subclass in any systematic way, but I could be mistaken. @ezyang would probably know better
@jbschlosser, thank you for your comments! I have addressed them. I have decided to keep broadcasting as
Ah interesting - can you point me to that warning please? I only see a warning regarding
Sorry, I confused
Okay awesome. I see the discussion in #16045 regarding broadcasting for loss functions that led to a warning instead of a hard error. Following the historical precedent set there, I think we should move forward with your broadcasting fix. Maybe in a future PR, it'd also make sense to add a similar warning to
@@ -2869,7 +2869,8 @@ def kl_div(
    else:
        reduction_enum = _Reduction.get_enum(reduction)

    reduced = torch.kl_div(input, target, reduction_enum, log_target=log_target)
    expanded_input, expanded_target = torch.broadcast_tensors(input, target)
I see that this matches what was done for e.g. `nn.MSELoss`. With manual broadcasting here, do we still need the kernel changes?
No, not really. But I like to keep it consistent, since forward is using `input.numel`, not `target.numel`.
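As an illustration (not the exact diff), the manual broadcasting in the Python functional amounts to expanding both tensors to a common shape before the kernel call, so the `input.numel()` used by the mean reduction refers to the broadcasted size.

```python
import torch
import torch.nn.functional as F

inp = torch.randn(4, 1, 5)   # log-probabilities
tgt = torch.rand(1, 3, 5)    # probabilities

# Both operands are expanded to the common shape (4, 3, 5) up front,
# mirroring what the functional wrapper now does.
expanded_input, expanded_target = torch.broadcast_tensors(inp, tgt)
out = F.kl_div(expanded_input, expanded_target, reduction="mean")
```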
JIT issues seem relevant.
@@ -61,12 +61,14 @@ void binary_cross_entropy_backward_out_kernel(Tensor& grad_input, const Tensor&
namespace at { namespace native {

Tensor kl_div_backward_cuda(const Tensor& grad, const Tensor& input, const Tensor& target, int64_t reduction, bool log_target) {
  auto grad_input = at::empty_like(input);
  const auto grad_expanded = grad.expand_as(input);
Hmm backward wasn't failing previously right?
No, it was not, because the tests used `target.shape() == input.shape()` inputs.
Oh, I was thinking that because grad has the same size as the potentially reduced output, we'd need to do some kind of expanding always, and whether target was broadcasted to input or not doesn't matter. But maybe TensorIterator is doing something?
TI was complaining about shape mismatch in some cases, and in some it would segfault. Probably worth investigating...
Yeah tensor iterator is supposed to handle broadcasting, but maybe there is an issue specifically with this case?
Either that, or the grad forwarding. I will try to isolate the issue, maybe it is something worth reporting about.
My understanding is that, before the fix, TI didn't know to broadcast because it was handed an unbroadcasted `target` and `grad`.
TI should have broadcasted inputs with shapes input=(5, 5, 5) and output=(5, 5) (where it complains about an output-input mismatch), and shapes input=(5, 5, 5), output=(5,) should not have caused any segfaults. Back then TI used to broadcast inputs and outputs to the same shape, but apparently this behavior might have changed, CC @ngimel.
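To make the expansion point concrete, here is a hypothetical Python rendition (not the CUDA kernel) of why `grad` has to be expanded: with a reduction the output, and hence the incoming `grad`, is a scalar, while the gradient wrt `input` has `input`'s shape.

```python
import torch
import torch.nn.functional as F

def kl_div_backward_sketch(grad, input, target):
    # Scalar grad (from a reduced output) is broadcast to input's shape,
    # analogous to grad.expand_as(input) in the CUDA backward.
    grad_expanded = grad.expand_as(input)
    # d/d_input of target * (log(target) - input) is -target; zero where target == 0.
    return torch.where(target > 0, -target, torch.zeros_like(input)) * grad_expanded

inp = torch.randn(5, 5, requires_grad=True)
tgt = torch.rand(5, 5)
loss = F.kl_div(inp, tgt, reduction="sum")
manual = kl_div_backward_sketch(torch.tensor(1.0), inp, tgt)
auto, = torch.autograd.grad(loss, inp)
assert torch.allclose(manual, auto)
```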
Thanks!
@jbschlosser, do you have more comments and/or suggestions? Otherwise, if it is fine with you, I will merge it.
@nikitaved I'm good with it, but is there any way to manually test the case that @soulitzer pointed out is wrongly being skipped?
@jbschlosser, I can confirm gradcheck is happy with both paths.
@nikitaved good to merge on my end then!
@pytorchbot merge
Benchmarks: #80334 (comment)
Fixes #80158 Fixes #78867 Fixes #69230
Supersedes #79007 Supersedes #69212 Supersedes #19659
Pull Request resolved: #80334
Approved by: https://github.com/ezyang

Summary: as above. Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/828c787ea98da39eb786925eedcb8527aae07153
Reviewed By: mehtanirav
Differential Revision: D37604775
Pulled By: mehtanirav
fbshipit-source-id: b188d47df5a3a820e5c15d9ce18b1a2c3f31f287
Summary: as above. Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/b5b9db9f844f4f100651c6afa57124fa5851edec
Reviewed By: DanilBaibak
Differential Revision: D37847477
Pulled By: DanilBaibak
fbshipit-source-id: a04919bbd2b746c30c654b971efcf76ef27ac5a6
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as
Superseded by #80334
Fixes #78867, fixes #65466. Adds forward-over-reverse AD support.
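As a rough illustration of what forward-over-reverse means here (using the later torch.func API rather than anything from this PR): forward-mode AD is pushed through the reverse-mode gradient of `kl_div` wrt the target, which is exactly the composition that requires `kl_div`'s backward to have a forward-AD formula.

```python
import torch
import torch.nn.functional as F
from torch.func import grad, jvp

log_probs = torch.log_softmax(torch.randn(3, 4), dim=-1)

def kl_loss(target):
    return F.kl_div(log_probs, target, reduction="sum")

primal = torch.softmax(torch.randn(3, 4), dim=-1)   # a valid probability target
tangent = torch.randn(3, 4)

# Forward-mode AD (jvp) applied to the reverse-mode gradient wrt target:
# the "forward-over-reverse" composition, yielding a Hessian-vector product.
grad_wrt_target, hvp = jvp(grad(kl_loss), (primal,), (tangent,))
```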