[WIP] Update to jax 0.6, support arbitrary sharding (drop MPI) by PhilipVinc · Pull Request #2059 · netket/netket · GitHub



Draft · wants to merge 40 commits into base: master

PhilipVinc (Member) commented Jun 3, 2025

WIP: experimenting with moving the whole of netket over to jax explicit sharding, which should simplify all the weird edge cases we currently hit when dealing with sharding.

  • This uses jax's new 'explicit sharding' interface, which is meant to be the stable way forward to replace shard_map while minimising issues.
  • This will allow us to drop a lot of messy code...

The goal is also to allow arbitrary sharding patterns, to support multiple GPUs out of the box, and to considerably simplify our internals.

A side objective would be to drop all of MPI.

Details:

  • General: users must now declare an explicit sharding mesh. We could provide a default way of declaring it, for example:
                import jax
                import netket as nk

                # Create a mesh spanning all the devices
                mesh = jax.make_mesh(
                    (len(jax.devices()),),  # Number of devices along the single mesh axis
                    ("S",),                 # Axis names (a tuple). 'S' is the standard name for 'samples'.
                    axis_types=(
                        jax.sharding.AxisType.Explicit,  # Explicit sharding is required by netket
                    ),
                )
                jax.sharding.set_mesh(mesh)  # Set this as the default mesh for jax.
  • Hilbert: all Hilbert indexing functions now correctly propagate the sharding of their input. It would be nice to also support an out_sharding=... argument in hilbert.all_states().
  • Sampler: sampler.init gained a new out_sharding argument, defaulting to sampler.init(out_sharding=P("S")). With this default, samples are sharded over the samples axis, as they are today, but arbitrary sharding patterns are now possible.
  • Operators: rewritten to correctly respect the sharding of the input and preserve it in the output.

Right now I'm slowly working through getting several bugs in jax fixed:
jax-ml/jax#28111 (done)
jax-ml/jax#28538 (done)
jax-ml/jax#28542 (done)
jax-ml/jax#29195

PhilipVinc added 13 commits May 31, 2025 13:59
deprecation fixes

deprecation of safe zip
Remove unused 'states' iterator in discrete hilbert

In Hilbert: only check when array is not sharded

change annotation

hilbert index: unconstrained

hilbert random, get things to work

cleanup hilbert

fix sharding in hilbert

fix tests

fix hilbert ruff
support sharding in ising

typos in operator

sharding fixes for local operator
fixes
fix qgtdense
fix

improvements
This assumes that whatever model the user provides correctly declares its sharding
@PhilipVinc (Member, Author)

cc @inailuig I'm putting this out there in case you ever want to take a stab at it as well. It's going to take a while.

codecov bot commented Jun 3, 2025

Codecov Report

Attention: Patch coverage is 50.00000% with 115 lines in your changes missing coverage. Please review.

Project coverage is 59.16%. Comparing base (aafb876) to head (f012665).
Report is 1 commits behind head on master.

| File with missing lines | Patch % | Lines |
|---|---|---|
| netket/operator/_local_operator/jax.py | 8.69% | 21 Missing ⚠️ |
| netket/jax/_map.py | 24.00% | 19 Missing ⚠️ |
| netket/hilbert/random/fock.py | 23.07% | 10 Missing ⚠️ |
| netket/jax/sharding.py | 60.00% | 5 Missing and 5 partials ⚠️ |
| netket/vqs/mc/kernels.py | 0.00% | 8 Missing ⚠️ |
| netket/sampler/parallel_tempering.py | 12.50% | 7 Missing ⚠️ |
| netket/optimizer/qgt/qgt_jacobian_dense.py | 25.00% | 6 Missing ⚠️ |
| netket/sampler/base.py | 58.33% | 3 Missing and 2 partials ⚠️ |
| netket/hilbert/random/base.py | 66.66% | 4 Missing ⚠️ |
| netket/sampler/rules/exchange.py | 0.00% | 4 Missing ⚠️ |

... and 11 more

❗ There is a different number of reports uploaded between BASE (aafb876) and HEAD (f012665).

HEAD has 1 fewer upload than BASE:

| Flag | BASE (aafb876) | HEAD (f012665) |
|---|---|---|
| | 6 | 5 |
Additional details and impacted files
@@             Coverage Diff             @@
##           master    #2059       +/-   ##
===========================================
- Coverage   85.14%   59.16%   -25.99%     
===========================================
  Files         317      317               
  Lines       19289    19315       +26     
  Branches     2429     2436        +7     
===========================================
- Hits        16423    11427     -4996     
- Misses       2108     7124     +5016     
- Partials      758      764        +6     
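As a quick consistency check on the diff above: every line is counted as exactly one hit, miss, or partial (16423 + 2108 + 758 = 19289 and 11427 + 7124 + 764 = 19315), and the reported coverage is hits over total lines:

$$
\text{coverage} = \frac{\text{hits}}{\text{hits} + \text{misses} + \text{partials}}, \qquad
\frac{16423}{19289} \approx 85.14\,\%, \qquad
\frac{11427}{19315} \approx 59.16\,\%.
$$

So the drop comes from roughly 5000 lines moving from hits to misses, while the total line count grew by only 26.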

