Forward-over-reverse gradgradcheck fails for `torch.gradient` on CUDA · Issue #69925 · pytorch/pytorch · GitHub

Forward-over-reverse gradgradcheck fails for torch.gradient on CUDA #69925


Closed
soulitzer opened this issue Dec 14, 2021 · 7 comments
Labels
module: autograd (Related to torch.autograd, and the autograd engine in general)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
ZeroTensor

Comments

@soulitzer
Contributor
soulitzer commented Dec 14, 2021

🐛 Describe the bug

To replicate: remove the expectedFailure decorator for test_fn_fwgrad_bwgrad in gradient's OpInfo.
At the time of posting, this only fails on the branch for #69740.
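
As a rough illustration of why removing the decorator surfaces the error, here is a minimal standard-library sketch. It does not use PyTorch's OpInfo machinery; `build_case` and the raised message are hypothetical stand-ins for the real test and its failure.

```python
import unittest


def build_case(expect_failure):
    """Build a test case whose only test always raises (hypothetical stand-in)."""

    class GradientCase(unittest.TestCase):
        def test_fn_fwgrad_bwgrad(self):
            # Stand-in for the real check, which raised a RuntimeError.
            raise RuntimeError("Expected all tensors to be on the same device")

    if expect_failure:
        # Same effect as decorating the method with @unittest.expectedFailure.
        GradientCase.test_fn_fwgrad_bwgrad = unittest.expectedFailure(
            GradientCase.test_fn_fwgrad_bwgrad
        )
    return GradientCase


def run(case):
    """Run one TestCase class and return the collected TestResult."""
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(case)
    result = unittest.TestResult()
    suite.run(result)
    return result


masked = run(build_case(expect_failure=True))
exposed = run(build_case(expect_failure=False))

print(len(masked.expectedFailures), len(masked.errors))    # 1 0
print(len(exposed.expectedFailures), len(exposed.errors))  # 0 1
```

With the decorator in place the run still counts as successful, which is why the failure only shows up once it is removed.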

======================================================================
ERROR: test_fn_fwgrad_bwgrad_gradient_cuda_complex128 (__main__.TestGradientsCUDA)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/local/pytorch4/torch/testing/_internal/common_utils.py", line 1482, in wrapper
    method(*args, **kwargs)
  File "/local/pytorch4/torch/testing/_internal/common_utils.py", line 1482, in wrapper
    method(*args, **kwargs)
  File "/local/pytorch4/torch/testing/_internal/common_device_type.py", line 381, in instantiated_test
    raise rte
  File "/local/pytorch4/torch/testing/_internal/common_device_type.py", line 376, in instantiated_test
    result = test(self, **param_kwargs)
  File "/local/pytorch4/torch/testing/_internal/common_device_type.py", line 753, in test_wrapper
    return test(*args, **kwargs)
  File "test/test_ops.py", line 827, in test_fn_fwgrad_bwgrad
    self._check_helper(device, dtype, op, op.get_op(), "fwgrad_bwgrad")
  File "test/test_ops.py", line 776, in _check_helper
    self.assertTrue(gradgradcheck(fn, gradcheck_args, **kwargs))
  File "/local/pytorch4/torch/testing/_internal/common_utils.py", line 2893, in gradgradcheck
    return torch.autograd.gradgradcheck(fn, inputs, grad_outputs, **kwargs)
  File "/local/pytorch4/torch/autograd/gradcheck.py", line 1544, in gradgradcheck
    return gradcheck(
  File "/local/pytorch4/torch/autograd/gradcheck.py", line 1398, in gradcheck
    return _gradcheck_helper(**args)
  File "/local/pytorch4/torch/autograd/gradcheck.py", line 1412, in _gradcheck_helper
    _gradcheck_real_imag(gradcheck_fn, func, func_out, tupled_inputs, outputs, eps,
  File "/local/pytorch4/torch/autograd/gradcheck.py", line 1068, in _gradcheck_real_imag
    gradcheck_fn(imag_fn, imag_func_out, imag_inputs, diff_imag_func_out, eps,
  File "/local/pytorch4/torch/autograd/gradcheck.py", line 1281, in _fast_gradcheck
    analytical_vJu = _get_analytical_jacobian_forward_ad(func, inputs, _as_tuple(func_out),
  File "/local/pytorch4/torch/autograd/gradcheck.py", line 343, in _get_analytical_jacobian_forward_ad
    raw_outputs = _as_tuple(fn(*dual_inputs))
  File "/local/pytorch4/torch/autograd/gradcheck.py", line 1030, in wrapped_fn
    return _as_tuple(fn(*new_inputs))
  File "/local/pytorch4/torch/autograd/gradcheck.py", line 1539, in new_func
    grad_inputs = torch.autograd.grad(outputs, input_args, grad_outputs, create_graph=True,
  File "/local/pytorch4/torch/autograd/__init__.py", line 275, in grad
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

Versions

At the time of posting, this only fails on the branch for #69740.
Once that has landed, it should be reproducible on the main branch.

cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @soulitzer @lezcano @Varal7

soulitzer added the module: autograd and triaged labels Dec 14, 2021
@jeffdaily
Collaborator

Is this related to #70027? I've been trying to debug it and it seems perhaps ZeroTensor is involved.

@soulitzer
Contributor Author

@anjali411 maybe #71160 is related?

@jeffdaily
Collaborator

Why do we have add_zerotensor and mul_zerotensor, but no div_zerotensor? I added it, and it seems to fix the rocm tests. No more mem access fault.

@anjali411
Contributor

Yeah, #71160 will fix this issue. It is caused by the CUDA dispatch key not being set correctly on EZTs (efficient zero tensors) created with a CUDA device.

@anjali411
Contributor

@jeffdaily yeah we could certainly add a fastpath for div as well. Here is a complete list of ops that can benefit from the fastpath in the JVP computation: #69687
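
To make the idea of a zero-tensor fastpath concrete, here is a toy pure-Python model. It is not PyTorch's implementation: `LazyZero`, `add`, `mul`, and `div` are hypothetical names, lists stand in for tensors, and broadcasting, inf/nan handling, and division by a zero placeholder are deliberately left out.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LazyZero:
    """Hypothetical all-zero placeholder that stores metadata but no data."""
    shape: tuple
    device: str  # the metadata must still say where the zeros "live"


def add(a, b):
    # 0 + x == x: return the other operand without computing anything.
    if isinstance(a, LazyZero):
        return b
    if isinstance(b, LazyZero):
        return a
    return [x + y for x, y in zip(a, b)]


def mul(a, b):
    # 0 * x == 0 for finite x: the result stays a placeholder.
    for operand in (a, b):
        if isinstance(operand, LazyZero):
            return operand
    return [x * y for x, y in zip(a, b)]


def div(a, b):
    # Only 0 / x is sketched; dividing by a zero placeholder is out of scope.
    if isinstance(b, LazyZero):
        raise NotImplementedError("x / 0 is not modelled in this sketch")
    if isinstance(a, LazyZero):
        return a
    return [x / y for x, y in zip(a, b)]


zero = LazyZero(shape=(3,), device="cuda:0")
dense = [1.0, 2.0, 4.0]

print(add(zero, dense))          # [1.0, 2.0, 4.0]
print(mul(dense, zero) is zero)  # True
print(div(zero, dense).device)   # cuda:0
```

The placeholder skips the arithmetic, but it still has to carry correct device metadata, since later operations rely on it.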

anjali411 added a commit that referenced this issue Jan 25, 2022
@jeffdaily
Collaborator

I added div_zerotensor in PR #71862. Waiting for CI to finish. It did resolve the rocm failures.

anjali411 added a commit that referenced this issue Jan 26, 2022
jeffdaily added a commit to ROCm/pytorch that referenced this issue Jan 26, 2022
anjali411 added a commit that referenced this issue Jan 28, 2022
facebook-github-bot pushed a commit that referenced this issue Jan 30, 2022
Summary:
Pull Request resolved: #71611

Fixes #71160 #69925

Test Plan: Imported from OSS

Reviewed By: george-qi

Differential Revision: D33834916

Pulled By: anjali411

fbshipit-source-id: 11cec343e95e2ee188ab7576f26f64aa19317891
pytorchmergebot pushed a commit that referenced this issue Jan 30, 2022
Summary:
Pull Request resolved: #71611

Fixes #71160 #69925

Test Plan: Imported from OSS

Reviewed By: george-qi

Differential Revision: D33834916

Pulled By: anjali411

fbshipit-source-id: 11cec343e95e2ee188ab7576f26f64aa19317891
(cherry picked from commit f6e86f8)
anjali411 added a commit that referenced this issue Feb 2, 2022
facebook-github-bot pushed a commit that referenced this issue Feb 2, 2022
Summary:
Pull Request resolved: #71611

Fixes #71160 #69925 #69913

Test Plan: Imported from OSS

Reviewed By: VitalyFedyunin

Differential Revision: D33897543

Pulled By: anjali411

fbshipit-source-id: f1d8608c351876b8c2619da5ef891f74bad30ab5
pytorchmergebot pushed a commit that referenced this issue Feb 2, 2022
Summary:
Pull Request resolved: #71611

Fixes #71160 #69925 #69913

Test Plan: Imported from OSS

Reviewed By: VitalyFedyunin

Differential Revision: D33897543

Pulled By: anjali411

fbshipit-source-id: f1d8608c351876b8c2619da5ef891f74bad30ab5
(cherry picked from commit 643e666)
cyyever pushed a commit to cyyever/pytorch_private that referenced this issue Feb 3, 2022
Summary:
Pull Request resolved: pytorch/pytorch#71611

Fixes pytorch/pytorch#71160 pytorch/pytorch#69925

Test Plan: Imported from OSS

Reviewed By: george-qi

Differential Revision: D33834916

Pulled By: anjali411

fbshipit-source-id: 11cec343e95e2ee188ab7576f26f64aa19317891
(cherry picked from commit f6e86f8)
cyyever pushed a commit to cyyever/pytorch_private that referenced this issue Feb 3, 2022
Summary:
Pull Request resolved: pytorch/pytorch#71611

Fixes pytorch/pytorch#71160 pytorch/pytorch#69925 #69913

Test Plan: Imported from OSS

Reviewed By: VitalyFedyunin

Differential Revision: D33897543

Pulled By: anjali411

fbshipit-source-id: f1d8608c351876b8c2619da5ef891f74bad30ab5
(cherry picked from commit 643e666)
@anjali411
Contributor

fixed by #71611

facebook-github-bot pushed a commit that referenced this issue Feb 3, 2022
Summary:
Fixes #70027.  Fixes #69925.  Reverts #70061.  Completes a task from #69687.

Pull Request resolved: #71862

Reviewed By: george-qi

Differential Revision: D33824078

Pulled By: anjali411

fbshipit-source-id: 114754161077281f9804476288ed796d344ee54b
pytorchmergebot pushed a commit that referenced this issue Feb 3, 2022
Summary:
Fixes #70027.  Fixes #69925.  Reverts #70061.  Completes a task from #69687.

Pull Request resolved: #71862

Reviewed By: george-qi

Differential Revision: D33824078

Pulled By: anjali411

fbshipit-source-id: 114754161077281f9804476288ed796d344ee54b
(cherry picked from commit 55ce1d5)
cyyever pushed a commit to cyyever/pytorch_private that referenced this issue Feb 9, 2022
Summary:
Pull Request resolved: pytorch/pytorch#71611

Fixes pytorch/pytorch#71160 pytorch/pytorch#69925

Test Plan: Imported from OSS

Reviewed By: george-qi

Differential Revision: D33834916

Pulled By: anjali411

fbshipit-source-id: 11cec343e95e2ee188ab7576f26f64aa19317891
(cherry picked from commit f6e86f8)
cyyever pushed a commit to cyyever/pytorch_private that referenced this issue Feb 9, 2022
Summary:
Pull Request resolved: pytorch/pytorch#71611

Fixes pytorch/pytorch#71160 pytorch/pytorch#69925 #69913

Test Plan: Imported from OSS

Reviewed By: VitalyFedyunin

Differential Revision: D33897543

Pulled By: anjali411

fbshipit-source-id: f1d8608c351876b8c2619da5ef891f74bad30ab5
(cherry picked from commit 643e666)
cyyever pushed a commit to cyyever/pytorch_private that referenced this issue Feb 9, 2022
Summary:
Fixes pytorch/pytorch#70027.  Fixes pytorch/pytorch#69925.  Reverts pytorch/pytorch#70061.  Completes a task from pytorch/pytorch#69687.

Pull Request resolved: pytorch/pytorch#71862

Reviewed By: george-qi

Differential Revision: D33824078

Pulled By: anjali411

fbshipit-source-id: 114754161077281f9804476288ed796d344ee54b
(cherry picked from commit 55ce1d5c02b5869497d07e916c46b6f50287542c)