Skip to content

Commit

Permalink
Remove references to jax.core.raise_to_shaped
Browse files Browse the repository at this point in the history
As of JAX v0.4.36, `core.raise_to_shaped` is deprecated, and simply returns the input unchanged.

PiperOrigin-RevId: 704944384
  • Loading branch information
Jake VanderPlas authored and The oryx Authors committed Dec 11, 2024
1 parent 32f02ef commit 449b41d
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 8 deletions.
8 changes: 4 additions & 4 deletions oryx/core/interpreters/harvest.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ def handle_sow(self, *values, name, tag, tree, mode):
raise ValueError(f'Variable has already been reaped: {name}')
avals = tree_util.tree_unflatten(
tree,
[jax_core.raise_to_shaped(jax_core.get_aval(v)) for v in values])
[jax_core.get_aval(v) for v in values])
vals = tree_util.tree_unflatten(tree, values)
pred = None
if mode == 'cond_clobber':
Expand Down Expand Up @@ -792,7 +792,7 @@ def _get_harvest_metadata(closed_jaxpr, settings, *args):
flat_args, in_tree = tree_util.tree_flatten(args)
flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree)
in_avals = jax_util.safe_map(
lambda a: jax_core.raise_to_shaped(jax_core.get_aval(a)),
lambda a: jax_core.get_aval(a),
flat_args)
pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
metadata = aux()
Expand Down Expand Up @@ -841,7 +841,7 @@ def _reap_scan_rule(trace: HarvestTrace, *vals, length, reverse, jaxpr,
cond_carry_avals[name] = None
if mode == 'cond_clobber':
reap_carry_avals[name] = aval
cond_carry_avals[name] = jax_core.raise_to_shaped(jax_core.get_aval(True))
cond_carry_avals[name] = jax_core.get_aval(True)

body_fun = jax_core.jaxpr_as_fun(jaxpr)

Expand Down Expand Up @@ -929,7 +929,7 @@ def _reap_while_rule(trace: HarvestTrace, *tracers, cond_jaxpr, body_jaxpr,
)
reap_avals[k] = meta['aval']
if mode == 'cond_clobber':
cond_avals[k] = jax_core.raise_to_shaped(jax_core.get_aval(True))
cond_avals[k] = jax_core.get_aval(True)

cond_fun = jax_core.jaxpr_as_fun(cond_jaxpr)
body_fun = jax_core.jaxpr_as_fun(body_jaxpr)
Expand Down
1 change: 0 additions & 1 deletion oryx/core/interpreters/inverse/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,6 @@ def unknown(cls, aval):
def new(cls, val):
val = np.array(val)
aval = jax_core.get_aval(val)
aval = jax_core.raise_to_shaped(aval)
ndslice = NDSlice.new(val, np.zeros_like(val))
return InverseAndILDJ(aval, frozenset([ndslice]))

Expand Down
3 changes: 1 addition & 2 deletions oryx/core/primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,7 @@ def subcall(self, name):
tie_all_p = jax_core.Primitive('tie_all')
tie_all_p.multiple_results = True
tie_all_p.def_impl(lambda *args: args)
tie_all_p.def_abstract_eval(lambda *args: safe_map( # pylint: disable=g-long-lambda
jax_core.raise_to_shaped, args))
tie_all_p.def_abstract_eval(lambda *args: args)

mlir.register_lowering(tie_all_p, lambda c, *args: args)

Expand Down
2 changes: 1 addition & 1 deletion oryx/core/trace_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def get_shaped_aval(x):
if hasattr(x, 'dtype') and hasattr(x, 'shape'):
return jax_core.ShapedArray(
x.shape, dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True))
return jax_core.raise_to_shaped(jax_core.get_aval(x))
return jax_core.get_aval(x)


def pv_like(x, abstract=True):
Expand Down

0 comments on commit 449b41d

Please sign in to comment.