jax.scipy.special.zeta seems to unconditionally use float32 precision #17734
Comments
Thanks for the report.
Ok.
It looks like the XLA bug was marked as fixed; in the next couple of days we should be able to test.
Hi @Gattocrucco The fix provided by openxla PR #10413 appears to address this issue. I tested the mentioned code with JAX version 0.4.26 in Google Colab. The results produced by JAX match those of SciPy to 15 decimal places. While both JAX and SciPy print 16 decimal places, JAX's result agrees with mpmath only up to 13 decimal places because mpmath truncates its output at 14 digits.

```python
import jax
jax.config.update("jax_enable_x64", True)  # enable float64, as in the original report

from jax.scipy import special as jspecial
from scipy import special
import mpmath

z = jspecial.zeta(2, 1).item()
z_scipy = special.zeta(2, 1)
z_accu = mpmath.zeta(2, 1)

print('JAX   :', z)
print('scipy :', z_scipy)
print('mpmath:', z_accu)
```
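For reference, the exact value is ζ(2, 1) = π²/6 ≈ 1.6449340668482264, so agreement with SciPy to 15 decimal places indicates genuine float64 accuracy. The 14-digit truncation in mpmath comes from its default working precision of 15 significant digits; raising it removes the truncation. A small sketch of that adjustment (`mp.dps` is standard mpmath, though this step is not part of the original comment):

```python
import mpmath

# Raise mpmath's working precision above the default 15 significant digits
# so the reference value no longer truncates before float64's last digit.
mpmath.mp.dps = 30
print(mpmath.zeta(2, 1))  # pi**2 / 6 to 30 significant digits
```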
Please find the gist for reference. Thank you.
Thanks for following up!
Description
I enable float64 for JAX, evaluate both JAX's and mpmath's zeta, and compare the results. Only the first 7 digits of JAX's output agree with mpmath, which is float32-level precision (float32 carries roughly 7 significant decimal digits):
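The exact snippet was not preserved in this thread; the following is a minimal sketch of the comparison described, assuming the standard `jax_enable_x64` flag and the illustrative arguments `zeta(2, 1)` used elsewhere in the thread:

```python
import jax
jax.config.update("jax_enable_x64", True)  # request float64 throughout

from jax.scipy import special as jspecial
import mpmath

z_jax = jspecial.zeta(2.0, 1.0).item()
z_ref = mpmath.zeta(2, 1)  # arbitrary-precision reference

print('jax   :', z_jax)
print('mpmath:', z_ref)
# On the affected versions, only the first ~7 significant digits agree,
# i.e. float32-level accuracy despite float64 being enabled.
```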
This did not happen before version 0.4.16.
What jax/jaxlib version are you using?
jax 0.4.16, jaxlib 0.4.16
Which accelerator(s) are you using?
CPU
Additional system info
Python 3.11.2, macOS 13.4
NVIDIA GPU info
No response