-
Notifications
You must be signed in to change notification settings - Fork 2.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Initial pickle and deepcopy singleton support.
- Loading branch information
1 parent
20e5838
commit 58fbc85
Showing
6 changed files
with
267 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,183 @@ | ||
from typing import Any | ||
|
||
from jax.extend.core import Primitive | ||
from jax import lax | ||
|
||
# Table that stores all primitive definitions. | ||
# Needed so that primitives are treated as singletons | ||
# when using cloudpickle. | ||
# Mapping type: (namespace, name) -> Primitive | ||
_PRIMITIVES_TABLE_: dict[tuple[str, str], Primitive] = {} | ||
|
||
def _primitives_table_get(namespace: str, name: str): | ||
if (namespace, name) in _PRIMITIVES_TABLE_: | ||
return _PRIMITIVES_TABLE_[(namespace, name)] | ||
raise NotImplementedError( | ||
f"Op {(namespace, name)} not found in primitives table." | ||
"You may not have " | ||
f"Primitives in table: {list(_PRIMITIVES_TABLE_.keys())}") | ||
|
||
def enable_pickle(primitive: Primitive): | ||
table_key = (primitive.namespace, primitive.name) | ||
if table_key in _PRIMITIVES_TABLE_: | ||
assert primitive is _PRIMITIVES_TABLE_[table_key], ( | ||
f"The op name {(primitive.namespace, primitive.name)} is already taken. " | ||
"If this is a new custom primitive for your project," | ||
"try changing the namespace with Primitive(..., namespace='<YOUR_PROJECT_NAME>')." | ||
) | ||
_PRIMITIVES_TABLE_[(primitive.namespace, primitive.name)] = primitive | ||
|
||
|
||
def _new_primitive_reduce_ex(self, __protocol) -> tuple[Any, ...]: | ||
return _primitives_table_get, (self.namespace, self.name) | ||
|
||
# Reassign all primitives to use the new __reduce_ex__ method | ||
Primitive.__reduce_ex__ = _new_primitive_reduce_ex # type: ignore [method-assign] | ||
|
||
# Setup our global dictionary | ||
enable_pickle(lax.abs_p) | ||
enable_pickle(lax.acos_p) | ||
enable_pickle(lax.acosh_p) | ||
enable_pickle(lax.add_p) | ||
enable_pickle(lax.after_all_p) | ||
enable_pickle(lax.and_p) | ||
enable_pickle(lax.argmax_p) | ||
enable_pickle(lax.argmin_p) | ||
enable_pickle(lax.asin_p) | ||
enable_pickle(lax.asinh_p) | ||
enable_pickle(lax.atan_p) | ||
enable_pickle(lax.atan2_p) | ||
enable_pickle(lax.atanh_p) | ||
enable_pickle(lax.bitcast_convert_type_p) | ||
enable_pickle(lax.broadcast_in_dim_p) | ||
enable_pickle(lax.cbrt_p) | ||
enable_pickle(lax.ceil_p) | ||
enable_pickle(lax.clamp_p) | ||
enable_pickle(lax.clz_p) | ||
enable_pickle(lax.complex_p) | ||
enable_pickle(lax.concatenate_p) | ||
enable_pickle(lax.conj_p) | ||
enable_pickle(lax.convert_element_type_p) | ||
enable_pickle(lax.copy_p) | ||
enable_pickle(lax.cos_p) | ||
enable_pickle(lax.cosh_p) | ||
enable_pickle(lax.create_token_p) | ||
enable_pickle(lax.div_p) | ||
enable_pickle(lax.dot_general_p) | ||
enable_pickle(lax.eq_p) | ||
enable_pickle(lax.eq_to_p) | ||
enable_pickle(lax.exp_p) | ||
enable_pickle(lax.exp2_p) | ||
enable_pickle(lax.expm1_p) | ||
enable_pickle(lax.floor_p) | ||
enable_pickle(lax.ge_p) | ||
enable_pickle(lax.gt_p) | ||
enable_pickle(lax.imag_p) | ||
enable_pickle(lax.infeed_p) | ||
enable_pickle(lax.integer_pow_p) | ||
enable_pickle(lax.iota_p) | ||
enable_pickle(lax.is_finite_p) | ||
enable_pickle(lax.le_p) | ||
enable_pickle(lax.le_to_p) | ||
enable_pickle(lax.log1p_p) | ||
enable_pickle(lax.log_p) | ||
enable_pickle(lax.logistic_p) | ||
enable_pickle(lax.lt_p) | ||
enable_pickle(lax.lt_to_p) | ||
enable_pickle(lax.max_p) | ||
enable_pickle(lax.min_p) | ||
enable_pickle(lax.mul_p) | ||
enable_pickle(lax.ne_p) | ||
enable_pickle(lax.neg_p) | ||
enable_pickle(lax.nextafter_p) | ||
enable_pickle(lax.not_p) | ||
enable_pickle(lax.or_p) | ||
enable_pickle(lax.outfeed_p) | ||
enable_pickle(lax.pad_p) | ||
enable_pickle(lax.population_count_p) | ||
enable_pickle(lax.pow_p) | ||
enable_pickle(lax.real_p) | ||
enable_pickle(lax.reduce_and_p) | ||
enable_pickle(lax.reduce_max_p) | ||
enable_pickle(lax.reduce_min_p) | ||
enable_pickle(lax.reduce_or_p) | ||
enable_pickle(lax.reduce_p) | ||
enable_pickle(lax.reduce_precision_p) | ||
enable_pickle(lax.reduce_prod_p) | ||
enable_pickle(lax.reduce_sum_p) | ||
enable_pickle(lax.reduce_xor_p) | ||
enable_pickle(lax.rem_p) | ||
enable_pickle(lax.reshape_p) | ||
enable_pickle(lax.rev_p) | ||
enable_pickle(lax.rng_bit_generator_p) | ||
enable_pickle(lax.rng_uniform_p) | ||
enable_pickle(lax.round_p) | ||
enable_pickle(lax.rsqrt_p) | ||
enable_pickle(lax.select_n_p) | ||
enable_pickle(lax.shift_left_p) | ||
enable_pickle(lax.shift_right_arithmetic_p) | ||
enable_pickle(lax.shift_right_logical_p) | ||
enable_pickle(lax.sign_p) | ||
enable_pickle(lax.sin_p) | ||
enable_pickle(lax.sinh_p) | ||
enable_pickle(lax.sort_p) | ||
enable_pickle(lax.sqrt_p) | ||
enable_pickle(lax.squeeze_p) | ||
enable_pickle(lax.sub_p) | ||
enable_pickle(lax.tan_p) | ||
enable_pickle(lax.tanh_p) | ||
enable_pickle(lax.top_k_p) | ||
enable_pickle(lax.transpose_p) | ||
enable_pickle(lax.xor_p) | ||
enable_pickle(lax.bessel_i0e_p) | ||
enable_pickle(lax.bessel_i1e_p) | ||
enable_pickle(lax.digamma_p) | ||
enable_pickle(lax.erfc_p) | ||
enable_pickle(lax.erf_inv_p) | ||
enable_pickle(lax.erf_p) | ||
enable_pickle(lax.igammac_p) | ||
enable_pickle(lax.igamma_grad_a_p) | ||
enable_pickle(lax.igamma_p) | ||
enable_pickle(lax.lgamma_p) | ||
enable_pickle(lax.polygamma_p) | ||
enable_pickle(lax.random_gamma_grad_p) | ||
enable_pickle(lax.regularized_incomplete_beta_p) | ||
enable_pickle(lax.zeta_p) | ||
enable_pickle(lax.dynamic_slice_p) | ||
enable_pickle(lax.dynamic_update_slice_p) | ||
enable_pickle(lax.gather_p) | ||
enable_pickle(lax.scatter_add_p) | ||
enable_pickle(lax.scatter_max_p) | ||
enable_pickle(lax.scatter_min_p) | ||
enable_pickle(lax.scatter_mul_p) | ||
enable_pickle(lax.scatter_p) | ||
enable_pickle(lax.slice_p) | ||
enable_pickle(lax.conv_general_dilated_p) | ||
enable_pickle(lax.reduce_window_max_p) | ||
enable_pickle(lax.reduce_window_min_p) | ||
enable_pickle(lax.reduce_window_p) | ||
enable_pickle(lax.reduce_window_sum_p) | ||
enable_pickle(lax.select_and_gather_add_p) | ||
enable_pickle(lax.select_and_scatter_p) | ||
enable_pickle(lax.select_and_scatter_add_p) | ||
enable_pickle(lax.cond_p) | ||
enable_pickle(lax.cumlogsumexp_p) | ||
enable_pickle(lax.cummax_p) | ||
enable_pickle(lax.cummin_p) | ||
enable_pickle(lax.cumprod_p) | ||
enable_pickle(lax.cumsum_p) | ||
enable_pickle(lax.linear_solve_p) | ||
enable_pickle(lax.scan_p) | ||
enable_pickle(lax.while_p) | ||
enable_pickle(lax.fft_p) | ||
enable_pickle(lax.all_gather_p) | ||
enable_pickle(lax.all_to_all_p) | ||
enable_pickle(lax.axis_index_p) | ||
enable_pickle(lax.pmax_p) | ||
enable_pickle(lax.pmin_p) | ||
enable_pickle(lax.ppermute_p) | ||
enable_pickle(lax.psum_p) | ||
enable_pickle(lax.approx_top_k_p) | ||
enable_pickle(lax.stop_gradient_p) | ||
enable_pickle(lax.sharding_constraint_p) | ||
enable_pickle(lax.device_put_p) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from cloudpickle import cloudpickle | ||
from copy import deepcopy | ||
|
||
from jax._src import core | ||
from jax._src import test_util as jtu | ||
|
||
import jax | ||
import jax.extend.cloudpickle_support | ||
|
||
|
||
class CopyingTest(jtu.JaxTestCase): | ||
|
||
def test_jaxpr_cloudpickle(self): | ||
def f(a, b): | ||
return a + b, b * a | ||
|
||
jxpr = jax.make_jaxpr(f)(1, 2) | ||
|
||
# Load then dump with cloudpickle | ||
encoded = cloudpickle.dumps(jxpr) | ||
loaded_jxpr = cloudpickle.loads(encoded) | ||
|
||
# Eval, jit and execute. | ||
fun = core.jaxpr_as_fun(loaded_jxpr) | ||
assert jax.jit(fun)(5, 3) == [8, 15] | ||
|
||
def test_jaxpr_deepcopy(self): | ||
def f(a, b): | ||
return b * a, a + b, | ||
|
||
jxpr = jax.make_jaxpr(f)(1, 2) | ||
|
||
# Deepcopy jaxpr | ||
copied_jaxpr = deepcopy(jxpr) | ||
|
||
# Eval, jit and execute. | ||
fun = core.jaxpr_as_fun(copied_jaxpr) | ||
assert jax.jit(fun)(5, 3) == [15, 8] |