From 018afc80bc1c4191e75532d07022c57eb6801ca2 Mon Sep 17 00:00:00 2001 From: unknown Date: Tue, 10 Sep 2024 11:11:29 +0300 Subject: [PATCH] adding more info about training and adding make_train with YX, ZY and ZX crops --- cellpose/gui/make_train.py | 57 +++++++++++++++++------------------ docs/train.rst | 61 +++++++++++++++++++++++++++++++++++++- 2 files changed, 87 insertions(+), 31 deletions(-) diff --git a/cellpose/gui/make_train.py b/cellpose/gui/make_train.py index 076b717e..bba5accd 100644 --- a/cellpose/gui/make_train.py +++ b/cellpose/gui/make_train.py @@ -25,8 +25,6 @@ def main(): help='axis of image which corresponds to image channels') input_img_args.add_argument('--z_axis', default=None, type=int, help='axis of image which corresponds to Z dimension') - input_img_args.add_argument('--t_axis', default=None, type=int, - help='axis of image which corresponds to T dimension') input_img_args.add_argument( '--chan', default=0, type=int, help= 'channel to segment; 0: GRAY, 1: RED, 2: GREEN, 3: BLUE. Default: %(default)s') @@ -39,22 +37,17 @@ def main(): input_img_args.add_argument( '--all_channels', action='store_true', help= 'use all channels in image if using own model and images with special channels') - training_args = parser.add_argument_group("training arguments") - training_args.add_argument( - '--mask_filter', default='_masks', type=str, help= - 'end string for masks to run on. use "_seg.npy" for manual annotations from the GUI. Default: %(default)s' - ) # algorithm settings algorithm_args = parser.add_argument_group("algorithm arguments") algorithm_args.add_argument('--sharpen_radius', required=False, default=0.0, - type=float, help='tile normalization') - algorithm_args.add_argument('--tile_norm', required=False, default=0.0, type=float, - help='tile normalization') + type=float, help='high-pass filtering radius. Default: %(default)s') + algorithm_args.add_argument('--tile_norm', required=False, default=0, type=int, + help='tile normalization block size. Default: %(default)s') algorithm_args.add_argument('--nimg_per_tif', required=False, default=10, type=int, - help='number of slices to save') + help='number of crops in XY to save per tiff. Default: %(default)s') algorithm_args.add_argument('--crop_size', required=False, default=512, type=int, - help='size of random crop to save') + help='size of random crop to save. Default: %(default)s') args = parser.parse_args() @@ -64,29 +57,33 @@ def main(): else: imf = None - image_names = io.get_image_files(args.dir, args.mask_filter, imf=imf, + image_names = io.get_image_files(args.dir, "_masks", imf=imf, look_one_level_down=args.look_one_level_down) - np.random.seed(0) - nimg_per_tif = 10 + nimg_per_tif = args.nimg_per_tif + crop_size = args.crop_size os.makedirs(os.path.join(args.dir, 'train/'), exist_ok=True) + pm = [(0, 1, 2, 3), (2, 0, 1, 3), (1, 0, 2, 3)] + npm = ["YX", "ZY", "ZX"] for name in image_names: name0 = os.path.splitext(os.path.split(name)[-1])[0] - img = io.imread(name) - #print(img.shape) - Ly, Lx = img.shape[1:3] - img = img[8:] - imgs = img[np.random.permutation(img.shape[0])[:args.nimg_per_tif]] - for k, img in enumerate(imgs): - if args.tile_norm: - img = transforms.normalize99_tile(img, blocksize=args.tile_norm) - if args.sharpen_radius: - img = transforms.smooth_sharpen_img(img, - sharpen_radius=args.sharpen_radius) - ly = np.random.randint(0, Ly - args.crop_size) - lx = np.random.randint(0, Lx - args.crop_size) - io.imsave(os.path.join(args.dir, f'train/{name0}_{k}.tif'), - img[ly:ly + args.crop_size, lx:lx + args.crop_size]) + img0 = io.imread(name) + img0 = transforms.convert_image(img0, channels=[args.chan, args.chan2], channel_axis=args.channel_axis, z_axis=args.z_axis) + for p in range(3): + img = img0.transpose(pm[p]).copy() + print(npm[p], img[0].shape) + Ly, Lx = img.shape[1:3] + imgs = img[np.random.permutation(img.shape[0])[:args.nimg_per_tif]] + for k, img in enumerate(imgs): + if args.tile_norm: + img = transforms.normalize99_tile(img, blocksize=args.tile_norm) + if args.sharpen_radius: + img = transforms.smooth_sharpen_img(img, + sharpen_radius=args.sharpen_radius) + ly = 0 if Ly - crop_size <= 0 else np.random.randint(0, Ly - crop_size) + lx = 0 if Lx - crop_size <= 0 else np.random.randint(0, Lx - crop_size) + io.imsave(os.path.join(args.dir, f'train/{name0}_{npm[p]}_{k}.tif'), + img[ly:ly + args.crop_size, lx:lx + args.crop_size].squeeze()) if __name__ == '__main__': diff --git a/docs/train.rst b/docs/train.rst index af7adb4b..dded5a45 100644 --- a/docs/train.rst +++ b/docs/train.rst @@ -83,7 +83,8 @@ In a notebook, you can train with the `train_seg` function: n_epochs=100, model_name="my_new_model") -Training arguments on the CLI +CLI training options +~~~~~~~~~~~~~~~~~~~~ :: @@ -114,3 +115,61 @@ Training arguments on the CLI Name of model to save as, defaults to name describing model architecture. Model is saved in the folder specified by --dir in models subfolder. + + +Re-training a model +~~~~~~~~~~~~~~~~~~~ + +We find that for re-training, using SGD generally works better, and it is the default in the GUI. +The options in the code above are the default options for retraining in the GUI and in the Cellpose 2.0 paper +``(weight_decay=1e-4, SGD=True, learning_rate=0.1, n_epochs=100)``, +although in the paper we often use 300 epochs instead of 100 epochs, and it may help to use more epochs, +especially when you have more training data. + +When re-training, keep in mind that the normalization happens per image that you train on, and often these are image crops from full images. +These crops may look different after normalization than the full images. To approximate per-crop normalization on the full images, we have the option for +tile normalization that can be set in ``model.eval``: ``normalize={"tile_norm_blocksize": 128}``. Alternatively/additionally, you may want to change +the overall normalization scaling on the full images, e.g. ``normalize={"percentile": [3, 98]``. You can visualize how the normalization looks in +a notebook for example with ``from cellpose import transforms; plt.imshow(transforms.normalize99(img, lower=3, upper=98))``. The default +that will be used for training on the image crops is ``[1, 99]``. + +You can create image crops from z-stacks (in XY, YZ and XZ) using the script ``cellpose/gui/make_train.py``: + +:: + python cellpose/gui/make_train.py --help + usage: make_train.py [-h] [--dir DIR] [--image_path IMAGE_PATH] [--look_one_level_down] [--img_filter IMG_FILTER] + [--channel_axis CHANNEL_AXIS] [--z_axis Z_AXIS] [--chan CHAN] [--chan2 CHAN2] [--invert] + [--all_channels] [--sharpen_radius SHARPEN_RADIUS] [--tile_norm TILE_NORM] + [--nimg_per_tif NIMG_PER_TIF] [--crop_size CROP_SIZE] + + cellpose parameters + + options: + -h, --help show this help message and exit + + input image arguments: + --dir DIR folder containing data to run or train on. + --image_path IMAGE_PATH + if given and --dir not given, run on single image instead of folder (cannot train with this + option) + --look_one_level_down + run processing on all subdirectories of current folder + --img_filter IMG_FILTER + end string for images to run on + --channel_axis CHANNEL_AXIS + axis of image which corresponds to image channels + --z_axis Z_AXIS axis of image which corresponds to Z dimension + --chan CHAN channel to segment; 0: GRAY, 1: RED, 2: GREEN, 3: BLUE. Default: 0 + --chan2 CHAN2 nuclear channel (if cyto, optional); 0: NONE, 1: RED, 2: GREEN, 3: BLUE. Default: 0 + --invert invert grayscale channel + --all_channels use all channels in image if using own model and images with special channels + + algorithm arguments: + --sharpen_radius SHARPEN_RADIUS + high-pass filtering radius. Default: 0.0 + --tile_norm TILE_NORM + tile normalization block size. Default: 0 + --nimg_per_tif NIMG_PER_TIF + number of crops in XY to save per tiff. Default: 10 + --crop_size CROP_SIZE + size of random crop to save. Default: 512