Replace deprecated jax.tree_* functions with jax.tree.*
The top-level `jax.tree_*` aliases have long been deprecated, and will soon be removed. Alternate APIs are in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule, added in JAX version 0.4.25.

PiperOrigin-RevId: 636877261
Jake VanderPlas authored and RLaxDev committed May 24, 2024
1 parent df8e600 commit 461b4cf
Showing 8 changed files with 22 additions and 22 deletions.
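The change is a one-for-one substitution throughout: `jax.tree_map` becomes `jax.tree.map` and `jax.tree_leaves` becomes `jax.tree.leaves`. A minimal sketch of the equivalence, assuming JAX 0.4.25 or newer for the `jax.tree` aliases (the pytree below is purely illustrative):

```python
import jax
import jax.numpy as jnp

# An arbitrary pytree, used only for illustration.
params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}

# Deprecated top-level aliases, scheduled for removal:
#   jax.tree_map(fn, tree), jax.tree_leaves(tree)

# Canonical API in jax.tree_util:
doubled = jax.tree_util.tree_map(lambda x: 2 * x, params)

# Shorter aliases in the jax.tree submodule (JAX >= 0.4.25), used in this commit:
doubled = jax.tree.map(lambda x: 2 * x, params)
leaves = jax.tree.leaves(params)  # replaces jax.tree_leaves(params)
```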
2 changes: 1 addition & 1 deletion examples/online_q_lambda.py
@@ -83,7 +83,7 @@ def sample(self, batch_size):
    if len(self._timesteps) != self._timesteps.maxlen:
      raise ValueError("Not enough timesteps for a full sequence.")

-   actions, timesteps = jax.tree_map(lambda *ts: np.stack(ts),
+   actions, timesteps = jax.tree.map(lambda *ts: np.stack(ts),
                                      *self._timesteps)
    return actions, timesteps

12 changes: 6 additions & 6 deletions rlax/_src/moving_averages.py
@@ -44,10 +44,10 @@ def debiased_moments(self):
    """Returns debiased moments as in Adam."""
    tiny = jnp.finfo(self.decay_product).tiny
    debias = 1.0 / jnp.maximum(1 - self.decay_product, tiny)
-   mean = jax.tree_map(lambda m1: m1 * debias, self.mu)
+   mean = jax.tree.map(lambda m1: m1 * debias, self.mu)
    # This computation of the variance may lose some numerical precision, if
    # the mean is not approximately zero.
-   variance = jax.tree_map(
+   variance = jax.tree.map(
        lambda m2, m: jnp.maximum(0.0, m2 * debias - jnp.square(m)),
        self.nu, mean)
    return EmaMoments(mean=mean, variance=variance)
@@ -67,7 +67,7 @@ def create_ema(decay=0.999, pmean_axis_name=None):
  """

  def init_state(template_tree):
-   zeros = jax.tree_map(lambda x: jnp.zeros_like(jnp.mean(x)), template_tree)
+   zeros = jax.tree.map(lambda x: jnp.zeros_like(jnp.mean(x)), template_tree)
    scalar_zero = jnp.ones([], dtype=jnp.float32)
    return EmaState(mu=zeros, nu=zeros, decay_product=scalar_zero)

@@ -79,9 +79,9 @@ def _update(moment, value):
    return decay * moment + (1 - decay) * mean

  def update_moments(tree, state):
-   squared_tree = jax.tree_map(jnp.square, tree)
-   mu = jax.tree_map(_update, state.mu, tree)
-   nu = jax.tree_map(_update, state.nu, squared_tree)
+   squared_tree = jax.tree.map(jnp.square, tree)
+   mu = jax.tree.map(_update, state.mu, tree)
+   nu = jax.tree.map(_update, state.nu, squared_tree)
    state = EmaState(
        mu=mu, nu=nu, decay_product=state.decay_product * decay)
    return state.debiased_moments(), state
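For context on the `debiased_moments` hunk above, the debiasing is the Adam-style correction for zero-initialised moments: dividing by `1 - decay_product` undoes the shrinkage toward zero. A tiny worked sketch with made-up numbers (a plain scalar stands in for the pytree of means):

```python
import jax.numpy as jnp

decay = 0.9
value = 5.0

# One EMA step from the zero-initialised state produced by init_state,
# mirroring the _update rule above.
mu = decay * 0.0 + (1 - decay) * value   # 0.5 -- biased toward zero
decay_product = 1.0 * decay              # running product of decays

# Debiasing as in debiased_moments: divide by (1 - decay_product).
tiny = jnp.finfo(jnp.float32).tiny
debias = 1.0 / jnp.maximum(1 - decay_product, tiny)
print(mu * debias)                       # ~5.0: the bias is corrected
```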
2 changes: 1 addition & 1 deletion rlax/_src/nested_updates.py
@@ -29,7 +29,7 @@ def conditional_update(new_tensors: Any, old_tensors: Any, is_time: Numeric):
      "Rlax conditional_update will be deprecated. Please use optax instead.",
      PendingDeprecationWarning, stacklevel=2
  )
- return jax.tree_map(
+ return jax.tree.map(
      lambda new, old: jax.lax.select(is_time, new, old),
      new_tensors, old_tensors)

4 changes: 2 additions & 2 deletions rlax/_src/nested_updates_test.py
@@ -40,7 +40,7 @@ def test_conditional_update_is_time(self):
    is_time = jnp.array(True)
    output = conditional_update(self._new_struct, self._old_struct, is_time)
    for o, exp in zip(
-       jax.tree_leaves(output), jax.tree_leaves(self._new_struct)):
+       jax.tree.leaves(output), jax.tree.leaves(self._new_struct)):
      np.testing.assert_allclose(o, exp)

  @chex.all_variants()
@@ -51,7 +51,7 @@ def test_conditional_update_is_not_time(self):
    is_not_time = jnp.array(False)
    output = conditional_update(self._new_struct, self._old_struct, is_not_time)
    for o, exp in zip(
-       jax.tree_leaves(output), jax.tree_leaves(self._old_struct)):
+       jax.tree.leaves(output), jax.tree.leaves(self._old_struct)):
      np.testing.assert_allclose(o, exp)


6 changes: 3 additions & 3 deletions rlax/_src/policy_gradients_test.py
@@ -130,7 +130,7 @@ def test_qpg_loss_batch(self, compile_fn, place_fn):
    qpg_loss = compile_fn(policy_gradients.qpg_loss)

    # Optionally convert to device array.
-   policy_logits, q_values = jax.tree_map(place_fn,
+   policy_logits, q_values = jax.tree.map(place_fn,
                                           (self.policy_logits, self.q_values))
    # Test outputs.
    actual = qpg_loss(policy_logits, q_values)
@@ -160,7 +160,7 @@ def test_rm_loss_batch(self, compile_fn, place_fn):
    rm_loss = compile_fn(policy_gradients.rm_loss)

    # Optionally convert to device array.
-   policy_logits, q_values = jax.tree_map(place_fn,
+   policy_logits, q_values = jax.tree.map(place_fn,
                                           (self.policy_logits, self.q_values))
    # Test outputs.
    actual = rm_loss(policy_logits, q_values)
@@ -191,7 +191,7 @@ def test_rpg_loss(self, compile_fn, place_fn):
    rpg_loss = compile_fn(policy_gradients.rpg_loss)

    # Optionally convert to device array.
-   policy_logits, q_values = jax.tree_map(place_fn,
+   policy_logits, q_values = jax.tree.map(place_fn,
                                           (self.policy_logits, self.q_values))
    # Test outputs.
    actual = rpg_loss(policy_logits, q_values)
2 changes: 1 addition & 1 deletion rlax/_src/pop_art.py
@@ -165,7 +165,7 @@ def art(state: PopArtState,
  state_new = PopArtState(shift_new, scale_new, second_moment_new)

  # Prevent gradients propagating back through the state.
- state_new = jax.tree_map(jax.lax.stop_gradient, state_new)
+ state_new = jax.tree.map(jax.lax.stop_gradient, state_new)

  return state_new

10 changes: 5 additions & 5 deletions rlax/_src/tree_util.py
@@ -61,7 +61,7 @@ def tree_map_zipped(fn: Callable[..., Any], nests: Sequence[Any]):
  if any(tree_structure(x) != tree_def for x in nests[1:]):
    raise ValueError('All elements must share the same tree structure.')
  return jax.tree_util.tree_unflatten(
-     tree_def, [fn(*d) for d in zip(*[jax.tree_leaves(x) for x in nests])])
+     tree_def, [fn(*d) for d in zip(*[jax.tree.leaves(x) for x in nests])])


def tree_split_key(rng_key: Array, tree_like: Any):
@@ -121,14 +121,14 @@ def tree_replace_masked(tree_data, tree_replacement, mask):
    the updated tensor.
  """
  if tree_replacement is None:
-   tree_replacement = jax.tree_map(jnp.zeros_like, tree_data)
- return jax.tree_map(
+   tree_replacement = jax.tree.map(jnp.zeros_like, tree_data)
+ return jax.tree.map(
      lambda data, replacement: base.replace_masked(data, replacement, mask),
      tree_data, tree_replacement)


def tree_fn(fn, **unmapped_kwargs):
- """Wrap a function to jax.tree_map over its arguments.
+ """Wrap a function to jax.tree.map over its arguments.

  You may set some named arguments via a partial to skip the `tree_map` on those
  arguments. Usual caveats of `partial` apply (e.g. set named args must be a
@@ -143,7 +143,7 @@ def tree_fn(fn, **unmapped_kwargs):
  """
  pfn = functools.partial(fn, **unmapped_kwargs)
  def _wrapped(*args):
-   return jax.tree_map(pfn, *args)
+   return jax.tree.map(pfn, *args)
  return _wrapped

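As a usage note on the `tree_fn` hunk above: the wrapper just tree-maps a partially applied function over its positional arguments. A self-contained sketch, with the wrapper re-declared locally and a hypothetical scaling function (not part of the commit):

```python
import functools
import jax
import jax.numpy as jnp

def tree_fn(fn, **unmapped_kwargs):
  """Mirrors the wrapper above: tree-map `fn` over positional pytree args."""
  pfn = functools.partial(fn, **unmapped_kwargs)
  def _wrapped(*args):
    return jax.tree.map(pfn, *args)
  return _wrapped

# `scale` is fixed via the keyword argument, so only the pytree is mapped over.
scale_tree = tree_fn(lambda x, scale: x * scale, scale=0.5)
halved = scale_tree({"a": jnp.ones(3), "b": jnp.arange(4.0)})
```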
6 changes: 3 additions & 3 deletions rlax/_src/tree_util_test.py
@@ -30,7 +30,7 @@ def test_tree_split_key(self):
    rng_key = jax.random.PRNGKey(42)
    tree_like = (1, (2, 3), {'a': 4})
    _, tree_keys = tree_util.tree_split_key(rng_key, tree_like)
-   self.assertLen(jax.tree_leaves(tree_keys), 4)
+   self.assertLen(jax.tree.leaves(tree_keys), 4)

  def test_tree_map_zipped(self):
    nests = [
@@ -75,9 +75,9 @@ def test_tree_split_leaves(self):
    }

    for keepdim in (False, True):
-     expd_shapes = jax.tree_map(lambda x: np.zeros(x.shape[1:]), t)
+     expd_shapes = jax.tree.map(lambda x: np.zeros(x.shape[1:]), t)
      if keepdim:
-       expd_shapes = jax.tree_map(lambda x: np.expand_dims(x, 0), expd_shapes)
+       expd_shapes = jax.tree.map(lambda x: np.expand_dims(x, 0), expd_shapes)

      res_trees = tree_util.tree_split_leaves(t, axis=0, keepdim=keepdim)
      self.assertLen(res_trees, 3)
