Modified Bessel function of the second kind in jax.scipy.special · Issue #9956 · jax-ml/jax · GitHub

Modified Bessel function of the second kind in jax.scipy.special #9956

Closed
egorssed opened this issue Mar 18, 2022 · 6 comments
Labels
enhancement New feature or request

Comments

@egorssed
egorssed commented Mar 18, 2022

Hello, I would like to use modified Bessel functions of the second kind with integer orders, but jax.scipy.special does not provide them yet.

Desired usage (the JAX call does not exist yet):

>>> scipy.special.kve(1., 2.)
1.03347
>>> jax.scipy.special.kve(1., 2.)
DeviceArray(1.03347, dtype=float64, weak_type=True)
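For reference, `kve` is the exponentially scaled form, kve(v, z) = kv(v, z) · exp(z). The small plain-SciPy check below (v = 1 and z = 800 are arbitrary choices) sketches why the scaled variant matters: K_v(z) decays like exp(-z), so the unscaled value underflows in float64 for large z, while `kve` stays of order 1/sqrt(z) and agrees with the first two terms of the large-z asymptotic expansion.

```python
import math

import scipy.special as sc

v, z = 1.0, 800.0

# Unscaled K_1(800) is roughly 1e-349, below the smallest positive float64,
# so it underflows.
kv_val = sc.kv(v, z)

# Scaled version kve(v, z) = kv(v, z) * exp(z) stays representable.
kve_val = sc.kve(v, z)

# First two terms of the large-z expansion of K_v(z) * exp(z).
approx = math.sqrt(math.pi / (2 * z)) * (1 + (4 * v**2 - 1) / (8 * z))

print(kv_val, kve_val, approx)

# For moderate z both forms agree up to rounding: kve(1, 2) == kv(1, 2) * e^2.
print(sc.kve(1.0, 2.0), sc.kv(1.0, 2.0) * math.exp(2.0))
```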

It seemed that I could use the implementation from tensorflow_probability and get JAX-compatible gradients using @tfp_custom_gradient, like in this comment. However, although tensorflow_probability's implementation exists, it doesn't actually work.

Mathematically, the modified Bessel function of the second kind can be expressed in terms of the modified Bessel function of the first kind:

$$K_\nu(z) = \frac{\pi}{2}\,\frac{I_{-\nu}(z) - I_{\nu}(z)}{\sin(\nu\pi)}$$

However, for integer orders this is a 0/0 limit, so your functions like jax.scipy.special.i1 don't really help; moreover, they also have problems with gradients.
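Using the standard connection formula K_ν(z) = (π/2)(I_{-ν}(z) - I_ν(z)) / sin(νπ), the integer-order case ν = n is a 0/0 limit because I_{-n}(z) = I_n(z). Applying L'Hôpital's rule in ν resolves it:

$$
\begin{aligned}
K_n(z) &= \lim_{\nu \to n} \frac{\pi}{2}\,\frac{I_{-\nu}(z) - I_{\nu}(z)}{\sin(\nu\pi)}
        = \frac{\pi}{2}\,\frac{\partial_\nu\bigl[I_{-\nu}(z) - I_{\nu}(z)\bigr]_{\nu=n}}{\pi\cos(n\pi)} \\
       &= \frac{(-1)^n}{2}\left[\frac{\partial I_{-\nu}(z)}{\partial \nu} - \frac{\partial I_{\nu}(z)}{\partial \nu}\right]_{\nu=n}
\end{aligned}
$$

Evaluated this way, K_n needs derivatives of I_ν with respect to its order, which fixed-order routines such as i0 and i1 cannot supply.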

@egorssed egorssed added the enhancement New feature or request label Mar 18, 2022
@egorssed
Author

Update: the nightly build of TensorFlow Probability (tfp-nightly=='0.17.0-dev20220322') works and even gives a gradient with respect to the argument z (though not with respect to the order v):

jax.grad(lambda x: tfp.substrates.jax.math.bessel_kve(x[0],x[1]))([0.7,1.2])
[DeviceArray(0., dtype=float32, weak_type=True), DeviceArray(-0.57418, dtype=float32, weak_type=True)]
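The z-gradient is the easy one: the standard derivative recurrence for K_ν gives it in closed form from the neighbouring orders, and the exp(z) scaling adds one extra term for kve:

$$
\begin{aligned}
\frac{\partial K_\nu(z)}{\partial z} &= -\tfrac{1}{2}\bigl[K_{\nu-1}(z) + K_{\nu+1}(z)\bigr] \\
\frac{\partial}{\partial z}\bigl[e^{z} K_\nu(z)\bigr] &= e^{z} K_\nu(z) - \tfrac{1}{2}\bigl[e^{z} K_{\nu-1}(z) + e^{z} K_{\nu+1}(z)\bigr]
\end{aligned}
$$

There is no comparably simple recurrence for the derivative with respect to ν, which is one plausible reason an implementation would leave the order-gradient unimplemented (the 0. in the output above).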

So for now this is fine, but it would be nice to have modified Bessel functions of the second kind in JAX itself, because they are the basis of the Matérn kernels used in Gaussian processes.
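As background for the Matérn use case: for half-integer ν = p + 1/2, x^ν K_ν(x) is a constant times exp(-x) times a degree-p polynomial, so the Matérn-1/2, 3/2 and 5/2 kernels need no Bessel call at all. Restricting to these orders is a common workaround in Gaussian-process libraries; general ν is where K_ν is genuinely required. A minimal JAX sketch; the function names are illustrative, not an existing JAX API.

```python
import jax
import jax.numpy as jnp

# Half-integer Matern kernels as functions of the distance r >= 0.

def matern12(r, lengthscale=1.0, variance=1.0):
    # nu = 1/2: the exponential kernel
    return variance * jnp.exp(-r / lengthscale)

def matern32(r, lengthscale=1.0, variance=1.0):
    # nu = 3/2
    x = jnp.sqrt(3.0) * r / lengthscale
    return variance * (1.0 + x) * jnp.exp(-x)

def matern52(r, lengthscale=1.0, variance=1.0):
    # nu = 5/2
    x = jnp.sqrt(5.0) * r / lengthscale
    return variance * (1.0 + x + x**2 / 3.0) * jnp.exp(-x)

def pairwise_distances(a, b):
    # Euclidean distances between the rows of a (n, d) and b (m, d).
    # sqrt has an infinite derivative at 0, so gradients w.r.t. the inputs on
    # the diagonal would need a guard; plain evaluation is fine.
    diff = a[:, None, :] - b[None, :, :]
    return jnp.sqrt(jnp.sum(diff**2, axis=-1))

# Usage: a Matern-5/2 Gram matrix for 20 random 2-D points.
X = jax.random.normal(jax.random.PRNGKey(0), (20, 2))
K = matern52(pairwise_distances(X, X), lengthscale=0.7)
print(K.shape)
```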

@renecotyfanboy
Contributor
renecotyfanboy commented Mar 30, 2022

I wrote a small snippet in pure JAX that handles the case where z is real. It uses an integral representation, evaluated with double-exponential quadrature (Eq 1.6):

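import jax.numpy as jnp
from jax.scipy.integrate import trapezoid  # jnp.trapz is deprecated in newer JAX releases

# Double-exponential change of variables s = phi(t) = exp(pi/2 * sinh(t))
def phi(t):
    return jnp.exp(jnp.pi / 2 * jnp.sinh(t))

# Jacobian ds/dt of the substitution
def dphi(t):
    return jnp.pi / 2 * jnp.cosh(t) * jnp.exp(jnp.pi / 2 * jnp.sinh(t))

def bessel_k(nu, z):
    # K_nu(z) = 1/2 (z/2)^nu * int_0^inf exp(-s - z^2/(4s)) s^(-nu-1) ds, for real z > 0
    nu = jnp.asarray(nu)[..., None]
    z = jnp.asarray(z)[..., None]
    t = jnp.linspace(-3, 3, 101)  # truncated quadrature nodes
    integrand = 0.5*(0.5*z)**nu*jnp.exp(-phi(t)-z**2/(4*phi(t)))*phi(t)**(-nu-1)*dphi(t)
    # Trapezoidal rule over the last axis (the nodes)
    return trapezoid(integrand, x=t, axis=-1)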

It seems to work pretty well when z is not too close to zero, and it should be differentiable with respect to both nu and z since it is pure JAX. However, I have no idea how to add this properly to the JAX codebase, and I'm not even sure I would be able to.
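A quick check of these claims, as a sketch. The representation being integrated is K_ν(z) = ½ (z/2)^ν ∫₀^∞ exp(-s - z²/(4s)) s^(-ν-1) ds with s = exp(π/2 · sinh t). The variant below (the name bessel_k_log is illustrative) is a rewrite of the same quadrature rather than the snippet verbatim: it evaluates the integrand in log space, because in float32 s^(-ν-1) near t = -3 overflows for moderately large ν (for example ν = 5), and inf · 0 gives nan. It keeps the snippet's 101 nodes and the cutoff |t| ≤ 3; that cutoff is probably what limits accuracy as z → 0, since the integrand's mass then moves toward t ≈ -3. The check uses the closed form K_{1/2}(z) = sqrt(π/(2z)) exp(-z).

```python
import jax
import jax.numpy as jnp
from jax.scipy.integrate import trapezoid

def bessel_k_log(nu, z, n_nodes=101, t_max=3.0):
    """Sketch: K_nu(z) for real z > 0 by double-exponential quadrature in log space."""
    z = jnp.asarray(z, dtype=float)[..., None]
    t = jnp.linspace(-t_max, t_max, n_nodes)
    log_s = jnp.pi / 2 * jnp.sinh(t)                     # s = exp(pi/2 * sinh t)
    s = jnp.exp(log_s)
    log_jac = jnp.log(jnp.pi / 2 * jnp.cosh(t)) + log_s  # log(ds/dt)
    # log of (z/2)^nu * exp(-s - z^2/(4s)) * s^(-nu-1) * ds/dt
    log_f = nu * jnp.log(z / 2) - s - z**2 / (4 * s) - (nu + 1) * log_s + log_jac
    return 0.5 * trapezoid(jnp.exp(log_f), x=t, axis=-1)

# Compare against the closed form K_{1/2}(z) = sqrt(pi / (2 z)) * exp(-z).
z = 1.0
exact = jnp.sqrt(jnp.pi / (2 * z)) * jnp.exp(-z)
approx = bessel_k_log(0.5, z)

# Gradients with respect to both the order nu and the argument z.
dk_dnu, dk_dz = jax.grad(bessel_k_log, argnums=(0, 1))(0.5, z)
print(approx, exact, dk_dnu, dk_dz)
```

The z-gradient can be compared with the recurrence value -(K_{-1/2}(1) + K_{3/2}(1))/2 ≈ -0.6916; the ν-gradient has no equally simple reference, so a central finite difference is a reasonable sanity check.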

@egorssed
Author

Hmm, if I were to code it up, I would probably try to evaluate the integral with Jordan's lemma and the residue theorem, as we usually do in math. Alternatively, I would try to derive a series representation of the modified Bessel function of the second kind from that of the first kind and truncate the series after some number of terms. Direct numerical integration seems a bit too straightforward, but I am not sure my approaches would be robust either.
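To illustrate the truncated-series route for the simplest integer order, the standard expansion K_0(z) = -(ln(z/2) + γ) I_0(z) + Σ_{k≥1} H_k (z²/4)^k / (k!)², with γ the Euler-Mascheroni constant and H_k the k-th harmonic number, can be cut off after a fixed number of terms. In this sketch the name k0_series and the 30-term cutoff are illustrative choices. The two large parts cancel as z grows, so in float32 it is only a sensible choice for small z, which is exactly the robustness concern above.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln

EULER_GAMMA = 0.5772156649015329  # Euler-Mascheroni constant

def k0_series(z, n_terms=30):
    """Sketch: K_0(z) for small real z > 0 from its truncated power series."""
    z = jnp.asarray(z, dtype=float)
    k = jnp.arange(1, n_terms + 1, dtype=z.dtype)
    # (z^2/4)^k / (k!)^2 for k = 1..n_terms, computed in log space
    terms = jnp.exp(k * jnp.log(z[..., None] ** 2 / 4) - 2 * gammaln(k + 1))
    harmonic = jnp.cumsum(1.0 / k)        # H_k = 1 + 1/2 + ... + 1/k
    i0 = 1.0 + jnp.sum(terms, axis=-1)    # I_0(z) from the same terms
    return -(jnp.log(z / 2) + EULER_GAMMA) * i0 + jnp.sum(harmonic * terms, axis=-1)

print(k0_series(jnp.array([0.1, 0.5, 1.0])))
```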

Probably the best way is just to look at SciPy's implementation, but it might not fit JAX's pure-function computation paradigm.
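One way to use SciPy's implementation without porting it is a host callback: jax.pure_callback calls scipy.special.kv from traced code, and jax.custom_jvp supplies the z-derivative from the recurrence dK_ν/dz = -(K_{ν-1} + K_{ν+1})/2. This is a sketch under stated assumptions: nu is a static Python float (so there is no gradient with respect to the order), every call round-trips to the host (no GPU/TPU kernel), and the wrapper is not set up for vmap. The names kv and _kv_host are illustrative, not JAX API.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np
import scipy.special


def _kv_host(nu, z):
    # Runs on the host with array inputs; cast back to the dtype JAX expects.
    return np.asarray(scipy.special.kv(nu, z), dtype=z.dtype)


@functools.partial(jax.custom_jvp, nondiff_argnums=(0,))
def kv(nu, z):
    """Sketch: K_nu(z) computed by SciPy through a host callback; nu is a static float."""
    z = jnp.asarray(z, dtype=float)
    out_type = jax.ShapeDtypeStruct(z.shape, z.dtype)
    return jax.pure_callback(_kv_host, out_type, jnp.asarray(nu, dtype=z.dtype), z)


@kv.defjvp
def _kv_jvp(nu, primals, tangents):
    (z,), (z_dot,) = primals, tangents
    # dK_nu/dz = -(K_{nu-1}(z) + K_{nu+1}(z)) / 2. Reusing kv keeps the rule
    # differentiable in z, so higher-order z-derivatives also work.
    dk_dz = -0.5 * (kv(nu - 1, z) + kv(nu + 1, z))
    return kv(nu, z), dk_dz * z_dot


# Usage: value and z-gradient at nu = 0.5, z = 1.0, the gradient under jit.
value = kv(0.5, 1.0)
grad_z = jax.jit(jax.grad(lambda z: kv(0.5, z)))(1.0)
print(value, grad_z)
```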

@HGangloff

Is there any news on this? Modified Bessel functions of the second kind (SciPy special's kn, kv, kve) are indeed needed for Matérn kernels, and it would be nice to see them in JAX.

@axch
Contributor
axch commented Jun 2, 2023

Assigning to @jakevdp for further triage. As I see it, the decisions here are:

  • Do we want to own modified Bessel functions of the second kind in JAX, or direct users to TFP? Or do something nutty like try to encourage TFP to push out a separate package of math functions that JAX could then be happy to take a dependency on?
  • Does the above discussion qualify as a bug in TFP's modified Bessel function, and if so, should we file the bug or push a fix there?

@jakevdp
Collaborator
jakevdp commented Nov 8, 2023

Closing for now; we'll track this in #12402

5 participants