
Feature request: JAX implementation of scipy.special.jv #11002

Open
saopeter opened this issue Jun 6, 2022 · 10 comments
Labels
enhancement (New feature or request), P3 (no schedule: we have no plan to work on this and, if it is unassigned, we would be happy to review a PR)

Comments

@saopeter commented Jun 6, 2022

Hi,

I'm working on a research problem that involves computations of Bessel functions of the first kind, of multiple different integer orders, and I'm trying to leverage JAX for autodiff and fast computation on GPUs.

The specific function required by the equations I'm working with is implemented by scipy in scipy.special.jv. JAX's implementation of scipy.special exposes a limited amount of functionality associated with Bessel functions -- the ones I could find are i0, i1, i0e, and i1e, but not any Bessel J.

A similar Bessel function related issue has been raised in the past, at #2466, where someone requested a JAX implementation of scipy.special.iv. The accepted solution for that issue was to use the TensorFlow Probability library implementation of ive. Unfortunately, this solution does not seem to work in my case: TFP does not seem to have an implementation of jv, and I cannot convert between ive and jv since the TFP implementation of ive does not accept complex arguments as inputs.

Thus I'd like to request that JAX (or TFP) implement scipy.special.jv. Alternatively, if someone could point me to a way I could translate one of the existing third-party Bessel function implementations (e.g. those provided by scipy) into a JAX-composable jv compatible with jax.grad and jax.jit myself, I would appreciate it.

Thanks,

@saopeter added the enhancement (New feature or request) label Jun 6, 2022
@Gattocrucco (Contributor) commented Jun 7, 2022

Maybe you could adapt an existing algorithm into a numba-jittable function, then see #1870 and/or https://github.com/josipd/jax/blob/master/jax/experimental/jambax.py for how to plug that function into a JAX primitive (see also #9956).

Here is some code I am using to add some special functions (without jit support), in case it helps:

import functools

import jax.numpy as jnp
from jax import core
from jax.interpreters import ad, batching
from scipy import special


def makejaxufunc(ufunc, *derivs):
    """Wrap a scipy ufunc as a JAX primitive; derivs are the derivatives
    w.r.t. each argument (None where no derivative is provided)."""
    # TODO use jax.something.standard_primitive

    prim = core.Primitive(ufunc.__name__)

    @functools.wraps(ufunc)
    def func(*args):
        return prim.bind(*args)

    @prim.def_impl
    def impl(*args):
        # eager evaluation just calls back into scipy (hence no jit support)
        return ufunc(*args)

    @prim.def_abstract_eval
    def abstract_eval(*args):
        shape = jnp.broadcast_shapes(*(x.shape for x in args))
        dtype = jnp.result_type(*(x.dtype for x in args))
        return core.ShapedArray(shape, dtype)

    jvps = (
        None if d is None
        # bind d via a default argument to avoid late binding of the loop variable
        else (lambda g, *args, d=d: d(*args) * g)
        for d in derivs
    )
    ad.defjvp(prim, *jvps)

    batching.defbroadcasting(prim)

    return func

j0 = makejaxufunc(special.j0, lambda x: -j1(x))
j1 = makejaxufunc(special.j1, lambda x: (j0(x) - jn(2, x)) / 2.0)
jn = makejaxufunc(special.jn, None, lambda n, x: (jn(n - 1, x) - jn(n + 1, x)) / 2.0)
kv = makejaxufunc(special.kv, None, lambda v, z: kvp(v, z, 1))
kvp = makejaxufunc(special.kvp, None, lambda v, z, n: kvp(v, z, n + 1), None)
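
A quick usage sketch of the wrappers above (it assumes the imports shown; since no lowering rule is defined, everything runs eagerly and jax.jit will not work):

import jax
import jax.numpy as jnp

x = jnp.float32(2.5)
print(j0(x))            # evaluates eagerly via scipy.special.j0 through the primitive's impl
print(jax.grad(j0)(x))  # uses the JVP rule: d/dx J_0(x) = -J_1(x)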

@tlu7 (Contributor) commented Jun 8, 2022

To make this jittable, we may need to turn the recurrence relations into a vectorized form. A similar approach was used to implement the associated Legendre functions in jax.scipy.special.lpmn.
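
Roughly, the idea is something like the sketch below (an illustration only, not the eventual implementation): unroll the three-term recurrence with lax.scan so the whole computation stays jittable. Upward recurrence loses accuracy once the order exceeds x, so a real implementation would recurse downward (Miller's algorithm), but the structure is the same.

import jax
import jax.numpy as jnp

def jn_upward(n_max, x, j0_val, j1_val):
    """Sketch only: J_0(x)..J_{n_max}(x) by upward recurrence, given starting
    values j0_val = J_0(x) and j1_val = J_1(x) with the same shape as x.
    n_max must be a static Python int; unstable when n_max >> x."""
    x = jnp.asarray(x)
    j0_val = jnp.asarray(j0_val, dtype=x.dtype)
    j1_val = jnp.asarray(j1_val, dtype=x.dtype)

    def step(carry, n):
        j_prev, j_cur = carry
        # DLMF 10.6.1: J_{n+1}(x) = (2n/x) J_n(x) - J_{n-1}(x)
        j_next = (2.0 * n / x) * j_cur - j_prev
        return (j_cur, j_next), j_next

    orders = jnp.arange(1, n_max, dtype=x.dtype)
    _, higher = jax.lax.scan(step, (j0_val, j1_val), orders)
    return jnp.concatenate([jnp.stack([j0_val, j1_val]), higher])

With n_max marked static, e.g. jax.jit(jn_upward, static_argnums=0), the whole thing compiles; the starting values for J_0 and J_1 still have to come from somewhere (a series expansion, or a callback as in the comments below).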

@KostadinovShalon

Hi, is there any update on this? I need to use J1 and J2 (I don't need the gradient, btw) but I've had no luck.

@tlu7 (Contributor) commented Oct 26, 2022

I will give it a try.

@DavidDevoogdt commented Nov 4, 2022

Hi, I would also like to see these implemented. In the meantime, a callback to scipy seems good enough for me. It is jittable, differentiable (e.g. jit(jacfwd(lambda x: jv(0.5, x)))(x) works fine), and vmappable. The derivatives are implemented by taking the appropriate function combinations (from https://dlmf.nist.gov/10).

EDIT: following the suggestion of @Gattocrucco, the function now uses pure_callback.

import jax  # needed for jax.lax.cond below
import jax.numpy as jnp
import scipy.special
from jax import custom_jvp, pure_callback, vmap

# see https://github.com/google/jax/issues/11002


def generate_bessel(function):
    """function is Jv, Yv, Hv_1,Hv_2"""

    @custom_jvp
    def cv(v, x):
        return pure_callback(
            lambda vx: function(*vx),
            x,
            (v, x),
            vectorized=True,
        )

    @cv.defjvp
    def cv_jvp(primals, tangents):
        v, x = primals
        dv, dx = tangents  # dv (tangent w.r.t. the order) is ignored
        primal_out = cv(v, x)

        # https://dlmf.nist.gov/10.6 formula 10.6.1
        tangents_out = jax.lax.cond(
            v == 0,
            lambda: -cv(v + 1, x),
            lambda: 0.5 * (cv(v - 1, x) - cv(v + 1, x)),
        )

        return primal_out, tangents_out * dx

    return cv


jv = generate_bessel(scipy.special.jv)
yv = generate_bessel(scipy.special.yv)
hankel1 = generate_bessel(scipy.special.hankel1)
hankel2 = generate_bessel(scipy.special.hankel2)


def generate_modified_bessel(function, sign):
    """function is Kv and Iv"""

    @custom_jvp
    def cv(v, x):
        return pure_callback(
            lambda vx: function(*vx),
            x,
            (v, x),
            vectorized=True,
        )

    @cv.defjvp
    def cv_jvp(primals, tangents):
        v, x = primals
        dv, dx = tangents  # dv (tangent w.r.t. the order) is ignored
        primal_out = cv(v, x)

        # https://dlmf.nist.gov/10.29 (derivatives of I_v and K_v)
        tangents_out = jax.lax.cond(
            v == 0,
            lambda: sign * cv(v + 1, x),
            lambda: sign * 0.5 * (cv(v - 1, x) + cv(v + 1, x)),
        )

        return primal_out, tangents_out * dx

    return cv


kv = generate_modified_bessel(scipy.special.kv, sign=-1)
iv = generate_modified_bessel(scipy.special.iv, sign=+1)


def spherical_bessel_generator(f):
    def g(v, x):
        return f(v + 0.5, x) * jnp.sqrt(jnp.pi / (2 * x))

    return g


spherical_jv = spherical_bessel_generator(jv)
spherical_yv = spherical_bessel_generator(yv)
spherical_hankel1 = spherical_bessel_generator(hankel1)
spherical_hankel2 = spherical_bessel_generator(hankel2)

For reference, here are the plots:


import matplotlib.pyplot as plt

x = jnp.linspace(0.0, 20.0, num=1000)


for func, name in zip(
    [jv, yv, iv, kv, spherical_jv, spherical_yv],
    ["jv", "yv", "iv", "kv", " spherical_jv", "spherical_yv"],
):

    plt.figure()

    for i in range(5):
        y = vmap(func, in_axes=(None, 0))(i, x)
        plt.plot(x, y, label=i)

    plt.ylim([-1.1, 1.1])
    plt.title(name)
    plt.legend()

    plt.draw()
    plt.pause(0.001)

    # plt.show()

print("done")

@Gattocrucco (Contributor) commented Nov 4, 2022

I would use jax.pure_callback. Is there a reason you are using jax.experimental.host_callback instead? pure_callback should work with vmap.

@axch (Contributor) commented Jun 2, 2023

Assigning to @jakevdp for further triage.

@arthurmloureiro

I am also interested; are there any updates here? :)

@HosseinKhodavirdi commented Jul 19, 2023

Since the spherical Hankel function generates complex numbers, I am having a difficult time taking its derivative. Using what @DavidDevoogdt shared, I wrote:

import jax
import numpy as np

# continues from @DavidDevoogdt's snippet above
spherical_hankel1 = spherical_bessel_generator(hankel1)

def dhn(n, x):
    # differentiate the real and imaginary parts separately
    real_func = lambda x: np.real(spherical_hankel1(n, x))
    imag_func = lambda x: np.imag(spherical_hankel1(n, x))
    diff = jax.jacfwd(real_func, argnums=0)(x) + 1j * jax.jacfwd(imag_func, argnums=0)(x)
    return diff

dhn_vec = vmap(dhn, in_axes=(None, 0))

and when printing:

krr=np.array([6.0])
print(dhn_vec(0, krr))

I get:

XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: CpuCallback error: RuntimeError: Incorrect output dtype for return value 0: Expected: float32, Actual: complex64

Any suggestions?

Thanks for the edit, Jake. Do you have any suggestions, @jakevdp?

@jakevdp (Collaborator) commented Jul 19, 2023

It looks like you'd need to update the pure_callback call to specify the appropriate dtype for the complex output. You're telling it that it should return float32, but the function you're calling returns complex64.

https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html has some discussion of how to use jax.pure_callback
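
Concretely, something along these lines should work (a rough sketch based on the generate_bessel helper above, with the result spec changed to a complex jax.ShapeDtypeStruct; generate_complex_bessel and hankel1_c are just placeholder names here):

import jax
import jax.numpy as jnp
import scipy.special
from jax import custom_jvp, pure_callback

def generate_complex_bessel(function):
    """Like generate_bessel above, but declares a complex result dtype so the
    callback may return complex values (e.g. Hankel functions)."""

    @custom_jvp
    def cv(v, x):
        x = jnp.asarray(x)
        out_dtype = jnp.result_type(x.dtype, jnp.complex64)  # complex64 or complex128
        result_spec = jax.ShapeDtypeStruct(x.shape, out_dtype)
        return pure_callback(
            lambda vx: function(*vx).astype(out_dtype),
            result_spec,
            (v, x),
            vectorized=True,
        )

    @cv.defjvp
    def cv_jvp(primals, tangents):
        v, x = primals
        _, dx = tangents
        # same recurrence as in the real-valued case (https://dlmf.nist.gov/10.6)
        tangent_out = jax.lax.cond(
            v == 0,
            lambda: -cv(v + 1, x),
            lambda: 0.5 * (cv(v - 1, x) - cv(v + 1, x)),
        )
        return cv(v, x), tangent_out * dx

    return cv

hankel1_c = generate_complex_bessel(scipy.special.hankel1)

The jacfwd-of-real-and-imaginary-parts trick from the previous comment should then no longer hit the dtype mismatch.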

@jakevdp removed their assignment Nov 3, 2023
@jakevdp added the P3 (no schedule) label Nov 3, 2023