Allow Jaxprs to be cloudpickle-able #9444

Open
chaserileyroberts opened this issue Feb 4, 2022 · 16 comments
Labels
enhancement New feature or request

Comments

@chaserileyroberts
Contributor

chaserileyroberts commented Feb 4, 2022

Right now, if you try to pickle a jaxpr, you are given an error.

>>> import jax
>>> f = lambda a: a + 1
>>> j = jax.make_jaxpr(f)(1)
>>> j
{ lambda ; a:i32[]. let b:i32[] = add a 1 in (b,) }
>>> import cloudpickle
>>> cloudpickle.dumps(j)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/chase/anaconda3/lib/python3.8/site-packages/cloudpickle/cloudpickle_fast.py", line 73, in dumps
    cp.dump(obj)
  File "/home/chase/anaconda3/lib/python3.8/site-packages/cloudpickle/cloudpickle_fast.py", line 563, in dump
    return Pickler.dump(self, obj)
TypeError: cannot pickle 'jaxlib.xla_extension.Traceback' object

However, there isn't much that needs to change. Rebuilding each equation in jaxpr.eqns with new_jaxpr_eqn, without passing its source info, makes the jaxpr pickleable.

j.jaxpr.eqns = [jax.core.new_jaxpr_eqn(x.invars, x.outvars, x.primitive, x.params) for x in j.jaxpr.eqns]
cloudpickle.dumps(j) # Works like a charm!
@chaserileyroberts chaserileyroberts added the enhancement New feature or request label Feb 4, 2022
@shoyer
Collaborator

shoyer commented Feb 4, 2022

Would you mind commenting on the underlying use-case here?

@chaserileyroberts
Contributor Author

Basically, what I'm experimenting with is a way to decompose a JAX program into multiple JAX+MPI programs that run distributed.

big_jaxpr = jax.make_jaxpr(some_func)(*args, **kwargs)
many_smaller_jaxprs = decompose_problem(big_jaxpr)
my_methods = [jax.jit(jax.core.jaxpr_as_fun(j)) for j in many_smaller_jaxprs]
for i, method in enumerate(my_methods):
  client.submit(method, *args, **kwargs, worker=i)

Since all of the functions in my_methods hold the jaxprs as local variables, those get pickled when sent off to the remote worker. Right now I'm using the above hack on all of the smaller jaxprs, which works fine for now, but it would be nice to have a tested and supported method.

Also, one thing that surprised me very much was that jitted methods were always correctly cached on the remote workers. I had no issue with recompilation, which I really wasn't expecting given that I'm mixing JAX + MPI and Dask.

@shoyer
Collaborator

shoyer commented Feb 5, 2022

OK, interesting. I'll let core JAX team comment on the feasibility here, but I guess it could make sense to either make Traceback pickleable, or make including tracebacks in JAXprs optional.

@chaserileyroberts
Contributor Author

chaserileyroberts commented Feb 5, 2022

Tracebacks in general are not pickleable sadly since they reference the full memory stack.

https://stackoverflow.com/questions/6132469/why-cant-i-pickle-an-errors-traceback-in-python

Your other suggestion of making tracebacks in Jaxprs optional is already the case, I believe. If you don't include a source_info argument when calling new_jaxpr_eqn, no traceback is included. We could perhaps utilize that and add an include_tracebacks=False option to make_jaxpr that forces all of the eqns to not have tracebacks.

@shoyer
Collaborator

shoyer commented Feb 7, 2022

Traceback here is a custom XLA/JAX thing, so in principle we could override it:
https://github.com/tensorflow/tensorflow/blob/fb91f402331605db55f1cda9603e18835245e6d1/tensorflow/compiler/xla/python/traceback.cc

(that said, it does currently include Python stack frames)

@chaserileyroberts
Contributor Author

Actually, upon further experimenting, I'm not sure pickling is viable for the use case I described above.

>>> import jax
>>> j = jax.make_jaxpr(lambda a, b: a + b)(1.0, 2.0)
>>> import cloudpickle
>>> j.jaxpr.eqns = [jax.core.new_jaxpr_eqn(x.invars, x.outvars, x.primitive, x.params) for x in j.jaxpr.eqns]
>>> s = cloudpickle.dumps(j)
>>> j2 = cloudpickle.loads(s)
>>> jax.core.jaxpr_as_fun(j2)(1.0, 2.0)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: MLIR translation rule for primitive 'add' not found for platform cpu

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/core.py", line 147, in jaxpr_as_fun
    return eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args)
  File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/core.py", line 330, in eval_jaxpr
    ans = eqn.primitive.bind(*subfuns, *map(read, eqn.invars), **bind_params)
  File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/core.py", line 272, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/core.py", line 275, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/core.py", line 591, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/_src/dispatch.py", line 92, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
  File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/_src/util.py", line 202, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/_src/util.py", line 195, in cached
    return f(*args, **kwargs)
  File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/_src/dispatch.py", line 111, in xla_primitive_callable
    compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), device, None,
  File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/_src/dispatch.py", line 169, in _xla_callable_uncached
    return lower_xla_callable(fun, device, backend, name, donated_invars,
  File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/_src/dispatch.py", line 258, in lower_xla_callable
    module = mlir.lower_jaxpr_to_module(
  File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/interpreters/mlir.py", line 409, in lower_jaxpr_to_module
    lower_jaxpr_to_fun(
  File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/interpreters/mlir.py", line 549, in lower_jaxpr_to_fun
    out_vals = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
  File "/home/chase/anaconda3/envs/jaxcpu/lib/python3.8/site-packages/jax/interpreters/mlir.py", line 628, in jaxpr_subcomp
    raise NotImplementedError(
NotImplementedError: MLIR translation rule for primitive 'add' not found for platform cpu

Since the Primitives are actual python objects, and the translation rules are determined by dictionary lookups of the objects by reference, those references aren't maintained between a pickle dump/load. I'm not sure this can be supported without a huge amount of structural changes to JAX.
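
To make the failure mode concrete, here is a minimal, JAX-free sketch. ToyPrimitive and TRANSLATION_RULES are hypothetical stand-ins, not JAX internals; the only assumption carried over from the comment above is that rules live in a dict keyed by the primitive object itself.

```python
import pickle


class ToyPrimitive:
    """Hypothetical stand-in for a primitive; hashes by object identity."""

    def __init__(self, name):
        self.name = name


add_p = ToyPrimitive("add")

# Rules are keyed by the primitive object, so a lookup depends on identity.
TRANSLATION_RULES = {add_p: lambda x, y: x + y}

# A pickle round trip rebuilds the instance, which yields a different object.
restored = pickle.loads(pickle.dumps(add_p))

same_name = restored.name == add_p.name      # state survives the round trip
same_object = restored is add_p              # identity does not
rule_found = restored in TRANSLATION_RULES   # so the rule lookup misses
```

The restored object carries the same name but is a new key as far as the dict is concerned, which mirrors the "translation rule not found" error above.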

@hawkinsp
Collaborator

hawkinsp commented Feb 7, 2022

Well: that would suggest that you'd want to pickle primitives by name not by object identity...

@PhilipVinc
Contributor

PhilipVinc commented Feb 7, 2022

I think that is the difference between pickle and cloudpickle. Using the former should avoid the problem.

EDIT: yeah, but this won't work....

@chaserileyroberts
Contributor Author

> I think that is the difference between pickle and cloudpickle. Using the former should avoid the problem.

I wish it were that simple. Normal pickle has its own struggles.

>>> import pickle
>>> j2 = pickle.dumps(j)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
_pickle.PicklingError: Can't pickle <function <lambda> at 0x7f55e2bf6550>: attribute lookup <lambda> on jax._src.lax.utils failed

@hawkinsp
Collaborator

hawkinsp commented Feb 7, 2022

I meant: add logic to core.Primitive to have it pickle by name, not by identity.

@hawkinsp
Collaborator

hawkinsp commented Feb 8, 2022

To add to the comments about Traceback above: yes, Stephan is right that the Traceback objects in jaxprs are JAX-internal traceback objects, not Python tracebacks. They exist for one reason only: they are optimized to be fast to collect.

e.g.:

In [1]: import jax

In [3]: Traceback = jax._src.lib.xla_client.Traceback

In [7]: %timeit x = Traceback.get_traceback()
798 ns ± 11 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

In [8]: import inspect

In [10]: %timeit y = inspect.stack()
5.36 ms ± 163 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [11]: import traceback

In [13]: %timeit x = traceback.extract_stack()
99.2 µs ± 1.04 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

i.e. calling the JAX traceback is at least 3 orders of magnitude faster than inspect.stack() for the handful of stack frames in my ipython session, and about 2 orders of magnitude faster than traceback.extract_stack().

The tracebacks are simply a list of pairs of (code, lasti) values, where code is a types.CodeType and lasti is an int. We store code values because they are extremely cheap to collect from the interpreter stack, and we do not attempt to interpret them in any way when they are collected. We defer turning the code and lasti values into values meaningful to the user until we need them.
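
The collect-now, interpret-later idea can be sketched in pure Python. This is an illustration only, not how the JAX/XLA Traceback is implemented (that is native code); collect_raw_frames and resolve_frame are hypothetical helpers, and sys._getframe is a CPython-specific call.

```python
import dis
import sys


def collect_raw_frames():
    """Cheap capture: record (code, lasti) per caller frame, innermost first."""
    raw = []
    frame = sys._getframe(1)  # skip this helper's own frame
    while frame is not None:
        raw.append((frame.f_code, frame.f_lasti))
        frame = frame.f_back
    return raw


def resolve_frame(code, lasti):
    """Deferred step: map an instruction offset to (file, function, line)."""
    lineno = code.co_firstlineno
    for offset, line in dis.findlinestarts(code):
        if offset > lasti:
            break
        if line is not None:  # newer Pythons may report offsets with no line
            lineno = line
    return (code.co_filename, code.co_name, lineno)


def user_function():
    return collect_raw_frames()


raw = user_function()               # nothing is interpreted at capture time
innermost = resolve_frame(*raw[0])  # interpretation happens only on demand
```

Capture stores references to existing code objects and an integer per frame; the line-table walk in resolve_frame runs only when a human-readable location is actually needed.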

Now, if I were to serialize the traceback values, I'd want to look at two things:

  • most of the time, we only need a single "user frame" from the traceback. We collect the entire traceback because it's easier and cheaper than thinking about which bits we need, and later filter it. So for serialization, you probably don't even need the traceback: you could distill it into a single frame.
  • for serialization, you should not keep the code objects around; they will not make sense on a remote worker. Instead, you'd want to turn the traceback into a tuple containing the interpretation of the frame and serialize that.

The relevant method in the JAX tree is mainly source_info_util.user_frame(). You could in essence call that at serialization time and replace the source info value with a new alternative value that only contains a single frame of interest.
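
As a sketch of the two points above (keep one user frame, and serialize its interpretation rather than code objects), using only the standard library. distill_user_frame, the internal_prefixes parameter and the file paths are hypothetical; this is not the implementation of source_info_util.user_frame().

```python
import pickle
import traceback


def distill_user_frame(stack, internal_prefixes):
    """Return (filename, function, lineno) for the innermost non-internal frame."""
    prefixes = tuple(internal_prefixes)
    for frame in reversed(stack):  # stack is ordered outermost first
        if not frame.filename.startswith(prefixes):
            return (frame.filename, frame.name, frame.lineno)
    return None


# Synthetic stack: a user call that descends into two "library" frames.
stack = [
    traceback.FrameSummary("/home/user/train.py", 12, "loss_fn", lookup_line=False),
    traceback.FrameSummary("/site-packages/somelib/api.py", 40, "add", lookup_line=False),
    traceback.FrameSummary("/site-packages/somelib/core.py", 99, "bind", lookup_line=False),
]

frame = distill_user_frame(stack, internal_prefixes=["/site-packages/somelib/"])

# A plain tuple of strings and ints pickles without any code objects.
restored = pickle.loads(pickle.dumps(frame))
```

The serialized value is a plain tuple, so it stays meaningful on a remote worker that does not have the original code objects.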

@EelcoHoogendoorn

Just throwing my hat in the ring here for a similar feature request.

My use case is model cloning across variants of an evolving codebase; i.e., making some architecture tweaks to my RL model without having to learn from scratch. Making every little thing configurable does not scale, nor does it really work, since you are constantly breaking backwards compatibility as you evolve your model with progressive insight. One way to do this is to commit every little code change separately, run each code version in a separate Python process, and clone by exchanging data across processes; but obviously that would be a horrible pain compared to the elegance of just being able to read/write pure JAX model.apply functions to disk.

I don't know how hard it will be to obtain a serializable representation, but I do think being able to do so would allow leveraging JAX's functional abilities in a nice way compared to other frameworks.

@chaserileyroberts
Contributor Author

chaserileyroberts commented Oct 19, 2023

I work on jax full time now so I'm going to try and lead this.

The two issues are the traceback thing and the global dictionary lookups. The traceback thing can be solved, I think, by just implementing the encoding/decoding APIs. The global dictionary one is harder to solve, unfortunately. I'm not sure how to make cloudpickle treat the primitives like "singletons".

@chaserileyroberts
Contributor Author

chaserileyroberts commented Oct 20, 2023

Ok so hacking in this change in core.py

# Table that stores all primitive definitions.
# Needed so that primitives are treated as singletons
# when using cloudpickle.
_PRIMITIVES_TABLE_: dict[tuple[str, str], Primitive] = {}

def primitives_table_get(namespace: str, name: str):
  if (namespace, name) in _PRIMITIVES_TABLE_:
    return _PRIMITIVES_TABLE_[(namespace, name)]
  raise NotImplementedError(
    f"Op {(namespace, name)} not found in primitives table. (Did you import it yet?).")

def primitives_table_set(namespace: str, name: str, primitive: Primitive):
  assert (namespace, name) not in _PRIMITIVES_TABLE_, (
    f"The op name {(namespace, name)} is already taken. "
    "Try changing the namespace with Primitive(..., namespace='<YOUR_PROJECT_NAME>')."
  )
  _PRIMITIVES_TABLE_[(namespace, name)] = primitive

class Primitive:
  name: str
  namespace: str
  # set for multi-output primitives.
  multiple_results: bool = False
  # set for call primitives processed in final style.
  call_primitive: bool = False
  # set for map primitives processed in final style.
  map_primitive: bool = False

  def __init__(self, name: str, namespace: str='__jax_internal__'):
    self.name = name
    self.namespace = namespace
    primitives_table_set(namespace, name, self)

  def __reduce__(self):
    return primitives_table_get, (self.namespace, self.name)

And this change just somewhere.

import jaxlib
jaxlib.xla_extension.Traceback.__reduce__ = lambda a: (lambda: None, ())

And now pickle works like a charm.

jxpr = jax.make_jaxpr(lambda a: a + a)(1)
res = cloudpickle.dumps(jxpr)
new_jxpr = cloudpickle.loads(res)
jax.core.jaxpr_as_fun(new_jxpr)(2)
# [Array(4, dtype=int32, weak_type=True)]

Nothing broke in vanilla JAX, so there are no obvious name collisions. I imagine, however, that something in google3 will collide. When that happens, using the namespace='...' trick should fix most of the issues quickly.

Will raise a PR tomorrow.

@jjyyxx
Contributor

jjyyxx commented Apr 25, 2024

https://gist.github.com/jjyyxx/f64e28f6ccc37c24af9fd17649710b26

I created a demo that successfully saves a traced JAX function using pickle only. This is important to me because during quick research, I often make many small tweaks without version control. Saving the traced JAX function provides a minimal interface to reproduce results later, even if the source code has changed.

Previously, I used ONNX, but it's slow, TensorFlow-dependent, and has many unsupported operations. Saving Jaxpr is much lighter and allows for separating the computation graph from data. I can save the Jaxpr once and only save parameters periodically during training.

Additionally, saving Jaxpr instead of compiled binaries or lowered IR gives me almost full control over further inference. It can be jitted, used on CPU/GPU, vmap-transformed, and exported to ONNX.

While it works for me, there are some unhandled edge cases:

  1. Partial eval function for custom_* (notably jax.nn.relu) is dropped. I lack deep understanding of JAX internals to handle this.
  2. Custom primitives are not handled.
  3. Host callbacks are not allowed.

The hackiest part is mapping primitive names to primitives, which currently involves scanning through modules.

@bionicles
Contributor

Hey, I agree the jaxpr makes sense to pickle since it's a portable IR for jax jit functions. However, if we just had "fn_from_jaxpr" then we could save the jaxprs as strings instead of pickles, which would facilitate human readability of serialized jit functions.

One way to do this would be:

  1. Test a bijection between PjitFunction and Jaxpr
  2. Rename make_jaxpr to jaxpr_from_fn and ideally remove the requirement to pass input
  3. Define fn_from_jaxpr to pass the test (how?)

Created issue:

jax jitted functions cloudpickled work but include some error messages #537 in cloudpipe / cloudpickle repo

Link: cloudpipe/cloudpickle#537

[Screenshot: 2024-06-28 061537]
