Skip to content

Commit

Permalink
Future-proof calls to jnp.solve on batched 1D inputs.
Browse files Browse the repository at this point in the history
This has been deprecated since JAX v0.4.25, and is no longer supported in JAX v0.5.0.

PiperOrigin-RevId: 712980403
  • Loading branch information
Jake VanderPlas authored and copybara-github committed Jan 7, 2025
1 parent 8ca8408 commit 48636c5
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions trax/layers/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,8 +808,7 @@ def log_gaussian_pdf(x, mu, sigma): # pylint: disable=invalid-name
"""
a = mu.shape[-1] * jnp.log(2 * jnp.pi)
_, b = jnp.linalg.slogdet(sigma)
y = jnp.linalg.solve(sigma, x - mu)
y = jnp.expand_dims(y, axis=-1)
y = jnp.linalg.solve(sigma, (x - mu)[..., None])
xm = jnp.expand_dims(x - mu, axis=-2)
c = jnp.matmul(xm, y)
c = jnp.squeeze(jnp.squeeze(c, axis=-1), axis=-1)
Expand Down

0 comments on commit 48636c5

Please sign in to comment.