Cloudpickle and deepcopy support for Jaxprs #18243

Conversation

chaserileyroberts
Contributor

@chaserileyroberts chaserileyroberts commented Oct 23, 2023

Fixes #9444

I needed to implement the __reduce_ex__ and __deepcopy__ methods on Primitive and SourceInfo, and added unit tests in a new file, copying_test.py.

This is generally useful in distributed environments. For example, I can write a transform that creates a shard_mapped jaxpr, send that jaxpr to each of my worker nodes over the network via Ray / Dask (both of which use cloudpickle), and then simply execute the jaxpr on each worker.
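
For readers unfamiliar with the pattern, here is a rough illustrative sketch (not the actual diff in this PR) of how by-name reconstruction can make a primitive-like object pickleable and deep-copyable; the registry and lookup helper names below are hypothetical:

# Illustrative sketch only -- assumes a module-level registry of primitives.
_primitive_registry = {}  # hypothetical: name -> Primitive instance

def _lookup_primitive(name):
    # Hypothetical reconstruction callable used by pickle on load.
    return _primitive_registry[name]

class Primitive:
    def __init__(self, name):
        self.name = name
        _primitive_registry[name] = self

    def __reduce_ex__(self, protocol):
        # Pickle as "look me up by name" rather than field-by-field, which
        # also sidesteps non-picklable attributes such as tracebacks.
        return (_lookup_primitive, (self.name,))

    def __deepcopy__(self, memo):
        # Primitives act as interned singletons, so a deep copy is itself.
        return self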

@chaserileyroberts
Contributor Author

Classic case of "works on my machine." Not sure what's wrong with the docs build.

Member

@froystig froystig left a comment

For the intended use case: is it necessary to render Jaxpr pickleable, or would it suffice to wrap things up in jit and de/serialize the jitted function instead? Jitted functions are already pickleable.

@chaserileyroberts
Contributor Author

chaserileyroberts commented Oct 24, 2023

Jitted functions are already pickleable.

Not for my use case. The jaxpr is still picked up in the closure, which makes pickling impossible.

>>> import jax
>>> import cloudpickle
>>> def f(a):
...     return a + a
... 
>>> j = jax.make_jaxpr(f)(1)
>>> cloudpickle.cloudpickle.dumps(j)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/chase/anaconda3/envs/python311/lib/python3.11/site-packages/cloudpickle/cloudpickle.py", line 1479, in dumps
    cp.dump(obj)
  File "/home/chase/anaconda3/envs/python311/lib/python3.11/site-packages/cloudpickle/cloudpickle.py", line 1245, in dump
    return super().dump(obj)
           ^^^^^^^^^^^^^^^^^
TypeError: cannot pickle 'jaxlib.xla_extension.Traceback' object
>>>
>>> jitted = jax.jit(jax.core.jaxpr_as_fun(j))
>>> cloudpickle.cloudpickle.dumps(jitted)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/chase/anaconda3/envs/python311/lib/python3.11/site-packages/cloudpickle/cloudpickle.py", line 1479, in dumps
    cp.dump(obj)
  File "/home/chase/anaconda3/envs/python311/lib/python3.11/site-packages/cloudpickle/cloudpickle.py", line 1245, in dump
    return super().dump(obj)
           ^^^^^^^^^^^^^^^^^
TypeError: cannot pickle 'jaxlib.xla_extension.Traceback' object

lower(...) isn't pickleable either.

>>> s = jax.jit(jax.core.jaxpr_as_fun(j)).lower(1)
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
>>> s
<jax._src.stages.Lowered object at 0x7fa49bc19070>
>>> cloudpickle.cloudpickle.dumps(s)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/chase/anaconda3/envs/clean11/lib/python3.11/site-packages/cloudpickle/cloudpickle.py", line 1479, in dumps
    cp.dump(obj)
  File "/home/chase/anaconda3/envs/clean11/lib/python3.11/site-packages/cloudpickle/cloudpickle.py", line 1245, in dump
    return super().dump(obj)
           ^^^^^^^^^^^^^^^^^
TypeError: cannot pickle 'jaxlib.mlir._mlir_libs._mlir.ir.Module' object

@hawkinsp
Collaborator

That wasn't exactly the question, though. You can pickle jit(f), no problem. Why is the jaxpr the right thing to pickle?

@chaserileyroberts
Contributor Author

Why is the jaxpr the right thing to pickle?

Because my goal is to transform f before sending it off to the workers. I cannot do that with just the original definition of f unless I rerun that transform on every single worker, which is not scalable for my goals.

@chaserileyroberts
Contributor Author

which is not scalable for my goals.

For example: if my transform is a randomized optimization, I would need to ensure all of the workers arrive at the exact same solution. Possible, but incredibly fragile. Solving this once on the client before dispatching avoids the issue entirely.

@chaserileyroberts
Contributor Author

chaserileyroberts commented Oct 24, 2023

I want jaxprs to be able to be passed around, modified, and executed using standard distributed Python tooling. This could have a wide range of applications beyond just pickling jit(f).

You could imagine the jaxprs being used with a Ray / Dask server just for the compiler optimization and not even the actual execution. Something like:

# Use the remote cluster to optimize your jaxpr.
jxpr = best_of([optimize.remote(f, seed) for seed in initial_seeds], my_metrics)

# Execute the jaxpr locally on your GPU.
jax.jit(jaxpr_as_fun(jxpr))(my_args)
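
To make that concrete, here is a hedged sketch of what the remote-optimization piece could look like with Ray, assuming jaxprs become pickleable; optimize_pass, my_metric, initial_seeds, and my_args are hypothetical placeholders:

import jax
import ray

@ray.remote
def optimize(f, seed):
    # Trace f to a jaxpr on a worker and run some (possibly randomized,
    # possibly expensive) optimization pass over it. Returning the result
    # requires the ClosedJaxpr to be cloudpickleable, which is what this
    # PR would enable.
    closed_jaxpr = jax.make_jaxpr(f)(1.0)
    return optimize_pass(closed_jaxpr, seed)  # hypothetical optimization pass

candidates = ray.get([optimize.remote(f, seed) for seed in initial_seeds])
jxpr = min(candidates, key=my_metric)  # the "best_of" from the snippet above

# Execute the chosen jaxpr locally on your GPU.
out = jax.jit(jax.core.jaxpr_as_fun(jxpr))(my_args)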

@froystig
Member

I think we can all imagine similar things (in fact, @mattjj had a branch doing something similar for a Ray prototype a while back, though we never merged it). While we all broadly agree that this can be potentially useful, it's not free, hence the questions about alternatives. For your immediate, concrete use case, does pickling jit(f) suffice? Could you do anything else short of pickling the jaxpr?

@chaserileyroberts
Contributor Author

chaserileyroberts commented Oct 25, 2023

For your immediate, concrete use case, does pickling jit(f) suffice?

Sadly, not really. I've tried, and it's usually pretty painful.

Could you do anything else short of pickling the jaxpr?

This is the setup I want to solve: I have a JAX function f and a transform trfm, which can be expensive and possibly nondeterministic. I want to execute trfm(f)(args...) in an SPMD fashion on a distributed mesh.

I've come up with several ideas for doing this:

  1. Make all of the transforms execute on the workers.

    • This is what is recommended currently in JAX.
    • This is reasonable for well-defined, deterministic transforms (e.g., grad, vmap, vjp), but can become difficult if trfm needs to do expensive optimization searches, or if any randomization is used.
    • When developers create their own custom transforms as jax.extend evolves, they're going to have a bad time when they need to debug a nondeterminism bug across a cluster.
  2. Make some kind of separate IR that is pickled instead of the Jaxprs.

    • I don't think anyone wants to support this lol.
  3. Support pickling Jaxprs.

    • This PR.
    • In this setup, the Jaxpr can be derived locally even if trfm is randomized or complicated. We then dispatch this jaxpr with Ray or Dask (via cloudpickle) to the entire mesh.
    • All nodes have the exact same Jaxpr, so when you run jit(jaxpr_as_fun(jaxpr))(args...) on all nodes simultaneously, the chance of bugs related to nodes running mismatching SPMD binaries drops significantly (see the sketch after this list).
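
As a concrete illustration of option 3, here is a hedged sketch of the client-side round trip this PR would enable (today the dumps call raises the TypeError shown earlier in the thread):

import cloudpickle
import jax

def f(a):
    return a + a

# Derive (and potentially transform/optimize) the jaxpr once on the client.
closed_jaxpr = jax.make_jaxpr(f)(1)

# With this PR, the round trip below succeeds; every worker receives the
# exact same jaxpr bytes, so all nodes run identical SPMD programs.
payload = cloudpickle.dumps(closed_jaxpr)

# On each worker, after receiving `payload` via Ray / Dask / etc.:
restored = cloudpickle.loads(payload)
out = jax.jit(jax.core.jaxpr_as_fun(restored))(1)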

Honestly, these are the only solutions I could think of. I've had to do both 1) and 2) at various times in the past, and they are always very fragile. If instead I could have 3) JustWork™, it would significantly simplify a lot of the dispatching infrastructure for auto-partitioning work.

There could be a fourth, even easier solution I am missing, but I haven't found it yet. Ideas are welcome!

While we all broadly agree that this can be potentially useful, it's not free, hence the questions about alternatives

Nothing is ever free, but what is the cost we're trying to avoid here? We have unit tests that will catch obvious problems quickly, and I am happy to be the one responsible for fixing issues related to this (I'll probably be the one hitting them the most anyway lol).

I can see the argument against adding another global dictionary to manage, but we already use similar global dictionaries for vmap, jit, grad, and, well, basically everything! It's not a weird thing to see in the JAX codebase.

I can also see an argument against name strings being used for infrastructure. I don't like this either, but including the namespace, and possibly also the jax.__version__ (I should add this...), should be enough to avoid conflicts/compatibility issues. There are no name conflicts as it stands today (at least in OSS land), and again, issues could likely be caught quickly with good unit tests.

So the cost is:

  • Manage 40 new LOC, a single extra global dictionary, and a few unit tests.
  • Risk that we add new attributes that are not pickleable in the future and have to deal with them.
    • Unit tests will likely catch it, and you can None them out in a __reduce_ex__ method like I did with Traceback. Annoying but not terrible.

The value:

  • Ray and Dask are automatically fully compatible with Jaxprs.
  • Unlocks distributed jaxpr optimizations and dispatch.

In my opinion it's super worth it.

@froystig
Member

At one point not long ago, @pschuh made executables experimentally serializable by relying on pickle's persistent ID mechanism. See:

https://github.com/google/jax/blob/cd177fd5663e1f25c94e76e6babf6d676c8f5c50/jax/experimental/serialize_executable.py#L62-L91

Could something similar be useful here, in particular to decouple a bit from the jax core type definitions (especially if we want this actually decoupled at first)?

https://docs.python.org/3/library/pickle.html#persistence-of-external-objects
https://docs.python.org/3/library/pickle.html#dispatch-tables
https://docs.python.org/3/library/pickle.html#custom-reduction-for-types-functions-and-other-objects
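
For reference, here is a minimal sketch of the persistent-ID hook described in those docs, applied under the assumption of a name-based primitive registry (the registry and class names below are hypothetical, not JAX APIs):

import io
import pickle

_PRIMITIVE_REGISTRY = {}  # hypothetical: e.g. {"add": add_p, "mul": mul_p, ...}

class JaxprPickler(pickle.Pickler):
    def persistent_id(self, obj):
        # Replace registered primitives with a stable name-based ID so the
        # objects themselves are never serialized.
        for name, prim in _PRIMITIVE_REGISTRY.items():
            if obj is prim:
                return ("primitive", name)
        return None  # everything else is pickled normally

class JaxprUnpickler(pickle.Unpickler):
    def persistent_load(self, pid):
        kind, name = pid
        if kind == "primitive":
            return _PRIMITIVE_REGISTRY[name]
        raise pickle.UnpicklingError(f"unknown persistent id: {pid!r}")

def dumps(obj):
    buf = io.BytesIO()
    JaxprPickler(buf).dump(obj)
    return buf.getvalue()

def loads(data):
    return JaxprUnpickler(io.BytesIO(data)).load()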

@chaserileyroberts chaserileyroberts force-pushed the chase/fix/pickle_support branch 2 times, most recently from fed1c78 to 58fbc85 on October 26, 2023 at 20:27
@chaserileyroberts
Contributor Author

@froystig please take a look at the latest implementation. I think it should be a much more agreeable solution than what I had before.

@chaserileyroberts chaserileyroberts force-pushed the chase/fix/pickle_support branch 3 times, most recently from f7c4610 to 8f4ac71 on October 26, 2023 at 20:40
@yashk2810
Collaborator

With the approach in #18243 (comment), why can't this live in your experimental project? In other words, with that approach, pickling jaxprs doesn't have to live in JAX itself, I think :)

I would recommend trying out what's suggested in the above comment to see if that works.

@pschuh
Collaborator

pschuh commented Oct 26, 2023

Should we instead be serializing the StableHLO if you want it post-transform?

@chaserileyroberts
Contributor Author

chaserileyroberts commented Oct 26, 2023

Should we instead be serializing the StableHLO if you want it post-transform?

I want to stay in Jaxpr / Python land. The serialization format is less important than its compatibility with standard Python cloud tooling.

why can't this live in your experimental project?

I can do anything I want internally, but I think this is valuable enough to the larger OSS community for it to exist and work easily.

@chaserileyroberts
Contributor Author

Closing as stale.
