Feature request: JAX implementation of scipy.special.jv #11002
Comments
Maybe you could adapt an existing algorithm into a numba-jittable function and see #1870 and/or https://github.com/josipd/jax/blob/master/jax/experimental/jambax.py to plug the function into a jax primitive (see also #9956). Here is some code I am using to add some special functions, but without jit support, in case it helps:

```python
import functools

import jax.numpy as jnp
from jax import core
from jax.interpreters import ad, batching
from scipy import special


def makejaxufunc(ufunc, *derivs):
    # TODO use jax.something.standard_primitive
    prim = core.Primitive(ufunc.__name__)

    @functools.wraps(ufunc)
    def func(*args):
        return prim.bind(*args)

    @prim.def_impl
    def impl(*args):
        return ufunc(*args)

    @prim.def_abstract_eval
    def abstract_eval(*args):
        shape = jnp.broadcast_shapes(*(x.shape for x in args))
        dtype = jnp.result_type(*(x.dtype for x in args))
        return core.ShapedArray(shape, dtype)

    # Bind d via a keyword default so each lambda keeps its own derivative
    # instead of closing over the generator's loop variable.
    jvps = (
        None if d is None
        else lambda g, *args, d=d: d(*args) * g
        for d in derivs
    )
    ad.defjvp(prim, *jvps)
    batching.defbroadcasting(prim)
    return func


j0 = makejaxufunc(special.j0, lambda x: -j1(x))
j1 = makejaxufunc(special.j1, lambda x: (j0(x) - jn(2, x)) / 2.0)
jn = makejaxufunc(special.jn, None, lambda n, x: (jn(n - 1, x) - jn(n + 1, x)) / 2.0)
kv = makejaxufunc(special.kv, None, lambda v, z: kvp(v, z, 1))
kvp = makejaxufunc(special.kvp, None, lambda v, z, n: kvp(v, z, n + 1), None)
```
To make these jittable, we may need to turn the recurrence relations into vectorized form. A similar approach was used to implement the associated Legendre functions in jax.scipy.special.lpmn.
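As a quick sanity check (my own snippet, plain scipy with no JAX involved), the derivative recurrence used for the JVPs above can be verified numerically against scipy's own Bessel derivative:

```python
# Numerically verify the recurrence used for the JVPs above:
#   Jn'(x) = (J_{n-1}(x) - J_{n+1}(x)) / 2
import numpy as np
from scipy import special

n, x = 3, 2.5
lhs = special.jvp(n, x)  # scipy's derivative of Jn
rhs = 0.5 * (special.jv(n - 1, x) - special.jv(n + 1, x))
print(np.isclose(lhs, rhs))
```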
Hi, is there any update on this? I need to use J1 and J2 (I do not need the gradient, by the way), but I have had no luck.
I will give it a try.
Hi, would also like to see these implemented. In the meantime, a callback function to scipy seems good enough for me: it is jittable and differentiable. EDIT: following the suggestion of @Gattocrucco, the function now uses pure_callback.
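For concreteness, here is a minimal sketch (my own, not the exact code from this comment) of wrapping scipy.special.jv with jax.pure_callback and attaching a custom JVP through the recurrence Jv'(x) = (J_{v-1}(x) - J_{v+1}(x)) / 2; the order v is treated as non-differentiable:

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy import special


@jax.custom_jvp
def jv(v, x):
    # Host callback to scipy; works under jit because pure_callback only
    # needs the abstract shape/dtype at trace time.
    out_shape = jax.ShapeDtypeStruct(jnp.shape(x), jnp.result_type(float, x))
    return jax.pure_callback(
        lambda v, x: special.jv(np.asarray(v), np.asarray(x)).astype(out_shape.dtype),
        out_shape, v, x)


@jv.defjvp
def jv_jvp(primals, tangents):
    v, x = primals
    _, x_dot = tangents  # tangent for the order v is ignored
    primal_out = jv(v, x)
    # standard recurrence for the derivative with respect to x
    tangent_out = 0.5 * (jv(v - 1, x) - jv(v + 1, x)) * x_dot
    return primal_out, tangent_out
```

With this, `jax.grad(lambda t: jv(0.0, t))(1.0)` should agree with `-scipy.special.j1(1.0)`, and `jax.jit` of the wrapped function also works.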
For reference, the plots: [plot images omitted]
Assigning to @jakevdp for further triage.
I am also interested; are there any updates here? :)
Since the spherical Hankel function generates complex numbers, I am having a difficult time taking its derivative. Using what @DavidDevoogdt shared, I wrote:

```python
# spherical_bessel_genearator comes from @DavidDevoogdt's snippet above
spherical_hankel1 = spherical_bessel_genearator(hankel1)

def dhn(n, x):
    # differentiate the real and imaginary parts separately,
    # then recombine, since the output is complex
    real_func = lambda x: np.real(spherical_hankel1(n, x))
    imag_func = lambda x: np.imag(spherical_hankel1(n, x))
    diff = jax.jacfwd(real_func, argnums=0)(x) + 1j * jax.jacfwd(imag_func, argnums=0)(x)
    return diff

dhn_vec = vmap(dhn, in_axes=(None, 0))
```

and when printing, I get: [error output not captured]

Any suggestions, @jakevdp? (Thanks for the edit, Jake.)
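The real/imaginary split itself does work under forward-mode differentiation; here is a self-contained demonstration of the same pattern on jnp.exp(1j * x) instead of the spherical Hankel wrapper (so any failure above likely comes from the wrapped callback rather than from the split):

```python
import jax
import jax.numpy as jnp


def f(x):
    # stand-in complex-valued function of a real input
    return jnp.exp(1j * x)


def df(x):
    # differentiate real and imaginary parts separately, then recombine
    re = jax.jacfwd(lambda t: jnp.real(f(t)))(x)
    im = jax.jacfwd(lambda t: jnp.imag(f(t)))(x)
    return re + 1j * im


print(df(0.0))  # d/dx exp(ix) at x = 0 is 1j
```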
It looks like you'd need to update the callback. https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html has some discussion of how to use external callbacks.
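One concrete piece of that: a pure_callback can return complex values if the declared result dtype is complex. A minimal sketch under that assumption (the `hankel1_cb` wrapper name is mine), calling scipy.special.hankel1:

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy import special


def hankel1_cb(n, x):
    # declare a complex result so the callback's complex values survive
    out = jax.ShapeDtypeStruct(jnp.shape(x), jnp.complex64)
    return jax.pure_callback(
        lambda n, x: special.hankel1(np.asarray(n), np.asarray(x)).astype(np.complex64),
        out, n, x)


# H1^(1)(x) = J1(x) + i*Y1(x); the wrapper is jittable
val = jax.jit(hankel1_cb)(1.0, 2.0)
```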
Hi,
I'm working on a research problem that involves computations of Bessel functions of the first kind, of multiple different integer orders, and I'm trying to leverage JAX for autodiff and fast computation on GPUs.
The specific function required by the equations I'm working with is implemented by scipy in scipy.special.jv. JAX's implementation of scipy.special exposes a limited amount of functionality associated with Bessel functions -- the ones I could find are i0, i1, i0e, and i1e, but not any Bessel J.
A similar Bessel function related issue has been raised in the past, at #2466, where someone requested a JAX implementation of scipy.special.iv. The accepted solution for that issue was to use the TensorFlow Probability library implementation of ive. Unfortunately, this solution does not seem to work in my case: TFP does not seem to have an implementation of jv, and I cannot convert between ive and jv since the TFP implementation of ive does not accept complex arguments as inputs.
Thus I'd like to request that JAX (or TFP) implement scipy.special.jv. Alternatively, if someone could point me to a way to translate one of the existing third-party Bessel function implementations (e.g. those provided by scipy) into a JAX-composable jv compatible with jax.grad and jax.jit myself, I would appreciate it.
Thanks,