Modified Bessel function of the second kind in jax.scipy.special #9956
Update: with the nightly build of TensorFlow Probability,

```python
jax.grad(lambda x: tfp.substrates.jax.math.bessel_kve(x[0], x[1]))([0.7, 1.2])
```

returns

```
[DeviceArray(0., dtype=float32, weak_type=True), DeviceArray(-0.57418, dtype=float32, weak_type=True)]
```

So for now it is fine, but it would be nice to see modified Bessel functions of the second kind in JAX, because they are at the core of the Matérn kernels used for Gaussian processes.
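To show where K_ν enters, here is a minimal sketch of the Matérn covariance k(r) = σ² · 2^(1-ν)/Γ(ν) · (√(2ν) r/ℓ)^ν · K_ν(√(2ν) r/ℓ). Since `jax.scipy.special` has no K_ν, the Bessel function is passed in as a hypothetical `kv(nu, x)` callable; the helper `kv_half` is an illustrative closed form K_{1/2}(x) = √(π/(2x)) e^(-x), valid only for ν = 1/2, under which the kernel should reduce to σ² exp(-r/ℓ).

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln

jax.config.update("jax_enable_x64", True)


def kv_half(nu, x):
    """Hypothetical helper: closed form of K_{1/2}(x). Ignores `nu`; only valid for nu = 0.5."""
    return jnp.sqrt(jnp.pi / (2.0 * x)) * jnp.exp(-x)


def matern(r, nu, kv, lengthscale=1.0, variance=1.0):
    """Matern covariance for distances r >= 0.

    `kv(nu, x)` is an assumed callable returning K_nu(x) for x > 0.
    At r = 0 the kernel's limiting value, `variance`, is returned.
    """
    r = jnp.asarray(r)
    is_zero = r == 0
    # Swap r = 0 for a dummy value so kv never sees x = 0, where K_nu diverges.
    safe_r = jnp.where(is_zero, 1.0, r)
    x = jnp.sqrt(2.0 * nu) * safe_r / lengthscale
    # log of the prefactor 2^(1 - nu) / Gamma(nu), kept in log space for stability.
    log_prefactor = (1.0 - nu) * jnp.log(2.0) - gammaln(nu)
    k = variance * jnp.exp(log_prefactor + nu * jnp.log(x)) * kv(nu, x)
    return jnp.where(is_zero, variance, k)
```

Any general-ν K_ν implementation (for example the TFP one above, rescaled by e^(-x) if it is the exponentially scaled `bessel_kve`) could be plugged in as `kv`.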
I wrote a small snippet in pure JAX that handles the case where z is real, using an integral representation and solving it with double-exponential quadrature (Eq. 1.6).
It seems to work well when z is not too close to zero, and it should be differentiable with respect to both nu and z since it is pure JAX. However, I have no idea how to integrate this properly into the JAX codebase, and I don't even know if I would be able to
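The following is a stand-in sketch, not the commenter's snippet. It assumes the standard integral representation K_ν(z) = ∫₀^∞ exp(-z cosh t) cosh(νt) dt for real z > 0 and swaps in a plain trapezoidal rule on a truncated interval; because the integrand is even in t and decays double-exponentially, this simple rule is already very accurate. The function name and the `t_max` / `num_points` defaults are illustrative choices.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def bessel_kv_quad(nu, z, t_max=10.0, num_points=2001):
    """Approximate K_nu(z) for real z > 0 that is not too close to 0.

    Uses K_nu(z) = int_0^inf exp(-z cosh t) cosh(nu t) dt and the trapezoidal
    rule on [0, t_max]. exp(-z cosh t) decays double-exponentially in t, so
    truncating at t_max is harmless unless z is very small.
    """
    t = jnp.linspace(0.0, t_max, num_points)
    f = jnp.exp(-z * jnp.cosh(t)) * jnp.cosh(nu * t)
    h = t[1] - t[0]
    # Trapezoidal rule: full weight for interior nodes, half weight at both ends.
    return h * (jnp.sum(f) - 0.5 * (f[0] + f[-1]))
```

Being pure JAX, it differentiates in both arguments: `jax.grad(bessel_kv_quad, argnums=1)(0.0, 1.0)` approximates dK₀/dz = -K₁(1), and `argnums=0` gives the derivative with respect to the order. For small z, `t_max` has to grow roughly like log(1/z) so that z·cosh(t_max) stays large, which is one reason this approach degrades near zero.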
Hmm, if I were to code it up, I would probably try to evaluate the integral using Jordan's lemma and complex residues, as we usually do in math. Or I would try to derive a series representation of the modified Bessel function of the second kind from that of the first kind and compute the series up to some term. Direct integration seems a bit too straightforward, but I am not sure my approaches would be robust either. Probably the best way is just to look up SciPy's implementation, but it might not fit JAX's pure-functional computation paradigm.
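One way to realize the series idea: for non-integer ν, K_ν follows from the first-kind series I_ν(z) = Σₖ (z/2)^(2k+ν) / (k! Γ(k+ν+1)) through the connection formula K_ν(z) = (π/2)(I₋ν(z) - I_ν(z)) / sin(νπ). This is a minimal sketch assuming 0 < |ν| < 1 and moderate z > 0, with a fixed, illustrative `num_terms`; it is much simpler than what a production library would do.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln

jax.config.update("jax_enable_x64", True)


def bessel_iv_series(nu, z, num_terms=40):
    """I_nu(z) = sum_k (z/2)^(2k+nu) / (k! Gamma(k+nu+1)) for z > 0.

    Assumes nu > -1, so every k + nu + 1 is positive and gammaln
    (which returns log|Gamma|) is the log of a positive number.
    """
    k = jnp.arange(num_terms)
    log_terms = (2 * k + nu) * jnp.log(z / 2) - gammaln(k + 1.0) - gammaln(k + nu + 1.0)
    return jnp.sum(jnp.exp(log_terms))


def bessel_kv_series(nu, z, num_terms=40):
    """K_nu(z) = (pi/2) (I_{-nu}(z) - I_nu(z)) / sin(nu pi), for 0 < |nu| < 1.

    Fails at integer nu (0/0) and loses accuracy for large z, where
    I_{-nu} and I_nu nearly cancel.
    """
    nu = jnp.abs(nu)  # K is even in its order: K_{-nu} = K_nu
    diff = bessel_iv_series(-nu, z, num_terms) - bessel_iv_series(nu, z, num_terms)
    return 0.5 * jnp.pi * diff / jnp.sin(nu * jnp.pi)
```

The cancellation in the numerator is the price of this route: as ν approaches an integer, both the numerator and sin(νπ) go to zero.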
Is there any news on this? Modified Bessel function of the second kind (scipy.special's
Assigning to @jakevdp for further triage. As I see it, the decisions here are:
Closing for now; we'll track this in #12402
Hello, I would like to use modified Bessel functions of the second kind with integer orders, but there are none in `jax.scipy` yet. Use example:

It seemed that I could use the implementation from `tensorflow_probability` and get JAX-compatible gradients using `@tfp_custom_gradient`, like in this comment. Although `tensorflow_probability` has an implementation, it actually doesn't work.

Mathematically, the Bessel function of the second kind can be expressed in terms of the Bessel function of the first kind, but at integer orders this expression becomes a 0/0 limit, so functions like `jax.scipy.special.i1` don't really help; moreover, they also have problems with gradients.
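The 0/0 can be written out explicitly. With I_ν the modified Bessel function of the first kind, the connection formula holds for non-integer ν; at an integer n, I₋ₙ = Iₙ and sin(nπ) = 0, and L'Hôpital's rule (using d/dν sin(νπ) = π cos(νπ) = π(-1)ⁿ at ν = n) gives the limit in terms of order derivatives of I_ν:

```latex
K_\nu(z) = \frac{\pi}{2}\,\frac{I_{-\nu}(z) - I_\nu(z)}{\sin(\nu\pi)}, \qquad \nu \notin \mathbb{Z},

K_n(z) = \lim_{\nu \to n} K_\nu(z)
       = \frac{(-1)^{n+1}}{2}\left[\frac{\partial I_\nu(z)}{\partial \nu}\bigg|_{\nu = n}
         + \frac{\partial I_\nu(z)}{\partial \nu}\bigg|_{\nu = -n}\right], \qquad n \in \mathbb{Z}.
```

For n = 0 this reduces to K₀(z) = -∂I_ν(z)/∂ν at ν = 0. An integer-order implementation therefore needs derivatives of I_ν with respect to its order (or a representation without the quotient), which fixed-order helpers such as `jax.scipy.special.i0` and `jax.scipy.special.i1` do not provide.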