Skip to content

Commit

Permalink
Initial pickle and deepcopy singleton support.
Browse files Browse the repository at this point in the history
  • Loading branch information
chaserileyroberts committed Oct 26, 2023
1 parent 20e5838 commit 58fbc85
Show file tree
Hide file tree
Showing 6 changed files with 267 additions and 3 deletions.
22 changes: 21 additions & 1 deletion jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@
from jax._src import traceback_util
from jax._src.typing import Array, DimSize, Shape
from jax._src import typing
from jax.version import _version

traceback_util.register_exclusion(__file__)

zip, unsafe_zip = safe_zip, zip
Expand Down Expand Up @@ -365,20 +367,38 @@ def __repr__(self):
Atom = Union[Var, Literal]

class Primitive:
# used to identify primitive singletons.
name: str
namespace: str
# set for multi-output primitives.
multiple_results: bool = False
# set for call primitives processed in final style.
call_primitive: bool = False
# set for map primitives processed in final style.
map_primitive: bool = False

def __init__(self, name: str):
def __init__(self, name: str, namespace: str=f"__jaxz{_version}__"):
self.name = name
self.namespace = namespace

def __repr__(self):
return f'{self.name}'

def __reduce_ex__(self, __protocol) -> tuple[Any, ...]:
raise NotImplementedError(
"Seems like you're trying to pickle some internals of JAX. "
"Don't worry, this is supported! You'll just need to enable it "
"with `import jax.extend.cloudpickle_support`. "
"There may still be some rough edges with the implementation. If you hit "
"any issues, please file a bug report at "
"https://github.com/google/jax/issues/new?labels=bug&template=bug-report.yml")

def __deepcopy__(self, memo):
# Primitives are singletons, so copying is the same
# as returning self.
memo[id(self)] = self
return self

def bind(self, *args, **params):
assert (not config.enable_checks.value or
all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
Expand Down
6 changes: 5 additions & 1 deletion jax/_src/source_info_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import sysconfig
import threading
import types
from typing import Optional, NamedTuple, Union
from typing import Any, Optional, NamedTuple, Union

import jax.version
from jax._src.lib import xla_client
Expand Down Expand Up @@ -132,6 +132,10 @@ def replace(self, *, traceback: Optional[Traceback] = None,
self.name_stack if name_stack is None else name_stack
)

def __reduce_ex__(self, __protocol) -> tuple[Any, ...]:
# Dropping Traceback for now as it is not easy to pickle.
return SourceInfo, (None, self.name_stack)

def new_source_info() -> SourceInfo:
return SourceInfo(None, NameStack())

Expand Down
183 changes: 183 additions & 0 deletions jax/extend/cloudpickle_support.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
from typing import Any

from jax.extend.core import Primitive
from jax import lax

# Table that stores all primitive definitions.
# Needed so that primitives are treated as singletons
# when using cloudpickle.
# Mapping type: (namespace, name) -> Primitive
_PRIMITIVES_TABLE_: dict[tuple[str, str], Primitive] = {}

def _primitives_table_get(namespace: str, name: str):
if (namespace, name) in _PRIMITIVES_TABLE_:
return _PRIMITIVES_TABLE_[(namespace, name)]
raise NotImplementedError(
f"Op {(namespace, name)} not found in primitives table."
"You may not have "
f"Primitives in table: {list(_PRIMITIVES_TABLE_.keys())}")

def enable_pickle(primitive: Primitive):
table_key = (primitive.namespace, primitive.name)
if table_key in _PRIMITIVES_TABLE_:
assert primitive is _PRIMITIVES_TABLE_[table_key], (
f"The op name {(primitive.namespace, primitive.name)} is already taken. "
"If this is a new custom primitive for your project,"
"try changing the namespace with Primitive(..., namespace='<YOUR_PROJECT_NAME>')."
)
_PRIMITIVES_TABLE_[(primitive.namespace, primitive.name)] = primitive


def _new_primitive_reduce_ex(self, __protocol) -> tuple[Any, ...]:
return _primitives_table_get, (self.namespace, self.name)

# Reassign all primitives to use the new __reduce_ex__ method
Primitive.__reduce_ex__ = _new_primitive_reduce_ex # type: ignore [method-assign]

# Setup our global dictionary
enable_pickle(lax.abs_p)
enable_pickle(lax.acos_p)
enable_pickle(lax.acosh_p)
enable_pickle(lax.add_p)
enable_pickle(lax.after_all_p)
enable_pickle(lax.and_p)
enable_pickle(lax.argmax_p)
enable_pickle(lax.argmin_p)
enable_pickle(lax.asin_p)
enable_pickle(lax.asinh_p)
enable_pickle(lax.atan_p)
enable_pickle(lax.atan2_p)
enable_pickle(lax.atanh_p)
enable_pickle(lax.bitcast_convert_type_p)
enable_pickle(lax.broadcast_in_dim_p)
enable_pickle(lax.cbrt_p)
enable_pickle(lax.ceil_p)
enable_pickle(lax.clamp_p)
enable_pickle(lax.clz_p)
enable_pickle(lax.complex_p)
enable_pickle(lax.concatenate_p)
enable_pickle(lax.conj_p)
enable_pickle(lax.convert_element_type_p)
enable_pickle(lax.copy_p)
enable_pickle(lax.cos_p)
enable_pickle(lax.cosh_p)
enable_pickle(lax.create_token_p)
enable_pickle(lax.div_p)
enable_pickle(lax.dot_general_p)
enable_pickle(lax.eq_p)
enable_pickle(lax.eq_to_p)
enable_pickle(lax.exp_p)
enable_pickle(lax.exp2_p)
enable_pickle(lax.expm1_p)
enable_pickle(lax.floor_p)
enable_pickle(lax.ge_p)
enable_pickle(lax.gt_p)
enable_pickle(lax.imag_p)
enable_pickle(lax.infeed_p)
enable_pickle(lax.integer_pow_p)
enable_pickle(lax.iota_p)
enable_pickle(lax.is_finite_p)
enable_pickle(lax.le_p)
enable_pickle(lax.le_to_p)
enable_pickle(lax.log1p_p)
enable_pickle(lax.log_p)
enable_pickle(lax.logistic_p)
enable_pickle(lax.lt_p)
enable_pickle(lax.lt_to_p)
enable_pickle(lax.max_p)
enable_pickle(lax.min_p)
enable_pickle(lax.mul_p)
enable_pickle(lax.ne_p)
enable_pickle(lax.neg_p)
enable_pickle(lax.nextafter_p)
enable_pickle(lax.not_p)
enable_pickle(lax.or_p)
enable_pickle(lax.outfeed_p)
enable_pickle(lax.pad_p)
enable_pickle(lax.population_count_p)
enable_pickle(lax.pow_p)
enable_pickle(lax.real_p)
enable_pickle(lax.reduce_and_p)
enable_pickle(lax.reduce_max_p)
enable_pickle(lax.reduce_min_p)
enable_pickle(lax.reduce_or_p)
enable_pickle(lax.reduce_p)
enable_pickle(lax.reduce_precision_p)
enable_pickle(lax.reduce_prod_p)
enable_pickle(lax.reduce_sum_p)
enable_pickle(lax.reduce_xor_p)
enable_pickle(lax.rem_p)
enable_pickle(lax.reshape_p)
enable_pickle(lax.rev_p)
enable_pickle(lax.rng_bit_generator_p)
enable_pickle(lax.rng_uniform_p)
enable_pickle(lax.round_p)
enable_pickle(lax.rsqrt_p)
enable_pickle(lax.select_n_p)
enable_pickle(lax.shift_left_p)
enable_pickle(lax.shift_right_arithmetic_p)
enable_pickle(lax.shift_right_logical_p)
enable_pickle(lax.sign_p)
enable_pickle(lax.sin_p)
enable_pickle(lax.sinh_p)
enable_pickle(lax.sort_p)
enable_pickle(lax.sqrt_p)
enable_pickle(lax.squeeze_p)
enable_pickle(lax.sub_p)
enable_pickle(lax.tan_p)
enable_pickle(lax.tanh_p)
enable_pickle(lax.top_k_p)
enable_pickle(lax.transpose_p)
enable_pickle(lax.xor_p)
enable_pickle(lax.bessel_i0e_p)
enable_pickle(lax.bessel_i1e_p)
enable_pickle(lax.digamma_p)
enable_pickle(lax.erfc_p)
enable_pickle(lax.erf_inv_p)
enable_pickle(lax.erf_p)
enable_pickle(lax.igammac_p)
enable_pickle(lax.igamma_grad_a_p)
enable_pickle(lax.igamma_p)
enable_pickle(lax.lgamma_p)
enable_pickle(lax.polygamma_p)
enable_pickle(lax.random_gamma_grad_p)
enable_pickle(lax.regularized_incomplete_beta_p)
enable_pickle(lax.zeta_p)
enable_pickle(lax.dynamic_slice_p)
enable_pickle(lax.dynamic_update_slice_p)
enable_pickle(lax.gather_p)
enable_pickle(lax.scatter_add_p)
enable_pickle(lax.scatter_max_p)
enable_pickle(lax.scatter_min_p)
enable_pickle(lax.scatter_mul_p)
enable_pickle(lax.scatter_p)
enable_pickle(lax.slice_p)
enable_pickle(lax.conv_general_dilated_p)
enable_pickle(lax.reduce_window_max_p)
enable_pickle(lax.reduce_window_min_p)
enable_pickle(lax.reduce_window_p)
enable_pickle(lax.reduce_window_sum_p)
enable_pickle(lax.select_and_gather_add_p)
enable_pickle(lax.select_and_scatter_p)
enable_pickle(lax.select_and_scatter_add_p)
enable_pickle(lax.cond_p)
enable_pickle(lax.cumlogsumexp_p)
enable_pickle(lax.cummax_p)
enable_pickle(lax.cummin_p)
enable_pickle(lax.cumprod_p)
enable_pickle(lax.cumsum_p)
enable_pickle(lax.linear_solve_p)
enable_pickle(lax.scan_p)
enable_pickle(lax.while_p)
enable_pickle(lax.fft_p)
enable_pickle(lax.all_gather_p)
enable_pickle(lax.all_to_all_p)
enable_pickle(lax.axis_index_p)
enable_pickle(lax.pmax_p)
enable_pickle(lax.pmin_p)
enable_pickle(lax.ppermute_p)
enable_pickle(lax.psum_p)
enable_pickle(lax.approx_top_k_p)
enable_pickle(lax.stop_gradient_p)
enable_pickle(lax.sharding_constraint_p)
enable_pickle(lax.device_put_p)
11 changes: 10 additions & 1 deletion jax/extend/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,14 @@
# See PEP 484 & https://github.com/google/jax/issues/7570

from jax._src.abstract_arrays import (
array_types as array_types
array_types as array_types,
)

from jax._src.core import (
eval_jaxpr as eval_jaxpr,
jaxpr_as_fun as jaxpr_as_fun,
Jaxpr as Jaxpr,
JaxprEqn as JaxprEqn,
JaxprDebugInfo as JaxprDebugInfo,
Primitive as Primitive,
)
10 changes: 10 additions & 0 deletions tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,16 @@ jax_test(
] + py_deps("numpy") + py_deps("scipy") + py_deps("absl/testing"),
)

jax_test(
name = "copying_test",
srcs = ["copying_test.py"],
shard_count = 1,
deps = [
"//jax:internal_test_util",
] + py_deps("cloudpickle"),
)


jax_test(
name = "lax_test",
srcs = ["lax_test.py"],
Expand Down
38 changes: 38 additions & 0 deletions tests/copying_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from cloudpickle import cloudpickle
from copy import deepcopy

from jax._src import core
from jax._src import test_util as jtu

import jax
import jax.extend.cloudpickle_support


class CopyingTest(jtu.JaxTestCase):

def test_jaxpr_cloudpickle(self):
def f(a, b):
return a + b, b * a

jxpr = jax.make_jaxpr(f)(1, 2)

# Load then dump with cloudpickle
encoded = cloudpickle.dumps(jxpr)
loaded_jxpr = cloudpickle.loads(encoded)

# Eval, jit and execute.
fun = core.jaxpr_as_fun(loaded_jxpr)
assert jax.jit(fun)(5, 3) == [8, 15]

def test_jaxpr_deepcopy(self):
def f(a, b):
return b * a, a + b,

jxpr = jax.make_jaxpr(f)(1, 2)

# Deepcopy jaxpr
copied_jaxpr = deepcopy(jxpr)

# Eval, jit and execute.
fun = core.jaxpr_as_fun(copied_jaxpr)
assert jax.jit(fun)(5, 3) == [15, 8]

0 comments on commit 58fbc85

Please sign in to comment.