Skip to content

Commit

Permalink
Initial pickle and deepcopy singleton support
Browse files Browse the repository at this point in the history
  • Loading branch information
chaserileyroberts committed Oct 23, 2023
1 parent 20e5838 commit 818358e
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 2 deletions.
36 changes: 35 additions & 1 deletion jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,21 +364,55 @@ def __repr__(self):

Atom = Union[Var, Literal]

# Table that stores all primitive definitions.
# Needed so that primitives are treated as singletons
# when using cloudpickle.
# Mapping type: (namespace, name) -> Primitive
_PRIMITIVES_TABLE_: dict[tuple[str, str], Primitive] = {}

def primitives_table_get(namespace: str, name: str):
if (namespace, name) in _PRIMITIVES_TABLE_:
return _PRIMITIVES_TABLE_[(namespace, name)]
raise NotImplementedError(
f"Op {(namespace, name)} not found in primitives table. (Did you import it yet?)."
f"Primitive table has {list(_PRIMITIVES_TABLE_.keys())}")

def primitives_table_set(namespace: str, name: str, primitive: Primitive):
assert (namespace, name) not in _PRIMITIVES_TABLE_, (
f"The op name {(namespace, name)} is already taken. "
"Try changing the namespace with Primitive(..., namespace='<YOUR_PROJECT_NAME>')."
)
_PRIMITIVES_TABLE_[(namespace, name)] = primitive


class Primitive:
# used to identify primitive singletons.
name: str
namespace: str
# set for multi-output primitives.
multiple_results: bool = False
# set for call primitives processed in final style.
call_primitive: bool = False
# set for map primitives processed in final style.
map_primitive: bool = False

def __init__(self, name: str):
def __init__(self, name: str, namespace: str="__jax_core__"):
self.name = name
self.namespace = namespace
primitives_table_set(namespace, name, self)

def __repr__(self):
return f'{self.name}'

def __reduce_ex__(self, __protocol) -> str | tuple[Any, ...]:
return primitives_table_get, (self.namespace, self.name)

def __deepcopy__(self, memo):
# Primitives are singletons, so copying is the same
# as returning self.
memo[id(self)] = self
return self

def bind(self, *args, **params):
assert (not config.enable_checks.value or
all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
Expand Down
6 changes: 5 additions & 1 deletion jax/_src/source_info_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import sysconfig
import threading
import types
from typing import Optional, NamedTuple, Union
from typing import Any, Optional, NamedTuple, Union

import jax.version
from jax._src.lib import xla_client
Expand Down Expand Up @@ -132,6 +132,10 @@ def replace(self, *, traceback: Optional[Traceback] = None,
self.name_stack if name_stack is None else name_stack
)

def __reduce_ex__(self, __protocol) -> str | tuple[Any, ...]:
# Dropping Traceback for now as it is not easy to pickle.
return SourceInfo, (None, self.name_stack)

def new_source_info() -> SourceInfo:
return SourceInfo(None, NameStack())

Expand Down
24 changes: 24 additions & 0 deletions tests/copying_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from cloudpickle import cloudpickle
from copy import deepcopy

import jax
from jax._src.core import jaxpr_as_fun

def test_jaxpr_cloudpickle():
def f(a, b):
return a + b, b * a

jxpr = jax.make_jaxpr(f)(1, 2)
encoded = cloudpickle.dumps(jxpr)
loaded_jxpr = cloudpickle.loads(encoded)

assert jaxpr_as_fun(loaded_jxpr)(5, 3) == [8, 15]

def test_jaxpr_deepcopy():
def f(a, b):
return b * a, a + b,

jxpr = jax.make_jaxpr(f)(1, 2)
copied_jaxpr = deepcopy(jxpr)

assert jaxpr_as_fun(copied_jaxpr)(5, 3) == [15, 8]

0 comments on commit 818358e

Please sign in to comment.