Skip to content

Commit

Permalink
adding more info about training and adding make_train with YX, ZY and…
Browse files Browse the repository at this point in the history
… ZX crops
  • Loading branch information
carsen-stringer committed Sep 10, 2024
1 parent 89e8609 commit 018afc8
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 31 deletions.
57 changes: 27 additions & 30 deletions cellpose/gui/make_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ def main():
help='axis of image which corresponds to image channels')
input_img_args.add_argument('--z_axis', default=None, type=int,
help='axis of image which corresponds to Z dimension')
input_img_args.add_argument('--t_axis', default=None, type=int,
help='axis of image which corresponds to T dimension')
input_img_args.add_argument(
'--chan', default=0, type=int, help=
'channel to segment; 0: GRAY, 1: RED, 2: GREEN, 3: BLUE. Default: %(default)s')
Expand All @@ -39,22 +37,17 @@ def main():
input_img_args.add_argument(
'--all_channels', action='store_true', help=
'use all channels in image if using own model and images with special channels')
training_args = parser.add_argument_group("training arguments")
training_args.add_argument(
'--mask_filter', default='_masks', type=str, help=
'end string for masks to run on. use "_seg.npy" for manual annotations from the GUI. Default: %(default)s'
)

# algorithm settings
algorithm_args = parser.add_argument_group("algorithm arguments")
algorithm_args.add_argument('--sharpen_radius', required=False, default=0.0,
type=float, help='tile normalization')
algorithm_args.add_argument('--tile_norm', required=False, default=0.0, type=float,
help='tile normalization')
type=float, help='high-pass filtering radius. Default: %(default)s')
algorithm_args.add_argument('--tile_norm', required=False, default=0, type=int,
help='tile normalization block size. Default: %(default)s')
algorithm_args.add_argument('--nimg_per_tif', required=False, default=10, type=int,
help='number of slices to save')
help='number of crops in XY to save per tiff. Default: %(default)s')
algorithm_args.add_argument('--crop_size', required=False, default=512, type=int,
help='size of random crop to save')
help='size of random crop to save. Default: %(default)s')

args = parser.parse_args()

Expand All @@ -64,29 +57,33 @@ def main():
else:
imf = None

image_names = io.get_image_files(args.dir, args.mask_filter, imf=imf,
image_names = io.get_image_files(args.dir, "_masks", imf=imf,
look_one_level_down=args.look_one_level_down)

np.random.seed(0)
nimg_per_tif = 10
nimg_per_tif = args.nimg_per_tif
crop_size = args.crop_size
os.makedirs(os.path.join(args.dir, 'train/'), exist_ok=True)
pm = [(0, 1, 2, 3), (2, 0, 1, 3), (1, 0, 2, 3)]
npm = ["YX", "ZY", "ZX"]
for name in image_names:
name0 = os.path.splitext(os.path.split(name)[-1])[0]
img = io.imread(name)
#print(img.shape)
Ly, Lx = img.shape[1:3]
img = img[8:]
imgs = img[np.random.permutation(img.shape[0])[:args.nimg_per_tif]]
for k, img in enumerate(imgs):
if args.tile_norm:
img = transforms.normalize99_tile(img, blocksize=args.tile_norm)
if args.sharpen_radius:
img = transforms.smooth_sharpen_img(img,
sharpen_radius=args.sharpen_radius)
ly = np.random.randint(0, Ly - args.crop_size)
lx = np.random.randint(0, Lx - args.crop_size)
io.imsave(os.path.join(args.dir, f'train/{name0}_{k}.tif'),
img[ly:ly + args.crop_size, lx:lx + args.crop_size])
img0 = io.imread(name)
img0 = transforms.convert_image(img0, channels=[args.chan, args.chan2], channel_axis=args.channel_axis, z_axis=args.z_axis)
for p in range(3):
img = img0.transpose(pm[p]).copy()
print(npm[p], img[0].shape)
Ly, Lx = img.shape[1:3]
imgs = img[np.random.permutation(img.shape[0])[:args.nimg_per_tif]]
for k, img in enumerate(imgs):
if args.tile_norm:
img = transforms.normalize99_tile(img, blocksize=args.tile_norm)
if args.sharpen_radius:
img = transforms.smooth_sharpen_img(img,
sharpen_radius=args.sharpen_radius)
ly = 0 if Ly - crop_size <= 0 else np.random.randint(0, Ly - crop_size)
lx = 0 if Lx - crop_size <= 0 else np.random.randint(0, Lx - crop_size)
io.imsave(os.path.join(args.dir, f'train/{name0}_{npm[p]}_{k}.tif'),
img[ly:ly + args.crop_size, lx:lx + args.crop_size].squeeze())


if __name__ == '__main__':
Expand Down
61 changes: 60 additions & 1 deletion docs/train.rst
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ In a notebook, you can train with the `train_seg` function:
n_epochs=100, model_name="my_new_model")


Training arguments on the CLI
CLI training options
~~~~~~~~~~~~~~~~~~~~

::

Expand Down Expand Up @@ -114,3 +115,61 @@ Training arguments on the CLI
Name of model to save as, defaults to name describing
model architecture. Model is saved in the folder
specified by --dir in models subfolder.


Re-training a model
~~~~~~~~~~~~~~~~~~~

We find that for re-training, using SGD generally works better, and it is the default in the GUI.
The options in the code above are the default options for retraining in the GUI and in the Cellpose 2.0 paper
``(weight_decay=1e-4, SGD=True, learning_rate=0.1, n_epochs=100)``,
although in the paper we often use 300 epochs instead of 100 epochs, and it may help to use more epochs,
especially when you have more training data.

When re-training, keep in mind that the normalization happens per image that you train on, and often these are image crops from full images.
These crops may look different after normalization than the full images. To approximate per-crop normalization on the full images, we have the option for
tile normalization that can be set in ``model.eval``: ``normalize={"tile_norm_blocksize": 128}``. Alternatively/additionally, you may want to change
the overall normalization scaling on the full images, e.g. ``normalize={"percentile": [3, 98]``. You can visualize how the normalization looks in
a notebook for example with ``from cellpose import transforms; plt.imshow(transforms.normalize99(img, lower=3, upper=98))``. The default
that will be used for training on the image crops is ``[1, 99]``.

You can create image crops from z-stacks (in XY, YZ and XZ) using the script ``cellpose/gui/make_train.py``:

::
python cellpose/gui/make_train.py --help
usage: make_train.py [-h] [--dir DIR] [--image_path IMAGE_PATH] [--look_one_level_down] [--img_filter IMG_FILTER]
[--channel_axis CHANNEL_AXIS] [--z_axis Z_AXIS] [--chan CHAN] [--chan2 CHAN2] [--invert]
[--all_channels] [--sharpen_radius SHARPEN_RADIUS] [--tile_norm TILE_NORM]
[--nimg_per_tif NIMG_PER_TIF] [--crop_size CROP_SIZE]

cellpose parameters

options:
-h, --help show this help message and exit

input image arguments:
--dir DIR folder containing data to run or train on.
--image_path IMAGE_PATH
if given and --dir not given, run on single image instead of folder (cannot train with this
option)
--look_one_level_down
run processing on all subdirectories of current folder
--img_filter IMG_FILTER
end string for images to run on
--channel_axis CHANNEL_AXIS
axis of image which corresponds to image channels
--z_axis Z_AXIS axis of image which corresponds to Z dimension
--chan CHAN channel to segment; 0: GRAY, 1: RED, 2: GREEN, 3: BLUE. Default: 0
--chan2 CHAN2 nuclear channel (if cyto, optional); 0: NONE, 1: RED, 2: GREEN, 3: BLUE. Default: 0
--invert invert grayscale channel
--all_channels use all channels in image if using own model and images with special channels

algorithm arguments:
--sharpen_radius SHARPEN_RADIUS
high-pass filtering radius. Default: 0.0
--tile_norm TILE_NORM
tile normalization block size. Default: 0
--nimg_per_tif NIMG_PER_TIF
number of crops in XY to save per tiff. Default: 10
--crop_size CROP_SIZE
size of random crop to save. Default: 512

0 comments on commit 018afc8

Please sign in to comment.