Failure of cuDNN initialization at jax 0.4.12 · Issue #28550 · jax-ml/jax · GitHub

Failure of cuDNN initialization at jax 0.4.12 #28550


Closed
Eureka10shen opened this issue May 6, 2025 · 3 comments
Labels
bug Something isn't working

Comments

@Eureka10shen

Description

Hi, I want to use JAX to run the Ferminet package.
I ran into a problem with jax 0.4.12; the errors are reported below:

INFO:absl:Starting QMC with 2 XLA devices per host across 1 hosts.
2025-05-06 22:36:14.518524: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:439] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2025-05-06 22:36:14.518581: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:443] Memory usage: 6029639680 bytes free, 25327697920 bytes total.
Traceback (most recent call last):
  File "/home/kf/projects_dir/202505_Ferminet/test.py", line 62, in <module>
    train(cfg)
  File "/home/kf/projects_dir/202505_Ferminet/ferminet/train.py", line 422, in train
    atoms = jnp.stack([jnp.array(atom.coords) for atom in cfg.system.molecule])
  File "/home/kf/anaconda3/envs/jax427/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 1780, in stack
    new_arrays.append(expand_dims(a, axis))
  File "/home/kf/anaconda3/envs/jax427/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 890, in expand_dims
    return lax.expand_dims(a, axis)
  File "/home/kf/anaconda3/envs/jax427/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 1328, in expand_dims
    return broadcast_in_dim(array, result_shape, broadcast_dims)
  File "/home/kf/anaconda3/envs/jax427/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 796, in broadcast_in_dim
    return broadcast_in_dim_p.bind(
  File "/home/kf/anaconda3/envs/jax427/lib/python3.10/site-packages/jax/_src/core.py", line 380, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/kf/anaconda3/envs/jax427/lib/python3.10/site-packages/jax/_src/core.py", line 383, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/kf/anaconda3/envs/jax427/lib/python3.10/site-packages/jax/_src/core.py", line 790, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/kf/anaconda3/envs/jax427/lib/python3.10/site-packages/jax/_src/dispatch.py", line 132, in apply_primitive
    compiled_fun = xla_primitive_callable(
  File "/home/kf/anaconda3/envs/jax427/lib/python3.10/site-packages/jax/_src/util.py", line 284, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/home/kf/anaconda3/envs/jax427/lib/python3.10/site-packages/jax/_src/util.py", line 277, in cached
    return f(*args, **kwargs)
  File "/home/kf/anaconda3/envs/jax427/lib/python3.10/site-packages/jax/_src/dispatch.py", line 223, in xla_primitive_callable
    compiled = _xla_callable_uncached(
  File "/home/kf/anaconda3/envs/jax427/lib/python3.10/site-packages/jax/_src/dispatch.py", line 253, in _xla_callable_uncached
    return computation.compile().unsafe_call
  File "/home/kf/anaconda3/envs/jax427/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2329, in compile
    executable = UnloadedMeshExecutable.from_hlo(
  File "/home/kf/anaconda3/envs/jax427/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2651, in from_hlo
    xla_executable, compile_options = _cached_compilation(
  File "/home/kf/anaconda3/envs/jax427/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2561, in _cached_compilation
    xla_executable = dispatch.compile_or_get_cached(
  File "/home/kf/anaconda3/envs/jax427/lib/python3.10/site-packages/jax/_src/dispatch.py", line 497, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options,
  File "/home/kf/anaconda3/envs/jax427/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/kf/anaconda3/envs/jax427/lib/python3.10/site-packages/jax/_src/dispatch.py", line 465, in backend_compile
    return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

The traceback shows that the error occurs during the jnp.stack operation.
I have tried everything I could think of but found no way to solve this problem.
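(Aside: the log above shows only ~6 GB of 25 GB free when cuDNN tries to create its handle, and by default XLA preallocates roughly 75% of GPU memory. A common workaround for this class of `CUDNN_STATUS_INTERNAL_ERROR` is to limit JAX's preallocation before importing it. A minimal sketch, assuming the standard XLA client flags:)

```python
import os

# These flags must be set BEFORE `import jax`; they are read once
# when the XLA GPU client initializes.
# Disable up-front preallocation so cuDNN has memory left for its handle,
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# or alternatively cap the preallocated fraction of GPU memory:
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"

# import jax  # import only after the flags are set
```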

System info (python version, jaxlib version, accelerator, etc.)

CUDA: 12.6
cuDNN version:

>>> import torch
>>> torch.backends.cudnn.version()
90501
>>> exit()
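(Note: 90501 decodes to cuDNN 9.5.1, since cuDNN 9.x encodes its version as major*10000 + minor*100 + patch, while the installed jaxlib wheel is tagged `cudnn89`, i.e. built against cuDNN 8.9. A mismatch like this is a plausible contributor to the initialization failure. A quick decode, assuming the cuDNN 9 encoding:)

```python
# Decode the integer returned by torch.backends.cudnn.version().
# cuDNN 9.x encodes versions as MAJOR*10000 + MINOR*100 + PATCH.
v = 90501
major, minor, patch = v // 10000, (v % 10000) // 100, v % 100
print(f"cuDNN {major}.{minor}.{patch}")  # → cuDNN 9.5.1
```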

Packages:

# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
_openmp_mutex             5.1                       1_gnu    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
absl-py                   2.2.2                    pypi_0    pypi
attrs                     25.3.0                   pypi_0    pypi
bzip2                     1.0.8                h5eee18b_6    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
ca-certificates           2025.2.25            h06a4308_0    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
chex                      0.1.82                   pypi_0    pypi
cloudpickle               3.1.1                    pypi_0    pypi
decorator                 5.2.1                    pypi_0    pypi
distrax                   0.1.5                    pypi_0    pypi
dm-tree                   0.1.9                    pypi_0    pypi
folx                      0.2.16                   pypi_0    pypi
gast                      0.6.0                    pypi_0    pypi
h5py                      3.13.0                   pypi_0    pypi
immutabledict             4.2.1                    pypi_0    pypi
jax                       0.4.12                   pypi_0    pypi
jaxlib                    0.4.12+cuda12.cudnn89          pypi_0    pypi
jaxtyping                 0.3.2                    pypi_0    pypi
ld_impl_linux-64          2.40                 h12ee557_0    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
libffi                    3.4.4                h6a678d5_1    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
libgcc-ng                 11.2.0               h1234567_1    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
libgomp                   11.2.0               h1234567_1    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
libstdcxx-ng              11.2.0               h1234567_1    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
libuuid                   1.41.5               h5eee18b_0    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
ml-collections            1.1.0                    pypi_0    pypi
ml-dtypes                 0.5.1                    pypi_0    pypi
ncurses                   6.4                  h6a678d5_0    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
numpy                     1.25.0                   pypi_0    pypi
openssl                   3.0.16               h5eee18b_0    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
opt-einsum                3.4.0                    pypi_0    pypi
optax                     0.2.0                    pypi_0    pypi
pandas                    2.2.3                    pypi_0    pypi
pip                       25.1               pyhc872135_2    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
pyblock                   0.6                      pypi_0    pypi
pyscf                     2.9.0                    pypi_0    pypi
python                    3.10.16              he870216_1    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
python-dateutil           2.9.0.post0              pypi_0    pypi
pytz                      2025.2                   pypi_0    pypi
pyyaml                    6.0.2                    pypi_0    pypi
readline                  8.2                  h5eee18b_0    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
scipy                     1.12.0                   pypi_0    pypi
setuptools                78.1.1          py310h06a4308_0    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
six                       1.17.0                   pypi_0    pypi
sqlite                    3.45.3               h5eee18b_0    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
tensorflow-probability    0.25.0                   pypi_0    pypi
tk                        8.6.14               h39e8969_0    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
toolz                     1.0.0                    pypi_0    pypi
typing-extensions         4.13.2                   pypi_0    pypi
tzdata                    2025.2                   pypi_0    pypi
wadler-lindig             0.1.5                    pypi_0    pypi
wheel                     0.45.1          py310h06a4308_0    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
wrapt                     1.17.2                   pypi_0    pypi
xz                        5.6.4                h5eee18b_1    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
zlib                      1.2.13               h5eee18b_1    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main

It should be noted that my machine cannot install JAX packages higher than version 0.4.12; otherwise this error is reported: CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12020, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

I installed the JAX with pip install --upgrade jax==0.4.12 jaxlib==0.4.12+cuda12.cudnn89 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html.
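(For reference, newer JAX releases are distributed via pip extras rather than the `-f` wheel index. This is only a sketch of the currently documented install path; it assumes an NVIDIA driver new enough for the CUDA 12 runtime the wheels bundle, so check the JAX installation docs for the exact command for your setup:)

```shell
# Sketch: install a recent CUDA-enabled JAX with bundled CUDA libraries.
pip install --upgrade "jax[cuda12]"
```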

@Eureka10shen added the bug label on May 6, 2025
@superbobry
Collaborator

Huh, I wonder if our installation instructions are up to date. We promise to support CUDA >= 12.1, but from the error it looks like we now require >= 12.2.

@hawkinsp does this look right?

@superbobry
Collaborator

@Eureka10shen I think you should be able to use the latest JAX release (0.6.0) if you upgrade your CUDA version.

@Eureka10shen
Author

@Eureka10shen I think you should be able to use the latest JAX release (0.6.0) if you upgrade your CUDA version.

It's a pity that neither of my two machines can be used properly with CUDA >= 12.1 (although they work well with other ML packages) and jax > 0.4.0. So I'm falling back to jax 0.3.27, for which I have modified the code. Thanks for your reply anyway 😀😀😀.
