From eeee38e9728c7f751ecfe2f20e2b55cb78b32015 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 8 Jan 2025 21:10:15 -0800 Subject: [PATCH] Avoid unecessary compat break btw train script and nearby timm versions w/ dtype addition. --- train.py | 4 ++-- validate.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index 5e9ecc9ca..c6c1fcb1a 100755 --- a/train.py +++ b/train.py @@ -744,7 +744,7 @@ def main(): distributed=args.distributed, collate_fn=collate_fn, pin_memory=args.pin_mem, - img_dtype=model_dtype, + img_dtype=model_dtype or torch.float32, device=device, use_prefetcher=args.prefetcher, use_multi_epochs_loader=args.use_multi_epochs_loader, @@ -769,7 +769,7 @@ def main(): distributed=args.distributed, crop_pct=data_config['crop_pct'], pin_memory=args.pin_mem, - img_dtype=model_dtype, + img_dtype=model_dtype or torch.float32, device=device, use_prefetcher=args.prefetcher, ) diff --git a/validate.py b/validate.py index 37fefaa6c..a414f36f6 100755 --- a/validate.py +++ b/validate.py @@ -307,7 +307,7 @@ def validate(args): crop_border_pixels=args.crop_border_pixels, pin_memory=args.pin_mem, device=device, - img_dtype=model_dtype, + img_dtype=model_dtype or torch.float32, tf_preprocessing=args.tf_preprocessing, )