vmap(custom_jvp) does not strip zeros from nondifferentiable return values, leading to AD crashes #25724

Open
patrick-kidger opened this issue Jan 5, 2025 · 1 comment
Labels
bug Something isn't working

Comments

@patrick-kidger
Collaborator

Description

This:

import jax
import jax.numpy as jnp

@jax.custom_jvp
def returns_int(x):
   return x, 1

@returns_int.defjvp
def f_jvp(p, t):
   (p,) = p
   (t,) = t
   p_out, p_aux = returns_int(p)
   t_out = jnp.ones_like(p_out)
   t_aux = jax.custom_derivatives.zero_from_primal(p_aux, symbolic_zeros=True)
   return (p_out, p_aux), (t_out, t_aux)

def run(x):
   _, integer = jax.vmap(returns_int)(x)
   print(integer)

jax.jvp(run, (jnp.arange(2.),), (jnp.arange(2.),))

prints:

Traced<ShapedArray(int32[2], weak_type=True)>with<JVPTrace> with
  primal = Array([1, 1], dtype=int32, weak_type=True)
  tangent = array([(b'',), (b'',)], dtype=[('float0', 'V')])
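For reference, the tangent in this output has JAX's float0 dtype, the trivial tangent type that JAX pairs with integer primals. A minimal sketch that builds an array of the same dtype directly with NumPy (the variable name is illustrative, not from the report):

```python
import jax
import numpy as np

# float0 is the zero-width tangent dtype JAX uses for integer/bool primals.
# Materializing it gives a concrete array instead of a symbolic zero.
zero_tangent = np.zeros((2,), dtype=jax.dtypes.float0)

print(repr(zero_tangent))
print(zero_tangent.dtype == jax.dtypes.float0)
```

Its repr should resemble the `tangent = ...` line above, which suggests the symbolic zero returned by the JVP rule was instantiated somewhere along the way.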

In contrast, if the jax.vmap is removed, there is no JVPTrace at all. It seems that custom_jvp attempts to strip symbolic zeros from nondifferentiable return values, but that this is defeated by the vmap wrapper.

So (a) there is a discrepancy there, but (b) this is now a problem for downstream nondifferentiable primitives! They see an AD tracer, they don't have an AD rule, they explode. Moreover, this spurious tangent isn't removable with lax.stop_gradient, because that skips over nondifferentiable types:

jax/jax/_src/lax/lax.py

Lines 2047 to 2054 in 54fd738

  elif (dtypes.issubdtype(_dtype(x), np.floating) or
        dtypes.issubdtype(_dtype(x), np.complexfloating)):
    # break abstractions to support legacy leaked tracer use cases
    if isinstance(x, ad.JVPTracer):
      return stop(x.primal)
    return ad_util.stop_gradient_p.bind(x)
  else:
    return x

!
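One way to see the dtype branch quoted above from the outside is to trace `lax.stop_gradient` and list the primitives it stages out. This is a sketch under the assumption that the quoted branch is in effect; `primitives_used` is a hypothetical helper, not a JAX API.

```python
import jax
import jax.numpy as jnp


def primitives_used(fn, x):
    """Hypothetical helper: names of the primitives staged out by tracing fn(x)."""
    closed_jaxpr = jax.make_jaxpr(fn)(x)
    return [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]


# Floating-point input: the stop_gradient primitive is bound.
float_prims = primitives_used(jax.lax.stop_gradient, jnp.arange(2.0))
# Integer input: per the quoted source, the value is returned untouched.
int_prims = primitives_used(jax.lax.stop_gradient, jnp.arange(2))

print("float input:", float_prims)
print("int input:  ", int_prims)
```

With the quoted code, the integer case should bind no primitive at all, which is why calling `lax.stop_gradient` on the integer output cannot drop its tangent.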

I imagine the fix is either to allow lax.stop_gradient to operate on nondifferentiable types, or to adjust custom_jvp to have consistent behavior regardless of whether it is vmapped. (Or maybe both?)

System info (python version, jaxlib version, accelerator, etc.)

JAX 0.4.38

@dfm
Collaborator

dfm commented Jan 8, 2025

Thanks for tracking this down and for the clear reproduction, @patrick-kidger! I haven't had a chance to dig in too deeply yet, but I think the key place where this issue hits is here:

# TODO(mattjj,frostig): instantiating any SymbolicZero output is easy, but can
# be wasteful in the rare case it actually triggers; handle symbolically!
outs = [instantiate(replace_rule_output_symbolic_zeros(x)) for x in outs]

which has an existing TODO for @mattjj and @froystig. To get consistent behavior under vmap, I think we would need to remove the instantiate and properly handle the symbolic zeros below.

I'll take a look soon!
