Add support for other pytorch device types, including MPS #1445

Open · wants to merge 1 commit into base: master

13 changes: 6 additions & 7 deletions models/base_model.py
@@ -30,9 +30,9 @@ def __init__(self, opt):
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
"""
self.opt = opt
self.gpu_ids = opt.gpu_ids
self.isTrain = opt.isTrain
self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU
self.device = opt.torch_devices[0]
self.torch_devices = opt.torch_devices
self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir
if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
torch.backends.cudnn.benchmark = True
@@ -153,11 +153,10 @@ def save_networks(self, epoch):
save_path = os.path.join(self.save_dir, save_filename)
net = getattr(self, 'net' + name)

if len(self.gpu_ids) > 0 and torch.cuda.is_available():
torch.save(net.module.cpu().state_dict(), save_path)
net.cuda(self.gpu_ids[0])
else:
torch.save(net.cpu().state_dict(), save_path)
# Saving requires moving net to cpu
torch.save(net.module.cpu().state_dict(), save_path)
# Move net back to original device
net.to(self.device)

def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
"""Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
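
A note on the checkpointing change above: saving no longer branches on CUDA availability. Below is a minimal standalone sketch of the round-trip it performs; the net/device names mirror the diff, and the nn.DataParallel wrapping is assumed to come from init_net:

import torch
import torch.nn as nn

def save_checkpoint(net: nn.DataParallel, save_path: str, device: torch.device):
    # The state_dict is taken from the wrapped module after moving it to the CPU,
    # so the checkpoint loads the same way regardless of the training backend.
    torch.save(net.module.cpu().state_dict(), save_path)
    # Move the network back to whatever device training runs on
    # (cuda:N, mps:0, or cpu) instead of hard-coding .cuda().
    net.to(device)
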
8 changes: 4 additions & 4 deletions models/cycle_gan_model.py
@@ -71,15 +71,15 @@ def __init__(self, opt):
# The naming is different from those used in the paper.
# Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
not opt.no_dropout, opt.init_type, opt.init_gain, self.torch_devices)
self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm,
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
not opt.no_dropout, opt.init_type, opt.init_gain, self.torch_devices)

if self.isTrain: # define discriminators
self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.torch_devices)
self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.torch_devices)

if self.isTrain:
if opt.lambda_identity > 0.0: # only works when input and output images have the same number of channels
27 changes: 13 additions & 14 deletions models/networks.py
@@ -4,7 +4,7 @@
import functools
from torch.optim import lr_scheduler


from util.util import backend_available
###############################################################################
# Helper Functions
###############################################################################
@@ -98,25 +98,24 @@ def init_func(m): # define the initialization function
net.apply(init_func) # apply the initialization function <init_func>


def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
def init_net(net, init_type='normal', init_gain=0.02, torch_devices=[]):
"""Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
Parameters:
net (network) -- the network to be initialized
init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
gain (float) -- scaling factor for normal, xavier and orthogonal.
gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
torch_devices (torch.device list) -- which devices the network runs on: e.g., cuda:0,cuda:1

Return an initialized network.
"""
if len(gpu_ids) > 0:
assert(torch.cuda.is_available())
net.to(gpu_ids[0])
net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs
assert backend_available(torch_devices[0].type), f'Backend {torch_devices[0].type} not available'
net.to(torch_devices[0])
net = torch.nn.DataParallel(net, torch_devices) # multi-GPUs
init_weights(net, init_type, init_gain=init_gain)
return net


def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]):
def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, torch_devices=[]):
"""Create a generator

Parameters:
@@ -128,7 +127,7 @@ def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, in
use_dropout (bool) -- if use dropout layers.
init_type (str) -- the name of our initialization method.
init_gain (float) -- scaling factor for normal, xavier and orthogonal.
gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
torch_devices (torch.device list) -- which devices the network runs on: e.g., cuda:0,cuda:1

Returns a generator

@@ -156,10 +155,10 @@ def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, in
net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
else:
raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
return init_net(net, init_type, init_gain, gpu_ids)
return init_net(net, init_type, init_gain, torch_devices)


def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]):
def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, torch_devices=[]):
"""Create a discriminator

Parameters:
@@ -170,7 +169,7 @@ def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal'
norm (str) -- the type of normalization layers used in the network.
init_type (str) -- the name of the initialization method.
init_gain (float) -- scaling factor for normal, xavier and orthogonal.
gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
torch_devices (torch.device list) -- which devices the network runs on: e.g., cuda:0,cuda:1

Returns a discriminator

@@ -200,7 +199,7 @@ def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal'
net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
else:
raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
return init_net(net, init_type, init_gain, gpu_ids)
return init_net(net, init_type, init_gain, torch_devices)


##############################################################################
@@ -282,7 +281,7 @@ def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', const
netD (network) -- discriminator network
real_data (tensor array) -- real images
fake_data (tensor array) -- generated images from the generator
device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
device (str) -- GPU / CPU / MPS
type (str) -- if we mix real and fake data or not [real | fake | mixed].
constant (float) -- the constant used in formula ( ||gradient||_2 - constant)^2
lambda_gp (float) -- weight for this loss
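
For reference, a quick usage sketch of the updated factory signatures; the argument values below are illustrative only, not defaults taken from the repository:

import torch
from models import networks

# e.g. an Apple-silicon GPU; use [torch.device('cuda', 0), torch.device('cuda', 1)] for multi-GPU CUDA
devices = [torch.device('mps', 0)]
netG = networks.define_G(input_nc=3, output_nc=3, ngf=64, netG='resnet_9blocks',
                         norm='instance', use_dropout=False,
                         init_type='normal', init_gain=0.02,
                         torch_devices=devices)
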
4 changes: 2 additions & 2 deletions models/pix2pix_model.py
@@ -54,11 +54,11 @@ def __init__(self, opt):
self.model_names = ['G']
# define networks (both generator and discriminator)
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
not opt.no_dropout, opt.init_type, opt.init_gain, self.torch_devices)

if self.isTrain: # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.torch_devices)

if self.isTrain:
# define loss functions
2 changes: 1 addition & 1 deletion models/template_model.py
@@ -57,7 +57,7 @@ def __init__(self, opt):
# you can use opt.isTrain to specify different behaviors for training and test. For example, some networks will not be used during test, and you don't need to load them.
self.model_names = ['G']
# define networks; you can use opt.isTrain to specify different behaviors for training and test.
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, gpu_ids=self.gpu_ids)
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, torch_devices=self.torch_devices)
if self.isTrain: # only defined during training time
# define your loss functions. You can use losses provided by torch.nn such as torch.nn.L1Loss.
# We also provide a GANLoss class "networks.GANLoss". self.criterionGAN = networks.GANLoss().to(self.device)
2 changes: 1 addition & 1 deletion models/test_model.py
@@ -43,7 +43,7 @@ def __init__(self, opt):
# specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
self.model_names = ['G' + opt.model_suffix] # only generator is needed.
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG,
opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.torch_devices)

# assigns the model to self.netG_[suffix] so that it can be loaded
# please see <BaseModel.load_networks>
33 changes: 23 additions & 10 deletions options/base_options.py
@@ -23,7 +23,9 @@ def initialize(self, parser):
parser.add_argument('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
parser.add_argument('--use_wandb', action='store_true', help='use wandb')
parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
parser.add_argument('--gpu_ids', type=str, default='-1', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU', dest='gpu_ids_str')
parser.add_argument('--device_ids', type=str, default='0', help='identifiers ("ordinals") of the torch devices to use, e.g. 0 or 0,2', dest='device_ids_str')
parser.add_argument('--device_type', type=str, default='cuda', help='torch device to use [cpu | cuda | mps]')
parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
# model parameters
parser.add_argument('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]')
@@ -123,15 +125,26 @@ def parse(self):

self.print_options(opt)

# set gpu ids
str_ids = opt.gpu_ids.split(',')
opt.gpu_ids = []
for str_id in str_ids:
id = int(str_id)
if id >= 0:
opt.gpu_ids.append(id)
if len(opt.gpu_ids) > 0:
torch.cuda.set_device(opt.gpu_ids[0])
device_ids = [int(i) for i in opt.device_ids_str.split(',') if int(i) >= 0]
gpu_ids = [int(i) for i in opt.gpu_ids_str.split(',') if int(i) >= 0]

# convert old-style `gpu_ids` arg into new-style `torch_devices`
# while catching invalid combinations of `gpu_ids` and device_type
if len(gpu_ids) > 0:
if opt.device_type != 'cuda':
# This is the canonical way to raise an error from a parser
self.parser.error('If --gpu_ids is specified, --device_type must be "cuda". To specify non-cuda device ids, use --device_ids')
device_ids = gpu_ids

# This doesn't currently allow specifying devices with a mix of different types
opt.torch_devices = [torch.device(f'{opt.device_type}:{i}') for i in device_ids]

if opt.device_type == 'cuda':
torch.cuda.set_device(device_ids[0])
elif opt.device_type == 'mps':
if device_ids != [0]:
# This is the canonical way to raise an error from a parser
self.parser.error('Devices other than "0" are not currently supported with MPS backend.')

self.opt = opt
return self.opt
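
To make the flag-to-device mapping concrete, here is a short sketch that mirrors the new tests below (the /tmp dataroot is a placeholder):

import sys
import torch
from unittest.mock import patch
from options.train_options import TrainOptions

# --device_type mps maps to a single MPS device
with patch.object(sys, 'argv', ['train.py', '--dataroot', '/tmp', '--device_type', 'mps']):
    opt = TrainOptions().parse()
    assert opt.torch_devices == [torch.device('mps', 0)]

# The legacy --gpu_ids flag still works and implies CUDA devices,
# e.g. --gpu_ids 0,2 yields [torch.device('cuda', 0), torch.device('cuda', 2)].
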
Empty file added tests/__init__.py
Empty file.
82 changes: 82 additions & 0 deletions tests/test_options.py
@@ -0,0 +1,82 @@
import unittest
from unittest.mock import patch
import torch

import sys
import argparse

from options.train_options import TrainOptions

class TestTrainOptions(unittest.TestCase):

required_args = ['train.py', '--dataroot', '/tmp']


def test_torch_devices_default(self):
with patch('sys.argv', self.required_args + []):
b = TrainOptions()
opt = b.parse()
self.assertEqual(opt.device_type, 'cuda')
self.assertEqual(opt.torch_devices, [torch.device('cuda', 0)])

def test_torch_devices_from_gpu_ids(self):
with patch.object(sys, 'argv', self.required_args + ['--gpu_ids', '0']):
opt = TrainOptions().parse()
self.assertEqual(opt.torch_devices, [torch.device('cuda', 0)])

with patch.object(sys, 'argv', self.required_args + ['--gpu_ids', '1']):
opt = TrainOptions().parse()
self.assertEqual(opt.torch_devices, [torch.device('cuda', 1)])

with patch.object(sys, 'argv', self.required_args + ['--gpu_ids', '0,2']):
opt = TrainOptions().parse()
self.assertEqual(opt.torch_devices, [torch.device('cuda', 0), torch.device('cuda', 2)])

def test_torch_devices_gpu_ids_type_conflict(self):
with patch.object(sys, 'argv', self.required_args + ['--gpu_ids', '0', '--device_type', 'cpu']):
with self.assertRaises(SystemExit) as cm:
TrainOptions().parse()


def test_torch_devices_mps(self):
with patch('sys.argv', self.required_args + ['--device_type', 'mps']):
opt = TrainOptions().parse()
self.assertEqual(opt.torch_devices, [torch.device('mps', 0)])

with patch('sys.argv', self.required_args + ['--device_type', 'mps', '--device_ids', '0']):
opt = TrainOptions().parse()
self.assertEqual(opt.torch_devices, [torch.device('mps', 0)])

with patch.object(sys, 'argv', self.required_args + ['--device_type', 'mps', '--device_ids', '2']):
with self.assertRaises(SystemExit) as cm:
TrainOptions().parse()

def test_torch_devices_cpu(self):
with patch('sys.argv', self.required_args + ['--device_type', 'cpu']):
opt = TrainOptions().parse()
self.assertEqual(opt.torch_devices, [torch.device('cpu', 0)])

with patch('sys.argv', self.required_args + ['--device_type', 'cpu', '--device_ids', '0,2']):
opt = TrainOptions().parse()
self.assertEqual(opt.torch_devices, [torch.device('cpu', 0), torch.device('cpu', 2)])



def test_torch_devices_cuda(self):
with patch('sys.argv', self.required_args + ['--device_type', 'cuda']):
opt = TrainOptions().parse()
self.assertEqual(opt.torch_devices, [torch.device('cuda', 0)])

with patch('sys.argv', self.required_args + ['--device_type', 'cuda', '--device_ids', '0,2']):
opt = TrainOptions().parse()
self.assertEqual(opt.torch_devices, [torch.device('cuda', 0), torch.device('cuda', 2)])
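
As a usage note, the new tests should run from the repository root with the standard unittest runner (assuming no extra test configuration is needed):

python -m unittest tests.test_options -v
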
9 changes: 9 additions & 0 deletions util/util.py
@@ -101,3 +101,12 @@ def mkdir(path):
"""
if not os.path.exists(path):
os.makedirs(path)


def backend_available(type_):
if type_ == 'cpu':
return True
elif type_ == 'cuda':
return torch.cuda.is_available()
else:
return getattr(torch.backends, type_).is_available()
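
For illustration, the expected behaviour of this helper; the generic torch.backends lookup is what lets mps (and future backends) work without being listed explicitly, and unknown backend names raise AttributeError from the getattr call:

backend_available('cpu')    # always True
backend_available('cuda')   # True only when torch.cuda.is_available()
backend_available('mps')    # delegates to torch.backends.mps.is_available() (PyTorch >= 1.12)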