Edge behavior in `jax.scipy.special.betainc` · Issue #21900 · jax-ml/jax · GitHub
Edge behavior in jax.scipy.special.betainc #21900

Closed
mdhaber opened this issue Jun 15, 2024 · 10 comments · Fixed by #27107
Labels: bug (Something isn't working)

@mdhaber commented Jun 15, 2024

Description

jax.scipy.special.betainc seems to have trouble with very small values of the parameter a, at least for certain values of b and x.

import matplotlib.pyplot as plt
import numpy as np
from scipy.special import betainc as betainc_scipy
import jax.numpy as xp
from jax.scipy.special import betainc as betainc_jax

a = np.logspace(-40, -1, 300)
b = 1
x = 0.25
plt.loglog(a, betainc_scipy(a, b, x), label='scipy')
plt.loglog(a, betainc_jax(xp.asarray(a), b, x), label='jax')
plt.xlabel('a')
plt.ylabel('betainc(a, 1, 0.25)')
plt.legend()
[plot: betainc(a, 1, 0.25) vs a — SciPy and JAX curves diverge for small a]

I know that it is difficult to guarantee accuracy to machine precision for all possible combinations of input : ) Just thought I'd point out this problem spot since it came up in SciPy testing (scipy/scipy#20963).


There is a separate issue related to edge case behavior in betainc that I thought I should bring up: there are some edge cases where betainc returns NaN but another result would be more useful (e.g. in Harrell-Davis quantile estimates). To some extent, the "correct" result is a matter of interpretation, but SciPy recently addressed similar cases in scipy/scipy#22425. While we're waiting for an Array API special function extension (data-apis/array-api#725), please consider adopting this behavior.

import jax.numpy as np
from jax.scipy import special
cases = [((0.0, 0.0, 0.0), np.nan),
         ((0.0, 0.0, 0.5), np.nan),
         ((0.0, 0.0, 1.0), np.nan),
         ((np.inf, np.inf, 0.0), np.nan),
         ((np.inf, np.inf, 0.5), np.nan),
         ((np.inf, np.inf, 1.0), np.nan),
         ((0.0, 1.0, 0.0), 0.0),
         ((0.0, 1.0, 0.5), 1.0),
         ((0.0, 1.0, 1.0), 1.0),
         ((1.0, 0.0, 0.0), 0.0),
         ((1.0, 0.0, 0.5), 0.0),
         ((1.0, 0.0, 1.0), 1.0),
         ((0.0, np.inf, 0.0), 0.0),
         ((0.0, np.inf, 0.5), 1.0),
         ((0.0, np.inf, 1.0), 1.0),
         ((np.inf, 0.0, 0.0), 0.0),
         ((np.inf, 0.0, 0.5), 0.0),
         ((np.inf, 0.0, 1.0), 1.0),
         ((1.0, np.inf, 0.0), 0.0),
         ((1.0, np.inf, 0.5), 1.0),
         ((1.0, np.inf, 1.0), 1.0),
         ((np.inf, 1.0, 0.0), 0.0),
         ((np.inf, 1.0, 0.5), 0.0),
         ((np.inf, 1.0, 1.0), 1.0)]

for args, ref in cases:
    print(special.betainc(*args), ref)

Produces (actual, desired):

nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan 0.0
nan 1.0
nan 1.0
nan 0.0
nan 0.0
nan 1.0
nan 0.0
nan 1.0
nan 1.0
nan 0.0
nan 0.0
nan 1.0
nan 0.0
nan 1.0
nan 1.0
nan 0.0
nan 0.0
nan 1.0
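For reference, the conventions in the table above can be collected into a small helper that returns the edge-case value, or `None` when the ordinary algorithm should run (a sketch in plain NumPy; the name `betainc_edge` is hypothetical):

```python
import numpy as np

def betainc_edge(a, b, x):
    """Return the conventional betainc value for edge inputs, or None."""
    # Indeterminate combinations: both shape parameters at the same extreme.
    if (a == 0 and b == 0) or (np.isinf(a) and np.isinf(b)):
        return np.nan
    # Ordinary inputs: defer to the regular algorithm.
    if not (a == 0 or b == 0 or np.isinf(a) or np.isinf(b)):
        return None
    # The CDF endpoint values hold for every remaining combination.
    if x == 0:
        return 0.0
    if x == 1:
        return 1.0
    # For 0 < x < 1: a == 0 or b == inf concentrates all mass at the left
    # endpoint, so the CDF is already 1; a == inf or b == 0 concentrates
    # it at the right endpoint, so the CDF is still 0.
    if a == 0 or np.isinf(b):
        return 1.0
    return 0.0
```

A full implementation would apply this check elementwise (e.g. with `jnp.where`) before falling back to the regular evaluation.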

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.26
jaxlib: 0.4.26
numpy: 1.25.2
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='e901fac133dc', release='6.1.85+', version='#1 SMP PREEMPT_DYNAMIC Sun Apr 28 14:29:16 UTC 2024', machine='x86_64')

@rajasekharporeddy (Contributor) commented Jun 17, 2024

Hi @mdhaber

JAX typically uses single-precision floating-point numbers for calculations, while SciPy defaults to double precision. This difference in precision can lead to different results, especially when working with very small numbers. If double precision is enabled in JAX, it yields results consistent with SciPy even for very small numbers:

import jax

jax.config.update('jax_enable_x64', True)

import matplotlib.pyplot as plt
import numpy as np
from scipy.special import betainc as betainc_scipy
import jax.numpy as xp
from jax.scipy.special import betainc as betainc_jax

a = np.logspace(-40, -1, 300)
b = 1
x = 0.25
plt.loglog(a, betainc_scipy(a, b, x), label='scipy')
plt.loglog(a, betainc_jax(xp.asarray(a), b, x), label='jax')
plt.xlabel('a')
plt.ylabel('betainc(a, 1, 0.25)')
plt.legend()

[plot: with jax_enable_x64, the SciPy and JAX curves agree]

Please find the gist for reference.

Thank you.

@mdhaber (Author) commented Jun 17, 2024

Thanks! Although this actually came up in the context of 32-bit calculations. The definitions should have been:

a = np.logspace(-40, -1, 300, dtype=np.float32)
b = np.float32(1.)
x = np.float32(0.25)

and the plot looks the same. SciPy's outputs are float32, so I assume that's being preserved internally, although perhaps it is converting back and forth.
In any case, I know the trouble area is toward the small end of normal numbers and extends into the subnormals, so I understand if it's not a priority. Feel free to close!


To zoom in:

import matplotlib.pyplot as plt
import numpy as np
from scipy.special import betainc as betainc_scipy
import jax.numpy as xp
from jax.scipy.special import betainc as betainc_jax

a0 = np.finfo(np.float32).smallest_normal
b = np.float32(1.)
x = np.float32(0.25)
factor = np.float32(10)
a = np.logspace(np.log10(a0), np.log10(a0*factor), 300, dtype=np.float32)
plt.loglog(a, betainc_scipy(a, b, x), label='scipy')
plt.loglog(a, betainc_jax(xp.asarray(a), b, x), label='jax')
plt.xlabel('a')
plt.ylabel('betainc(a, 1, 0.25)')
plt.legend()

[plot: zoomed view of betainc(a, 1, 0.25) near the float32 smallest normal]

@rajasekharporeddy (Contributor) commented Jun 24, 2024

Hi @mdhaber

IIUC, according to scipy/scipy#8495 (comment), SciPy does all the internal (low-level C) calculations in float64 even if the input is float32 or another dtype, whereas JAX does them in float32 itself. That might be causing this difference.

Thank you.
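If internal precision is indeed the difference, one user-side workaround is to upcast the arguments to float64 and round the result back down, mimicking what an internally upcasting implementation would do (a sketch using SciPy's `betainc` for the float64 evaluation; the name `betainc_via_f64` is hypothetical):

```python
import numpy as np
from scipy.special import betainc

def betainc_via_f64(a, b, x):
    # Evaluate in double precision, then round the result back to
    # float32, as an internally upcasting implementation would.
    a64, b64, x64 = (np.asarray(v, dtype=np.float64) for v in (a, b, x))
    return betainc(a64, b64, x64).astype(np.float32)
```

The rounded-back value is correct whenever the float64 evaluation itself is accurate enough, which covers the small-`a` region from the plots above.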

@pearu (Collaborator) commented Jun 24, 2024

Whatever the reasons for SciPy to use float64 internally (one practical reason could be that no float32 implementation was available to SciPy, for instance), evaluating a function correctly in float32 requires an algorithm that properly handles overflows, underflows, and cancellations. Using higher precision is a typical cheap trick for avoiding these floating-point errors while keeping the algorithm simple.
So, I wonder where the jax.scipy.special.betainc implementation lives; that may explain the behavior observed in this issue.
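To illustrate the point about float32-aware algorithms: for `b == 1` the identity `betainc(a, 1, x) = x**a` holds, and the complementary probability `1 - x**a` can be computed to full float32 accuracy with `expm1`, while the naive subtraction cancels to zero. (A sketch of the idea, not JAX's actual algorithm.)

```python
import numpy as np

a = np.float32(2e-38)   # near the float32 smallest normal
x = np.float32(0.25)

# betainc(a, 1, x) = x**a, so 1 - betainc(a, 1, x) = -expm1(a * log(x)).
naive = np.float32(1.0) - x**a      # x**a rounds to 1.0, so this is 0.0
stable = -np.expm1(a * np.log(x))   # keeps full float32 accuracy
```

No higher precision is needed here; the algorithm just has to avoid the cancellation.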

@jakevdp (Collaborator) commented Jun 24, 2024

@mdhaber (Author) commented Jun 24, 2024

> SciPy does all the internal (low-level C) calculations in float64 even if the input is float32 or another dtype.

But that comment is about scipy.ndimage.affine_transform, not scipy.special.betainc.

I confirmed with @steppi that SciPy now uses Boost's ibeta for betainc, and the types seem to be preserved in the calculation.

Here is where betainc is defined in terms of ibeta.
https://github.com/scipy/scipy/blob/e36e728081475466d2faae65e1dfecfa2314c857/scipy/special/functions.json#L118-L123

Here is where ibeta is used for float and double instantiations of the function.
https://github.com/scipy/scipy/blob/e36e728081475466d2faae65e1dfecfa2314c857/scipy/special/boost_special_functions.h#L106-L116

and Boost's ibeta is templated:
https://beta.boost.org/doc/libs/1_68_0/libs/math/doc/html/math_toolkit/sf_beta/ibeta_function.html

@vfdev-5 (Collaborator) commented Jul 9, 2024

Checking the same code on a GPU, the plots look a bit different:

import matplotlib.pyplot as plt
import numpy as np
from scipy.special import betainc as betainc_scipy
import jax.numpy as xp
from jax.scipy.special import betainc as betainc_jax

a = np.logspace(-40, -1, 300)
b = 1
x = 0.25

output = betainc_jax(xp.asarray(a), b, x)

plt.loglog(a, betainc_scipy(a, b, x), label='scipy')
plt.loglog(a, output, label='jax')
plt.xlabel('a')
plt.ylabel('betainc(a, 1, 0.25)')
plt.legend()

print(output.devices(), output.dtype)
# {cuda(id=0)} float32

[plot: on GPU (float32), the JAX curve follows SciPy for small a]

and

import matplotlib.pyplot as plt
import numpy as np
from scipy.special import betainc as betainc_scipy
import jax.numpy as xp
from jax.scipy.special import betainc as betainc_jax

a0 = np.finfo(np.float32).smallest_normal
b = np.float32(1.)
x = np.float32(0.25)
factor = np.float32(10)
a = np.logspace(np.log10(a0), np.log10(a0*factor), 300, dtype=np.float32)

output = betainc_jax(xp.asarray(a), b, x)
plt.loglog(a, betainc_scipy(a, b, x), label='scipy')
plt.loglog(a, output, label='jax')
plt.xlabel('a')
plt.ylabel('betainc(a, 1, 0.25)')
plt.legend()

print(output.devices(), output.dtype)
# {cuda(id=0)} float32

[plot: zoomed view near the float32 smallest normal, GPU results]

So, to reproduce the issue, we can force CPU execution by adding the following on top of the code:

import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["JAX_PLATFORMS"] = "cpu"

@mdhaber (Author) commented Feb 8, 2025

Thanks for the investigation @vfdev-5. I extended the top post with a description of a separate issue regarding edge case behavior, which is probably easier to address.


@pearu (Collaborator) commented Mar 13, 2025

Heads up: #27107 fixes all issues reported here.

@mdhaber (Author) commented Mar 20, 2025

Thank you!
