diff --git a/models/base_model.py b/models/base_model.py index 6de961b51a2..c343b20e3b0 100644 --- a/models/base_model.py +++ b/models/base_model.py @@ -30,9 +30,9 @@ def __init__(self, opt): -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. """ self.opt = opt - self.gpu_ids = opt.gpu_ids self.isTrain = opt.isTrain - self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU + self.device = opt.torch_devices[0] + self.torch_devices = opt.torch_devices self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark. torch.backends.cudnn.benchmark = True @@ -153,11 +153,10 @@ def save_networks(self, epoch): save_path = os.path.join(self.save_dir, save_filename) net = getattr(self, 'net' + name) - if len(self.gpu_ids) > 0 and torch.cuda.is_available(): - torch.save(net.module.cpu().state_dict(), save_path) - net.cuda(self.gpu_ids[0]) - else: - torch.save(net.cpu().state_dict(), save_path) + # Saving requires moving net to cpu + torch.save(net.module.cpu().state_dict(), save_path) + # Move net back to original device + net.to(self.device) def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)""" diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py index 15bb72d8ddc..d62fdfafbcc 100644 --- a/models/cycle_gan_model.py +++ b/models/cycle_gan_model.py @@ -71,15 +71,15 @@ def __init__(self, opt): # The naming is different from those used in the paper. # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X) self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, - not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) + not opt.no_dropout, opt.init_type, opt.init_gain, self.torch_devices) self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm, - not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) + not opt.no_dropout, opt.init_type, opt.init_gain, self.torch_devices) if self.isTrain: # define discriminators self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD, - opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids) + opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.torch_devices) self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD, - opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids) + opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.torch_devices) if self.isTrain: if opt.lambda_identity > 0.0: # only works when input and output images have the same number of channels diff --git a/models/networks.py b/models/networks.py index b3a10c99c20..32503a45390 100644 --- a/models/networks.py +++ b/models/networks.py @@ -4,7 +4,7 @@ import functools from torch.optim import lr_scheduler - +from util.util import backend_available ############################################################################### # Helper Functions ############################################################################### @@ -98,25 +98,24 @@ def init_func(m): # define the initialization function net.apply(init_func) # apply the initialization function -def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]): +def init_net(net, init_type='normal', init_gain=0.02, torch_devices=[]): """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights Parameters: net (network) -- the network to be initialized init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal gain (float) -- scaling factor for normal, xavier and orthogonal. - gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 + torch_devices (torch.device list) -- which devices the network runs on: e.g., cuda:0,cuda:1 Return an initialized network. """ - if len(gpu_ids) > 0: - assert(torch.cuda.is_available()) - net.to(gpu_ids[0]) - net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs + assert(backend_available(torch_devices[0].type), f'Backend {torch_devices[0].type} not available') + net.to(torch_devices[0]) + net = torch.nn.DataParallel(net, torch_devices) # multi-GPUs init_weights(net, init_type, init_gain=init_gain) return net -def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]): +def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, torch_devices=[]): """Create a generator Parameters: @@ -128,7 +127,7 @@ def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, in use_dropout (bool) -- if use dropout layers. init_type (str) -- the name of our initialization method. init_gain (float) -- scaling factor for normal, xavier and orthogonal. - gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 + torch_devices (torch.device list) -- which devices the network runs on: e.g., cuda:0,cuda:1 Returns a generator @@ -156,10 +155,10 @@ def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, in net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout) else: raise NotImplementedError('Generator model name [%s] is not recognized' % netG) - return init_net(net, init_type, init_gain, gpu_ids) + return init_net(net, init_type, init_gain, torch_devices) -def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]): +def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, torch_devices=[]): """Create a discriminator Parameters: @@ -170,7 +169,7 @@ def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal' norm (str) -- the type of normalization layers used in the network. init_type (str) -- the name of the initialization method. init_gain (float) -- scaling factor for normal, xavier and orthogonal. - gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 + torch_devices (torch.device list) -- which devices the network runs on: e.g., cuda:0,cuda:1 Returns a discriminator @@ -200,7 +199,7 @@ def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal' net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer) else: raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD) - return init_net(net, init_type, init_gain, gpu_ids) + return init_net(net, init_type, init_gain, torch_devices) ############################################################################## @@ -282,7 +281,7 @@ def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', const netD (network) -- discriminator network real_data (tensor array) -- real images fake_data (tensor array) -- generated images from the generator - device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') + device (str) -- GPU / CPU / MPS type (str) -- if we mix real and fake data or not [real | fake | mixed]. constant (float) -- the constant used in formula ( ||gradient||_2 - constant)^2 lambda_gp (float) -- weight for this loss diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py index 939eb887ee3..38cf7472e5f 100644 --- a/models/pix2pix_model.py +++ b/models/pix2pix_model.py @@ -54,11 +54,11 @@ def __init__(self, opt): self.model_names = ['G'] # define networks (both generator and discriminator) self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, - not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) + not opt.no_dropout, opt.init_type, opt.init_gain, self.torch_devices) if self.isTrain: # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD, - opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids) + opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.torch_devices) if self.isTrain: # define loss functions diff --git a/models/template_model.py b/models/template_model.py index 68cdaf6a9a2..560bc5b55c0 100644 --- a/models/template_model.py +++ b/models/template_model.py @@ -57,7 +57,7 @@ def __init__(self, opt): # you can use opt.isTrain to specify different behaviors for training and test. For example, some networks will not be used during test, and you don't need to load them. self.model_names = ['G'] # define networks; you can use opt.isTrain to specify different behaviors for training and test. - self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, gpu_ids=self.gpu_ids) + self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, torch_devices=self.torch_devices) if self.isTrain: # only defined during training time # define your loss functions. You can use losses provided by torch.nn such as torch.nn.L1Loss. # We also provide a GANLoss class "networks.GANLoss". self.criterionGAN = networks.GANLoss().to(self.device) diff --git a/models/test_model.py b/models/test_model.py index fe15f40176e..5624e24aea0 100644 --- a/models/test_model.py +++ b/models/test_model.py @@ -43,7 +43,7 @@ def __init__(self, opt): # specify the models you want to save to the disk. The training/test scripts will call and self.model_names = ['G' + opt.model_suffix] # only generator is needed. self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, - opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) + opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.torch_devices) # assigns the model to self.netG_[suffix] so that it can be loaded # please see diff --git a/options/base_options.py b/options/base_options.py index 7a437cc35f3..890121a4a79 100644 --- a/options/base_options.py +++ b/options/base_options.py @@ -23,7 +23,9 @@ def initialize(self, parser): parser.add_argument('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)') parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models') parser.add_argument('--use_wandb', action='store_true', help='use wandb') - parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') + parser.add_argument('--gpu_ids', type=str, default='-1', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU', dest='gpu_ids_str') + parser.add_argument('--device_ids', type=str, default='0', help='identifiers ("ordinals") of torch devices (i.e. GPUs, CPUs) to use, e.g. 0 or 0,2. ', dest='device_ids_str') + parser.add_argument('--device_type', type=str, default='cuda', help='torch device to use [cpu | cuda | mps]') parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') # model parameters parser.add_argument('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]') @@ -123,15 +125,26 @@ def parse(self): self.print_options(opt) - # set gpu ids - str_ids = opt.gpu_ids.split(',') - opt.gpu_ids = [] - for str_id in str_ids: - id = int(str_id) - if id >= 0: - opt.gpu_ids.append(id) - if len(opt.gpu_ids) > 0: - torch.cuda.set_device(opt.gpu_ids[0]) + device_ids = [int(i) for i in opt.device_ids_str.split(',') if int(i) >= 0] + gpu_ids = [int(i) for i in opt.gpu_ids_str.split(',') if int(i) >= 0] + + # convert old-style `gpu_ids` arg into new-style `torch_devices` + # while catching invalid combinations of `gpu_ids` and device_type + if len(gpu_ids) > 0: + if opt.device_type !='cuda': + #  This is the canonical way to raise an error from a parser + self.parser.error('If --gpu_ids is specified, --device_type must be "cuda". To specify non-cuda device ids, use --device_ids') + device_ids = gpu_ids + + # This doesn't currently allow specifying devices with a mix of different types + opt.torch_devices = [torch.device(f'{opt.device_type}:{i}') for i in device_ids] + + if opt.device_type == 'cuda': + torch.cuda.device(device_ids[0]) + elif opt.device_type == 'mps': + if device_ids != [0]: + #  This is the canonical way to raise an error from a parser + self.parser.error('Devices other than "0" are not currently supported with MPS backend.') self.opt = opt return self.opt diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/test_options.py b/tests/test_options.py new file mode 100644 index 00000000000..a66fd9b6cae --- /dev/null +++ b/tests/test_options.py @@ -0,0 +1,82 @@ +import unittest +from unittest.mock import patch +import torch + +import sys +import argparse + +from options.train_options import TrainOptions + +class TestTrainOptions(unittest.TestCase): + + required_args = ['train.py', '--dataroot', '/tmp'] + + + def test_torch_devices_default(self): + with patch('sys.argv', self.required_args + []): + b = TrainOptions() + opt = b.parse() + self.assertEqual(opt.device_type, 'cuda') + self.assertEqual(opt.torch_devices, [torch.device('cuda', 0)]) + + def test_torch_devices_from_gpu_ids(self): + with patch.object(sys, 'argv', self.required_args + ['--gpu_ids', '0']): + opt = TrainOptions().parse() + self.assertEqual(opt.torch_devices, [torch.device('cuda', 0)]) + + with patch.object(sys, 'argv', self.required_args + ['--gpu_ids', '1']): + opt = TrainOptions().parse() + self.assertEqual(opt.torch_devices, [torch.device('cuda', 1)]) + + with patch.object(sys, 'argv', self.required_args + ['--gpu_ids', '0,2']): + opt = TrainOptions().parse() + self.assertEqual(opt.torch_devices, [torch.device('cuda', 0), torch.device('cuda', 2)]) + + def test_torch_devices_gpu_ids_type_conflict(self): + with patch.object(sys, 'argv', self.required_args + ['--gpu_ids', '0', '--device_type', 'cpu']): + with self.assertRaises(SystemExit) as cm: + TrainOptions().parse() + + + def test_torch_devices_mps(self): + with patch('sys.argv', self.required_args + ['--device_type', 'mps']): + opt = TrainOptions().parse() + self.assertEqual(opt.torch_devices, [torch.device('mps', 0)]) + + with patch('sys.argv', self.required_args + ['--device_type', 'mps', '--device_ids', '0']): + opt = TrainOptions().parse() + self.assertEqual(opt.torch_devices, [torch.device('mps', 0)]) + + with patch.object(sys, 'argv', self.required_args + ['--device_type', 'mps', '--device_ids', '2']): + with self.assertRaises(SystemExit) as cm: + TrainOptions().parse() + + def test_torch_devices_cpu(self): + with patch('sys.argv', self.required_args + ['--device_type', 'cpu']): + opt = TrainOptions().parse() + self.assertEqual(opt.torch_devices, [torch.device('cpu', 0)]) + + with patch('sys.argv', self.required_args + ['--device_type', 'cpu', '--device_ids', '0,2']): + opt = TrainOptions().parse() + self.assertEqual(opt.torch_devices, [torch.device('cpu', 0), torch.device('cpu', 2)]) + + + + def test_torch_devices_cuda(self): + with patch('sys.argv', self.required_args + ['--device_type', 'cuda']): + opt = TrainOptions().parse() + self.assertEqual(opt.torch_devices, [torch.device('cuda', 0)]) + + with patch('sys.argv', self.required_args + ['--device_type', 'cuda', '--device_ids', '0,2']): + opt = TrainOptions().parse() + self.assertEqual(opt.torch_devices, [torch.device('cuda', 0), torch.device('cuda', 2)]) + + + + + + + + + + diff --git a/util/util.py b/util/util.py index b050c13e1d6..477f3b7f2a5 100644 --- a/util/util.py +++ b/util/util.py @@ -101,3 +101,12 @@ def mkdir(path): """ if not os.path.exists(path): os.makedirs(path) + + +def backend_available(type_): + if type_ == 'cpu': + return True + elif type_ == 'cuda': + return torch.cuda.is_available() + else: + return getattr(torch.backends,type_).is_available()