[WIP] Update to jax 0.6, support arbitrary sharding (drop MPI) #2059
base: master
deprecation fixes
fixes deprecation of safe zip
Remove unused 'states' iterator in discrete hilbert
In Hilbert: only check when array is not sharded
change annotation
hilbert index: unconstrained
hilbert random, get things to work
cleanup hilbert
fix sharding in hilbert
fix tests
fix hilbert
ruff
support sharding in ising
typos in operator
sharding fixes for local operator
c
remove comments
This assumes that whatever model the user provides correctly declares its sharding
cc @inailuig I'm putting this out there in case you ever want to take a stab at it as well.
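A minimal sketch of a model that declares its sharding, assuming a plain jitted JAX function stands in for the user's model. The names `log_psi`, `samples_sharding`, the toy weights and the mesh axis name "S" are illustrative choices, not NetKet's API. The output is explicitly constrained to stay sharded over the samples axis, so callers do not have to infer where the per-sample values live.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 1D mesh over all available devices; "S" names the samples axis.
mesh = Mesh(np.array(jax.devices()), axis_names=("S",))
samples_sharding = NamedSharding(mesh, P("S"))


@jax.jit
def log_psi(params, sigma):
    # Toy stand-in for a user model: one log-amplitude per sample (row of sigma).
    out = jnp.tanh(sigma @ params["w"]).sum(axis=-1)
    # The model declares that its per-sample outputs stay sharded over "S".
    return jax.lax.with_sharding_constraint(out, samples_sharding)


n_dev = len(jax.devices())
n_sites, n_samples = 6, 4 * n_dev  # batch size divisible by the device count
params = {"w": jnp.ones((n_sites, 3)) * 0.1}
# Place the input samples sharded along "S" before calling the model.
sigma = jax.device_put(jnp.ones((n_samples, n_sites)), samples_sharding)
logpsi = log_psi(params, sigma)
```

If a model omitted the constraint, XLA would still pick some output layout, but downstream code could no longer rely on it matching the samples sharding.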
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@             Coverage Diff             @@
##           master    #2059       +/-   ##
===========================================
- Coverage   85.14%   59.16%   -25.99%
===========================================
  Files         317      317
  Lines       19289    19315       +26
  Branches     2429     2436        +7
===========================================
- Hits        16423    11427     -4996
- Misses       2108     7124     +5016
- Partials      758      764        +6
… with jax 0.7; TODO remove those lines completely
WIP: experimenting with moving the whole of NetKet over to JAX explicit sharding, which should simplify all the weird edge cases we hit when dealing with sharding.
shard_map while minimising issues. The goal is also to allow arbitrary sharding patterns, to support multiple GPUs out of the box, and to considerably simplify our internals.
A side objective is to drop MPI entirely.
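To illustrate why sharding can stand in for MPI, here is a sketch (not code from this PR) assuming per-sample local energies live on a 1D mesh whose axis is named "S". The global mean and variance are written as ordinary jnp code, and XLA inserts the cross-device reduction that previously required an MPI allreduce. The array `e_loc` and the function `energy_stats` are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1D mesh over every visible device (GPUs if present, otherwise CPU).
mesh = Mesh(np.array(jax.devices()), axis_names=("S",))
n_dev = mesh.devices.size

# Hypothetical per-sample local energies, split across devices along "S"
# (8 samples per device).
e_loc = jax.device_put(
    jnp.arange(8 * n_dev, dtype=jnp.float32),
    NamedSharding(mesh, P("S")),
)


@jax.jit
def energy_stats(e_loc):
    # Global mean and variance over all samples on all devices: the compiler
    # inserts the all-reduce that an explicit MPI call used to provide.
    mean = jnp.mean(e_loc)
    var = jnp.mean(jnp.abs(e_loc - mean) ** 2)
    return mean, var


mean, var = energy_stats(e_loc)
```

The same code runs unchanged on one CPU or on several GPUs; only the mesh built from jax.devices() differs.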
Details:
out_sharding=... argument in hilbert.all_states().
sampler.init: added a new out_sharding argument, sampler.init(out_sharding=P("S")), defaulting to P("S"). This shards samples over the samples axis, matching the current default behaviour, but also allows arbitrary sharding patterns.
Right now I'm slowly working through getting several bugs in JAX fixed:
jax-ml/jax#28111 (done)
jax-ml/jax#28538 (done)
jax-ml/jax#28542 (done)
jax-ml/jax#29195
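The out_sharding arguments described under Details can be sketched with two hypothetical stand-ins, `all_states` and `sampler_init`, for a spin-1/2 system on a 1D mesh with axis "S". These are assumptions for illustration only; NetKet's real hilbert.all_states() and sampler.init signatures and internals may differ.

```python
import itertools

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 1D mesh; "S" is the samples axis.
mesh = Mesh(np.array(jax.devices()), axis_names=("S",))


def all_states(n_sites, out_sharding=None):
    """Hypothetical stand-in for hilbert.all_states(), not NetKet's API:
    every spin-1/2 configuration (entries +-1), optionally placed with
    the requested PartitionSpec."""
    states = np.array(
        list(itertools.product([-1.0, 1.0], repeat=n_sites)), dtype=np.float32
    )
    if out_sharding is None:
        return jnp.asarray(states)
    return jax.device_put(states, NamedSharding(mesh, out_sharding))


def sampler_init(key, n_chains, n_sites, out_sharding=P("S")):
    """Hypothetical stand-in for sampler.init(out_sharding=P("S")), not
    NetKet's API: random initial chains, sharded over "S" by default."""
    sharding = NamedSharding(mesh, out_sharding)
    init = jax.jit(
        lambda k: jax.random.choice(k, jnp.array([-1.0, 1.0]), (n_chains, n_sites)),
        out_shardings=sharding,  # the caller chooses the layout of the output
    )
    return init(key)
```

Passing out_sharding=P() instead would replicate the chains on every device, which is the kind of non-default pattern such an argument makes possible. In this sketch the batch sizes are assumed to split evenly over the devices when P("S") is used.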