Hi, I want to use JAX to run the FermiNet package.
I ran into a problem with jax == 0.4.12; the errors are reported below:
INFO:absl:Starting QMC with 2 XLA devices per host across 1 hosts.
2025-05-06 22:36:14.518524: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:439] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2025-05-06 22:36:14.518581: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:443] Memory usage: 6029639680 bytes free, 25327697920 bytes total.
Traceback (most recent call last):
File "/home/kf/projects_dir/202505_Ferminet/test.py", line 62, in <module>
train(cfg)
File "/home/kf/projects_dir/202505_Ferminet/ferminet/train.py", line 422, in train
atoms = jnp.stack([jnp.array(atom.coords) for atom in cfg.system.molecule])
File "/home/kf/anaconda3/envs/jax427/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 1780, in stack
new_arrays.append(expand_dims(a, axis))
File "/home/kf/anaconda3/envs/jax427/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 890, in expand_dims
return lax.expand_dims(a, axis)
File "/home/kf/anaconda3/envs/jax427/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 1328, in expand_dims
return broadcast_in_dim(array, result_shape, broadcast_dims)
File "/home/kf/anaconda3/envs/jax427/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 796, in broadcast_in_dim
return broadcast_in_dim_p.bind(
File "/home/kf/anaconda3/envs/jax427/lib/python3.10/site-packages/jax/_src/core.py", line 380, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/home/kf/anaconda3/envs/jax427/lib/python3.10/site-packages/jax/_src/core.py", line 383, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/home/kf/anaconda3/envs/jax427/lib/python3.10/site-packages/jax/_src/core.py", line 790, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/kf/anaconda3/envs/jax427/lib/python3.10/site-packages/jax/_src/dispatch.py", line 132, in apply_primitive
compiled_fun = xla_primitive_callable(
File "/home/kf/anaconda3/envs/jax427/lib/python3.10/site-packages/jax/_src/util.py", line 284, in wrapper
return cached(config._trace_context(), *args, **kwargs)
File "/home/kf/anaconda3/envs/jax427/lib/python3.10/site-packages/jax/_src/util.py", line 277, in cached
return f(*args, **kwargs)
File "/home/kf/anaconda3/envs/jax427/lib/python3.10/site-packages/jax/_src/dispatch.py", line 223, in xla_primitive_callable
compiled = _xla_callable_uncached(
File "/home/kf/anaconda3/envs/jax427/lib/python3.10/site-packages/jax/_src/dispatch.py", line 253, in _xla_callable_uncached
return computation.compile().unsafe_call
File "/home/kf/anaconda3/envs/jax427/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2329, in compile
executable = UnloadedMeshExecutable.from_hlo(
File "/home/kf/anaconda3/envs/jax427/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2651, in from_hlo
xla_executable, compile_options = _cached_compilation(
File "/home/kf/anaconda3/envs/jax427/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2561, in _cached_compilation
xla_executable = dispatch.compile_or_get_cached(
File "/home/kf/anaconda3/envs/jax427/lib/python3.10/site-packages/jax/_src/dispatch.py", line 497, in compile_or_get_cached
return backend_compile(backend, computation, compile_options,
File "/home/kf/anaconda3/envs/jax427/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/kf/anaconda3/envs/jax427/lib/python3.10/site-packages/jax/_src/dispatch.py", line 465, in backend_compile
return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
It can be seen that the error occurs during the jnp.stack operation.
I have tried everything I could find but have not been able to solve this problem.
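For what it's worth, the log shows only ~6 GB free out of ~25 GB when cuDNN tries to initialize, and CUDNN_STATUS_INTERNAL_ERROR is a common symptom of that. One workaround worth trying (a sketch, not guaranteed for this setup) is to stop JAX from preallocating most of GPU memory; these are documented JAX/XLA environment variables, and they must be set before the first `import jax`:

```python
import os

# Workaround sketch: limit JAX's GPU memory preallocation so cuDNN has enough
# free memory to create its handle. These are documented JAX/XLA environment
# variables; they must be set BEFORE the first `import jax`.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # allocate on demand
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"   # or cap at 50% of GPU memory

# import jax.numpy as jnp   # import only after the variables are set;
# jnp.stack(...)            # the first GPU op compiles and initializes cuDNN
```

It is also worth checking `nvidia-smi` for other processes already holding GPU memory before the run starts.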
System info (python version, jaxlib version, accelerator, etc.)
CUDA: 12.6
Note that my machine cannot install JAX packages newer than version 0.4.12; otherwise this error is reported: CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12020, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
I installed JAX with pip install --upgrade jax==0.4.12 jaxlib==0.4.12+cuda12.cudnn89 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html.
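For reference, the numbers in that backend error use the packed CUDA version encoding (major * 1000 + minor * 10), so 12010 means CUDA 12.1 is installed while the wheel was built against 12020, i.e. CUDA 12.2. A small helper (purely illustrative; the function name is my own) to decode them:

```python
def decode_cuda_version(packed: int) -> str:
    """Decode the packed CUDA version integer, e.g. 12010 -> "12.1"."""
    major, rest = divmod(packed, 1000)
    minor = rest // 10
    return f"{major}.{minor}"

print(decode_cuda_version(12010))  # installed CUDA per the error: 12.1
print(decode_cuda_version(12020))  # version the JAX wheel was built against: 12.2
```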
@Eureka10shen I think you should be able to use the latest JAX release (0.6.0) if you upgrade your CUDA version.
It's a pity that neither of my two machines can be used properly with CUDA >= 12.1 (although other ML packages work well) and jax > 0.4.0. So I'm falling back to jax = 0.3.27 and modifying the code to work with it. Thanks for your reply anyway 😀😀😀.