Edge behavior in jax.scipy.special.betainc #21900
Hi @mdhaber, JAX typically uses single-precision floating-point numbers for calculations, while SciPy defaults to double precision. This difference in precision can lead to slightly different results, especially when working with very small numbers. If double precision is enabled in JAX, then JAX yields results that are consistent with SciPy even for very small numbers:

```python
import jax
jax.config.update('jax_enable_x64', True)

import matplotlib.pyplot as plt
import numpy as np
from scipy.special import betainc as betainc_scipy
import jax.numpy as xp
from jax.scipy.special import betainc as betainc_jax

a = np.logspace(-40, -1, 300)
b = 1
x = 0.25

plt.loglog(a, betainc_scipy(a, b, x), label='scipy')
plt.loglog(a, betainc_jax(xp.asarray(a), b, x), label='jax')
plt.xlabel('a')
plt.ylabel('betainc(a, 1, 0.25)')
plt.legend()
```

Please find the gist for reference. Thank you.
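As a quick numeric check (my addition, not part of the original comment), the two results can also be compared directly rather than by eye; the tolerance below is an assumption, not a documented guarantee:

```python
import jax
jax.config.update('jax_enable_x64', True)  # double precision, as above

import numpy as np
import jax.numpy as xp
from jax.scipy.special import betainc as betainc_jax
from scipy.special import betainc as betainc_scipy

a = np.logspace(-40, -1, 300)
# Compare JAX (float64) against SciPy elementwise.
np.testing.assert_allclose(
    np.asarray(betainc_jax(xp.asarray(a), 1.0, 0.25)),
    betainc_scipy(a, 1.0, 0.25),
    rtol=1e-10,
)
print("float64 results agree to rtol=1e-10")
```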
---

Thanks! Although this actually came up in the context of 32-bit calculations. The definitions should have been:

```python
a = np.logspace(-40, -1, 300, dtype=np.float32)
b = np.float32(1.)
x = np.float32(0.25)
```

and the plot looks the same. SciPy's outputs are [plot omitted]. To zoom in:

```python
import matplotlib.pyplot as plt
import numpy as np
from scipy.special import betainc as betainc_scipy
import jax.numpy as xp
from jax.scipy.special import betainc as betainc_jax

a0 = np.finfo(np.float32).smallest_normal
b = np.float32(1.)
x = np.float32(0.25)
factor = np.float32(10)
a = np.logspace(np.log10(a0), np.log10(a0*factor), 300, dtype=np.float32)

plt.loglog(a, betainc_scipy(a, b, x), label='scipy')
plt.loglog(a, betainc_jax(xp.asarray(a), b, x), label='jax')
plt.xlabel('a')
plt.ylabel('betainc(a, 1, 0.25)')
plt.legend()
```
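To quantify the disagreement in the zoomed range rather than just plot it, a sketch along these lines (my addition; the printed values depend on platform and JAX version) reports the worst relative difference against SciPy:

```python
import numpy as np
import jax.numpy as xp
from jax.scipy.special import betainc as betainc_jax
from scipy.special import betainc as betainc_scipy

a0 = np.finfo(np.float32).smallest_normal
a = np.logspace(np.log10(a0), np.log10(a0 * np.float32(10)),
                300, dtype=np.float32)
b = np.float32(1.)
x = np.float32(0.25)

ref = betainc_scipy(a, b, x)                        # SciPy as the reference
res = np.asarray(betainc_jax(xp.asarray(a), b, x))  # float32 JAX result

# Relative error against SciPy; nanmax skips any NaNs JAX may return.
rel_err = np.abs(res - ref) / np.abs(ref)
print("max relative error:", np.nanmax(rel_err))
print("NaNs returned by JAX:", int(np.isnan(res).sum()))
```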
---

Hi @mdhaber, IIUC, according to scipy/scipy#8495 (comment), SciPy does all the internal (low-level C) calculations in `float64`. Thank you.
---

Whatever the reasons for SciPy to use float64 internally (one practical reason could be that no float32 implementations are available to SciPy, for instance), evaluating functions correctly in float32 requires an algorithm that properly handles overflows, underflows, and cancellations. Using higher precision is a typical cheap trick that lets implementations avoid paying attention to these floating-point errors and keep the algorithms simple.
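As a minimal sketch of that "cheap trick" (my illustration, not code from the thread): upcast to float64, evaluate, and cast back down, so the float32 caller never sees the intermediate rounding problems:

```python
import numpy as np
from scipy.special import betainc

# Hypothetical float32 inputs near the problem region.
a = np.logspace(-40, -1, 300, dtype=np.float32)

# Do the arithmetic in float64, where intermediate underflow and
# cancellation are far less likely, then cast the result back down.
result = betainc(a.astype(np.float64), 1.0, 0.25).astype(np.float32)
print(result[:3], result.dtype)
```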
---

JAX's implementation is here, and mentions that it's based on http://dlmf.nist.gov/8.17.E23: https://github.com/google/jax/blob/2b728d55b6054bba8ae26b3523722e80d660e771/jax/_src/lax/special.py#L182-L190
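For reference, the continued fraction in question is DLMF 8.17.22, with coefficients from 8.17.23 and 8.17.24 (my transcription of the DLMF formulas, not text from the thread):

```latex
% DLMF 8.17.22: continued fraction for the regularized incomplete beta
I_x(a,b) = \frac{x^a (1-x)^b}{a\,B(a,b)}\,
           \cfrac{1}{1+\cfrac{d_1}{1+\cfrac{d_2}{1+\cdots}}}

% DLMF 8.17.23 and 8.17.24: even- and odd-indexed coefficients
d_{2m}   = \frac{m(b-m)\,x}{(a+2m-1)(a+2m)}, \qquad
d_{2m+1} = -\frac{(a+m)(a+b+m)\,x}{(a+2m)(a+2m+1)}
```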
---

But that comment is about an older version of SciPy. I confirmed with @steppi that SciPy now uses Boost's `ibeta`. [links to the relevant SciPy call site and Boost's implementation omitted]
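Independent of which library is taken as the reference, ground-truth values can be produced with arbitrary-precision arithmetic via mpmath (my addition; mpmath is not used anywhere in this thread):

```python
import numpy as np
from mpmath import mp, betainc as betainc_mp

mp.dps = 50  # work with 50 significant digits

def betainc_ref(a, b, x):
    """Regularized incomplete beta I_x(a, b) at high precision."""
    return float(betainc_mp(a, b, 0, x, regularized=True))

# e.g. a value near the float32 smallest normal
print(betainc_ref(float(np.float32(1.2e-38)), 1.0, 0.25))
```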
---

Checking the same code on GPU, we get slightly different plots:

```python
import matplotlib.pyplot as plt
import numpy as np
from scipy.special import betainc as betainc_scipy
import jax.numpy as xp
from jax.scipy.special import betainc as betainc_jax

a = np.logspace(-40, -1, 300)
b = 1
x = 0.25

output = betainc_jax(xp.asarray(a), b, x)
plt.loglog(a, betainc_scipy(a, b, x), label='scipy')
plt.loglog(a, output, label='jax')
plt.xlabel('a')
plt.ylabel('betainc(a, 1, 0.25)')
plt.legend()
print(output.devices(), output.dtype)
# {cuda(id=0)} float32
```

and

```python
import matplotlib.pyplot as plt
import numpy as np
from scipy.special import betainc as betainc_scipy
import jax.numpy as xp
from jax.scipy.special import betainc as betainc_jax

a0 = np.finfo(np.float32).smallest_normal
b = np.float32(1.)
x = np.float32(0.25)
factor = np.float32(10)
a = np.logspace(np.log10(a0), np.log10(a0*factor), 300, dtype=np.float32)

output = betainc_jax(xp.asarray(a), b, x)
plt.loglog(a, betainc_scipy(a, b, x), label='scipy')
plt.loglog(a, output, label='jax')
plt.xlabel('a')
plt.ylabel('betainc(a, 1, 0.25)')
plt.legend()
print(output.devices(), output.dtype)
# {cuda(id=0)} float32
```

So, to reproduce the issue on CPU, we can add the following on top of the code:

```python
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["JAX_PLATFORMS"] = "cpu"
```
---

Thanks for the investigation @vfdev-5. I extended the top post with a description of a separate issue regarding edge-case behavior, which is probably easier to address.
---

Heads up: #27107 fixes all issues reported here.
---

Thank you!
---

Description

`jax.scipy.special.betainc` seems to have trouble with very small values of the parameter `a`, at least for certain values of `b` and `x`. I know that it is difficult to guarantee accuracy to machine precision for all possible combinations of input : ) Just thought I'd point out this problem spot since it came up in SciPy testing (scipy/scipy#20963).

There is a separate issue related to edge-case behavior in `betainc` that I thought I should bring up: there are some edge cases where `betainc` returns NaN but another result would be more useful (e.g. in Harrell-Davis quantile estimates). To some extent, the "correct" result is a matter of interpretation, but SciPy recently addressed similar cases in scipy/scipy#22425. While we're waiting for an Array API special function extension (data-apis/array-api#725), please consider adopting this behavior.

Produces:

(actual, desired)
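As a hypothetical probe of the kind of boundary cases described above (my sketch; these particular inputs are guesses, and expected outputs are deliberately left unasserted; see scipy/scipy#22425 for the conventions SciPy adopted):

```python
import numpy as np
from scipy.special import betainc as betainc_scipy
import jax.numpy as xp
from jax.scipy.special import betainc as betainc_jax

# Boundary parameter values of the kind that arise in Harrell-Davis
# quantile estimation; these specific triples are illustrative guesses.
cases = [(0.0, 1.0, 0.5), (1.0, 0.0, 0.5), (0.0, 0.0, 0.5)]
for a, b, x in cases:
    print((a, b, x),
          "scipy:", betainc_scipy(a, b, x),
          "jax:", float(betainc_jax(xp.asarray(a), b, x)))
```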
System info (python version, jaxlib version, accelerator, etc.)

```
jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.25.2
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='e901fac133dc', release='6.1.85+', version='#1 SMP PREEMPT_DYNAMIC Sun Apr 28 14:29:16 UTC 2024', machine='x86_64')
```