diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index f81945602..e79e22045 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -125,6 +125,7 @@ def __next__(self): if self.debug: logger.info(f"found {len(self.files)} images in the archive") + new_images = [] while len(images) + len(new_images) < self.batch_size: if self.image_index >= len(self.files): break @@ -166,6 +167,10 @@ def collate_fn_remove_corrupted(batch): def main(args): + assert args.load_archive == ( + args.metadata is not None + ), "load_archive must be used with metadata / load_archiveはmetadataと一緒に使う必要があります" + # model location is model_dir + repo_id # repo id may be like "user/repo" or "user/repo/branch", so we need to remove slash model_location = os.path.join(args.model_dir, args.repo_id.replace("/", "_")) @@ -436,7 +441,7 @@ def run_batch( else: image_md = images_metadata.get(image_path, None) if image_md is None: - image_md = {"image_size": [image_size.width, image_size.height]} + image_md = {"image_size": list(image_size)} images_metadata[image_path] = image_md if "tags" not in image_md: image_md["tags"] = [] @@ -464,6 +469,7 @@ def run_batch( # version check major, minor, patch = metadata.get("format_version", "0.0.0").split(".") + major, minor, patch = int(major), int(minor), int(patch) if major > 1 or (major == 1 and minor > 0): logger.warning( f"metadata format version {major}.{minor}.{patch} is higher than supported version 1.0.0. Some features may not work." @@ -480,7 +486,7 @@ def run_batch( # prepare DataLoader or something similar :) use_loader = False if args.load_archive: - loader = ArchiveImageLoader(image_paths, args.batch_size) + loader = ArchiveImageLoader([str(p) for p in image_paths], args.batch_size) use_loader = True elif args.max_data_loader_n_workers is not None: # 読み込みの高速化のためにDataLoaderを使うオプション