-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Cloudpickle and deepcopy support for Jaxprs #18243
Cloudpickle and deepcopy support for Jaxprs #18243
Conversation
6e40270
to
818358e
Compare
946edc1
to
34772c1
Compare
Classic case of works on my machine. Not sure what's wrong with the docs build |
34772c1
to
a51f095
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For the intended use case: is it necessary to render Jaxpr pickleable, or would it suffice to wrap things up in jit
and de/serialize the jitted function instead? Jitted functions are already pickleable.
Not for my usecase. The jaxpr is still picked up in the closure, making the pickle impossible. >>> import jax
>>> import cloudpickle
>>> def f(a):
... return a + a
...
>>> j = jax.make_jaxpr(f)(1)
>>> cloudpickle.cloudpickle.dumps(j)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/chase/anaconda3/envs/python311/lib/python3.11/site-packages/cloudpickle/cloudpickle.py", line 1479, in dumps
cp.dump(obj)
File "/home/chase/anaconda3/envs/python311/lib/python3.11/site-packages/cloudpickle/cloudpickle.py", line 1245, in dump
return super().dump(obj)
^^^^^^^^^^^^^^^^^
TypeError: cannot pickle 'jaxlib.xla_extension.Traceback' object
>>>
>>> jitted = jax.jit(jax.core.jaxpr_as_fun(j))
>>> cloudpickle.cloudpickle.dumps(jitted)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/chase/anaconda3/envs/python311/lib/python3.11/site-packages/cloudpickle/cloudpickle.py", line 1479, in dumps
cp.dump(obj)
File "/home/chase/anaconda3/envs/python311/lib/python3.11/site-packages/cloudpickle/cloudpickle.py", line 1245, in dump
return super().dump(obj)
^^^^^^^^^^^^^^^^^
TypeError: cannot pickle 'jaxlib.xla_extension.Traceback' object
>>> s = jax.jit(jax.core.jaxpr_as_fun(j)).lower(1)
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
>>> s
<jax._src.stages.Lowered object at 0x7fa49bc19070>
>>> cloudpickle.cloudpickle.dumps(s)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/chase/anaconda3/envs/clean11/lib/python3.11/site-packages/cloudpickle/cloudpickle.py", line 1479, in dumps
cp.dump(obj)
File "/home/chase/anaconda3/envs/clean11/lib/python3.11/site-packages/cloudpickle/cloudpickle.py", line 1245, in dump
return super().dump(obj)
^^^^^^^^^^^^^^^^^
TypeError: cannot pickle 'jaxlib.mlir._mlir_libs._mlir.ir.Module' object |
That wasn't exactly the question, though. You can pickle |
Because my goal is to be able to transform |
For example: if my transform is a randomized optimization, I would need to ensure all of the workers came to the exact same solution. Possible, but incredibly fragile. Solving this first on the client before dispatching prevents this issue forever. |
I want jaxprs to be able to be passed around, modified, and executed using standard distributed python tooling. This could have a wide range of applications beyond just pickling You could imagine the jaxprs being used with a Ray / Dask server just for the compiler optimization and not even the actual execution. Something like: # Use the remote cluster to optimize your jaxpr.
jxpr = best_of([optimize.remote(f, seed) for seed in initial_seeds], my_metrics)
# Execute the jaxpr locally on your GPU.
jax.jit(jaxpr_as_fun(jxpr))(my_args) |
I think we can imagine similar such things (in fact, @mattjj had a branch doing something similar for a Ray prototype a while back, though we never merged it). While we all broadly agree that this can be potentially useful, it's not free, hence the questions about alternatives. For your immediate, concrete use case, does pickling |
Sadly not really, I've tried and it's usually pretty painful.
This is the setup I want to solve: I have a jax function, I've come up with several ideas to do this
Honestly, these are the only solutions I could think of. I've had to both 1) and 2) at various times in the past and they always are very fragile. If instead I could have 3) JustWork™, it would significantly simplify a lot of the dispatching infrastructure for auto-partitioning work. There could be a forth even easier solution I am missing, but I haven't found it yet. Ideas are welcomed!
Nothing is ever free, but what is the cost we're trying to avoid here? We have unit tests that will catch obvious problems quickly, and I am happy to be the one responsible to fix issues related to this (I'll probably be the one hitting issues the most anyway lol). I can see the argument against adding another global dictionary to manage, but we already use similar global dictionaries for I can also see an argument against the name strings being used for infrastructure. I also don't like this either, but the inclusion of So the cost is:
The value:
In my opinion it's super worth it. |
At one point not long ago, @pschuh made executables experimentally serializable by relying on pickle's persistent ID mechanism. See: Could something similar be useful here, in particular to decouple a bit from the jax core type definitions (especially if we want this actually decoupled at first)? https://docs.python.org/3/library/pickle.html#persistence-of-external-objects |
fed1c78
to
58fbc85
Compare
@froystig please take a look at the latest implementation. I think it should be a much more agreeable solution than what I had before. |
f7c4610
to
8f4ac71
Compare
With #18243 (comment), why can't this live in your experimental project? In other words, pickling jaxprs doesn't have to live in JAX with the above approach I think :) I would recommend that you try out what's recommended in the above comment and see if that works? |
8f4ac71
to
3ef0609
Compare
3ef0609
to
e923b69
Compare
Should we instead be serializing the stablehlo if you want it post-transform? |
I want to stay in Jaxpr / python land. The serialization is less important than its compatibility with standard python cloud tooling.
I can do anything I want internally, but I think this is valuable enough to the larger OSS community for it to exist and work easily. |
Closing as stale. |
Fixes #9444
Needed to implement the methods
__reduce_ex__
and__deepcopy__
onPrimitive
andSourceInfo
. Added unit tests in a new filecopying_test.py
.This is generally useful in distributed environments, i.e., I can make a transform to create a
shard_map
ed jaxpr, send this jaxpr to each of my worker nodes over the network via Ray / Dask (both of which use cloudpickle), and then just execute the jaxprs.