Efficient empirical NTKs in PyTorch

jiaxi98/empirical-ntks

Empirical NTKs in PyTorch

This repository contains code for efficiently computing empirical NTKs and is published alongside our paper "More Than a Toy: Random Matrix Models Predict How Real-World Neural Representations Generalize" (ICML 2022).

Usage

The following command computes the empirical NTK on a subset of CIFAR-10 (the first 2,500 training samples and the first 1,000 test samples). The output is a 3,500 x 2,500 matrix, with one row per selected sample (train and test) and one column per training sample.

python3 ntk.py CIFAR-10_0_2500_0_1000 resnet-18_pretrained --workers-per-device 2 --grad-chunksize 1900000 --mm-col-chunksize 20000 --loader-batch-size 50 --loader-num-workers 12

The following command computes the empirical NTK for all of CIFAR-10 (50,000 training and 10,000 test samples). The output is a 60,000 x 50,000 matrix.

python3 ~/empirical-ntks/ntk.py CIFAR-10 resnet-18_pretrained --workers-per-device 4 --grad-chunksize 1900000 --mm-col-chunksize 20000 --loader-batch-size 50 --loader-num-workers 12
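The dataset argument appears to encode which samples are used. The sketch below shows how the two names above map to the quoted output shapes. It assumes a hypothetical name format "<dataset>_<train_start>_<train_end>_<test_start>_<test_end>" with half-open ranges; this format is inferred from the examples, not taken from utils.py, and expected_ntk_shape is an illustrative helper rather than part of the repository.

```python
# Hypothetical helper, not part of the repository. The name format
# "<dataset>_<train_start>_<train_end>_<test_start>_<test_end>" is inferred
# from the README examples; check utils.py for the real parsing rules.
CIFAR10_TRAIN, CIFAR10_TEST = 50_000, 10_000  # standard CIFAR-10 split sizes


def expected_ntk_shape(spec: str) -> tuple[int, int]:
    """Return (rows, cols) of the NTK: all selected samples x training samples."""
    parts = spec.split("_")
    if parts[0] != "CIFAR-10":
        raise ValueError("this sketch only knows the CIFAR-10 split sizes")
    if len(parts) == 1:
        # Bare dataset name: use the full train and test sets.
        n_train, n_test = CIFAR10_TRAIN, CIFAR10_TEST
    elif len(parts) == 5:
        # Assumed half-open ranges [start, end) into the train and test sets.
        tr_start, tr_end, te_start, te_end = map(int, parts[1:])
        n_train, n_test = tr_end - tr_start, te_end - te_start
    else:
        raise ValueError(f"unrecognized dataset spec: {spec!r}")
    return n_train + n_test, n_train


print(expected_ntk_shape("CIFAR-10_0_2500_0_1000"))  # (3500, 2500)
print(expected_ntk_shape("CIFAR-10"))                # (60000, 50000)
```

Under this reading, the rows cover every selected sample and the columns cover only the training samples, which matches both shapes stated above.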

To work with other datasets or models, see utils.py for further options.

Implementation

We pursue a simple strategy for computing the empirical NTK: compute the N x P Jacobian matrix J (for N samples and P parameters) and multiply it by its transpose. To make this computation feasible, we compute the Jacobian in chunks along the P axis, each of size N x P0 (where P0 is set by --grad-chunksize), and store each (still large) chunk in RAM. For each chunk, we compute the N x N product of the chunk with its transpose and add it to a running total; summing these products over all chunks yields J J^T. Each such product is itself chunked again along the P axis (and optionally along the N axis), and each resulting matrix multiplication is sent to the GPU. This matrix multiplication step is typically the bottleneck in computation time.
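A minimal NumPy sketch of the chunked strategy described above, on a toy problem. It assumes a hypothetical two-layer tanh network with scalar output whose Jacobian is written out by hand (ntk.py instead works with ResNet-18 and PyTorch), and it runs entirely on the CPU. The names toy_jacobian and ntk_from_jacobian_chunks are illustrative, and treating P0 and the inner slice width as analogues of --grad-chunksize and --mm-col-chunksize is an assumption.

```python
import numpy as np


def toy_jacobian(X, W, v):
    """Jacobian of f(x) = v . tanh(W x) w.r.t. (W, v), one row per sample.

    Hypothetical toy model differentiated by hand; returns an (N, P) array
    with P = W.size + v.size (W entries first, row-major, then v).
    """
    H = np.tanh(X @ W.T)                                  # (N, hidden)
    dW = (v * (1.0 - H**2))[:, :, None] * X[:, None, :]   # (N, hidden, d)
    return np.concatenate([dW.reshape(len(X), -1), H], axis=1)


def ntk_from_jacobian_chunks(jacobian_chunks, n, mm_col_chunk, row_chunk=None):
    """Accumulate K = J @ J.T from a stream of (n, P0) column blocks of J."""
    row_chunk = row_chunk or n
    ntk = np.zeros((n, n))  # float64 accumulator in this sketch
    for chunk in jacobian_chunks:  # outer loop: one N x P0 block held in RAM
        for c0 in range(0, chunk.shape[1], mm_col_chunk):
            cols = chunk[:, c0:c0 + mm_col_chunk]  # chunk again along P
            for r0 in range(0, n, row_chunk):      # optional chunking along N
                # In ntk.py, products like this one are sent to the GPU.
                ntk[r0:r0 + row_chunk] += cols[r0:r0 + row_chunk] @ cols.T
    return ntk


rng = np.random.default_rng(0)
X = rng.standard_normal((6, 4))  # N = 6 samples with 4 input features
W = rng.standard_normal((5, 4))  # hidden width 5
v = rng.standard_normal(5)
J = toy_jacobian(X, W, v)        # shape (6, 25): P = 5 * 4 + 5
# For brevity the blocks are sliced from a precomputed J; the library
# instead computes each N x P0 block directly.
chunks = (J[:, s:s + 10] for s in range(0, J.shape[1], 10))  # P0 = 10
K = ntk_from_jacobian_chunks(chunks, n=6, mm_col_chunk=4, row_chunk=4)
print(np.allclose(K, J @ J.T))   # True
```

Because J J^T equals the sum of J_k J_k^T over any partition of the parameter axis, the chunk sizes change only memory use and speed, not the result (up to floating-point rounding); that is why the final check against the unchunked product holds for any choice of P0, slice width, or row block.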

By optimizing data transfer, increasing GPU utilization, and parallelizing with care, our implementation improves significantly over naive baselines. See ntk.py for implementation details.

Performance

Our library computes an empirical NTK (60,000 x 50,000) for a ResNet-18 over CIFAR-10 at float32 precision in 43 minutes (<1e-6 seconds per NTK entry) on a machine with four A100 GPUs and 755GB RAM.
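The quoted figures can be checked with back-of-the-envelope arithmetic. The sketch below assumes 4 bytes per float32 entry and, for the memory estimate, that one N x P0 Jacobian chunk (N = 60,000 and P0 = 1,900,000, the --grad-chunksize value from the usage examples) is held densely in RAM; the actual memory layout in ntk.py may differ.

```python
# Back-of-the-envelope check of the performance figures (float32 = 4 bytes).
N_ROWS, N_COLS = 60_000, 50_000  # all samples x training samples
WALL_CLOCK_S = 43 * 60           # 43 minutes
BYTES_PER_FLOAT32 = 4
GRAD_CHUNKSIZE = 1_900_000       # --grad-chunksize in the usage examples

entries = N_ROWS * N_COLS
seconds_per_entry = WALL_CLOCK_S / entries
ntk_gb = entries * BYTES_PER_FLOAT32 / 1e9
# Assumes a single N x P0 Jacobian chunk is stored densely in float32.
jac_chunk_gb = N_ROWS * GRAD_CHUNKSIZE * BYTES_PER_FLOAT32 / 1e9

print(f"{entries:,} entries, {seconds_per_entry:.2e} s/entry")
print(f"NTK: {ntk_gb:.0f} GB, one Jacobian chunk: {jac_chunk_gb:.0f} GB")
```

This prints 3,000,000,000 entries at 8.60e-07 seconds per entry, a 12 GB output matrix, and a 456 GB Jacobian chunk. Under the dense-float32 assumption, a single chunk would take up a large share of the 755GB of RAM, which is consistent with the description of the stored chunk as "still large".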

Citation

If you find this code useful in your research, please consider citing our paper:

@inproceedings{wei2022more,
  title = {More Than a Toy: Random Matrix Models Predict How Real-World Neural Representations Generalize},
  author = {Wei, Alexander and Hu, Wei and Steinhardt, Jacob},
  booktitle = {Proceedings of the 39th International Conference on Machine Learning},
  year = {2022}
}
