Experiment with custom partitioning #1932
base: master
Codecov Report
Attention: Patch coverage is

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##           master    #1932      +/-   ##
==========================================
- Coverage   84.72%   84.65%   -0.08%
==========================================
  Files         307      307
  Lines       18695    18717      +22
  Branches     3664     3669       +5
==========================================
+ Hits        15839    15844       +5
- Misses       2114     2130      +16
- Partials      742      743       +1
```
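The percentages in the Codecov table can be reproduced from its line counts. The small check below rests on two assumptions. First, coverage is Codecov's usual ratio hits / (hits + misses + partials). Second, Codecov's default `round: down` setting at two decimals acts like rounding toward negative infinity. That second assumption matches the reported -0.08% change, where round-to-nearest would give -0.07%.

```python
import math

# Line counts copied from the Codecov table (base = master, head = PR #1932).
BASE = {"hits": 15839, "misses": 2114, "partials": 742}
HEAD = {"hits": 15844, "misses": 2130, "partials": 743}


def tracked_lines(counts):
    # Every tracked line is exactly one of: hit, missed, or partially covered.
    return counts["hits"] + counts["misses"] + counts["partials"]


def coverage_pct(counts):
    # Assumed Codecov ratio: partially covered lines count as not covered.
    return 100.0 * counts["hits"] / tracked_lines(counts)


def round_down(value, precision=2):
    # Assumption: "round: down" behaves like floor at the given precision.
    factor = 10 ** precision
    return math.floor(value * factor) / factor


base_pct = coverage_pct(BASE)
head_pct = coverage_pct(HEAD)
print(tracked_lines(BASE), tracked_lines(HEAD))    # 18695 18717
print(round_down(base_pct), round_down(head_pct))  # 84.72 84.65
print(round_down(head_pct - base_pct))             # -0.08
```

The unrounded change is about -0.073 percentage points. Flooring it at two decimals yields the -0.08% shown in the report.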
```diff
@@ -508,3 +510,51 @@ def device_count() -> int:
    jax.device_count() if config.netket_experimental_sharding is True, and mpi.rank otherwise.
    """
    return mpi.n_nodes * device_count_per_rank()
```
Can we move this function to a separate file?
Bump?
Adding the gradients will require quite a bit more work.
Sigh. It would have been lovely.
For posterity: this is also missing the vmap rule.
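The review thread notes that gradients, and a vmap rule, are still missing. The sketch below assumes, as those comments suggest, that a custom-partitioned function has no differentiation rule of its own. Under that assumption, one way to make it differentiable is to wrap it in `jax.custom_vjp` and write the backward pass with ordinary `jnp` operations, which XLA shards through its normal propagation.

This is not the PR's code. The names `row_norm2`, `row_norm2_op` and the mesh axis `samples` are hypothetical. The sketch uses the `infer_sharding_from_operands`/`partition` callbacks of `jax.experimental.custom_partitioning`, and it does not add a batching (vmap) rule.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.custom_partitioning import custom_partitioning
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def _row_axis(sharding):
    # Mesh axis sharding the leading (row) dimension, or None when replicated.
    spec = getattr(sharding, "spec", P())
    return spec[0] if len(spec) > 0 else None


def _row_norm2(x):
    # Plain kernel: one scalar per row.
    return jnp.sum(x * x, axis=-1)


@custom_partitioning
def row_norm2_op(x):
    # Hypothetical custom-partitioned forward op.
    return _row_norm2(x)


def _infer(mesh, arg_shapes, result_shape):
    # Output rows follow the operand's row sharding.
    return NamedSharding(mesh, P(_row_axis(arg_shapes[0].sharding)))


def _partition(mesh, arg_shapes, result_shape):
    axis = _row_axis(arg_shapes[0].sharding)
    # Keep the reduced (last) axis whole on each device.
    arg_shardings = (NamedSharding(mesh, P(axis, None)),)
    return mesh, _row_norm2, NamedSharding(mesh, P(axis)), arg_shardings


row_norm2_op.def_partition(infer_sharding_from_operands=_infer, partition=_partition)


@jax.custom_vjp
def row_norm2(x):
    return row_norm2_op(x)


def _row_norm2_fwd(x):
    # Forward pass: the custom-partitioned op, with x saved as the residual.
    return row_norm2_op(x), x


def _row_norm2_bwd(x, g):
    # d/dx_ij of sum_k x_ik^2 is 2 x_ij; ordinary jnp ops, sharded by XLA as usual.
    return (2.0 * x * g[:, None],)


row_norm2.defvjp(_row_norm2_fwd, _row_norm2_bwd)

# Usage: differentiate through the op under jit on a row-sharded input.
mesh = Mesh(np.array(jax.devices()), ("samples",))
n_rows = 2 * len(jax.devices())
x = jax.device_put(
    jnp.arange(3.0 * n_rows).reshape(n_rows, 3),
    NamedSharding(mesh, P("samples", None)),
)
grad_fn = jax.jit(jax.grad(lambda v: jnp.sum(row_norm2(v))))
g = grad_fn(x)
```

Here only the forward kernel is custom-partitioned. A backward pass that needs its own hand-written sharding logic would require a second `custom_partitioning` function.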
Force-pushed from f6a2477 to 6e593ca.
Addresses #1921.
Replaces shard_map with custom partitioning in the operators where it was used.
Infers the sharding from the inputs at compile time, instead of hardcoding it.
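The idea of inferring the sharding from the inputs can be sketched with `jax.experimental.custom_partitioning`. This sketch assumes the interface that takes `infer_sharding_from_operands` and `partition` callbacks; newer JAX releases that use the Shardy partitioner also expect a `sharding_rule`. It is not the PR's code: the kernel `per_sample_norm2` and the mesh axis name `samples` are hypothetical.

The output sharding and the per-device layout are read from the operand's sharding when the function is compiled. This is the difference from `shard_map`, where they are fixed in advance through `in_specs`/`out_specs`.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.custom_partitioning import custom_partitioning
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def _sample_axis(sharding):
    # Mesh axis sharding the leading (samples) dimension, or None if replicated.
    # Shardings that cannot be decoded into a NamedSharding have no `.spec`.
    spec = getattr(sharding, "spec", P())
    return spec[0] if len(spec) > 0 else None


@custom_partitioning
def per_sample_norm2(x):
    # Hypothetical stand-in for an operator kernel: one scalar per sample (row).
    return jnp.sum(x * x, axis=-1)


def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
    # The output sharding is read off the operand at compile time, not hardcoded.
    return NamedSharding(mesh, P(_sample_axis(arg_shapes[0].sharding)))


def partition(mesh, arg_shapes, result_shape):
    axis = _sample_axis(arg_shapes[0].sharding)

    def local_kernel(x_block):
        # Runs on each device's block of rows; the site axis stays whole locally.
        return jnp.sum(x_block * x_block, axis=-1)

    arg_shardings = (NamedSharding(mesh, P(axis, None)),)
    result_sharding = NamedSharding(mesh, P(axis))
    return mesh, local_kernel, result_sharding, arg_shardings


per_sample_norm2.def_partition(
    infer_sharding_from_operands=infer_sharding_from_operands,
    partition=partition,
)


@jax.jit
def apply(x):
    return per_sample_norm2(x)


# Usage: shard the samples over every available device.
mesh = Mesh(np.array(jax.devices()), ("samples",))
n_samples = 4 * len(jax.devices())
x = jnp.arange(2.0 * n_samples).reshape(n_samples, 2)
x = jax.device_put(x, NamedSharding(mesh, P("samples", None)))
y = apply(x)
```

Because `partition` requests `P(axis, None)` for the operand, an input that is also sharded along the site axis is resharded before the local kernel runs. `shard_map` would instead make that choice up front through fixed `in_specs`.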