Skip to content
/ gwgan Public

Learning Generative Models across Incomparable Spaces (ICML 2019)

License

Notifications You must be signed in to change notification settings

bunnech/gwgan

Repository files navigation

GW GAN

(still under construction)

PyTorch Code for reproducing key results of the paper Learning Generative Models across Incomparable Spaces by Charlotte Bunne, David Alvarez-Melis, Andreas Krause and Stefanie Jegelka.

Citation

If you use this code for your research, please cite our paper.

@inproceedings{bunne2019,
  title={{Learning Generative Models across Incomparable Spaces}},
  author={Bunne, Charlotte and Alvarez-Melis, David and Krause, Andreas and Jegelka, Stefanie},
  year={2019},
  booktitle={International Conference on Machine Learning (ICML)},
  volume={97},
}

Installation

To reproduce the experiments, download the source code from the Git repository.

git clone https://github.com/bunnech/gw_gan

Known dependencies: Python (3.7.2), numpy (1.15.4), pandas (0.23.4), matplotlib (3.0.2), seaborn (0.9.0), torch (1.0.0), torchvision (0.2.1).

Experiments

We provide the source code to run GW GAN on a 2D Gaussians dataset as well as on MNIST, fashion-MNIST and gray-scale CIFAR.

on a 2D Gaussians Dataset

In order to reproduce experiments on 2D Gaussians, you can either run the bash script run_gwgan_mlp.sh with pre-defined settings. Alternatively, call

python3 main_gwgan_mlp.py --modes 4mode --num_iter 10000 --l1reg

with the following environment options:

  • --modes defined the number of modes. Available options are 4mode, 5mode, and 8mode. To generate samples in 2D from 3D data, choose 3d_4mode
  • --num_iter defines the number of training iterations (recommended: 10000)
  • --l1reg is a flag which activates l1-regularization (see paper for details)
  • --advsy is a flag which activates the adversary.
  • --id for identification of the training run.

on MNIST, fashion-MNIST and gray-scale CIFAR Dataset

In order to reproduce experiments on 2D Gaussians, you can either run the bash script run_gwgan_cnn.sh with pre-defined settings. Alternatively, call

python3 main_gwgan_cnn.py --data fmnist --num_epochs 100 --beta 35

with the following environment options:

  • --data selects the dataset. Choose between mnist, fmnist and cifar_gray.
  • --num_epochs defines the number of training epochs (default: 200)
  • --n_channels defines the number of channels of the CNN architecture (default=1).
  • --beta defines the parameter of the Procrustes-based orthogonal regularization (recommended for MNIST (mnist): 32, fashion MNIST (fmnist): 35, gray-scale CIFAR (cifar_gray): 40)
  • --cuda is a flag to run the code on GPUs.
  • --id for identification of the training run.

Code Structure

.optra/gromov_wasserstein.py contains the implementation of the Gromov-Wasserstein discrepancy with the mofifications described in the paper.

.model/ contains scripts to generate datasets (data.py), the computation of the loss and regularization approaches (loss.py) as well as the network architectures (model_mlp.py and model_cnn.py).

License

This project is licensed under the MIT License - see the LICENSE file for details.

About

Learning Generative Models across Incomparable Spaces (ICML 2019)

Resources

License

Stars

Watchers

Forks