You can run this code on your own machine or on Google Colab.
- Local option: If you choose to run locally, you will need to install some Python packages; see installation.md for instructions.
- Colab: The first few sections of the notebook will install all required dependencies. You can try out the Colab option by clicking the badge below:
Fill in the sections marked with `TODO`; in particular, see the `TODO` markers throughout the notebook. See the homework PDF for more details.
### Acknowledgements
This assignment was created starting with the following three base repositories:
- The original paper repo, FeatureLearningRotNet. While this repo was a good first implementation, it is written in a very old version of PyTorch, so it was not ideal to base the assignment on. It was, however, critical for understanding the concepts in the paper more deeply.
- The repo Self_Supervised_CNN-RotNet, a more recent TensorFlow implementation, which allowed us to overcome some of the hurdles in the original paper repository. We based our JAX conversion heavily on the framework it provides. Many thanks to the creators.
- The repo Flax-ResNets, which implements ResNets in both PyTorch and Flax and was very helpful for understanding the JAX/Flax ideas we needed to implement something on our own.
Other useful links and guides in no particular order:
- The Flax Getting Started Guide: the quintessential source for creating neural networks with JAX.
- The Flax Guide to Transfer Learning, the guide to Extracting Intermediates, and this Kaggle Notebook example. We would like to note that setting up transfer learning in JAX/Flax in this manner was extremely challenging, and the way the framework is written leaves a lot to be desired here; see the first sketch after this list.
- How to Think in JAX: especially useful for understanding what JIT is and how to work with it. Pay close attention to `static_argnums` to save yourself a lot of pain later on! A short sketch follows below.
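
To make the transfer-learning point above concrete, here is a minimal sketch of pulling intermediate activations out of a Flax module with the `capture_intermediates` option of `flax.linen.Module.apply`. The `SmallCNN` module, its layer sizes, and the four-class output are illustrative assumptions, not code from the assignment.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class SmallCNN(nn.Module):
    """Toy CNN standing in for a RotNet-style backbone (illustrative only)."""

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=16, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = x.reshape((x.shape[0], -1))   # flatten spatial dims
        return nn.Dense(features=4)(x)    # e.g. 4 rotation classes

model = SmallCNN()
variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 32, 32, 3)))

# capture_intermediates=True makes apply() also return each submodule's
# outputs, which is what lets you reuse the backbone's features.
logits, state = model.apply(
    variables, jnp.ones((1, 32, 32, 3)), capture_intermediates=True
)
print(state["intermediates"].keys())  # Conv_0, Dense_0 activations, etc.
```

From there, transfer learning amounts to keeping the backbone parameters fixed and training only a new head, which is exactly where we found the framework's ergonomics lacking.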
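And a minimal sketch of the `static_argnums` point: arguments that determine shapes or Python control flow must be marked static, or tracing under `jax.jit` fails. The `split_mean` function and its `num_chunks` argument are made up for illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp

# num_chunks determines the output shape, so jit must treat it as a
# compile-time constant; marking it static avoids a tracing error.
@partial(jax.jit, static_argnums=(1,))
def split_mean(x, num_chunks):
    chunks = x.reshape(num_chunks, -1)  # shape depends on num_chunks
    return chunks.mean(axis=1)

print(split_mean(jnp.arange(12.0), 3))  # [1.5 5.5 9.5]
```

Note that each distinct value of a static argument triggers a fresh compilation, so reserve `static_argnums` for arguments that only take a few values.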
### Authors
- Tianlun Zhang | Email @berkeley.edu
- Jackson Gao | Email @berkeley.edu
- Jaewon Lee | Email @berkeley.edu
- Aman Saraf | Email @berkeley.edu