Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added different upsampling-modes #357

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion hubconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ def unet_carvana(pretrained=False, scale=0.5):
UNet model trained on the Carvana dataset ( https://www.kaggle.com/c/carvana-image-masking-challenge/data ).
Set the scale to 0.5 (50%) when predicting.
"""
net = _UNet(n_channels=3, n_classes=2, bilinear=False)
net = _UNet(n_channels=3, n_classes=2, upscaling_mode='transpose')
if pretrained:
if scale == 0.5:
checkpoint = 'https://github.com/milesial/Pytorch-UNet/releases/download/v3.0/unet_carvana_scale0.5_epoch2.pth'
Expand Down
12 changes: 10 additions & 2 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,10 @@ def get_args():
help='Minimum probability value to consider a mask pixel white')
parser.add_argument('--scale', '-s', type=float, default=0.5,
help='Scale factor for the input images')
parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling')
parser.add_argument('--bilinear', action='store_true', default=False, help='deprecated, use `--upscaling_mode=upsample` instead')
parser.add_argument('--upscaling_mode', default='transpose',
const='transpose', nargs='?', choices=['upsample', 'unpool', 'transpose'],
help='Upscaling operation (default: %(default)s)')

return parser.parse_args()

Expand All @@ -81,8 +84,13 @@ def mask_to_image(mask: np.ndarray):
args = get_args()
in_files = args.input
out_files = get_output_filenames(args)
upscaling_mode = args.upscaling_mode

net = UNet(n_channels=3, n_classes=2, bilinear=args.bilinear)
# if the deprecated --bilinear flag is set, it overrides upscaling_mode
if args.bilinear:
upscaling_mode = 'upsample'

net = UNet(n_channels=3, n_classes=2, upscaling_mode=upscaling_mode)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f'Loading model {args.model}')
Expand Down
13 changes: 10 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ def get_args():
parser.add_argument('--validation', '-v', dest='val', type=float, default=10.0,
help='Percent of the data that is used as validation (0-100)')
parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision')
parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling')
parser.add_argument('--bilinear', action='store_true', default=False, help='deprecated, use `--upscaling_mode=upsample` instead')
parser.add_argument('--upscaling_mode', default='transpose', const='transpose', nargs='?', choices=['upsample', 'unpool', 'transpose'], help='Upscaling operation (default: %(default)s)')
parser.add_argument('--classes', '-c', type=int, default=2, help='Number of classes')

return parser.parse_args()
Expand All @@ -167,15 +168,21 @@ def get_args():
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f'Using device {device}')

upscaling_mode = args.upscaling_mode

# if the deprecated --bilinear flag is set, it overrides upscaling_mode
if args.bilinear:
upscaling_mode = 'upsample'

# Change here to adapt to your data
# n_channels=3 for RGB images
# n_classes is the number of probabilities you want to get per pixel
net = UNet(n_channels=3, n_classes=args.classes, bilinear=args.bilinear)
net = UNet(n_channels=3, n_classes=args.classes, upscaling_mode=upscaling_mode)

logging.info(f'Network:\n'
f'\t{net.n_channels} input channels\n'
f'\t{net.n_classes} output channels (classes)\n'
f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling')
f'"\t{net.upscaling_mode}" upscaling')

if args.load:
net.load_state_dict(torch.load(args.load, map_location=device))
Expand Down
30 changes: 15 additions & 15 deletions unet/unet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,33 +4,33 @@


class UNet(nn.Module):
def __init__(self, n_channels, n_classes, bilinear=False):
def __init__(self, n_channels, n_classes, upscaling_mode='transpose'):
super(UNet, self).__init__()
self.n_channels = n_channels
self.n_classes = n_classes
self.bilinear = bilinear
self.upscaling_mode = upscaling_mode

self.inc = DoubleConv(n_channels, 64)
self.down1 = Down(64, 128)
self.down2 = Down(128, 256)
self.down3 = Down(256, 512)
factor = 2 if bilinear else 1
factor = 1 if upscaling_mode == 'transpose' else 2
self.down4 = Down(512, 1024 // factor)
self.up1 = Up(1024, 512 // factor, bilinear)
self.up2 = Up(512, 256 // factor, bilinear)
self.up3 = Up(256, 128 // factor, bilinear)
self.up4 = Up(128, 64, bilinear)
self.up1 = Up(1024, 512 // factor, upscaling_mode)
self.up2 = Up(512, 256 // factor, upscaling_mode)
self.up3 = Up(256, 128 // factor, upscaling_mode)
self.up4 = Up(128, 64, upscaling_mode)
self.outc = OutConv(64, n_classes)

def forward(self, x):
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
x2, indices2 = self.down1(x1)
x3, indices3 = self.down2(x2)
x4, indices4 = self.down3(x3)
x5, indices5 = self.down4(x4)
x = self.up1(x5, x4, indices5)
x = self.up2(x, x3, indices4)
x = self.up3(x, x2, indices3)
x = self.up4(x, x1, indices2)
logits = self.outc(x)
return logits
28 changes: 18 additions & 10 deletions unet/unet_parts.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,31 +30,39 @@ class Down(nn.Module):

def __init__(self, in_channels, out_channels):
super().__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
DoubleConv(in_channels, out_channels)
)
self.maxpool = nn.MaxPool2d(2, return_indices=True)
self.conv = DoubleConv(in_channels, out_channels)

def forward(self, x):
return self.maxpool_conv(x)
x, indices = self.maxpool(x)
return self.conv(x), indices


class Up(nn.Module):
"""Upscaling then double conv"""

def __init__(self, in_channels, out_channels, bilinear=True):
def __init__(self, in_channels, out_channels, upscaling_mode='transpose'):
super().__init__()

self.upscaling_mode = upscaling_mode
# if bilinear, use the normal convolutions to reduce the number of channels
if bilinear:
if self.upscaling_mode == 'upsample':
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
else:
elif self.upscaling_mode == 'unpool':
self.up = nn.MaxUnpool2d(2)
self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
elif self.upscaling_mode == 'transpose':
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
self.conv = DoubleConv(in_channels, out_channels)
else:
raise ValueError('mode should be one of [`upsample`, `unpool`, `transpose`]')

def forward(self, x1, x2):
x1 = self.up(x1)
def forward(self, x1, x2, indices):
if self.upscaling_mode == 'unpool':
x1 = self.up(x1, indices)
else:
x1 = self.up(x1)
# input is CHW
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
Expand Down