diff --git a/jax/_src/core.py b/jax/_src/core.py index 6ceb3d48ec9f..d4bbe3df40da 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -364,8 +364,31 @@ def __repr__(self): Atom = Union[Var, Literal] +# Table that stores all primitive definitions. +# Needed so that primitives are treated as singletons +# when using cloudpickle. +# Mapping type: (namespace, name) -> Primitive +_PRIMITIVES_TABLE_: dict[tuple[str, str], Primitive] = {} + +def primitives_table_get(namespace: str, name: str): + if (namespace, name) in _PRIMITIVES_TABLE_: + return _PRIMITIVES_TABLE_[(namespace, name)] + raise NotImplementedError( + f"Op {(namespace, name)} not found in primitives table. (Did you import it yet?)." + f"Primitive table has {list(_PRIMITIVES_TABLE_.keys())}") + +def primitives_table_set(namespace: str, name: str, primitive: Primitive): + assert (namespace, name) not in _PRIMITIVES_TABLE_, ( + f"The op name {(namespace, name)} is already taken. " + "Try changing the namespace with Primitive(..., namespace='')." + ) + _PRIMITIVES_TABLE_[(namespace, name)] = primitive + + class Primitive: + # used to identify primitive singletons. name: str + namespace: str # set for multi-output primitives. multiple_results: bool = False # set for call primitives processed in final style. @@ -373,12 +396,23 @@ class Primitive: # set for map primitives processed in final style. map_primitive: bool = False - def __init__(self, name: str): + def __init__(self, name: str, namespace: str="__jax_core__"): self.name = name + self.namespace = namespace + primitives_table_set(namespace, name, self) def __repr__(self): return f'{self.name}' + def __reduce_ex__(self, __protocol) -> str | tuple[Any, ...]: + return primitives_table_get, (self.namespace, self.name) + + def __deepcopy__(self, memo): + # Primitives are singletons, so copying is the same + # as returning self. + memo[id(self)] = self + return self + def bind(self, *args, **params): assert (not config.enable_checks.value or all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args diff --git a/jax/_src/source_info_util.py b/jax/_src/source_info_util.py index efec6cd2e561..3aa649054a52 100644 --- a/jax/_src/source_info_util.py +++ b/jax/_src/source_info_util.py @@ -21,7 +21,7 @@ import sysconfig import threading import types -from typing import Optional, NamedTuple, Union +from typing import Any, Optional, NamedTuple, Union import jax.version from jax._src.lib import xla_client @@ -132,6 +132,10 @@ def replace(self, *, traceback: Optional[Traceback] = None, self.name_stack if name_stack is None else name_stack ) + def __reduce_ex__(self, __protocol) -> str | tuple[Any, ...]: + # Dropping Traceback for now as it is not easy to pickle. + return SourceInfo, (None, self.name_stack) + def new_source_info() -> SourceInfo: return SourceInfo(None, NameStack()) diff --git a/tests/copying_test.py b/tests/copying_test.py new file mode 100644 index 000000000000..300848d71938 --- /dev/null +++ b/tests/copying_test.py @@ -0,0 +1,25 @@ +from functools import partial +from cloudpickle import cloudpickle +from copy import deepcopy + +import jax +from jax._src.core import jaxpr_as_fun + +def test_jaxpr_cloudpickle(): + def f(a, b): + return a + b, b * a + + jxpr = jax.make_jaxpr(f)(1, 2) + encoded = cloudpickle.dumps(jxpr) + loaded_jxpr = cloudpickle.loads(encoded) + + assert jaxpr_as_fun(loaded_jxpr)(5, 3) == [8, 15] + +def test_jaxpr_deepcopy(): + def f(a, b): + return b * a, a + b, + + jxpr = jax.make_jaxpr(f)(1, 2) + copied_jaxpr = deepcopy(jxpr) + + assert jaxpr_as_fun(copied_jaxpr)(5, 3) == [15, 8]