Add NVIDIA apex support and gradient checkpointing to reduce memory footprint #1090

Open · wants to merge 6 commits into master
10 changes: 10 additions & 0 deletions models/base_model.py
@@ -227,3 +227,13 @@ def set_requires_grad(self, nets, requires_grad=False):
if net is not None:
for param in net.parameters():
param.requires_grad = requires_grad

def make_data_parallel(self):
"""Make models data parallel"""
if len(self.gpu_ids) == 0:
return
for name in self.model_names:
if isinstance(name, str):
net = getattr(self, 'net' + name)
net = torch.nn.DataParallel(net, self.gpu_ids) # multi-GPUs
setattr(self, 'net' + name, net)
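
The new helper centralizes the DataParallel wrapping that is removed from networks.init_net further down. A minimal sketch of the ordering this enables, using a toy module rather than this repository's networks and assuming Apex and a CUDA device are available:

```python
# Illustrative ordering sketch, not this repository's code.
import torch
from apex import amp  # assumes NVIDIA Apex is installed

net = torch.nn.Linear(8, 8).cuda()
opt = torch.optim.Adam(net.parameters())

# amp.initialize has to see the bare module, so it runs first...
net, opt = amp.initialize(net, opt, opt_level='O1')

# ...and only afterwards is the module wrapped for multi-GPU execution,
# which is what make_data_parallel() now does in each model's __init__.
net = torch.nn.DataParallel(net, device_ids=[0])
```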
43 changes: 36 additions & 7 deletions models/cycle_gan_model.py
@@ -3,7 +3,12 @@
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks
from torch.utils.checkpoint import checkpoint

try:
from apex import amp
except ImportError:
print("Please install NVIDIA Apex for safe mixed precision if you want to use non default --opt_level")

class CycleGANModel(BaseModel):
"""
@@ -96,6 +101,13 @@ def __init__(self, opt):
self.optimizers.append(self.optimizer_G)
self.optimizers.append(self.optimizer_D)

if opt.apex:
[self.netG_A, self.netG_B, self.netD_A, self.netD_B], [self.optimizer_G, self.optimizer_D] = amp.initialize(
[self.netG_A, self.netG_B, self.netD_A, self.netD_B], [self.optimizer_G, self.optimizer_D], opt_level=opt.opt_level, num_losses=3)

# need to be wrapped after amp.initialize
self.make_data_parallel()

def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.

@@ -112,11 +124,17 @@ def set_input(self, input):
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
self.fake_B = self.netG_A(self.real_A) # G_A(A)
self.rec_A = self.netG_B(self.fake_B) # G_B(G_A(A))
if not self.isTrain or not self.opt.checkpointing:
self.rec_A = self.netG_B(self.fake_B) # G_B(G_A(A))
else:
self.rec_A = checkpoint(self.netG_B, self.fake_B)
self.fake_A = self.netG_B(self.real_B) # G_B(B)
self.rec_B = self.netG_A(self.fake_A) # G_A(G_B(B))
if not self.isTrain or not self.opt.checkpointing:
self.rec_B = self.netG_A(self.fake_A) # G_A(G_B(B))
else:
self.rec_B = checkpoint(self.netG_A, self.fake_A)

def backward_D_basic(self, netD, real, fake):
def backward_D_basic(self, netD, real, fake, loss_id):
"""Calculate GAN loss for the discriminator

Parameters:
@@ -135,18 +153,23 @@ def backward_D_basic(self, netD, real, fake):
loss_D_fake = self.criterionGAN(pred_fake, False)
# Combined loss and calculate gradients
loss_D = (loss_D_real + loss_D_fake) * 0.5
loss_D.backward()
if self.opt.apex:
with amp.scale_loss(loss_D, self.optimizer_D, loss_id=loss_id) as loss_D_scaled:
loss_D_scaled.backward()
else:
loss_D.backward()

return loss_D

def backward_D_A(self):
"""Calculate GAN loss for discriminator D_A"""
fake_B = self.fake_B_pool.query(self.fake_B)
self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B, loss_id=0)

def backward_D_B(self):
"""Calculate GAN loss for discriminator D_B"""
fake_A = self.fake_A_pool.query(self.fake_A)
self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A, loss_id=1)

def backward_G(self):
"""Calculate the loss for generators G_A and G_B"""
@@ -175,7 +198,13 @@ def backward_G(self):
self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
# combined loss and calculate gradients
self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
self.loss_G.backward()

if self.opt.apex:
with amp.scale_loss(self.loss_G, self.optimizer_G, loss_id=2) as loss_G_scaled:
loss_G_scaled.backward()
else:
self.loss_G.backward()


def optimize_parameters(self):
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
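For readers unfamiliar with the multi-loss Apex API used above, the pattern is roughly the following. This is a minimal sketch with toy models, not this repository's code, and it assumes Apex and a GPU are available:

```python
# Minimal sketch of the amp multi-loss pattern (toy models, illustrative only).
import torch
from apex import amp

model_g = torch.nn.Linear(8, 8).cuda()
model_d = torch.nn.Linear(8, 1).cuda()
opt_g = torch.optim.Adam(model_g.parameters())
opt_d = torch.optim.Adam(model_d.parameters())

# Lists of models/optimizers go in and come back out; num_losses tells amp how many
# independent loss scales to maintain (one per loss_id used with scale_loss below).
[model_g, model_d], [opt_g, opt_d] = amp.initialize(
    [model_g, model_d], [opt_g, opt_d], opt_level='O1', num_losses=2)

x = torch.randn(4, 8).cuda()

loss_g = model_g(x).mean()
with amp.scale_loss(loss_g, opt_g, loss_id=0) as scaled_g:
    scaled_g.backward()   # gradients are scaled; the scale is tracked per loss_id

loss_d = model_d(x).abs().mean()
with amp.scale_loss(loss_d, opt_d, loss_id=1) as scaled_d:
    scaled_d.backward()

opt_g.step()
opt_d.step()
```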
1 change: 0 additions & 1 deletion models/networks.py
@@ -111,7 +111,6 @@ def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
if len(gpu_ids) > 0:
assert(torch.cuda.is_available())
net.to(gpu_ids[0])
net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs
init_weights(net, init_type, init_gain=init_gain)
return net

27 changes: 25 additions & 2 deletions models/pix2pix_model.py
@@ -1,7 +1,14 @@
import torch
from torch.utils.checkpoint import checkpoint

from .base_model import BaseModel
from . import networks

try:
from apex import amp
except ImportError:
print("Please install NVIDIA Apex for safe mixed precision if you want to use non default --opt_level")


class Pix2PixModel(BaseModel):
""" This class implements the pix2pix model, for learning a mapping from input images to output images given paired data.
@@ -70,6 +77,12 @@ def __init__(self, opt):
self.optimizers.append(self.optimizer_G)
self.optimizers.append(self.optimizer_D)

if opt.apex:
[self.netG, self.netD], [self.optimizer_G, self.optimizer_D] = amp.initialize(
[self.netG, self.netD], [self.optimizer_G, self.optimizer_D], opt_level=opt.opt_level, num_losses=2)

self.make_data_parallel()

def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.

@@ -99,7 +112,12 @@ def backward_D(self):
self.loss_D_real = self.criterionGAN(pred_real, True)
# combine loss and calculate gradients
self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
self.loss_D.backward()

if self.opt.apex:
with amp.scale_loss(self.loss_D, self.optimizer_D, loss_id=0) as loss_D_scaled:
loss_D_scaled.backward()
else:
self.loss_D.backward()

def backward_G(self):
"""Calculate GAN and L1 loss for the generator"""
@@ -111,7 +129,12 @@ def backward_G(self):
self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1
# combine loss and calculate gradients
self.loss_G = self.loss_G_GAN + self.loss_G_L1
self.loss_G.backward()

if self.opt.apex:
with amp.scale_loss(self.loss_G, self.optimizer_G, loss_id=1) as loss_G_scaled:
loss_G_scaled.backward()
else:
self.loss_G.backward()

def optimize_parameters(self):
self.forward() # compute fake images: G(A)
2 changes: 2 additions & 0 deletions models/template_model.py
@@ -67,6 +67,8 @@ def __init__(self, opt):
self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
self.optimizers = [self.optimizer]

# need to be wrapped after amp.initialize
self.make_data_parallel()
# Our program will automatically call <model.setup> to define schedulers, load networks, and print networks

def set_input(self, input):
1 change: 1 addition & 0 deletions models/test_model.py
@@ -48,6 +48,7 @@ def __init__(self, opt):
# assigns the model to self.netG_[suffix] so that it can be loaded
# please see <BaseModel.load_networks>
setattr(self, 'netG' + opt.model_suffix, self.netG) # store netG in self.
self.make_data_parallel()

def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
9 changes: 9 additions & 0 deletions options/train_options.py
@@ -36,5 +36,14 @@ def initialize(self, parser):
parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]')
parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')

# training optimizations
parser.add_argument('--checkpointing', action='store_true',
help='if true, applies gradient checkpointing; saves memory but makes training slower')
parser.add_argument('--opt_level', default='O0', help='amp opt_level, default="O0" equals fp32 training')
self.isTrain = True
return parser

def parse(self):
opt = BaseOptions.parse(self)
opt.apex = opt.opt_level != "O0"
return opt
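
The --checkpointing flag drives the torch.utils.checkpoint calls added to the generator forward passes above. As a standalone illustration of the trade-off (a sketch with a toy module, not this repository's networks):

```python
# Gradient checkpointing sketch (toy module, illustrative only).
import torch
from torch.utils.checkpoint import checkpoint

net = torch.nn.Sequential(
    torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 64))
x = torch.randn(16, 64, requires_grad=True)

# Normal forward keeps every intermediate activation alive for backward().
y = net(x)

# Checkpointed forward discards the intermediates and recomputes them during
# backward(), trading extra compute time for a smaller activation memory footprint.
y_ckpt = checkpoint(net, x)
y_ckpt.sum().backward()
```

With both features enabled, a hypothetical training invocation (dataset and model flags are placeholders) would look something like `python train.py --dataroot ./datasets/horse2zebra --model cycle_gan --checkpointing --opt_level O1`.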