8000 Extend CSR constructor to support batched indices and values by IvanYashchuk · Pull Request #74542 · pytorch/pytorch · GitHub

Extend CSR constructor to support batched indices and values #74542


Closed
wants to merge 25 commits into from

Conversation

IvanYashchuk
Collaborator

This is the first portion of changes required to enable Batched CSR format described in #60854 (comment).

Currently, only the same batch shape for indices and values is allowed. In the future, we could enable "broadcasting" of indices and batched values, as done in xFormers (https://github.com/facebookresearch/xformers/blob/dd96b8d8beda5308fb433c1ef3ff04b7f178c263/xformers/components/attention/_sputnik_sparse.py#L441).

This PR adds the possibility to construct a batched CSR matrix with torch.sparse_csr_tensor; the resulting batched CSR tensor can then be converted to a dense tensor with a .to_dense() call.
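As a rough pure-Python sketch (not PyTorch code) of what the batched layout stores: each batch entry carries its own crow/col/values triple, and .to_dense() expands each entry independently. The helper and sample data below are hypothetical illustrations, not the PR's implementation.

```python
def csr_to_dense(crow, col, vals, ncols):
    """Expand one CSR matrix (crow/col/vals) into a dense list of lists."""
    nrows = len(crow) - 1
    dense = [[0] * ncols for _ in range(nrows)]
    for r in range(nrows):
        # crow[r]:crow[r+1] delimits the stored entries of row r
        for i in range(crow[r], crow[r + 1]):
            dense[r][col[i]] = vals[i]
    return dense

# A batch of two 2x2 CSR matrices with identical sparsity patterns,
# mirroring the batched tensors this PR allows torch.sparse_csr_tensor
# to accept (indices and values gain leading batch dimensions).
batch_crow = [[0, 2, 4], [0, 2, 4]]
batch_col = [[0, 1, 0, 1], [0, 1, 0, 1]]
batch_vals = [[1, 2, 3, 4], [5, 6, 7, 8]]

batch_dense = [
    csr_to_dense(c, j, v, ncols=2)
    for c, j, v in zip(batch_crow, batch_col, batch_vals)
]
# batch_dense == [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
```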

@IvanYashchuk IvanYashchuk added module: sparse Related to torch.sparse release notes: sparse release notes category topic: improvements topic category labels Mar 22, 2022
@IvanYashchuk IvanYashchuk requested a review from cpuhrsch March 22, 2022 14:19
@facebook-github-bot
Contributor
facebook-github-bot commented Mar 22, 2022

🔗 Helpful links

💊 CI failures summary and remediations

As of commit 0c79aa7 (more details on the Dr. CI page):


None of the CI failures appear to be your fault 💚



🚧 1 fixed upstream failure:

These were probably caused by upstream breakages that were already fixed.

Please rebase on the viable/strict branch.

If your commit is older than viable/strict, run these commands:

git fetch https://github.com/pytorch/pytorch viable/strict
git rebase FETCH_HEAD

This comment was automatically generated by Dr. CI.


// std::array<int64_t, 2> size = {0, 0};
auto size = DimVector(IntArrayRef(col_indices.sizes().data(), col_indices.dim() - 1));
size.push_back(crow_indices.size(-1) - 1);
size.push_back(col_indices.max().item<int64_t>() + 1);
Contributor

col_indices are always guaranteed to be int64_t now?

Collaborator Author

No, but here .max().item() gives a Scalar and .item<int64_t> casts the Scalar to int64_t.

return item().to##name(); \
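The size-inference rule in the snippet under review can be sketched in pure Python (the helper name is hypothetical): batch dimensions come from all but the last dimension of col_indices, the row count from crow_indices.size(-1) - 1, and the column count from the largest column index plus one.

```python
def infer_batched_csr_size(crow_indices_shape, col_indices_shape, col_values):
    """Return [*batch_dims, nrows, ncols] from CSR component shapes/values."""
    batch_dims = list(col_indices_shape[:-1])   # everything but the last dim
    nrows = crow_indices_shape[-1] - 1          # crow stores nrows + 1 row pointers
    ncols = max(col_values) + 1                 # tightest width covering all columns
    return batch_dims + [nrows, ncols]

# Batch of shape (2,): two 2x2 matrices -> inferred size [2, 2, 2]
size = infer_batched_csr_size((2, 3), (2, 4), [0, 1, 0, 1, 0, 1, 0, 1])
```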

@IvanYashchuk
Collaborator Author

Okay, a few tests are really failing. I'll resolve the failures.

from functools import reduce
from operator import mul  # needed for reduce(mul, ...)

for batch_shape in ((), (2,), (2, 3)):
    prod = reduce(mul, batch_shape, 1)
    crow_indices = torch.tensor([0, 2, 4], device=device).repeat(prod, 1).reshape(*batch_shape, -1)
Contributor

Technically these indices don't have to be the same for each batch entry. A more powerful test would potentially modify them as well to be different for each batch entry.
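A sketch of that suggestion in pure Python (names are hypothetical test scaffolding): each batch entry gets a different sparsity pattern, while batched CSR still requires every entry to share the same row count and the same nnz.

```python
# Each batch entry below has a different sparsity pattern, but batched CSR
# still needs every entry to have the same number of rows and the same nnz.
batch_crow = [[0, 1, 2], [0, 2, 2]]  # entry 0: one value per row; entry 1: both values in row 0
batch_col = [[0, 1], [0, 1]]
batch_vals = [[1.0, 2.0], [3.0, 4.0]]

for crow, col, vals in zip(batch_crow, batch_col, batch_vals):
    assert len(crow) == len(batch_crow[0])    # same number of rows in every entry
    assert crow[-1] == len(col) == len(vals)  # last row pointer equals nnz
```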

Contributor
@cpuhrsch cpuhrsch left a comment


I think generally this looks fine, but let's wait until you've resolved the current CI failures.

@pytorch-bot
pytorch-bot bot commented Mar 23, 2022

We have recently simplified the CIFlow labels and ciflow/cuda is no longer in use.
You can use any of the following

  • ciflow/trunk (.github/workflows/trunk.yml): all jobs we run per-commit on master
  • ciflow/periodic (.github/workflows/periodic.yml): all jobs we run periodically on master
  • ciflow/all: trunk + periodic; all jobs we run in master CI
  • ciflow/nightly (.github/workflows/nightly.yml): all jobs we run nightly
  • ciflow/binaries: all binary build and upload jobs

@pytorch-bot
pytorch-bot bot commented Mar 23, 2022

We have recently simplified the CIFlow labels and ciflow/cpu is no longer in use.
You can use any of the following

  • ciflow/trunk (.github/workflows/trunk.yml): all jobs we run per-commit on master
  • ciflow/periodic (.github/workflows/periodic.yml): all jobs we run periodically on master
  • ciflow/all: trunk + periodic; all jobs we run in master CI
  • ciflow/nightly (.github/workflows/nightly.yml): all jobs we run nightly
  • ciflow/binaries: all binary build and upload jobs

@IvanYashchuk IvanYashchuk added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 23, 2022
@cpuhrsch
Contributor
cpuhrsch commented Apr 4, 2022

@IvanYashchuk - could you rebase this on top of a green commit please? See https://hud.pytorch.org/ (e.g. c5872e6). Hopefully that'll fix the lint CI error.

@IvanYashchuk
Collaborator Author

The base of this PR is the viable/strict branch with bf16552, which is green.

@cpuhrsch
Contributor
cpuhrsch commented Apr 4, 2022

@IvanYashchuk - well then let's try rerunning those jobs again.

@cpuhrsch
Contributor
cpuhrsch commented Apr 4, 2022

@IvanYashchuk - FYI there's a PR that aims to prevent a broken master lint job from holding up PRs that are built upon viable strict. #75199

@cpuhrsch
Contributor
cpuhrsch commented Apr 4, 2022

@pytorchbot merge this

malfet added a commit that referenced this pull request Apr 5, 2022
It caused a number of internal only compilation failures, for example
see:
#74425 (comment)
and #74542 (comment)

[ghstack-poisoned]
malfet added a commit that referenced this pull request Apr 5, 2022
It caused a number of internal only compilation failures, for example
see:
#74425 (comment)
and #74542 (comment)

[ghstack-poisoned]
malfet added a commit that referenced this pull request Apr 5, 2022
It caused a number of internal only compilation failures, for example
see:
#74425 (comment)
and #74542 (comment)

ghstack-source-id: 14889fa
Pull Request resolved: #75085
pytorchmergebot pushed a commit that referenced this pull request Apr 5, 2022
It caused a number of internal only compilation failures, for example
see:
#74425 (comment)
and #74542 (comment)

Pull Request resolved: #75085

Approved by: https://github.com/ngimel, https://github.com/albanD
@b0noI
Contributor
b0noI commented Apr 5, 2022

@pytorchbot revert this

@b0noI
Contributor
b0noI commented Apr 5, 2022

Internal errors:

Summary:
stderr: caffe2/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp:983:20: error: lambda capture 'C_crow_indices' is not used [-Werror,-Wunused-lambda-capture]
auto fix_nnz = [&C_crow_indices, &m](int nnz) -> int {
~^~~~~~~~~~~~~~~
caffe2/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp:983:37: error: lambda capture 'm' is not used [-Werror,-Wunused-lambda-capture]
auto fix_nnz = [&C_crow_indices, &m](int nnz) -> int {
~~~^


pytorchmergebot added a commit that referenced this pull request Apr 5, 2022
@cpuhrsch
Contributor
cpuhrsch commented Apr 5, 2022

@malfet - We might want to make this error part of the CI too

@cpuhrsch cpuhrsch reopened this Apr 5, 2022
@IvanYashchuk IvanYashchuk requested a review from a team as a code owner April 5, 2022 21:57
@cpuhrsch
Contributor
cpuhrsch commented Apr 5, 2022

@IvanYashchuk - I merged master and added a simple fix for this and will attempt to merge again once the CI runs green

diff --git a/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp b/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp
index 27432431f7..7cfe1248fb 100644
--- a/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp
+++ b/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp
@@ -983,15 +983,21 @@ void add_out_sparse_csr(
   auto C_col_indices_ptr = C_col_indices.data_ptr<int>();

   // Windows compilers don't support nested macros
-  // so we need this lambda outside of the AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES
-  auto fix_nnz = [&C_crow_indices, &m](int nnz) -> int {
-    // For some reason POINTER_MODE_HOST is not working here
-    // Let's extract manually the nnz from the C_crow_indices
-    #if AT_ROCM_ENABLED()
+  // so we need this lambda outside of the
+  // AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES
+  auto fix_nnz = [
+#if AT_ROCM_ENABLED()
+                     &C_crow_indices,
+                     &m
+#endif
+  ](int nnz) -> int {
+// For some reason POINTER_MODE_HOST is not working here
+// Let's extract manually the nnz from the C_crow_indices
+#if AT_ROCM_ENABLED()
     return std::max({nnz, C_crow_indices.narrow(-1, m, 1).item<int>()});
-    #else
+#else
     return nnz;
-    #endif
+#endif
   };

@malfet
Contributor
malfet commented Apr 5, 2022

@malfet - We might want to make this error part of the CI too

[Edit] This warning is only generated by clang (see a really old gcc feature request), and we do not have CUDA+clang builds configured in our CI at the moment (trying to add this in #75293)

facebook-github-bot pushed a commit that referenced this pull request Apr 7, 2022
Summary:
It caused a number of internal only compilation failures, for example
see:
#74425 (comment)
and #74542 (comment)

Pull Request resolved: #75085

Approved by: https://github.com/ngimel, https://github.com/albanD

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/90a56fc515dbac9534a1a14110f9edf089430f81

Reviewed By: b0noI

Differential Revision: D35404322

Pulled By: malfet

fbshipit-source-id: aaa7033d0b7cbfcc1d4b3eeff86d09eba428f068
@IvanYashchuk
Collaborator Author

@cpuhrsch, let's try one more time? 🤞

@cpuhrsch
Contributor
cpuhrsch commented Apr 7, 2022

@pytorchbot merge this

facebook-github-bot pushed a commit that referenced this pull request Apr 8, 2022
Summary:
This is the first portion of changes required to enable Batched CSR format described in #60854 (comment).

Currently, only the same batch shape for indices and values is allowed. In the future, we could enable "broadcasting" of indices and batched values, as done in xFormers (https://github.com/facebookresearch/xformers/blob/dd96b8d8beda5308fb433c1ef3ff04b7f178c263/xformers/components/attention/_sputnik_sparse.py#L441).

This PR adds the possibility to construct a batched CSR matrix with `torch.sparse_csr_tensor`; the resulting batched CSR tensor can then be converted to a dense tensor with a `.to_dense()` call.

Pull Request resolved: #74542
Approved by: https://github.com/cpuhrsch

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/c7ae23b50e5f96261889ab9d55df1be7a6b1d55f

Reviewed By: b0noI

Differential Revision: D35485699

fbshipit-source-id: fa1c0c5cf256ac886717a9016a83e62ea2772f75
Labels
ciflow/trunk Trigger trunk jobs on your pull request cla signed module: sparse Related to torch.sparse open source release notes: sparse release notes category Reverted topic: improvements topic category triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
Projects
None yet
Development

Successfully merging this pull request may close these issues.

9 participants