vmap(custom_jvp) does not strip zeros from nondifferentiable return values, leading to AD crashes #25724
Labels: bug
Description
This:
prints:
In contrast, if the `jax.vmap` is removed, then there is no `JVPTrace` at all. It seems that `custom_jvp` attempts to strip symbolic zeros from nondifferentiable return values, but that this is defeated by having a `vmap` wrapper.

So (a) there is a discrepancy there, but (b) this is now a problem for downstream nondifferentiable primitives! They see an AD tracer, they don't have an AD rule, and they explode. Moreover, this spurious tangent isn't removable with `lax.stop_gradient`, because that skips over nondifferentiable types (see `jax/_src/lax/lax.py`, lines 2047 to 2054 at commit `54fd738`).
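To illustrate the dtype gate the report points at, the hypothetical helper `stop_gradient_floats_only` below applies `jax.lax.stop_gradient` only to floating or complex leaves and returns integer leaves untouched. This is an assumption-based sketch of the behavior described, not a copy of JAX's source. Float leaves end up with a zero tangent; an integer leaf is passed through as-is, so if it were already carrying a spurious AD tracer, nothing in this path would remove it.

```python
import jax
import jax.numpy as jnp
import numpy as np


def stop_gradient_floats_only(tree):
    """Hypothetical helper: stop gradients on inexact leaves only."""
    def stop(x):
        # Only floating/complex leaves get stop_gradient applied.
        if jnp.issubdtype(jnp.result_type(x), jnp.inexact):
            return jax.lax.stop_gradient(x)
        # Integer/bool leaves are returned unchanged.
        return x
    return jax.tree_util.tree_map(stop, tree)


def g(x):
    y = x * 3.0
    n = jnp.floor(x).astype(jnp.int32)
    return stop_gradient_floats_only((y, n))


x = jnp.array([1.5, 2.5], dtype=jnp.float32)
(y, n), (y_dot, n_dot) = jax.jvp(g, (x,), (jnp.ones_like(x),))
print(y_dot)         # float leaf: tangent zeroed by stop_gradient
print(n_dot.dtype)   # integer leaf: never touched by stop_gradient
```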
I imagine the fix is either to allow `lax.stop_gradient` to operate on nondifferentiable types, or to adjust `custom_jvp` to have consistent behavior regardless of whether it is vmapped. (Or maybe both?)

System info (python version, jaxlib version, accelerator, etc.)
JAX 0.4.38