Skip to content

Commit

Permalink
Move training-only dependencies to [train] extra
Browse files Browse the repository at this point in the history
  • Loading branch information
akx committed Feb 13, 2024
1 parent 6ab5146 commit b38ba1f
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 8 deletions.
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,10 @@ TODO: write this section

## Installation

`k-diffusion` can be installed via PyPI (`pip install k-diffusion`) but it will not include training and inference scripts, only library code that others can depend on. To run the training and inference scripts, clone this repository and run `pip install -e <path to repository>`.
`k-diffusion` can be installed via PyPI (`pip install k-diffusion`) but it will not include training and inference scripts, only library code that others can depend on.

To run the training and inference scripts, clone this repository and run `pip install -e <path to repository>[train]`
(to install with the `train` extra that includes additional libraries required for training).

## Training

Expand Down
11 changes: 8 additions & 3 deletions k_diffusion/augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
import operator

import numpy as np
from skimage import transform
try:
import skimage.transform as skt
except ImportError:
skt = None
import torch
from torch import nn

Expand Down Expand Up @@ -31,6 +34,8 @@ def rotate2d(theta):

class KarrasAugmentationPipeline:
def __init__(self, a_prob=0.12, a_scale=2**0.2, a_aniso=2**0.2, a_trans=1/8, disable_all=False):
if not skt:
raise ImportError('Please install scikit-image to use KarrasAugmentationPipeline')
self.a_prob = a_prob
self.a_scale = a_scale
self.a_aniso = a_aniso
Expand Down Expand Up @@ -78,9 +83,9 @@ def __call__(self, image):
image_orig = np.array(image, dtype=np.float32) / 255
if image_orig.ndim == 2:
image_orig = image_orig[..., None]
tf = transform.AffineTransform(mat.numpy())
tf = skt.AffineTransform(mat.numpy())
if not self.disable_all:
image = transform.warp(image_orig, tf.inverse, order=3, mode='reflect', cval=0.5, clip=False, preserve_range=True)
image = skt.warp(image_orig, tf.inverse, order=3, mode='reflect', cval=0.5, clip=False, preserve_range=True)
else:
image = image_orig
cond = torch.zeros_like(cond)
Expand Down
5 changes: 4 additions & 1 deletion k_diffusion/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import os
from pathlib import Path

from cleanfid.inception_torchscript import InceptionV3W
import clip
import torch
from torch import nn
Expand All @@ -16,6 +15,10 @@
class InceptionV3FeatureExtractor(nn.Module):
def __init__(self, device='cpu'):
super().__init__()
try:
from cleanfid.inception_torchscript import InceptionV3W
except ImportError as ie:
raise ImportError('Please install clean-fid to use InceptionV3FeatureExtractor') from ie
path = Path(os.environ.get('XDG_CACHE_HOME', Path.home() / '.cache')) / 'k-diffusion'
url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
digest = 'f58cb9b6ec323ed63459aa4fb441fe750cfe39fafad6da5cb504a16f19e958f4'
Expand Down
9 changes: 6 additions & 3 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,24 @@ license = MIT

[options]
packages = find:
install_requires =
install_requires =
accelerate
clean-fid
clip-anytorch
dctorch
einops
jsonmerge
kornia
Pillow
safetensors
scikit-image
scipy
torch >= 2.1
torchdiffeq
torchsde
torchvision
tqdm

[options.extras_require]
train =
clean-fid
scikit-image
wandb

0 comments on commit b38ba1f

Please sign in to comment.