`jax.scipy.special.zeta` seems to unconditionally use float32 precision · Issue #17734 · jax-ml/jax · GitHub

Closed
Gattocrucco opened this issue Sep 22, 2023 · 5 comments
@Gattocrucco
Contributor

Description

With float64 enabled in JAX, I evaluate zeta with both JAX and mpmath and compare the results. Only the first 7 digits of the JAX result agree with mpmath:

from jax import config
config.update("jax_enable_x64", True)

from jax.scipy import special as jspecial
import mpmath

z = jspecial.zeta(2, 1).item()
z_accu = mpmath.zeta(2, 1)
print('JAX:   ', z)
print('mpmath:', z_accu)

Output:

JAX:    1.6449342726140739
mpmath: 1.64493406684823

This was not happening to me before version 0.4.16.
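
The size of the mismatch can be checked without mpmath, because `zeta(2, 1)` is the Riemann zeta function at 2, which equals π²/6. The sketch below is not part of the original report. It takes the value printed above and expresses its relative error in units of float32 and float64 machine epsilon. The variable names are illustrative.

```python
import math

# Value printed by JAX 0.4.16 in the report above.
reported = 1.6449342726140739
# zeta(2, 1) is the Riemann zeta function at 2, i.e. pi**2 / 6.
exact = math.pi ** 2 / 6

rel_err = abs(reported - exact) / exact
eps32 = 2.0 ** -23  # float32 machine epsilon
eps64 = 2.0 ** -52  # float64 machine epsilon

print(f"relative error      : {rel_err:.3e}")
print(f"in float32 epsilons : {rel_err / eps32:.2f}")
print(f"in float64 epsilons : {rel_err / eps64:.3e}")
```

A relative error of roughly one float32 epsilon, many orders of magnitude above float64 epsilon, is consistent with the computation having been carried out in float32 even though a float64 result was returned.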

What jax/jaxlib version are you using?

jax 0.4.16, jaxlib 0.4.16

Which accelerator(s) are you using?

CPU

Additional system info

Python 3.11.2, macOS 13.4

NVIDIA GPU info

No response

@Gattocrucco added the bug label on Sep 22, 2023
@Gattocrucco changed the title from "jax.scipy.zeta seems to unconditionally use float32 precision" to "jax.scipy.special.zeta seems to unconditionally use float32 precision" on Sep 22, 2023
@jakevdp
Collaborator
jakevdp commented Sep 22, 2023

Thanks for the report. jax.scipy.special.zeta calls directly into XLA's Zeta operation, so the best place to file this issue would be at https://github.com/openxla/xla

@jakevdp added the XLA label on Sep 22, 2023
@Gattocrucco
Contributor Author

Ok

@jakevdp
Collaborator
jakevdp commented Mar 13, 2024

It looks like the XLA bug was marked as fixed – in the next couple days we should be able to test jax.scipy.special.zeta with the jaxlib nightly build to confirm that things look right on the JAX side.

@jakevdp self-assigned this on Mar 13, 2024
@rajasekharporeddy
Contributor

Hi @Gattocrucco

The fix in openxla PR #10413 appears to resolve this issue. I tested the code above with JAX 0.4.26 in Google Colab. The JAX result matches SciPy to 15 decimal places; the two differ only in the 16th. JAX matches the printed mpmath value to only 13 decimal places, but that is because mpmath prints 15 significant digits by default, so its 14th decimal place is rounded.

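from jax import config
config.update("jax_enable_x64", True)  # float64 must be enabled, as in the original report

from jax.scipy import special as jspecial
from scipy import special
import mpmath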


z = jspecial.zeta(2, 1).item()
z_scipy = special.zeta(2, 1)
z_accu = mpmath.zeta(2, 1)
print('JAX   :', z)
print('scipy :', z_scipy)
print('mpmath:', z_accu)

Output:

JAX   : 1.6449340668482264
scipy : 1.6449340668482266
mpmath: 1.64493406684823

Please find the gist for reference.

Thank you
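
The digit counts quoted in the comment above can be checked mechanically from the printed strings. This is a minimal sketch using only the outputs shown in the thread. The helper `matching_decimals` is a hypothetical name introduced here, not part of JAX, SciPy, or mpmath.

```python
def matching_decimals(a: str, b: str) -> int:
    """Count the leading decimal places on which two printed numbers agree."""
    int_a, _, frac_a = a.partition(".")
    int_b, _, frac_b = b.partition(".")
    if int_a != int_b:
        return 0
    count = 0
    for digit_a, digit_b in zip(frac_a, frac_b):
        if digit_a != digit_b:
            break
        count += 1
    return count


# Strings copied from the output printed in the thread.
jax_out = "1.6449340668482264"
scipy_out = "1.6449340668482266"
mpmath_out = "1.64493406684823"

print("JAX vs SciPy :", matching_decimals(jax_out, scipy_out))
print("JAX vs mpmath:", matching_decimals(jax_out, mpmath_out))
```

The shorter agreement with mpmath reflects how many digits mpmath prints, not a loss of accuracy in JAX: the last printed mpmath digit is rounded, so a digit-by-digit comparison stops one place early.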

@jakevdp
Collaborator
jakevdp commented Apr 25, 2024

Thanks for following up!
