Allow Jaxprs to be cloudpickle-able #9444
Would you mind commenting on the underlying use case here?
Basically, what I'm experimenting with is a way to decompose a JAX program into multiple JAX+MPI programs that run distributed:
big_jaxpr = jax.make_jaxpr(some_func)(*args, **kwargs)
many_smaller_jaxprs = decompose_problem(big_jaxpr)
my_methods = [jax.jit(jax.core.jaxpr_as_fun(j)) for j in many_smaller_jaxprs]
for i, method in enumerate(my_methods):
    client.submit(method, *args, **kwargs, worker=i)
Since all of the functions in … Also, one thing that surprised me very much was that …
OK, interesting. I'll let the core JAX team comment on the feasibility here, but I guess it could make sense to either make …
https://stackoverflow.com/questions/6132469/why-cant-i-pickle-an-errors-traceback-in-python
Your other suggestion of making tracebacks in Jaxprs optional is already the case, I believe. If you don't include a …
(That said, it does currently include Python stack frames.)
Actually, upon further experimentation, I'm not sure pickling is viable for the use case I described above:
>>> import jax
>>> j = jax.make_jaxpr(lambda a, b: a + b)(1.0, 2.0)
>>> import cloudpickle
>>> j.jaxpr.eqns = [jax.core.new_jaxpr_eqn(x.invars, x.outvars, x.primitive, x.params) for x in j.jaxpr.eqns]
>>> s = cloudpickle.dumps(j)
>>> j2 = cloudpickle.loads(s)
>>> jax.core.jaxpr_as_fun(j2)(1.0, 2.0)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: MLIR translation rule for primitive 'add' not found for platform cpu
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/core.py", line 147, in jaxpr_as_fun
return eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args)
File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/core.py", line 330, in eval_jaxpr
ans = eqn.primitive.bind(*subfuns, *map(read, eqn.invars), **bind_params)
File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/core.py", line 272, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/core.py", line 275, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/core.py", line 591, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/_src/dispatch.py", line 92, in apply_primitive
compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/_src/util.py", line 202, in wrapper
return cached(config._trace_context(), *args, **kwargs)
File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/_src/util.py", line 195, in cached
return f(*args, **kwargs)
File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/_src/dispatch.py", line 111, in xla_primitive_callable
compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), device, None,
File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/_src/dispatch.py", line 169, in _xla_callable_uncached
return lower_xla_callable(fun, device, backend, name, donated_invars,
File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/_src/profiler.py", line 206, in wrapper
return func(*args, **kwargs)
File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/_src/dispatch.py", line 258, in lower_xla_callable
module = mlir.lower_jaxpr_to_module(
File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/interpreters/mlir.py", line 409, in lower_jaxpr_to_module
lower_jaxpr_to_fun(
File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/interpreters/mlir.py", line 549, in lower_jaxpr_to_fun
out_vals = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/interpreters/mlir.py", line 628, in jaxpr_subcomp
raise NotImplementedError(
NotImplementedError: MLIR translation rule for primitive 'add' not found for platform cpu
Since the …
Well, that would suggest that you'd want to pickle primitives by name, not by object identity.
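A minimal, JAX-free sketch of the failure above, assuming (as the error message suggests) that lowering rules are looked up in a dict keyed by the primitive object itself. The Primitive class, the lowering_rules table, and the lower helper are hypothetical stand-ins, not JAX's API. Default pickling rebuilds a brand-new instance from its __dict__, so the copy keeps the name but loses the identity that the lookup depends on.

```python
import pickle


class Primitive:
    """Toy primitive with default (identity-based) equality and hashing."""

    def __init__(self, name):
        self.name = name


add_p = Primitive("add")

# Hypothetical rule table keyed by the primitive *object*, not its name.
lowering_rules = {add_p: lambda x, y: x + y}


def lower(prim, *args):
    rule = lowering_rules.get(prim)
    if rule is None:
        raise NotImplementedError(f"translation rule for primitive {prim.name!r} not found")
    return rule(*args)


# Default pickling stores the instance state and builds a new object on load.
add_p_copy = pickle.loads(pickle.dumps(add_p))

print(add_p_copy.name == add_p.name)  # True: same name...
print(add_p_copy is add_p)            # False: ...but a different object
print(lower(add_p, 1.0, 2.0))         # 3.0
try:
    lower(add_p_copy, 1.0, 2.0)
except NotImplementedError as err:
    print(err)                        # translation rule for primitive 'add' not found
```

Pickling by name would mean replacing the default reduction with one that looks the existing object up again on load.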
I think that is the difference between … EDIT: yeah, but this won't work...
I wish it were that simple. Normal pickle has its own struggles:
>>> j2 = pickle.dumps(j)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
_pickle.PicklingError: Can't pickle <function <lambda> at 0x7f55e2bf6550>: attribute lookup <lambda> on jax._src.lax.utils failed
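Plain pickle serializes functions by reference: it records the module and qualified name and looks the function up again when loading. A lambda's qualified name is <lambda>, so that lookup fails, which is the error above. The sketch below shows both cases using only the standard library; cloudpickle, by contrast, serializes such functions by value.

```python
import pickle


def add_named(x, y):
    # Module-level function: pickle stores only "<module>.add_named" and
    # looks it up again on load (so it must be importable where you load it).
    return x + y


add_lambda = lambda x, y: x + y  # its __qualname__ is "<lambda>"

restored = pickle.loads(pickle.dumps(add_named))
print(restored is add_named)  # True: the same function object is found again

try:
    pickle.dumps(add_lambda)
    lambda_pickled = True
except pickle.PicklingError as err:
    lambda_pickled = False
    print("PicklingError:", err)
```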
I meant: add logic to …
To add to the comments about …, e.g.:
… i.e., calling the JAX traceback is at least 3 orders of magnitude faster than …
The tracebacks are simply a list of pairs of …
Now, if I were to serialize the traceback values, I'd want to look at two things:
…
The relevant method in the JAX tree is mainly …
Just throwing my hat in the ring here for a similar feature request. My use case is model cloning across variants of an evolving codebase, i.e., making some architecture tweaks to my RL model without having to learn from scratch. Making every little thing configurable does not really scale, nor does it really work, since you are constantly breaking backwards compatibility as you evolve your model with progressive insight. One way to do this is to commit every little code change separately, run that code version in a separate Python process, and then clone by exchanging data across processes; but obviously that would also be a horrible pain compared to the elegance of simply reading/writing pure JAX model.apply functions to disk. I don't know how hard it will be to obtain a serializable representation, but I do think being able to do so would allow leveraging JAX's functional abilities in a nice way compared to other frameworks.
I work on JAX full time now, so I'm going to try to lead this. The two issues are the traceback thing and the global dictionary lookups. The traceback thing can be solved, I think, by just implementing the encoding/decoding APIs. The global dictionary one is unfortunately harder to solve. I'm not sure how to make cloudpickle treat the primitives like "singletons".
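One possible reading of "implementing the encoding/decoding APIs" for tracebacks, sketched with the standard library's copyreg module: register a reducer that encodes a live traceback as plain (filename, line number, function name) tuples, so the object round-trips with that summary but without the traceback itself. SourceInfo here is a hypothetical toy class, not JAX's own source-info type, and whether JAX would keep such a summary or drop it entirely is an assumption.

```python
import copyreg
import pickle
import sys
import traceback


class SourceInfo:
    """Hypothetical stand-in for an object that carries a live traceback."""

    def __init__(self, tb, frames=None):
        self.tb = tb  # live traceback objects cannot be pickled
        # Plain-data summary: (filename, lineno, function name) per frame.
        if frames is None and tb is not None:
            frames = [(f.filename, f.lineno, f.name) for f in traceback.extract_tb(tb)]
        self.frames = frames or []


def _rebuild_source_info(frames):
    # Decoding: only the picklable summary survives; the traceback is gone.
    return SourceInfo(None, frames)


def _reduce_source_info(info):
    # Encoding: serialize the summary instead of the traceback object.
    return _rebuild_source_info, (info.frames,)


copyreg.pickle(SourceInfo, _reduce_source_info)


def capture():
    try:
        raise RuntimeError("where was I created?")
    except RuntimeError:
        return SourceInfo(sys.exc_info()[2])


info = capture()
restored = pickle.loads(pickle.dumps(info))
print(restored.tb, restored.frames == info.frames)  # None True
```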
OK, so hacking in this change in …:
# Table that stores all primitive definitions.
# Needed so that primitives are treated as singletons
# when using cloudpickle.
_PRIMITIVES_TABLE_: dict[tuple[str, str], Primitive] = {}

def primitives_table_get(namespace: str, name: str):
  if (namespace, name) in _PRIMITIVES_TABLE_:
    return _PRIMITIVES_TABLE_[(namespace, name)]
  raise NotImplementedError(
      f"Op {(namespace, name)} not found in primitives table. (Did you import it yet?)")

def primitives_table_set(namespace: str, name: str, primitive: Primitive):
  assert (namespace, name) not in _PRIMITIVES_TABLE_, (
      f"The op name {(namespace, name)} is already taken. "
      "Try changing the namespace with Primitive(..., namespace='<YOUR_PROJECT_NAME>')."
  )
  _PRIMITIVES_TABLE_[(namespace, name)] = primitive

class Primitive:
  name: str
  namespace: str
  # set for multi-output primitives.
  multiple_results: bool = False
  # set for call primitives processed in final style.
  call_primitive: bool = False
  # set for map primitives processed in final style.
  map_primitive: bool = False

  def __init__(self, name: str, namespace: str = '__jax_internal__'):
    self.name = name
    self.namespace = namespace
    primitives_table_set(namespace, name, self)

  def __reduce__(self):
    # Pickle a lookup by (namespace, name) so unpickling returns the registered object.
    return primitives_table_get, (self.namespace, self.name)
And this change, placed anywhere:
import jaxlib
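# Drop tracebacks when pickling: unpickling yields None in their place.
jaxlib.xla_extension.Traceback.__reduce__ = lambda a: (lambda: None, ())
And now pickle works like a charm:
import cloudpickle
import jax
jxpr = jax.make_jaxpr(lambda a: a + a)(1)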
res = cloudpickle.dumps(jxpr)
new_jxpr = cloudpickle.loads(res)
jax.core.jaxpr_as_fun(new_jxpr)(2)
# [Array(4, dtype=int32, weak_type=True)]
Nothing broke in vanilla JAX, so there are no obvious name collisions. I imagine, however, that something in google3 will collide. When that happens, using the … Will raise a PR tomorrow.
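The registry-plus-__reduce__ idea from the change above, reduced to a self-contained sketch so the singleton behaviour can be checked without JAX. The names mirror the hack, but this is a simplified re-creation, not the code from the eventual PR: a ValueError replaces the assert (so the check survives python -O), and the extra class attributes are omitted.

```python
import pickle

# (namespace, name) -> the one Primitive object registered under that key.
_PRIMITIVES = {}


def primitives_table_get(namespace, name):
    try:
        return _PRIMITIVES[(namespace, name)]
    except KeyError:
        raise NotImplementedError(
            f"Op {(namespace, name)} not found (did you import it yet?)") from None


class Primitive:
    def __init__(self, name, namespace="__jax_internal__"):
        key = (namespace, name)
        if key in _PRIMITIVES:
            raise ValueError(f"The op name {key} is already taken; use another namespace.")
        self.name, self.namespace = name, namespace
        _PRIMITIVES[key] = self

    def __reduce__(self):
        # Pickle a lookup by key instead of the object's state.
        return primitives_table_get, (self.namespace, self.name)


add_p = Primitive("add")
my_add_p = Primitive("add", namespace="my_project")  # same name, other namespace

assert pickle.loads(pickle.dumps(add_p)) is add_p
assert pickle.loads(pickle.dumps(my_add_p)) is my_add_p

try:
    Primitive("add")  # duplicate key in the default namespace
    duplicate_allowed = True
except ValueError:
    duplicate_allowed = False
print(duplicate_allowed)  # False
```

Because unpickling calls primitives_table_get, the module that defines a primitive has to be imported before the pickle is loaded, which is what the "did you import it yet?" message hints at.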
https://gist.github.com/jjyyxx/f64e28f6ccc37c24af9fd17649710b26
I created a demo that successfully saves a traced JAX function using pickle only. This is important to me because during quick research I often make many small tweaks without version control. Saving the traced JAX function provides a minimal interface to reproduce results later, even if the source code has changed. Previously, I used ONNX, but it's slow, TensorFlow-dependent, and has many unsupported operations. Saving a Jaxpr is much lighter and allows separating the computation graph from data: I can save the Jaxpr once and only save parameters periodically during training. Additionally, saving the Jaxpr instead of compiled binaries or lowered IR gives me almost full control over further inference. It can be jitted, run on CPU/GPU, vmap-transformed, and exported to ONNX. While it works for me, there are some unhandled edge cases:
…
The hackiest part is mapping primitive names to primitives, which currently involves scanning through modules.
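The gist's own code is not reproduced here; this is a hypothetical sketch of the general technique described, building a name-to-primitive table by scanning module attributes. The toy modules stand in for whichever modules actually define primitives, and a name shared by two different objects is treated as an error rather than silently picking one.

```python
import types


class Primitive:
    """Toy primitive; only the name matters for this sketch."""

    def __init__(self, name):
        self.name = name


# Hypothetical modules standing in for the library's real primitive modules.
lax_like = types.ModuleType("lax_like")
lax_like.add_p = Primitive("add")
lax_like.mul_p = Primitive("mul")
other_like = types.ModuleType("other_like")
other_like.sin_p = Primitive("sin")
other_like.helper = lambda x: x  # non-primitive attributes are ignored


def build_primitive_table(modules):
    """Map primitive name -> primitive object by scanning module attributes."""
    table = {}
    for module in modules:
        for value in vars(module).values():
            if isinstance(value, Primitive):
                existing = table.setdefault(value.name, value)
                if existing is not value:
                    raise ValueError(f"two different primitives are named {value.name!r}")
    return table


table = build_primitive_table([lax_like, other_like])
print(sorted(table))  # ['add', 'mul', 'sin']
```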
Hey, I agree the jaxpr makes sense to pickle, since it's a portable IR for JAX jit functions. However, if we just had a "fn_from_jaxpr", then we could save jaxprs as strings instead of pickles, which would make serialized jit functions human-readable. One idea to do this would be …
Created issue: jax jitted functions cloudpickled work but include some error messages #537 in cloudpipe / cloudpickle repo
Right now, if you try to pickle a jaxpr, you get an error. However, there isn't much that needs to change: simply mapping the jaxpr.eqns to new_jaxpr_eqn calls that don't pass the source info makes the jaxpr picklable.
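The proposed workaround, sketched on a hypothetical toy equation type rather than JAX's real equation class: a live traceback stored in the source info makes the whole list unpicklable, and rebuilding each equation without it (here with dataclasses.replace, playing the role of new_jaxpr_eqn) fixes that. As the REPL session in the thread shows, this alone is not sufficient for real jaxprs, because primitives must also survive the round trip by identity.

```python
import dataclasses
import pickle
import sys
from typing import Any


@dataclasses.dataclass(frozen=True)
class Eqn:
    """Hypothetical toy equation; fields loosely mirror invars/outvars/primitive/params."""
    invars: tuple
    outvars: tuple
    primitive: str
    params: dict
    source_info: Any = None  # may hold a live, unpicklable traceback


def _current_tb():
    try:
        raise RuntimeError
    except RuntimeError:
        return sys.exc_info()[2]


eqns = [Eqn(("a", "b"), ("c",), "add", {}, source_info=_current_tb())]

try:
    pickle.dumps(eqns)
    picklable_before = True
except TypeError:  # "cannot pickle 'traceback' object"
    picklable_before = False

# The workaround: rebuild every equation without its source info.
stripped = [dataclasses.replace(e, source_info=None) for e in eqns]
restored = pickle.loads(pickle.dumps(stripped))
print(picklable_before, restored == stripped)  # False True
```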