diff --git a/train.py b/train.py index 5e9ecc9ca..c6c1fcb1a 100755 --- a/train.py +++ b/train.py @@ -744,7 +744,7 @@ def main(): distributed=args.distributed, collate_fn=collate_fn, pin_memory=args.pin_mem, - img_dtype=model_dtype, + img_dtype=model_dtype or torch.float32, device=device, use_prefetcher=args.prefetcher, use_multi_epochs_loader=args.use_multi_epochs_loader, @@ -769,7 +769,7 @@ def main(): distributed=args.distributed, crop_pct=data_config['crop_pct'], pin_memory=args.pin_mem, - img_dtype=model_dtype, + img_dtype=model_dtype or torch.float32, device=device, use_prefetcher=args.prefetcher, ) diff --git a/validate.py b/validate.py index 37fefaa6c..a414f36f6 100755 --- a/validate.py +++ b/validate.py @@ -307,7 +307,7 @@ def validate(args): crop_border_pixels=args.crop_border_pixels, pin_memory=args.pin_mem, device=device, - img_dtype=model_dtype, + img_dtype=model_dtype or torch.float32, tf_preprocessing=args.tf_preprocessing, )