From b1d1af910b0a1b3554b8798a9e531c4533561ab3 Mon Sep 17 00:00:00 2001 From: Carsen Stringer Date: Sat, 13 Jul 2024 15:55:56 -0400 Subject: [PATCH 01/22] adding more args to denoise --- cellpose/denoise.py | 18 ++++++++---------- cellpose/resnet_torch.py | 1 + 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/cellpose/denoise.py b/cellpose/denoise.py index 1270632d..61443510 100644 --- a/cellpose/denoise.py +++ b/cellpose/denoise.py @@ -937,6 +937,7 @@ def train(net, train_data=None, train_labels=None, train_files=None, test_data=N nimg_test_per_epoch = nimg_test if nimg_test_per_epoch is None else nimg_test_per_epoch nbatch = 0 + train_losses, test_losses = [], [] for iepoch in range(n_epochs): np.random.seed(iepoch) rperm = np.random.choice(np.arange(0, nimg), size=(nimg_per_epoch,), @@ -974,7 +975,7 @@ def train(net, train_data=None, train_labels=None, train_files=None, test_data=N nsum += len(img) nbatch += 1 - if iepoch % 10 == 0 or iepoch < 10: + if iepoch % 5 == 0 or iepoch < 10: lavg = lavg / nsum lavg_per = lavg_per / nsum if test_data is not None or test_files is not None: @@ -1005,21 +1006,18 @@ def train(net, train_data=None, train_labels=None, train_files=None, test_data=N lavgt += loss.item() * img.shape[0] nsum += len(img) + lavgt = lavgt / nsum denoise_logger.info( "Epoch %d, Time %4.1fs, Loss %0.3f, loss_per %0.3f, Loss Test %0.3f, LR %2.4f" - % (iepoch, time.time() - tic, lavg, lavg_per, lavgt / nsum, + % (iepoch, time.time() - tic, lavg, lavg_per, lavgt, learning_rate[iepoch])) + test_losses.append(lavgt) else: denoise_logger.info( "Epoch %d, Time %4.1fs, Loss %0.3f, loss_per %0.3f, LR %2.4f" % (iepoch, time.time() - tic, lavg, lavg_per, learning_rate[iepoch])) - elif iepoch < 50: - lavg = lavg / nsum - lavg_per = lavg_per / nsum - denoise_logger.info( - "Epoch %d, Time %4.1fs, Loss %0.3f, loss_per %0.3f, LR %2.4f" % - (iepoch, time.time() - tic, lavg, lavg_per, learning_rate[iepoch])) - + train_losses.append(lavg) + if save_path is not None: if iepoch == n_epochs - 1 or iepoch % save_every == 1: if save_each: #separate files as model progresses @@ -1031,7 +1029,7 @@ def train(net, train_data=None, train_labels=None, train_files=None, test_data=N else: filename = save_path - return filename + return filename, train_losses, test_losses if __name__ == "__main__": diff --git a/cellpose/resnet_torch.py b/cellpose/resnet_torch.py index 56b0a36a..808b26be 100644 --- a/cellpose/resnet_torch.py +++ b/cellpose/resnet_torch.py @@ -199,6 +199,7 @@ class CPnet(nn.Module): def __init__(self, nbase, nout, sz, mkldnn=False, conv_3D=False, max_pool=True, diam_mean=30.): super().__init__() + self.nchan = nbase[0] self.nbase = nbase self.nout = nout self.sz = sz From 22cf562228c80356c90a4c5d441fdfa8e5224e40 Mon Sep 17 00:00:00 2001 From: Carsen Stringer Date: Sat, 13 Jul 2024 16:12:35 -0400 Subject: [PATCH 02/22] adding model_name as input --- cellpose/denoise.py | 35 +++++++++++++++++++---------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/cellpose/denoise.py b/cellpose/denoise.py index 61443510..f8e3705e 100644 --- a/cellpose/denoise.py +++ b/cellpose/denoise.py @@ -849,7 +849,7 @@ def train(net, train_data=None, train_labels=None, train_files=None, test_data=N save_every=100, save_each=False, poisson=0.7, beta=0.7, blur=0.7, gblur=1.0, iso=True, downsample=0., learning_rate=0.005, n_epochs=500, momentum=0.9, weight_decay=0.00001, batch_size=8, nimg_per_epoch=None, - nimg_test_per_epoch=None): + nimg_test_per_epoch=None, model_name=None): # net properties device = net.device @@ -864,21 +864,24 @@ def train(net, train_data=None, train_labels=None, train_files=None, test_data=N d = datetime.datetime.now() if save_path is not None: - filename = "" - lstrs = ["per", "seg", "rec"] - for k, (l, s) in enumerate(zip(lam, lstrs)): - filename += f"{s}_{l:.2f}_" - if poisson.sum() > 0: - filename += "poisson_" - if blur.sum() > 0: - if iso: - filename += "blur_" - else: - filename += "bluraniso_" - if downsample.sum() > 0: - filename += "downsample_" - filename += d.strftime("%Y_%m_%d_%H_%M_%S.%f") - filename = os.path.join(save_path, filename) + if model_name is None: + filename = "" + lstrs = ["per", "seg", "rec"] + for k, (l, s) in enumerate(zip(lam, lstrs)): + filename += f"{s}_{l:.2f}_" + if poisson.sum() > 0: + filename += "poisson_" + if blur.sum() > 0: + if iso: + filename += "blur_" + else: + filename += "bluraniso_" + if downsample.sum() > 0: + filename += "downsample_" + filename += d.strftime("%Y_%m_%d_%H_%M_%S.%f") + filename = os.path.join(save_path, filename) + else: + filename = os.path.join(save_path, model_name) print(filename) for i in range(len(poisson)): denoise_logger.info( From c0dfd356940fdadb907d7d9f337fab034a3e803e Mon Sep 17 00:00:00 2001 From: Carsen Stringer Date: Tue, 16 Jul 2024 09:46:26 -0400 Subject: [PATCH 03/22] interleave different noise types --- cellpose/denoise.py | 79 +++++++++++++++++++++++++++++++-------------- 1 file changed, 54 insertions(+), 25 deletions(-) diff --git a/cellpose/denoise.py b/cellpose/denoise.py index f8e3705e..17827143 100644 --- a/cellpose/denoise.py +++ b/cellpose/denoise.py @@ -210,7 +210,7 @@ def img_norm(imgi): def add_noise(lbl, alpha=4, beta=0.7, poisson=0.7, blur=0.7, gblur=1.0, downsample=0.7, ds_max=7, diams=None, pscale=None, iso=True, sigma0=None, sigma1=None, - ds=None): + ds=None, partial_blur=False): """Adds noise to the input image. Args: @@ -234,6 +234,7 @@ def add_noise(lbl, alpha=4, beta=0.7, poisson=0.7, blur=0.7, gblur=1.0, downsamp """ device = lbl.device imgi = torch.zeros_like(lbl) + Ly, Lx = lbl.shape[-2:] diams = diams if diams is not None else 30. * torch.ones(len(lbl), device=device) #ds0 = 1 if ds is None else ds.item() @@ -278,9 +279,24 @@ def add_noise(lbl, alpha=4, beta=0.7, poisson=0.7, blur=0.7, gblur=1.0, downsamp gfilt = torch.einsum("ck,cl->ckl", gfilt0, gfilt1) gfilt /= gfilt.sum(axis=(1, 2), keepdims=True) - imgi[iblur] = conv2d(lbl[iblur].transpose(1, 0), gfilt.unsqueeze(1), + lbl_blur = conv2d(lbl[iblur].transpose(1, 0), gfilt.unsqueeze(1), padding=gfilt.shape[-1] // 2, groups=gfilt.shape[0]).transpose(1, 0) + if partial_blur: + #yc, xc = np.random.randint(100, Ly-100), np.random.randint(100, Lx-100) + imgi[iblur] = lbl[iblur].clone() + Lxc = int(Lx * 0.85) + ym, xm = torch.meshgrid(torch.zeros(Ly, dtype=torch.float32), + torch.arange(0, Lxc, dtype=torch.float32), + indexing="ij") + mask = torch.exp(-(ym**2 + xm**2) / 2*(0.001**2)) + mask -= mask.min() + mask /= mask.max() + lbl_blur_crop = lbl_blur[:, :, :, :Lxc] + imgi[iblur, :, :, :Lxc] = (lbl_blur_crop * mask + + (1-mask) * imgi[iblur, :, :, :Lxc]) + else: + imgi[iblur] = lbl_blur imgi[~iblur] = lbl[~iblur] @@ -476,7 +492,7 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None, normalize=True, rescale=None, diameter=None, tile=True, tile_overlap=0.1, augment=False, resample=True, invert=False, flow_threshold=0.4, cellprob_threshold=0.0, do_3D=False, anisotropy=None, stitch_threshold=0.0, - min_size=15, niter=None, interp=True): + min_size=15, niter=None, interp=True, bsize=224): """ Restore array or list of images using the image restoration model, and then segment. @@ -541,7 +557,7 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None, channel_axis=channel_axis, z_axis=z_axis, normalize=normalize_params, rescale=rescale, diameter=diameter, tile=tile, - tile_overlap=tile_overlap) + tile_overlap=tile_overlap, bsize=bsize) # turn off special normalization for segmentation normalize_params = normalize_default @@ -557,7 +573,7 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None, invert=invert, flow_threshold=flow_threshold, cellprob_threshold=cellprob_threshold, do_3D=do_3D, anisotropy=anisotropy, stitch_threshold=stitch_threshold, min_size=min_size, niter=niter, - interp=interp) + interp=interp, bsize=bsize) return masks, flows, styles, img_restore @@ -658,7 +674,8 @@ def __init__(self, gpu=False, pretrained_model=False, nchan=1, model_type=None, self.net_type = "cellpose_denoise" def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None, - normalize=True, rescale=None, diameter=None, tile=True, tile_overlap=0.1): + normalize=True, rescale=None, diameter=None, tile=True, + tile_overlap=0.1, bsize=224): """ Restore array or list of images using the image restoration model. @@ -716,7 +733,7 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None, isinstance(rescale, np.ndarray) else rescale, diameter=diameter[i] if isinstance(diameter, list) or isinstance(diameter, np.ndarray) else diameter, tile=tile, - tile_overlap=tile_overlap) + tile_overlap=tile_overlap, bsize=bsize) imgs.append(imgi) if isinstance(x, np.ndarray): imgs = np.array(imgs) @@ -765,18 +782,18 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None, x[..., c] = self._eval(self.net, x[..., c:c + 1], batch_size=batch_size, normalize=normalize, rescale=rescale0, tile=tile, - tile_overlap=tile_overlap) + tile_overlap=tile_overlap, bsize=bsize) else: x[..., c] = self._eval(self.net_chan2, x[..., c:c + 1], batch_size=batch_size, normalize=normalize, rescale=rescale0, tile=tile, - tile_overlap=tile_overlap) + tile_overlap=tile_overlap, bsize=bsize) x = x[0] if squeeze else x return x def _eval(self, net, x, batch_size=8, normalize=True, rescale=None, tile=True, - tile_overlap=0.1): + tile_overlap=0.1, bsize=224): """ Run image restoration model on a single channel. @@ -829,7 +846,7 @@ def _eval(self, net, x, batch_size=8, normalize=True, rescale=None, tile=True, if img.ndim == 2: img = img[:, :, np.newaxis] yf, style = run_net(net, img, batch_size=batch_size, augment=False, - tile=tile, tile_overlap=tile_overlap) + tile=tile, tile_overlap=tile_overlap, bsize=bsize) img = transforms.resize_image(yf, Ly=x.shape[-3], Lx=x.shape[-2]) if img.ndim == 2: @@ -847,7 +864,8 @@ def train(net, train_data=None, train_labels=None, train_files=None, test_data=N test_labels=None, test_files=None, train_probs=None, test_probs=None, lam=[1., 1.5, 0.], scale_range=0.5, seg_model_type="cyto2", save_path=None, save_every=100, save_each=False, poisson=0.7, beta=0.7, blur=0.7, gblur=1.0, - iso=True, downsample=0., learning_rate=0.005, n_epochs=500, momentum=0.9, + iso=True, downsample=0., ds_max=7, + learning_rate=0.005, n_epochs=500, weight_decay=0.00001, batch_size=8, nimg_per_epoch=None, nimg_test_per_epoch=None, model_name=None): @@ -950,7 +968,7 @@ def train(net, train_data=None, train_labels=None, train_files=None, test_data=N for param_group in optimizer.param_groups: param_group["lr"] = learning_rate[iepoch] lavg, lavg_per, nsum = 0, 0, 0 - for ibatch in range(0, nimg_per_epoch, batch_size): + for ibatch in range(0, nimg_per_epoch, batch_size * nnoise): inds = rperm[ibatch:ibatch + batch_size] if train_data is None: imgs = [np.maximum(0, io.imread(train_files[i])[:nchan]) for i in inds] @@ -958,15 +976,23 @@ def train(net, train_data=None, train_labels=None, train_files=None, test_data=N else: imgs = [train_data[i][:nchan] for i in inds] lbls = [train_labels[i][1:] for i in inds] - inoise = nbatch % nnoise - img, lbl, scale = random_rotate_and_resize_noise( - imgs, lbls, diam_train[inds].copy(), poisson=poisson[inoise], - beta=beta[inoise], gblur=gblur[inoise], blur=blur[inoise], iso=iso, - downsample=downsample[inoise], diam_mean=diam_mean, device=device) - #print(torch.isnan(img).sum()) - if torch.isnan(img).sum(): - import pdb - pdb.set_trace() + #inoise = nbatch % nnoise + for inoise in range(nnoise): + imgi, lbli, scale = random_rotate_and_resize_noise( + imgs, lbls, diam_train[inds].copy(), poisson=poisson[inoise], + beta=beta[inoise], gblur=gblur[inoise], blur=blur[inoise], iso=iso, + downsample=downsample[inoise], diam_mean=diam_mean, ds_max=ds_max, + device=device) + if inoise == 0: + img = imgi + lbl = lbli + else: + img = torch.cat((img, imgi), axis=0) + lbl = torch.cat((lbl, lbli), axis=0) + if nnoise > 0: + iperm = np.random.permutation(img.shape[0]) + img, lbl = img[iperm], lbl[iperm] + optimizer.zero_grad() loss, loss_per = train_loss(net, img[:, :nchan], net1=net1, img=img[:, nchan:], lbl=lbl, lam=lam) @@ -1069,6 +1095,8 @@ def train(net, train_data=None, train_labels=None, train_files=None, test_data=N help="scale of gaussian blurring stddev") training_args.add_argument("--downsample", default=0., type=float, help="fraction of images to downsample") + training_args.add_argument("--ds_max", default=7, type=int, + help="max downsampling factor") training_args.add_argument("--lam_per", default=1.0, type=float, help="weighting of perceptual loss") training_args.add_argument("--lam_seg", default=1.5, type=float, @@ -1184,10 +1212,11 @@ def train(net, train_data=None, train_labels=None, train_files=None, test_data=N model.net, train_data=train_data, train_labels=labels, train_files=train_files, test_data=test_data, test_labels=test_labels, test_files=test_files, train_probs=train_probs, test_probs=test_probs, poisson=poisson, beta=beta, - blur=blur, gblur=gblur, downsample=downsample, iso=True, n_epochs=args.n_epochs, + blur=blur, gblur=gblur, downsample=downsample, ds_max=args.ds_max, + iso=True, n_epochs=args.n_epochs, learning_rate=args.learning_rate, - lam=[args.lam_per, args.lam_seg, args.lam_rec - ], seg_model_type=args.seg_model_type, nimg_per_epoch=nimg_per_epoch, + lam=[args.lam_per, args.lam_seg, args.lam_rec], + seg_model_type=args.seg_model_type, nimg_per_epoch=nimg_per_epoch, nimg_test_per_epoch=nimg_test_per_epoch, save_path=save_path) From a8af348fbee34fdcd680088cc0b7307a679b0af3 Mon Sep 17 00:00:00 2001 From: Carsen Stringer Date: Tue, 16 Jul 2024 09:56:35 -0400 Subject: [PATCH 04/22] bug in io --- cellpose/denoise.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/cellpose/denoise.py b/cellpose/denoise.py index 17827143..0096b34d 100644 --- a/cellpose/denoise.py +++ b/cellpose/denoise.py @@ -1121,6 +1121,8 @@ def train(net, train_data=None, train_labels=None, train_files=None, test_data=N io.logger_setup() args = parser.parse_args() + lams = [args.lam_per, args.lam_seg, args.lam_rec] + print("lam", lams) if len(args.noise_type) > 0: noise_type = args.noise_type @@ -1163,8 +1165,7 @@ def train(net, train_data=None, train_labels=None, train_files=None, test_data=N train_data, labels, train_files, train_probs = None, None, None, None test_data, test_labels, test_files, test_probs = None, None, None, None if len(args.file_list) == 0: - output = io.load_train_test_data(args.dir, args.test_dir, "_img", "_masks", 0, - 0) + output = io.load_train_test_data(args.dir, args.test_dir, "_img", "_masks", 0) images, labels, image_names, test_images, test_labels, image_names_test = output train_data = [] for i in range(len(images)): @@ -1215,7 +1216,7 @@ def train(net, train_data=None, train_labels=None, train_files=None, test_data=N blur=blur, gblur=gblur, downsample=downsample, ds_max=args.ds_max, iso=True, n_epochs=args.n_epochs, learning_rate=args.learning_rate, - lam=[args.lam_per, args.lam_seg, args.lam_rec], + lam=lams, seg_model_type=args.seg_model_type, nimg_per_epoch=nimg_per_epoch, nimg_test_per_epoch=nimg_test_per_epoch, save_path=save_path) From b0c03afe0505ce3852124ac35d7a04af59091ff0 Mon Sep 17 00:00:00 2001 From: Carsen Stringer Date: Tue, 16 Jul 2024 20:04:32 -0400 Subject: [PATCH 05/22] bug in inoise --- cellpose/denoise.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/cellpose/denoise.py b/cellpose/denoise.py index 0096b34d..8d4c057b 100644 --- a/cellpose/denoise.py +++ b/cellpose/denoise.py @@ -993,14 +993,16 @@ def train(net, train_data=None, train_labels=None, train_files=None, test_data=N iperm = np.random.permutation(img.shape[0]) img, lbl = img[iperm], lbl[iperm] - optimizer.zero_grad() - loss, loss_per = train_loss(net, img[:, :nchan], net1=net1, - img=img[:, nchan:], lbl=lbl, lam=lam) - - loss.backward() - optimizer.step() - lavg += loss.item() * img.shape[0] - lavg_per += loss_per.item() * img.shape[0] + for inoise in range(nnoise): + optimizer.zero_grad() + imgi = img[inoise * batch_size: (inoise+1) * batch_size] + loss, loss_per = train_loss(net, imgi[:, :nchan], net1=net1, + img=imgi[:, nchan:], lbl=lbl, lam=lam) + loss.backward() + optimizer.step() + lavg += loss.item() * img.shape[0] + lavg_per += loss_per.item() * img.shape[0] + nsum += len(img) nbatch += 1 From de0e6142b352bdf389ea752067c787f182c88959 Mon Sep 17 00:00:00 2001 From: Carsen Stringer Date: Wed, 17 Jul 2024 06:42:13 -0400 Subject: [PATCH 06/22] bug in nnoise --- cellpose/denoise.py | 52 ++++++++++++++++++++++++++------------------- 1 file changed, 30 insertions(+), 22 deletions(-) diff --git a/cellpose/denoise.py b/cellpose/denoise.py index 8d4c057b..7c3971ae 100644 --- a/cellpose/denoise.py +++ b/cellpose/denoise.py @@ -969,7 +969,7 @@ def train(net, train_data=None, train_labels=None, train_files=None, test_data=N param_group["lr"] = learning_rate[iepoch] lavg, lavg_per, nsum = 0, 0, 0 for ibatch in range(0, nimg_per_epoch, batch_size * nnoise): - inds = rperm[ibatch:ibatch + batch_size] + inds = rperm[ibatch : ibatch + batch_size * nnoise] if train_data is None: imgs = [np.maximum(0, io.imread(train_files[i])[:nchan]) for i in inds] lbls = [io.imread(train_labels_files[i])[1:] for i in inds] @@ -977,32 +977,40 @@ def train(net, train_data=None, train_labels=None, train_files=None, test_data=N imgs = [train_data[i][:nchan] for i in inds] lbls = [train_labels[i][1:] for i in inds] #inoise = nbatch % nnoise - for inoise in range(nnoise): - imgi, lbli, scale = random_rotate_and_resize_noise( - imgs, lbls, diam_train[inds].copy(), poisson=poisson[inoise], - beta=beta[inoise], gblur=gblur[inoise], blur=blur[inoise], iso=iso, - downsample=downsample[inoise], diam_mean=diam_mean, ds_max=ds_max, - device=device) - if inoise == 0: - img = imgi - lbl = lbli - else: - img = torch.cat((img, imgi), axis=0) - lbl = torch.cat((lbl, lbli), axis=0) + rnoise = np.random.permutation(nnoise) + for i, inoise in enumerate(rnoise): + if i * batch_size < len(imgs): + imgi, lbli, scale = random_rotate_and_resize_noise( + imgs[i * batch_size : (i + 1) * batch_size], + lbls[i * batch_size : (i + 1) * batch_size], + diam_train[inds][i * batch_size : (i + 1) * batch_size].copy(), + poisson=poisson[inoise], + beta=beta[inoise], gblur=gblur[inoise], blur=blur[inoise], iso=iso, + downsample=downsample[inoise], diam_mean=diam_mean, ds_max=ds_max, + device=device) + if i == 0: + img = imgi + lbl = lbli + else: + img = torch.cat((img, imgi), axis=0) + lbl = torch.cat((lbl, lbli), axis=0) + if nnoise > 0: iperm = np.random.permutation(img.shape[0]) img, lbl = img[iperm], lbl[iperm] - for inoise in range(nnoise): + for i in range(nnoise): optimizer.zero_grad() - imgi = img[inoise * batch_size: (inoise+1) * batch_size] - loss, loss_per = train_loss(net, imgi[:, :nchan], net1=net1, - img=imgi[:, nchan:], lbl=lbl, lam=lam) - loss.backward() - optimizer.step() - lavg += loss.item() * img.shape[0] - lavg_per += loss_per.item() * img.shape[0] - + imgi = img[i * batch_size: (i + 1) * batch_size] + lbli = lbl[i * batch_size: (i + 1) * batch_size] + if imgi.shape[0] > 0: + loss, loss_per = train_loss(net, imgi[:, :nchan], net1=net1, + img=imgi[:, nchan:], lbl=lbli, lam=lam) + loss.backward() + optimizer.step() + lavg += loss.item() * imgi.shape[0] + lavg_per += loss_per.item() * imgi.shape[0] + nsum += len(img) nbatch += 1 From c4d9cb66e8907981e5f4b0346fc228d09f9eee68 Mon Sep 17 00:00:00 2001 From: Carsen Stringer Date: Tue, 23 Jul 2024 12:09:56 -0400 Subject: [PATCH 07/22] fixing bugs with GUI vs CLI for denoising --- cellpose/core.py | 4 ++-- cellpose/denoise.py | 34 +++++++++++++++++++--------------- cellpose/gui/gui3d.py | 34 ++++++++++++++++++---------------- cellpose/models.py | 8 ++++++-- 4 files changed, 45 insertions(+), 35 deletions(-) diff --git a/cellpose/core.py b/cellpose/core.py index aa8399e8..f67009b5 100644 --- a/cellpose/core.py +++ b/cellpose/core.py @@ -272,7 +272,7 @@ def _run_tiled(net, imgi, batch_size=8, augment=False, bsize=224, tile_overlap=0 yf = np.zeros((Lz, nout, imgi.shape[-2], imgi.shape[-1]), np.float32) styles = [] if ny * nx > batch_size: - ziterator = trange(Lz, file=tqdm_out) + ziterator = trange(Lz, file=tqdm_out, mininterval=30) for i in ziterator: yfi, stylei = _run_tiled(net, imgi[i], augment=augment, bsize=bsize, tile_overlap=tile_overlap) @@ -283,7 +283,7 @@ def _run_tiled(net, imgi, batch_size=8, augment=False, bsize=224, tile_overlap=0 ntiles = ny * nx nimgs = max(2, int(np.round(batch_size / ntiles))) niter = int(np.ceil(Lz / nimgs)) - ziterator = trange(niter, file=tqdm_out) + ziterator = trange(niter, file=tqdm_out, mininterval=30) for k in ziterator: IMGa = np.zeros((ntiles * nimgs, nchan, ly, lx), np.float32) for i in range(min(Lz - k * nimgs, nimgs)): diff --git a/cellpose/denoise.py b/cellpose/denoise.py index 7c3971ae..0c95eb36 100644 --- a/cellpose/denoise.py +++ b/cellpose/denoise.py @@ -245,20 +245,22 @@ def add_noise(lbl, alpha=4, beta=0.7, poisson=0.7, blur=0.7, gblur=1.0, downsamp iblur = np.random.rand(len(lbl)) < blur if iblur.sum() > 0: if sigma0 is None: - # was 10 - xrand = np.random.exponential(1, size=iblur.sum()) - xrand = np.clip(xrand * 0.5, 0.1, 1.0) - xrand *= gblur - sigma0 = diams[iblur] / 30. * 5. * torch.from_numpy(xrand).float().to( - device) - #(1 + torch.rand(iblur.sum(), device=device)) - if not iso: - sr = diams[iblur] / 30. * 2 * (1 + - torch.rand(iblur.sum(), device=device)) - sigma1 = (torch.rand(iblur.sum(), device=device) > 0.66) * sr + if iso: + # was 10 + # xrand = np.random.exponential(1, size=iblur.sum()) + # xrand = np.clip(xrand * 0.5, 0.1, 1.0) + # xrand *= gblur + # sigma0 = diams[iblur] / 30. * 5. * torch.from_numpy(xrand).float().to( + # device) + # #(1 + torch.rand(iblur.sum(), device=device)) + # sigma1 = sigma0.clone() + sigma0 = diams[iblur] / 30. * gblur * (1/gblur + + (1 - 1/gblur) * torch.rand(iblur.sum(), device=device)) + sigma1 = sigma0.clone() else: - sigma1 = sigma0.clone( - ) #+ torch.randint(0, 3, size=(len(sigma0.clone()),), device=device) + sigma0 = diams[iblur] / 30. * gblur * (1/gblur + + (1 - 1/gblur) * torch.rand(iblur.sum(), device=device)) + sigma1 = sigma0.clone() / 10. else: sigma0 = sigma0 * torch.ones((iblur.sum(),), device=device) sigma1 = sigma1 * torch.ones((iblur.sum(),), device=device) @@ -555,6 +557,7 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None, img_restore = self.dn.eval(x, batch_size=batch_size, channels=channels, channel_axis=channel_axis, z_axis=z_axis, + do_3D=do_3D, normalize=normalize_params, rescale=rescale, diameter=diameter, tile=tile, tile_overlap=tile_overlap, bsize=bsize) @@ -674,7 +677,7 @@ def __init__(self, gpu=False, pretrained_model=False, nchan=1, model_type=None, self.net_type = "cellpose_denoise" def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None, - normalize=True, rescale=None, diameter=None, tile=True, + normalize=True, rescale=None, diameter=None, tile=True, do_3D=False, tile_overlap=0.1, bsize=224): """ Restore array or list of images using the image restoration model. @@ -729,6 +732,7 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None, isinstance(channels[i], np.ndarray)) and len(channels[i]) == 2)) else channels, channel_axis=channel_axis, z_axis=z_axis, normalize=normalize, + do_3D=do_3D, rescale=rescale[i] if isinstance(rescale, list) or isinstance(rescale, np.ndarray) else rescale, diameter=diameter[i] if isinstance(diameter, list) or @@ -742,7 +746,7 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None, else: # reshape image x = transforms.convert_image(x, channels, channel_axis=channel_axis, - z_axis=z_axis) + z_axis=z_axis, do_3D=do_3D) if x.ndim < 4: squeeze = True x = x[np.newaxis, ...] diff --git a/cellpose/gui/gui3d.py b/cellpose/gui/gui3d.py index 8804a7e9..8a3f9eea 100644 --- a/cellpose/gui/gui3d.py +++ b/cellpose/gui/gui3d.py @@ -38,8 +38,8 @@ def avg3d(C): """ Ly, Lx = C.shape # pad T by 2 - T = np.zeros((Ly + 2, Lx + 2), np.float32) - M = np.zeros((Ly, Lx), np.float32) + T = np.zeros((Ly + 2, Lx + 2), "float32") + M = np.zeros((Ly, Lx), "float32") T[1:-1, 1:-1] = C.copy() y, x = np.meshgrid(np.arange(0, Ly, 1, int), np.arange(0, Lx, 1, int), indexing="ij") @@ -244,7 +244,7 @@ def add_mask(self, points=None, color=(100, 200, 50), dense=True): vc = stroke[iz, 2] if iz.sum() > 0: # get points inside drawn points - mask = np.zeros((np.ptp(vr) + 4, np.ptp(vc) + 4), np.uint8) + mask = np.zeros((np.ptp(vr) + 4, np.ptp(vc) + 4), "uint8") pts = np.stack((vc - vc.min() + 2, vr - vr.min() + 2), axis=-1)[:, np.newaxis, :] mask = cv2.fillPoly(mask, [pts], (255, 0, 0)) @@ -265,7 +265,7 @@ def add_mask(self, points=None, color=(100, 200, 50), dense=True): elif ioverlap.sum() > 0: ar, ac = ar[~ioverlap], ac[~ioverlap] # compute outline of new mask - mask = np.zeros((np.ptp(ar) + 4, np.ptp(ac) + 4), np.uint8) + mask = np.zeros((np.ptp(ar) + 4, np.ptp(ac) + 4), "uint8") mask[ar - ar.min() + 2, ac - ac.min() + 2] = 1 contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) @@ -282,7 +282,7 @@ def add_mask(self, points=None, color=(100, 200, 50), dense=True): pix = np.append(pix, np.vstack((ars, acs)), axis=-1) mall = mall[:, pix[0].min():pix[0].max() + 1, - pix[1].min():pix[1].max() + 1].astype(np.float32) + pix[1].min():pix[1].max() + 1].astype("float32") ymin, xmin = pix[0].min(), pix[1].min() if len(zdraw) > 1: mall, zfill = interpZ(mall, zdraw - zmin) @@ -422,15 +422,15 @@ def update_ortho(self): for j in range(2): if j == 0: if self.view == 0: - image = self.stack[zmin:zmax, :, x].transpose(1, 0, 2) + image = self.stack[zmin:zmax, :, x].transpose(1, 0, 2).copy() else: image = self.stack_filtered[zmin:zmax, :, - x].transpose(1, 0, 2) + x].transpose(1, 0, 2).copy() else: image = self.stack[ zmin:zmax, - y, :] if self.view == 0 else self.stack_filtered[zmin:zmax, - y, :] + y, :].copy() if self.view == 0 else self.stack_filtered[zmin:zmax, + y, :].copy() if self.nchan == 1: # show single channel image = image[..., 0] @@ -458,11 +458,13 @@ def update_ortho(self): self.imgOrtho[j].setLevels( self.saturation[0][self.currentZ]) elif self.color == 4: - image = image.astype(np.float32).mean(axis=-1).astype(np.uint8) + if image.ndim > 2: + image = image.astype("float32").mean(axis=2).astype("uint8") self.imgOrtho[j].setImage(image, autoLevels=False, lut=None) self.imgOrtho[j].setLevels(self.saturation[0][self.currentZ]) elif self.color == 5: - image = image.astype(np.float32).mean(axis=-1).astype(np.uint8) + if image.ndim > 2: + image = image.astype("float32").mean(axis=2).astype("uint8") self.imgOrtho[j].setImage(image, autoLevels=False, lut=self.cmap[0]) self.imgOrtho[j].setLevels(self.saturation[0][self.currentZ]) @@ -470,7 +472,7 @@ def update_ortho(self): self.pOrtho[1].setAspectLocked(lock=True, ratio=1. / self.zaspect) else: - image = np.zeros((10, 10), np.uint8) + image = np.zeros((10, 10), "uint8") self.imgOrtho[0].setImage(image, autoLevels=False, lut=None) self.imgOrtho[0].setLevels([0.0, 255.0]) self.imgOrtho[1].setImage(image, autoLevels=False, lut=None) @@ -478,8 +480,8 @@ def update_ortho(self): zrange = zmax - zmin self.layer_ortho = [ - np.zeros((self.Ly, zrange, 4), np.uint8), - np.zeros((zrange, self.Lx, 4), np.uint8) + np.zeros((self.Ly, zrange, 4), "uint8"), + np.zeros((zrange, self.Lx, 4), "uint8") ] if self.masksOn: for j in range(2): @@ -488,7 +490,7 @@ def update_ortho(self): else: cp = self.cellpix[zmin:zmax, y] self.layer_ortho[j][..., :3] = self.cellcolors[cp, :] - self.layer_ortho[j][..., 3] = self.opacity * (cp > 0).astype(np.uint8) + self.layer_ortho[j][..., 3] = self.opacity * (cp > 0).astype("uint8") if self.selected > 0: self.layer_ortho[j][cp == self.selected] = np.array( [255, 255, 255, self.opacity]) @@ -499,7 +501,7 @@ def update_ortho(self): op = self.outpix[zmin:zmax, :, x].T else: op = self.outpix[zmin:zmax, y] - self.layer_ortho[j][op > 0] = np.array(self.outcolor).astype(np.uint8) + self.layer_ortho[j][op > 0] = np.array(self.outcolor).astype("uint8") for j in range(2): self.layerOrtho[j].setImage(self.layer_ortho[j]) diff --git a/cellpose/models.py b/cellpose/models.py index 5329d7f5..6f4edfc9 100644 --- a/cellpose/models.py +++ b/cellpose/models.py @@ -35,7 +35,7 @@ "lowhigh": None, "percentile": None, "normalize": True, - "norm3D": False, + "norm3D": True, "sharpen_radius": 0, "smooth_radius": 0, "tile_norm_blocksize": 0, @@ -263,7 +263,7 @@ def __init__(self, gpu=False, pretrained_model=False, model_type=None, if (pretrained_model and not Path(pretrained_model).exists() and np.any([pretrained_model == s for s in all_models])): model_type = pretrained_model - + # check if model_type is builtin or custom user model saved in .cellpose/models if model_type is not None and np.any([model_type == s for s in all_models]): if np.any([model_type == s for s in MODEL_NAMES]): @@ -286,6 +286,10 @@ def __init__(self, gpu=False, pretrained_model=False, model_type=None, models_logger.warning( "pretrained_model path does not exist, using default model") use_default = True + else: + if pretrained_model[-13:] == "nucleitorch_0": + builtin = True + self.diam_mean = 17. builtin = True if use_default else builtin self.pretrained_model = model_path( From e7ac746ca7c1283800e64706e4b31bda85d01dbb Mon Sep 17 00:00:00 2001 From: Carsen Stringer Date: Tue, 23 Jul 2024 13:59:16 -0400 Subject: [PATCH 08/22] setting gblur --- cellpose/denoise.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cellpose/denoise.py b/cellpose/denoise.py index 0c95eb36..40aee0e8 100644 --- a/cellpose/denoise.py +++ b/cellpose/denoise.py @@ -1151,7 +1151,7 @@ def train(net, train_data=None, train_labels=None, train_files=None, test_data=N blur = 0.8 downsample = 0. beta = 0.1 - gblur = 1.0 + gblur = 12.0 elif noise_type == "downsample": poisson = 0.8 blur = 0.8 From 35b431ec0b8820a94f15cbe3daf57ed9869675c9 Mon Sep 17 00:00:00 2001 From: Carsen Stringer Date: Tue, 23 Jul 2024 16:48:30 -0400 Subject: [PATCH 09/22] setting gblur --- cellpose/denoise.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cellpose/denoise.py b/cellpose/denoise.py index 40aee0e8..ad05a0b1 100644 --- a/cellpose/denoise.py +++ b/cellpose/denoise.py @@ -1156,8 +1156,8 @@ def train(net, train_data=None, train_labels=None, train_files=None, test_data=N poisson = 0.8 blur = 0.8 downsample = 0.8 - beta = 0.01 - gblur = 0.5 + beta = 0.05 + gblur = 8.0 elif noise_type == "all": poisson = [0.8, 0.8, 0.8] blur = [0., 0.8, 0.8] From 3ecd2790cea7fb02579f5d568e5459a98e857450 Mon Sep 17 00:00:00 2001 From: Carsen Stringer Date: Tue, 23 Jul 2024 21:36:24 -0400 Subject: [PATCH 10/22] setting for blurring --- cellpose/denoise.py | 9 ++++++--- cellpose/train.py | 23 +++++++++++++++-------- 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/cellpose/denoise.py b/cellpose/denoise.py index ad05a0b1..173701c3 100644 --- a/cellpose/denoise.py +++ b/cellpose/denoise.py @@ -1062,9 +1062,9 @@ def train(net, train_data=None, train_labels=None, train_files=None, test_data=N train_losses.append(lavg) if save_path is not None: - if iepoch == n_epochs - 1 or iepoch % save_every == 1: + if iepoch == n_epochs - 1 or (iepoch % save_every == 0 and iepoch != 0): if save_each: #separate files as model progresses - filename0 = filename + "_epoch_" + str(iepoch) + filename0 = str(filename) + f"_epoch_{iepoch:%04d}" else: filename0 = filename denoise_logger.info(f"saving network parameters to {filename0}") @@ -1125,6 +1125,9 @@ def train(net, train_data=None, train_labels=None, train_files=None, test_data=N help="learning rate. Default: %(default)s") training_args.add_argument("--n_epochs", default=2000, type=int, help="number of epochs. Default: %(default)s") + training_args.add_argument( + "--save_each", default=False, action="store_true", + help="save each epoch as separate model") training_args.add_argument( "--nimg_per_epoch", default=0, type=int, help="number of images per epoch. Default is length of training images") @@ -1151,7 +1154,7 @@ def train(net, train_data=None, train_labels=None, train_files=None, test_data=N blur = 0.8 downsample = 0. beta = 0.1 - gblur = 12.0 + gblur = 10.0 elif noise_type == "downsample": poisson = 0.8 blur = 0.8 diff --git a/cellpose/train.py b/cellpose/train.py index 01d5fc64..af214210 100644 --- a/cellpose/train.py +++ b/cellpose/train.py @@ -328,7 +328,7 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, test_probs=None, load_files=True, batch_size=8, learning_rate=0.005, n_epochs=2000, weight_decay=1e-5, momentum=0.9, SGD=False, channels=None, channel_axis=None, rgb=False, normalize=True, compute_flows=False, - save_path=None, save_every=100, nimg_per_epoch=None, + save_path=None, save_every=100, save_each=False, nimg_per_epoch=None, nimg_test_per_epoch=None, rescale=True, scale_range=None, bsize=224, min_train_masks=5, model_name=None): """ @@ -359,6 +359,7 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, compute_flows (bool, optional): Boolean - whether to compute flows during training. Defaults to False. save_path (str, optional): String - where to save the trained model. Defaults to None. save_every (int, optional): Integer - save the network every [save_every] epochs. Defaults to 100. + save_each (bool, optional): Boolean - save the network to a new filename at every [save_each] epoch. Defaults to False. nimg_per_epoch (int, optional): Integer - minimum number of images to train on per epoch. Defaults to None. nimg_test_per_epoch (int, optional): Integer - minimum number of images to test on per epoch. Defaults to None. rescale (bool, optional): Boolean - whether or not to rescale images during training. Defaults to True. @@ -441,10 +442,10 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, t0 = time.time() model_name = f"cellpose_{t0}" if model_name is None else model_name save_path = Path.cwd() if save_path is None else Path(save_path) - model_path = save_path / "models" / model_name + filename = save_path / "models" / model_name (save_path / "models").mkdir(exist_ok=True) - train_logger.info(f">>> saving model to {model_path}") + train_logger.info(f">>> saving model to {filename}") lavg, nsum = 0, 0 for iepoch in range(n_epochs): @@ -519,11 +520,17 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, ) lavg, nsum = 0, 0 - if iepoch > 0 and iepoch % save_every == 0: - net.save_model(model_path) - net.save_model(model_path) - - return model_path + if iepoch == n_epochs - 1 or (iepoch % save_every == 0 and iepoch != 0): + if save_each and iepoch != n_epochs - 1: #separate files as model progresses + filename0 = str(filename) + f"_epoch_{iepoch:%04d}" + else: + filename0 = filename + train_logger.info(f"saving network parameters to {filename0}") + net.save_model(filename0) + + net.save_model(filename) + + return filename def train_size(net, pretrained_model, train_data=None, train_labels=None, From 9a6b19051a2d9c33354887e12eeec9ad23d043a5 Mon Sep 17 00:00:00 2001 From: Carsen Stringer Date: Wed, 24 Jul 2024 17:20:56 -0400 Subject: [PATCH 11/22] proportional ds --- cellpose/denoise.py | 39 ++++++++++++++++++++++----------------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/cellpose/denoise.py b/cellpose/denoise.py index 173701c3..184ea318 100644 --- a/cellpose/denoise.py +++ b/cellpose/denoise.py @@ -241,8 +241,20 @@ def add_noise(lbl, alpha=4, beta=0.7, poisson=0.7, blur=0.7, gblur=1.0, downsamp ds = ds * torch.ones( (len(lbl),), device=device, dtype=torch.long) if ds is not None else ds + # downsample + ii = [] + idownsample = np.random.rand(len(lbl)) < downsample + if (ds is None and idownsample.sum() > 0.) or not iso: + ds = torch.ones(len(lbl), dtype=torch.long, device=device) + ds[idownsample] = torch.randint(2, ds_max + 1, size=(idownsample.sum(),), + device=device) + ii = torch.nonzero(ds > 1).flatten() + elif ds is not None and (ds > 1).sum(): + ii = torch.nonzero(ds > 1).flatten() + # add gaussian blur - iblur = np.random.rand(len(lbl)) < blur + iblur = torch.rand(len(lbl), device=device) < blur + iblur[ii] = True if iblur.sum() > 0: if sigma0 is None: if iso: @@ -254,8 +266,10 @@ def add_noise(lbl, alpha=4, beta=0.7, poisson=0.7, blur=0.7, gblur=1.0, downsamp # device) # #(1 + torch.rand(iblur.sum(), device=device)) # sigma1 = sigma0.clone() - sigma0 = diams[iblur] / 30. * gblur * (1/gblur + - (1 - 1/gblur) * torch.rand(iblur.sum(), device=device)) + xr = torch.rand(len(lbl), device=device) + if ii.shape[0] > 0: + xr[ii] = (ds[ii].float() / 2.) / gblur + sigma0 = diams[iblur] / 30. * gblur * (1 / gblur + (1 - 1 / gblur) * xr[iblur]) sigma1 = sigma0.clone() else: sigma0 = diams[iblur] / 30. * gblur * (1/gblur + @@ -302,16 +316,7 @@ def add_noise(lbl, alpha=4, beta=0.7, poisson=0.7, blur=0.7, gblur=1.0, downsamp imgi[~iblur] = lbl[~iblur] - # downsample - ii = [] - idownsample = np.random.rand(len(lbl)) < downsample - if (ds is None and idownsample.sum() > 0.) or not iso: - ds = torch.ones(len(lbl), dtype=torch.long, device=device) - ds[idownsample] = torch.randint(2, ds_max + 1, size=(idownsample.sum(),), - device=device) - ii = torch.nonzero(ds > 1) - elif ds is not None and (ds > 1).sum(): - ii = torch.nonzero(ds > 1) + # apply downsample for k in ii: i0 = imgi[k:k + 1, :, ::ds[k], ::ds[k]] if iso else imgi[k:k + 1, :, ::ds[k]] imgi[k] = interpolate(i0, size=lbl[k].shape[-2:], mode="bilinear") @@ -1159,14 +1164,14 @@ def train(net, train_data=None, train_labels=None, train_files=None, test_data=N poisson = 0.8 blur = 0.8 downsample = 0.8 - beta = 0.05 - gblur = 8.0 + beta = 0.03 + gblur = 5.0 elif noise_type == "all": poisson = [0.8, 0.8, 0.8] blur = [0., 0.8, 0.8] downsample = [0., 0., 0.8] - beta = [0.7, 0.1, 0.01] - gblur = [0., 1.0, 0.5] + beta = [0.7, 0.1, 0.05] + gblur = [0., 10.0, 8.0] else: raise ValueError(f"{noise_type} noise_type is not supported") else: From 043f89cfc2322ab006b71ea68821e8adf0d8b4c1 Mon Sep 17 00:00:00 2001 From: Carsen Stringer Date: Wed, 24 Jul 2024 17:26:35 -0400 Subject: [PATCH 12/22] bug in ds --- cellpose/denoise.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cellpose/denoise.py b/cellpose/denoise.py index 184ea318..8238c371 100644 --- a/cellpose/denoise.py +++ b/cellpose/denoise.py @@ -267,7 +267,7 @@ def add_noise(lbl, alpha=4, beta=0.7, poisson=0.7, blur=0.7, gblur=1.0, downsamp # #(1 + torch.rand(iblur.sum(), device=device)) # sigma1 = sigma0.clone() xr = torch.rand(len(lbl), device=device) - if ii.shape[0] > 0: + if len(ii) > 0: xr[ii] = (ds[ii].float() / 2.) / gblur sigma0 = diams[iblur] / 30. * gblur * (1 / gblur + (1 - 1 / gblur) * xr[iblur]) sigma1 = sigma0.clone() From 7e61f5052bd80ade9793e2d07bb6869946468a66 Mon Sep 17 00:00:00 2001 From: Carsen Stringer Date: Mon, 29 Jul 2024 08:11:43 -0400 Subject: [PATCH 13/22] updating denoise --- cellpose/denoise.py | 4 ++-- cellpose/transforms.py | 17 +++++++++++++---- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/cellpose/denoise.py b/cellpose/denoise.py index 8238c371..adf57961 100644 --- a/cellpose/denoise.py +++ b/cellpose/denoise.py @@ -1170,8 +1170,8 @@ def train(net, train_data=None, train_labels=None, train_files=None, test_data=N poisson = [0.8, 0.8, 0.8] blur = [0., 0.8, 0.8] downsample = [0., 0., 0.8] - beta = [0.7, 0.1, 0.05] - gblur = [0., 10.0, 8.0] + beta = [0.7, 0.1, 0.03] + gblur = [0., 10.0, 5.0] else: raise ValueError(f"{noise_type} noise_type is not supported") else: diff --git a/cellpose/transforms.py b/cellpose/transforms.py index 4fcacfcd..980e22a5 100644 --- a/cellpose/transforms.py +++ b/cellpose/transforms.py @@ -729,7 +729,7 @@ def resize_image(img0, Ly=None, Lx=None, rsz=None, interpolation=cv2.INTER_LINEA return imgs -def pad_image_ND(img0, div=16, extra=1, min_size=None): +def pad_image_ND(img0, div=16, extra=1, min_size=None, zpad=False): """Pad image for test-time so that its dimensions are a multiple of 16 (2D or 3D). Args: @@ -758,7 +758,13 @@ def pad_image_ND(img0, div=16, extra=1, min_size=None): ypad2 = extra * div // 2 + Lpad - Lpad // 2 if img0.ndim > 3: - pads = np.array([[0, 0], [0, 0], [xpad1, xpad2], [ypad1, ypad2]]) + if zpad: + Lpad = int(div * np.ceil(img0.shape[-3] / div) - img0.shape[-3]) + zpad1 = extra * div // 2 + Lpad // 2 + zpad2 = extra * div // 2 + Lpad - Lpad // 2 + else: + zpad1, zpad2 = 0, 0 + pads = np.array([[0, 0], [zpad1, zpad2], [xpad1, xpad2], [ypad1, ypad2]]) else: pads = np.array([[0, 0], [xpad1, xpad2], [ypad1, ypad2]]) @@ -767,8 +773,11 @@ def pad_image_ND(img0, div=16, extra=1, min_size=None): Ly, Lx = img0.shape[-2:] ysub = np.arange(xpad1, xpad1 + Ly) xsub = np.arange(ypad1, ypad1 + Lx) - - return I, ysub, xsub + if zpad: + zsub = np.arange(zpad1, zpad1 + img0.shape[-3]) + return I, ysub, xsub, zsub + else: + return I, ysub, xsub def random_rotate_and_resize(X, Y=None, scale_range=1., xy=(224, 224), do_3D=False, From 2382f361a007dbd23a03af2ec0cd54d2874e4fba Mon Sep 17 00:00:00 2001 From: Carsen Stringer Date: Mon, 5 Aug 2024 11:17:09 -0400 Subject: [PATCH 14/22] bug in pretrained_model --- cellpose/denoise.py | 16 ++++++---------- cellpose/models.py | 2 +- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/cellpose/denoise.py b/cellpose/denoise.py index adf57961..070e9e47 100644 --- a/cellpose/denoise.py +++ b/cellpose/denoise.py @@ -258,22 +258,18 @@ def add_noise(lbl, alpha=4, beta=0.7, poisson=0.7, blur=0.7, gblur=1.0, downsamp if iblur.sum() > 0: if sigma0 is None: if iso: - # was 10 - # xrand = np.random.exponential(1, size=iblur.sum()) - # xrand = np.clip(xrand * 0.5, 0.1, 1.0) - # xrand *= gblur - # sigma0 = diams[iblur] / 30. * 5. * torch.from_numpy(xrand).float().to( - # device) - # #(1 + torch.rand(iblur.sum(), device=device)) - # sigma1 = sigma0.clone() xr = torch.rand(len(lbl), device=device) if len(ii) > 0: xr[ii] = (ds[ii].float() / 2.) / gblur sigma0 = diams[iblur] / 30. * gblur * (1 / gblur + (1 - 1 / gblur) * xr[iblur]) sigma1 = sigma0.clone() else: - sigma0 = diams[iblur] / 30. * gblur * (1/gblur + - (1 - 1/gblur) * torch.rand(iblur.sum(), device=device)) + xr = torch.rand(len(lbl), device=device) + if len(ii) > 0: + xr[ii] = (ds[ii].float() / 2.) / gblur + sigma0 = diams[iblur] / 30. * gblur * (1 / gblur + (1 - 1 / gblur) * xr[iblur]) + #sigma0 = diams[iblur] / 30. * gblur * (1/gblur + + # (1 - 1/gblur) * torch.rand(iblur.sum(), device=device)) sigma1 = sigma0.clone() / 10. else: sigma0 = sigma0 * torch.ones((iblur.sum(),), device=device) diff --git a/cellpose/models.py b/cellpose/models.py index 6f4edfc9..ad9f7eaa 100644 --- a/cellpose/models.py +++ b/cellpose/models.py @@ -286,7 +286,7 @@ def __init__(self, gpu=False, pretrained_model=False, model_type=None, models_logger.warning( "pretrained_model path does not exist, using default model") use_default = True - else: + elif pretrained_model: if pretrained_model[-13:] == "nucleitorch_0": builtin = True self.diam_mean = 17. From 0e1dd920a3b408293de13ee4f543286934b74d9b Mon Sep 17 00:00:00 2001 From: Carsen Stringer Date: Mon, 5 Aug 2024 11:31:48 -0400 Subject: [PATCH 15/22] adding more zeros to LR --- cellpose/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cellpose/train.py b/cellpose/train.py index af214210..f58240ef 100644 --- a/cellpose/train.py +++ b/cellpose/train.py @@ -516,7 +516,7 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, lavgt /= len(rperm) lavg /= nsum train_logger.info( - f"{iepoch}, train_loss={lavg:.4f}, test_loss={lavgt:.4f}, LR={LR[iepoch]:.4f}, time {time.time()-t0:.2f}s" + f"{iepoch}, train_loss={lavg:.4f}, test_loss={lavgt:.4f}, LR={LR[iepoch]:.6f}, time {time.time()-t0:.2f}s" ) lavg, nsum = 0, 0 From 615a674df6b11173366b1ae768cc274a1289b3f3 Mon Sep 17 00:00:00 2001 From: Carsen Stringer Date: Thu, 15 Aug 2024 12:39:23 -0400 Subject: [PATCH 16/22] adding in original blur --- cellpose/core.py | 2 +- cellpose/denoise.py | 22 +++++++++-------- paper/3.0/analysis.py | 57 ++++++------------------------------------- paper/3.0/figures.py | 40 +++++++++++++++++++++--------- 4 files changed, 49 insertions(+), 72 deletions(-) diff --git a/cellpose/core.py b/cellpose/core.py index f67009b5..40883e97 100644 --- a/cellpose/core.py +++ b/cellpose/core.py @@ -220,7 +220,7 @@ def run_net(net, imgs, batch_size=8, augment=False, tile=True, tile_overlap=0.1, # slices from padding # slc = [slice(0, self.nclasses) for n in range(imgs.ndim)] # changed from imgs.shape[n]+1 for first slice size slc = [slice(0, imgs.shape[n] + 1) for n in range(imgs.ndim)] - slc[-3] = slice(0, 3) + slc[-3] = slice(0, net.nout) slc[-2] = slice(ysub[0], ysub[-1] + 1) slc[-1] = slice(xsub[0], xsub[-1] + 1) slc = tuple(slc) diff --git a/cellpose/denoise.py b/cellpose/denoise.py index 070e9e47..317cdc7f 100644 --- a/cellpose/denoise.py +++ b/cellpose/denoise.py @@ -210,7 +210,7 @@ def img_norm(imgi): def add_noise(lbl, alpha=4, beta=0.7, poisson=0.7, blur=0.7, gblur=1.0, downsample=0.7, ds_max=7, diams=None, pscale=None, iso=True, sigma0=None, sigma1=None, - ds=None, partial_blur=False): + ds=None, uniform_blur=False, partial_blur=False): """Adds noise to the input image. Args: @@ -257,20 +257,22 @@ def add_noise(lbl, alpha=4, beta=0.7, poisson=0.7, blur=0.7, gblur=1.0, downsamp iblur[ii] = True if iblur.sum() > 0: if sigma0 is None: - if iso: + if not iso or uniform_blur: xr = torch.rand(len(lbl), device=device) if len(ii) > 0: xr[ii] = (ds[ii].float() / 2.) / gblur sigma0 = diams[iblur] / 30. * gblur * (1 / gblur + (1 - 1 / gblur) * xr[iblur]) - sigma1 = sigma0.clone() + if iso: + sigma1 = sigma0.clone() + else: + sigma1 = sigma0.clone() / 10. else: - xr = torch.rand(len(lbl), device=device) - if len(ii) > 0: - xr[ii] = (ds[ii].float() / 2.) / gblur - sigma0 = diams[iblur] / 30. * gblur * (1 / gblur + (1 - 1 / gblur) * xr[iblur]) - #sigma0 = diams[iblur] / 30. * gblur * (1/gblur + - # (1 - 1/gblur) * torch.rand(iblur.sum(), device=device)) - sigma1 = sigma0.clone() / 10. + xrand = np.random.exponential(1, size=iblur.sum()) + xrand = np.clip(xrand * 0.5, 0.1, 1.0) + xrand *= gblur + sigma0 = diams[iblur] / 30. * 5. * torch.from_numpy(xrand).float().to( + device) + sigma1 = sigma0.clone() else: sigma0 = sigma0 * torch.ones((iblur.sum(),), device=device) sigma1 = sigma1 * torch.ones((iblur.sum(),), device=device) diff --git a/paper/3.0/analysis.py b/paper/3.0/analysis.py index 62e3a400..aff669cf 100644 --- a/paper/3.0/analysis.py +++ b/paper/3.0/analysis.py @@ -24,47 +24,7 @@ device = torch.device("cuda") try: - import segmentation_models_pytorch as smp - - class Transformer(nn.Module): - - def __init__(self, pretrained_model=None, encoder="mit_b5", - encoder_weights="imagenet", decoder="FPN"): - super().__init__() - net_fcn = smp.FPN if decoder == "FPN" else smp.MAnet - self.net = net_fcn( - encoder_name=encoder, - encoder_weights=encoder_weights if pretrained_model is None else - None, # use `imagenet` pre-trained weights for encoder initialization - in_channels=3, - classes=3, - activation=None) - self.nout = 3 - self.mkldnn = False - if pretrained_model is not None: - state_dict = torch.load(pretrained_model) - if list(state_dict.keys())[0][:7] == "module.": - from collections import OrderedDict - new_state_dict = OrderedDict() - for k, v in state_dict.items(): - name = k[ - 7:] # remove 'module.' of DataParallel/DistributedDataParallel - new_state_dict[name] = v - self.net.load_state_dict(new_state_dict) - else: - self.load_state_dict(state_dict) - - def forward(self, X): - X = torch.cat( - (X, torch.zeros( - (X.shape[0], 1, X.shape[2], X.shape[3]), device=X.device)), dim=1) - y = self.net(X) - return y, torch.zeros((X.shape[0], 256), device=X.device) - - @property - def device(self): - return next(self.parameters()).device - + from cellpose.segformer import Transformer except Exception as e: print(e) print("need to install segmentation_models_pytorch to run transformer") @@ -402,21 +362,20 @@ def cyto3_comparisons(folder): ] net_types = ["generalist", "specialist", "transformer"] - for net_type in net_types: + for net_type in net_types[-1:]: if net_type == "generalist": seg_model = models.Cellpose(gpu=True, model_type="cyto3") elif net_type == "transformer": - seg_model = models.CellposeModel(gpu=True, pretrained_model=None) pretrained_model = "/home/carsen/.cellpose/models/transformer_cp3" - seg_model.net = Transformer(pretrained_model=pretrained_model, - decoder="MAnet").to(device) - for f in folders: + seg_model = models.CellposeModel(gpu=True, backbone="transformer", + pretrained_model=pretrained_model) + for f in folders[:3]: if net_type == "specialist": seg_model = models.CellposeModel(gpu=True, model_type=f"{f}_cp3") root = Path(folder) / f"images_{f}" channels = [1, 2] if f == "tissuenet" or f == "cyto2" else [1, 0] - tifs = (root / "test").glob("*.tif") + tifs = natsorted((root / "test").glob("*.tif")) tifs = [tif for tif in tifs] tifs = [ tif for tif in tifs @@ -424,7 +383,7 @@ def cyto3_comparisons(folder): ] if net_type != "generalist": d = np.load( - f"/media/carsen/ssd4/datasets_cellpose/{f}_generalist_masks.npy", + Path(folder) / f"{f}_generalist_masks.npy", allow_pickle=True).item() diams = d["diams"] else: @@ -456,7 +415,7 @@ def cyto3_comparisons(folder): dat["performance"] = [ap, tp, fp, fn] dat["diams"] = diams - #p.save(f"/media/carsen/ssd4/datasets_cellpose/{f}_{net_type}_masks.npy", dat) + #np.save(f"/media/carsen/ssd4/datasets_cellpose/{f}_{net_type}_masks.npy", dat) if __name__ == '__main__': diff --git a/paper/3.0/figures.py b/paper/3.0/figures.py index 6e3cbd7e..97803a20 100644 --- a/paper/3.0/figures.py +++ b/paper/3.0/figures.py @@ -1021,6 +1021,16 @@ def load_benchmarks_specialist(folder, thresholds=np.arange(0.5, 1.05, 0.05)): imgs_all.append(test_care) masks_all.append(masks_care) + dat = np.load(root / "noisy_test" / f"test_{noise_type}_denoiseg_specialist.npy", + allow_pickle=True).item() + test_dns = dat["test_denoiseg"][:nimg_test] + masks_dns = dat["masks_denoiseg"][:nimg_test] + imgs_all.append(test_dns) + masks_all.append(masks_dns) + masks_dns = dat["masks_denoiseg_seg"][:nimg_test] + imgs_all.append(test_dns) + masks_all.append(masks_dns) + dat = np.load(root / "noisy_test" / f"test_{noise_type}_cp3.npy", allow_pickle=True).item() istr = ["rec", "per", "seg", "perseg"] @@ -1069,11 +1079,16 @@ def suppfig_specialist(folder, save_fig=True): legstr0 = [] for ls in legstr[:-1]: legstr0.append(" ".join(ls.split(" ")[1:])) - legstr0.insert(4, "CARE") + legstr0[-1] = u"\u2013 " + legstr0[-1] + legstr0.insert(4, u"\u2013 CARE") + legstr0.insert(5, u"\u2013 denoiseg") + legstr0.insert(6, "-- denoiseg\n(segmentation)") cols0 = list(cols[:-1].copy()) cols0.insert(4, [1, 0.5, 1]) + cols0.insert(5, 0.4*np.ones(3)) + cols0.insert(6, 0.4*np.ones(3)) print(len(cols0)) - legstr0[-1] = "Cellpose3\n(per. + seg.)" + legstr0[-1] = u"\u2013 Cellpose3\n(per. + seg.)" il = 0 @@ -1094,7 +1109,7 @@ def suppfig_specialist(folder, save_fig=True): imset = imgs_all[1].copy() ax = plt.subplot(grid[0, j]) pos = ax.get_position().bounds - ax.set_position([pos[0] - 0.015 * j, pos[1] - 0.04, pos[2], pos[3]]) + ax.set_position([pos[0] - 0.02 * j, pos[1] - 0.04, pos[2], pos[3]]) ly, lx = 128, 128 dy, dx = 20, 30 ni = 5 @@ -1119,17 +1134,18 @@ def suppfig_specialist(folder, save_fig=True): ax.text(0.02, 1.2, "Specialist dataset", fontsize="large", fontstyle="italic", transform=ax.transAxes) - transl = mtransforms.ScaledTranslation(-50 / 72, 8 / 72, fig.dpi_scale_trans) + transl = mtransforms.ScaledTranslation(-45 / 72, 8 / 72, fig.dpi_scale_trans) ax = plt.subplot(grid[0, -1]) pos = ax.get_position().bounds - ax.set_position([pos[0] + 0.03, pos[1] - 0.03, pos[2] * 0.8, + ax.set_position([pos[0] + 0.01, pos[1] - 0.03, pos[2] * 0.8, pos[3] * 1]) #+pos[3]*0.15-0.03, pos[2], pos[3]*0.7]) il = plot_label(ltr, il, ax, transl, fs_title) - theight = [0, 1, 2, 3, 4, 5, 6, 7, 5.1] - for k in [1, 2, 3, 4, 8]: - ax.plot(thresholds, aps[k, :, :].mean(axis=0), color=cols0[k]) + theight = [0, 0, 4, 3, 6, 5, 1, 5, 7, 8, 7.1] + for k in [1, 2, 3, 4, 5, 6, 10]: + ax.plot(thresholds, aps[k, :, :].mean(axis=0), color=cols0[k], + lw=3 if k==4 else 1, ls="--" if k==6 else "-") #ax.errorbar(thresholds, aps[k,:,:].mean(axis=0), aps[k,:,:].std(axis=0) / 10**0.5, color=cols0[k]) - ax.text(0.59, 0.55 + 0.08 * theight[k], legstr0[k], color=cols0[k], + ax.text(0.7, 0.3 + 0.09 * theight[k], legstr0[k], color=cols0[k], transform=ax.transAxes) ax.set_ylim([0, 0.8]) ax.set_ylabel("average precision (AP)") @@ -1139,11 +1155,11 @@ def suppfig_specialist(folder, save_fig=True): transl = mtransforms.ScaledTranslation(-10 / 72, 20 / 72, fig.dpi_scale_trans) - kk = [2, 3, 4, 8] + kk = [2, 3, 4, 10] iex = 8 ylim = [10, 310] xlim = [100, 500] - legstr0[-1] = "Cellpose3 (per. + seg.)" + legstr0[-1] = u"\u2013 Cellpose3 (per. + seg.)" for j, k in enumerate(kk): ax = plt.subplot(grid[1, j]) pos = ax.get_position().bounds @@ -1156,7 +1172,7 @@ def suppfig_specialist(folder, save_fig=True): ax.axis("off") ax.set_ylim(ylim) ax.set_xlim(xlim) - ax.set_title(legstr0[k], color=cols0[k], fontsize="medium") + ax.set_title(legstr0[k][2:], color=cols0[k], fontsize="medium") ax.text(1, -0.04, f"AP@0.5 = {aps[k,iex,0] : 0.2f}", va="top", ha="right", transform=ax.transAxes) if j == 0: From 6759f70253fdd0791b05f8500b7df9569378fd79 Mon Sep 17 00:00:00 2001 From: Carsen Stringer Date: Sat, 17 Aug 2024 08:37:54 -0400 Subject: [PATCH 17/22] adding anio --- cellpose/denoise.py | 44 +++++++++++++++++++++++++++++--------------- 1 file changed, 29 insertions(+), 15 deletions(-) diff --git a/cellpose/denoise.py b/cellpose/denoise.py index 317cdc7f..7ecab448 100644 --- a/cellpose/denoise.py +++ b/cellpose/denoise.py @@ -257,15 +257,20 @@ def add_noise(lbl, alpha=4, beta=0.7, poisson=0.7, blur=0.7, gblur=1.0, downsamp iblur[ii] = True if iblur.sum() > 0: if sigma0 is None: - if not iso or uniform_blur: + if uniform_blur and iso: xr = torch.rand(len(lbl), device=device) if len(ii) > 0: - xr[ii] = (ds[ii].float() / 2.) / gblur + xr[ii] = ds[ii].float() / 2. / gblur sigma0 = diams[iblur] / 30. * gblur * (1 / gblur + (1 - 1 / gblur) * xr[iblur]) - if iso: - sigma1 = sigma0.clone() - else: - sigma1 = sigma0.clone() / 10. + sigma1 = sigma0.clone() + elif not iso: + xr = torch.rand(len(lbl), device=device) + if len(ii) > 0: + xr[ii] = (ds[ii].float()) / gblur + xr[ii] = xr[ii] + torch.rand(len(ii), device=device) * 0.7 - 0.35 + xr[ii] = torch.clip(xr[ii], 0.05, 1.5) + sigma0 = diams[iblur] / 30. * gblur * xr[iblur] + sigma1 = sigma0.clone() / 10. else: xrand = np.random.exponential(1, size=iblur.sum()) xrand = np.clip(xrand * 0.5, 0.1, 1.0) @@ -341,7 +346,7 @@ def add_noise(lbl, alpha=4, beta=0.7, poisson=0.7, blur=0.7, gblur=1.0, downsamp def random_rotate_and_resize_noise(data, labels=None, diams=None, poisson=0.7, blur=0.7, downsample=0.0, beta=0.7, gblur=1.0, diam_mean=30, - ds_max=7, iso=True, rotate=True, + ds_max=7, uniform_blur=False, iso=True, rotate=True, device=torch.device("cuda"), xy=(224, 224), nchan_noise=1, keep_raw=True): """ @@ -894,13 +899,12 @@ def train(net, train_data=None, train_labels=None, train_files=None, test_data=N lstrs = ["per", "seg", "rec"] for k, (l, s) in enumerate(zip(lam, lstrs)): filename += f"{s}_{l:.2f}_" + if not iso: + filename += "aniso_" if poisson.sum() > 0: filename += "poisson_" if blur.sum() > 0: - if iso: - filename += "blur_" - else: - filename += "bluraniso_" + filename += "blur_" if downsample.sum() > 0: filename += "downsample_" filename += d.strftime("%Y_%m_%d_%H_%M_%S.%f") @@ -1112,7 +1116,7 @@ def train(net, train_data=None, train_labels=None, train_files=None, test_data=N help="scale of gaussian blurring stddev") training_args.add_argument("--downsample", default=0., type=float, help="fraction of images to downsample") - training_args.add_argument("--ds_max", default=7, type=int, + training_args.add_argument("--ds_max", default=10, type=int, help="max downsampling factor") training_args.add_argument("--lam_per", default=1.0, type=float, help="weighting of perceptual loss") @@ -1146,6 +1150,8 @@ def train(net, train_data=None, train_labels=None, train_files=None, test_data=N if len(args.noise_type) > 0: noise_type = args.noise_type + uniform_blur = False + iso = True if noise_type == "poisson": poisson = 0.8 blur = 0. @@ -1157,19 +1163,27 @@ def train(net, train_data=None, train_labels=None, train_files=None, test_data=N blur = 0.8 downsample = 0. beta = 0.1 - gblur = 10.0 + gblur = 0.5 elif noise_type == "downsample": poisson = 0.8 blur = 0.8 downsample = 0.8 beta = 0.03 - gblur = 5.0 + gblur = 1.0 elif noise_type == "all": poisson = [0.8, 0.8, 0.8] blur = [0., 0.8, 0.8] downsample = [0., 0., 0.8] beta = [0.7, 0.1, 0.03] gblur = [0., 10.0, 5.0] + uniform_blur = True + elif noise_type == "aniso": + poisson = 0.8 + blur = 0.8 + downsample = 0.8 + beta = 0.1 + gblur = args.ds_max * 1.5 + iso = False else: raise ValueError(f"{noise_type} noise_type is not supported") else: @@ -1234,7 +1248,7 @@ def train(net, train_data=None, train_labels=None, train_files=None, test_data=N test_data=test_data, test_labels=test_labels, test_files=test_files, train_probs=train_probs, test_probs=test_probs, poisson=poisson, beta=beta, blur=blur, gblur=gblur, downsample=downsample, ds_max=args.ds_max, - iso=True, n_epochs=args.n_epochs, + iso=iso, uniform_blur=uniform_blur, n_epochs=args.n_epochs, learning_rate=args.learning_rate, lam=lams, seg_model_type=args.seg_model_type, nimg_per_epoch=nimg_per_epoch, From 2e9f44a6a72cb7a2f1ccc0c338cf2213f731b9fa Mon Sep 17 00:00:00 2001 From: Carsen Stringer Date: Sat, 17 Aug 2024 08:52:38 -0400 Subject: [PATCH 18/22] reverting ds_max to 7 --- cellpose/denoise.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cellpose/denoise.py b/cellpose/denoise.py index 7ecab448..499e2dec 100644 --- a/cellpose/denoise.py +++ b/cellpose/denoise.py @@ -1116,7 +1116,7 @@ def train(net, train_data=None, train_labels=None, train_files=None, test_data=N help="scale of gaussian blurring stddev") training_args.add_argument("--downsample", default=0., type=float, help="fraction of images to downsample") - training_args.add_argument("--ds_max", default=10, type=int, + training_args.add_argument("--ds_max", default=7, type=int, help="max downsampling factor") training_args.add_argument("--lam_per", default=1.0, type=float, help="weighting of perceptual loss") From 1d90853db43b4f4fcda1c61fed2914ba12a310d0 Mon Sep 17 00:00:00 2001 From: Carsen Stringer Date: Sat, 17 Aug 2024 09:05:04 -0400 Subject: [PATCH 19/22] adding uniform blur arg --- cellpose/denoise.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cellpose/denoise.py b/cellpose/denoise.py index 499e2dec..f4fa603e 100644 --- a/cellpose/denoise.py +++ b/cellpose/denoise.py @@ -876,7 +876,7 @@ def train(net, train_data=None, train_labels=None, train_files=None, test_data=N test_labels=None, test_files=None, train_probs=None, test_probs=None, lam=[1., 1.5, 0.], scale_range=0.5, seg_model_type="cyto2", save_path=None, save_every=100, save_each=False, poisson=0.7, beta=0.7, blur=0.7, gblur=1.0, - iso=True, downsample=0., ds_max=7, + iso=True, uniform_blur=False, downsample=0., ds_max=7, learning_rate=0.005, n_epochs=500, weight_decay=0.00001, batch_size=8, nimg_per_epoch=None, nimg_test_per_epoch=None, model_name=None): From fc473ca7b3647cd00074f9a08cfc4a2f69e4ddab Mon Sep 17 00:00:00 2001 From: Carsen Stringer Date: Fri, 6 Sep 2024 13:01:57 -0400 Subject: [PATCH 20/22] updating training --- cellpose/core.py | 6 +- cellpose/denoise.py | 79 ++++++--- cellpose/models.py | 41 ++--- cellpose/transforms.py | 25 ++- paper/3.0/analysis.py | 123 +++++--------- paper/3.0/figures.py | 377 +++++++++++++++++++++++++++++++++-------- 6 files changed, 433 insertions(+), 218 deletions(-) diff --git a/cellpose/core.py b/cellpose/core.py index 40883e97..96acdb1b 100644 --- a/cellpose/core.py +++ b/cellpose/core.py @@ -272,7 +272,8 @@ def _run_tiled(net, imgi, batch_size=8, augment=False, bsize=224, tile_overlap=0 yf = np.zeros((Lz, nout, imgi.shape[-2], imgi.shape[-1]), np.float32) styles = [] if ny * nx > batch_size: - ziterator = trange(Lz, file=tqdm_out, mininterval=30) + ziterator = (trange(Lz, file=tqdm_out, mininterval=30) + if Lz > 1 else range(Lz)) for i in ziterator: yfi, stylei = _run_tiled(net, imgi[i], augment=augment, bsize=bsize, tile_overlap=tile_overlap) @@ -283,7 +284,8 @@ def _run_tiled(net, imgi, batch_size=8, augment=False, bsize=224, tile_overlap=0 ntiles = ny * nx nimgs = max(2, int(np.round(batch_size / ntiles))) niter = int(np.ceil(Lz / nimgs)) - ziterator = trange(niter, file=tqdm_out, mininterval=30) + ziterator = (trange(niter, file=tqdm_out, mininterval=30) + if Lz > 1 else range(niter)) for k in ziterator: IMGa = np.zeros((ntiles * nimgs, nchan, ly, lx), np.float32) for i in range(min(Lz - k * nimgs, nimgs)): diff --git a/cellpose/denoise.py b/cellpose/denoise.py index f4fa603e..9641d409 100644 --- a/cellpose/denoise.py +++ b/cellpose/denoise.py @@ -794,13 +794,13 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None, x[..., c] = self._eval(self.net, x[..., c:c + 1], batch_size=batch_size, normalize=normalize, rescale=rescale0, tile=tile, - tile_overlap=tile_overlap, bsize=bsize) + tile_overlap=tile_overlap, bsize=bsize)[...,0] else: x[..., c] = self._eval(self.net_chan2, x[..., c:c + 1], batch_size=batch_size, normalize=normalize, rescale=rescale0, tile=tile, - tile_overlap=tile_overlap, bsize=bsize) + tile_overlap=tile_overlap, bsize=bsize)[...,0] x = x[0] if squeeze else x return x @@ -845,31 +845,39 @@ def _eval(self, net, x, batch_size=8, normalize=True, rescale=None, tile=True, do_normalization = True if normalize_params["normalize"] else False - tqdm_out = utils.TqdmToLogger(denoise_logger, level=logging.INFO) - iterator = trange(nimg, file=tqdm_out, - mininterval=30) if nimg > 1 else range(nimg) - imgs = np.zeros((*x.shape[:-1], 1), np.float32) - for i in iterator: - img = np.asarray(x[i]) - if do_normalization: - img = transforms.normalize_img(img, **normalize_params) - if rescale != 1.0: - img = transforms.resize_image(img, rsz=[rescale, rescale]) - if img.ndim == 2: - img = img[:, :, np.newaxis] - yf, style = run_net(net, img, batch_size=batch_size, augment=False, - tile=tile, tile_overlap=tile_overlap, bsize=bsize) - img = transforms.resize_image(yf, Ly=x.shape[-3], Lx=x.shape[-2]) - - if img.ndim == 2: - img = img[:, :, np.newaxis] - imgs[i] = img - del yf, style + img = np.asarray(x) + if do_normalization: + img = transforms.normalize_img(img, **normalize_params) + if rescale != 1.0: + img = transforms.resize_image(img, rsz=rescale) + yf, style = run_net(self.net, img, bsize=bsize, + tile=tile, tile_overlap=tile_overlap) + yf = transforms.resize_image(yf, shape[1], shape[2]) + imgs = yf + del yf, style + + # imgs = np.zeros((*x.shape[:-1], 1), np.float32) + # for i in iterator: + # img = np.asarray(x[i]) + # if do_normalization: + # img = transforms.normalize_img(img, **normalize_params) + # if rescale != 1.0: + # img = transforms.resize_image(img, rsz=[rescale, rescale]) + # if img.ndim == 2: + # img = img[:, :, np.newaxis] + # yf, style = run_net(net, img, batch_size=batch_size, augment=False, + # tile=tile, tile_overlap=tile_overlap, bsize=bsize) + # img = transforms.resize_image(yf, Ly=x.shape[-3], Lx=x.shape[-2]) + + # if img.ndim == 2: + # img = img[:, :, np.newaxis] + # imgs[i] = img + # del yf, style net_time = time.time() - tic if nimg > 1: denoise_logger.info("imgs denoised in %2.2fs" % (net_time)) - return imgs.squeeze() + return imgs def train(net, train_data=None, train_labels=None, train_files=None, test_data=None, @@ -997,7 +1005,8 @@ def train(net, train_data=None, train_labels=None, train_files=None, test_data=N diam_train[inds][i * batch_size : (i + 1) * batch_size].copy(), poisson=poisson[inoise], beta=beta[inoise], gblur=gblur[inoise], blur=blur[inoise], iso=iso, - downsample=downsample[inoise], diam_mean=diam_mean, ds_max=ds_max, + downsample=downsample[inoise], uniform_blur=uniform_blur, + diam_mean=diam_mean, ds_max=ds_max, device=device) if i == 0: img = imgi @@ -1049,8 +1058,8 @@ def train(net, train_data=None, train_labels=None, train_files=None, test_data=N img, lbl, scale = random_rotate_and_resize_noise( imgs, lbls, diam_test[inds].copy(), poisson=poisson[inoise], beta=beta[inoise], blur=blur[inoise], gblur=gblur[inoise], - iso=iso, downsample=downsample[inoise], diam_mean=diam_mean, - device=device) + iso=iso, downsample=downsample[inoise], uniform_blur=uniform_blur, + diam_mean=diam_mean, ds_max=ds_max, device=device) loss, loss_per = test_loss(net, img[:, :nchan], net1=net1, img=img[:, nchan:], lbl=lbl, lam=lam) @@ -1158,18 +1167,32 @@ def train(net, train_data=None, train_labels=None, train_files=None, test_data=N downsample = 0. beta = 0.7 gblur = 1.0 - elif noise_type == "blur": + elif noise_type == "blur_expr": poisson = 0.8 blur = 0.8 downsample = 0. beta = 0.1 gblur = 0.5 - elif noise_type == "downsample": + elif noise_type == "blur": + poisson = 0.8 + blur = 0.8 + downsample = 0. + beta = 0.1 + gblur = 10.0 + uniform_blur = True + elif noise_type == "downsample_expr": poisson = 0.8 blur = 0.8 downsample = 0.8 beta = 0.03 gblur = 1.0 + elif noise_type == "downsample": + poisson = 0.8 + blur = 0.8 + downsample = 0.8 + beta = 0.03 + gblur = 5.0 + uniform_blur = True elif noise_type == "all": poisson = [0.8, 0.8, 0.8] blur = [0., 0.8, 0.8] diff --git a/cellpose/models.py b/cellpose/models.py index ad9f7eaa..a1b3e2b6 100644 --- a/cellpose/models.py +++ b/cellpose/models.py @@ -502,37 +502,18 @@ def _run_cp(self, x, compute_masks=True, normalize=True, invert=False, niter=Non del yf else: tqdm_out = utils.TqdmToLogger(models_logger, level=logging.INFO) - iterator = trange(nimg, file=tqdm_out, - mininterval=30) if nimg > 1 else range(nimg) - styles = np.zeros((nimg, self.nbase[-1]), np.float32) + img = np.asarray(x) + if do_normalization: + img = transforms.normalize_img(img, **normalize_params) + if rescale != 1.0: + img = transforms.resize_image(img, rsz=rescale) + yf, style = run_net(self.net, img, bsize=bsize, augment=augment, + tile=tile, tile_overlap=tile_overlap) if resample: - dP = np.zeros((2, nimg, shape[1], shape[2]), np.float32) - cellprob = np.zeros((nimg, shape[1], shape[2]), np.float32) - else: - dP = np.zeros( - (2, nimg, int(shape[1] * rescale), int(shape[2] * rescale)), - np.float32) - cellprob = np.zeros( - (nimg, int(shape[1] * rescale), int(shape[2] * rescale)), - np.float32) - for i in iterator: - img = np.asarray(x[i]) - if do_normalization: - img = transforms.normalize_img(img, **normalize_params) - if rescale != 1.0: - img = transforms.resize_image(img, rsz=rescale) - yf, style = run_net(self.net, img, bsize=bsize, augment=augment, - tile=tile, tile_overlap=tile_overlap) - if resample: - yf = transforms.resize_image(yf, shape[1], shape[2]) - - cellprob[i] = yf[:, :, 2] - dP[:, i] = yf[:, :, :2].transpose((2, 0, 1)) - if self.nclasses == 4: - if i == 0: - bd = np.zeros_like(cellprob) - bd[i] = yf[:, :, 3] - styles[i][:len(style)] = style + yf = transforms.resize_image(yf, shape[1], shape[2]) + dP = np.moveaxis(yf[..., :2], source=-1, destination=0).copy() + cellprob = yf[..., 2] + styles = style del yf, style styles = styles.squeeze() diff --git a/cellpose/transforms.py b/cellpose/transforms.py index 980e22a5..c9154594 100644 --- a/cellpose/transforms.py +++ b/cellpose/transforms.py @@ -469,9 +469,9 @@ def convert_image(x, channels, channel_axis=None, z_axis=None, do_3D=False, ncha if len(to_squeeze) > 0: channel_axis = update_axis( channel_axis, to_squeeze, - x.ndim) if channel_axis is not None else channel_axis + x.ndim) if channel_axis is not None else None z_axis = update_axis(z_axis, to_squeeze, - x.ndim) if z_axis is not None else z_axis + x.ndim) if z_axis is not None else None x = x.squeeze() # put z axis first @@ -480,7 +480,19 @@ def convert_image(x, channels, channel_axis=None, z_axis=None, do_3D=False, ncha if channel_axis is not None: channel_axis += 1 z_axis = 0 - + elif z_axis is None and x.ndim > 2 and channels is not None and min(x.shape) > 5 : + # if there are > 5 channels and channels!=None, assume first dimension is z + min_dim = min(x.shape) + if min_dim != channel_axis: + z_axis = (x.shape).index(min_dim) + if z_axis != 0: + x = move_axis(x, m_axis=z_axis, first=True) + if channel_axis is not None: + channel_axis += 1 + transforms_logger.warning(f"z_axis not specified, assuming it is dim {z_axis}") + transforms_logger.warning(f"if this is actually the channel_axis, use 'model.eval(channel_axis={z_axis}, ...)'") + z_axis = 0 + if z_axis is not None: if x.ndim == 3: x = x[..., np.newaxis] @@ -500,7 +512,7 @@ def convert_image(x, channels, channel_axis=None, z_axis=None, do_3D=False, ncha if channel_axis is None: x = move_min_dim(x) - + if x.ndim > 3: transforms_logger.info( "multi-stack tiff read in as having %d planes %d channels" % @@ -723,7 +735,8 @@ def resize_image(img0, Ly=None, Lx=None, rsz=None, interpolation=cv2.INTER_LINEA else: imgs = np.zeros((img0.shape[0], Ly, Lx, img0.shape[-1]), np.float32) for i, img in enumerate(img0): - imgs[i] = cv2.resize(img, (Lx, Ly), interpolation=interpolation) + imgi = cv2.resize(img, (Lx, Ly), interpolation=interpolation) + imgs[i] = imgi if imgi.ndim > 2 else imgi[..., np.newaxis] else: imgs = cv2.resize(img0, (Lx, Ly), interpolation=interpolation) return imgs @@ -835,7 +848,7 @@ def random_rotate_and_resize(X, Y=None, scale_range=1., xy=(224, 224), do_3D=Fal # generate random augmentation parameters flip = np.random.rand() > .5 theta = np.random.rand() * np.pi * 2 if rotate else 0. - scale[n] = (1 - scale_range / 2) + scale_range * np.random.rand() + scale[n] = 2 ** (-2 + 5 * np.random.rand())#(1 - scale_range / 2) + scale_range * np.random.rand() if rescale is not None: scale[n] *= 1. / rescale[n] dxy = np.maximum(0, np.array([Lx * scale[n] - xy[1], diff --git a/paper/3.0/analysis.py b/paper/3.0/analysis.py index aff669cf..2a9eb1de 100644 --- a/paper/3.0/analysis.py +++ b/paper/3.0/analysis.py @@ -35,49 +35,56 @@ def seg_eval_cp3(folder, noise_type="poisson"): """ need to download test_poisson.npy, test_blur.npy, test_downsample.npy (for cells and/or nuclei) - - (was computed with old flows, but results similar with new flows) """ + """ ctypes = ["cyto2", "nuclei"] - for ctype in ctypes: + for c, ctype in enumerate(ctypes): + print(ctype) + pretrained_models = [f"/home/carsen/.cellpose/models/{model_names[noise_type]}{istr}_{ctype}" + for istr in ["_rec", "_seg", "_per", ""]] + pretrained_models.extend([f"/home/carsen/.cellpose/models/{model_names[noise_type]}_cyto3", + f"/home/carsen/.cellpose/models/oneclick_{ctype}", + f"/home/carsen/.cellpose/models/oneclick_cyto3"]) + + seg_model = models.CellposeModel(gpu=True, model_type=ctype) + folder_name = ctype - diam_mean = 30 if ctype == "cyto2" else 17 root = Path(folder) / f"images_{folder_name}/" - + model_name = model_names[noise_type] + nimg_test = 68 if ctype=="cyto2" else 111 + diam_mean = 30. if ctype == "cyto2" else 17. ### cellpose enhance dat = np.load(root / "noisy_test" / f"test_{noise_type}.npy", - allow_pickle=True).item() - test_noisy = dat["test_noisy"] - masks_true = dat["masks_true"] - diam_test = dat["diam_test"] if "diam_test" in dat else 30. * np.ones( + allow_pickle=True).item() + test_noisy = dat["test_noisy"][:nimg_test] + masks_true = dat["masks_true"][:nimg_test] + diam_test = dat["diam_test"][:nimg_test] if "diam_test" in dat else diam_mean * np.ones( len(test_noisy)) - istr = ["rec", "seg", "per", "perseg"] - for k in range(len(istr)): - model_name = model_names[noise_type] - if istr[k] != "perseg": - model_name += "_" + istr[k] - model = denoise.DenoiseModel(gpu=True, nchan=1, diam_mean=diam_mean, - model_type=f"{model_name}_{ctype}") - imgs2 = model.eval([test_noisy[i][0] for i in range(len(test_noisy))], - diameter=diam_test, channel_axis=0) - print(imgs2[0].shape) - seg_model = models.CellposeModel(gpu=True, model_type=ctype) - masks2, flows2, styles2 = seg_model.eval(imgs2, channels=[1, 0], - diameter=diam_test, channel_axis=0, - normalize=True) - flows = [flow[0] for flow in flows2] + thresholds = np.arange(0.5, 1.05, 0.05) + istrs = ["rec", "seg", "per", "perseg", "noise_spec", "data_spec", "gen"] + + print(pretrained_models) + aps = [] + for istr, pretrained_model in zip(istrs, pretrained_models): + dn_model = denoise.DenoiseModel(gpu=True, nchan=1, + diam_mean = 30 if "cyto" in pretrained_model else 17, + pretrained_model=pretrained_model) + dn_model.pretrained_model = "test" + imgs2 = dn_model.eval([test_noisy[i][0] for i in range(len(test_noisy))], + diameter=diam_test, channel_axis=0) + + masks2, flows, styles2 = seg_model.eval(imgs2, channels=[1, 0], + diameter=diam_test, channel_axis=-1, + normalize=True) + + ap, tp, fp, fn = metrics.average_precision(masks_true, masks2, threshold=thresholds) + print(f"{noise_type} {istr} AP@0.5 \t = {ap[:,0].mean(axis=0):.3f}") - ap, tp, fp, fn = metrics.average_precision(masks_true, masks2) - if ctype == "cyto2": - print(f"{istr[k]} AP@0.5 \t = {ap[:68,0].mean(axis=0):.3f}") - else: - print(f"{istr[k]} AP@0.5 \t = {ap[:,0].mean(axis=0):.3f}") - - dat[f"test_{istr[k]}"] = imgs2 - dat[f"masks_{istr[k]}"] = masks2 - dat[f"flows_{istr[k]}"] = flows - - #np.save(root / "noisy_test" / f"test_{noise_type}_cp3.npy", dat) + dat[f"test_{istr}"] = imgs2 + dat[f"masks_{istr}"] = masks2 + dat[f"flows_{istr}"] = flows + aps.append(ap) + np.save(root / "noisy_test" / f"test_{noise_type}_cp3_all.npy", dat) if noise_type == "poisson": ### cellpose retrained @@ -97,7 +104,7 @@ def seg_eval_cp3(folder, noise_type="poisson"): dat[f"masks_retrain"] = masks2 - #np.save(root / "noisy_test" / f"test_{noise_type}_cp_retrain.npy", dat) + np.save(root / "noisy_test" / f"test_{noise_type}_cp_retrain.npy", dat) def blind_denoising(folder): @@ -310,48 +317,6 @@ def specialist_training(root): noise2void.train_test_specialist(root, n_epochs=100, lr=4e-4, test=True) -def seg_eval_oneclick(folder): - noise_types = ["poisson", "blur", "downsample"] - ctypes = ["cyto2", "nuclei"] - for c, ctype in enumerate(ctypes): - folder_name = ctype - diam_mean = 30. - root = Path(f"/media/carsen/ssd4/datasets_cellpose/images_{folder_name}/") - print(ctype) - for n, noise_type in enumerate(noise_types): - print(noise_type) - ### cellpose enhance - dat = np.load(root / "noisy_test" / f"test_{noise_type}.npy", - allow_pickle=True).item() - test_noisy = dat["test_noisy"] - masks_true = dat["masks_true"] - diam_test = dat["diam_test"] if "diam_test" in dat else 30. * np.ones( - len(test_noisy)) - - model = denoise.DenoiseModel(gpu=True, nchan=1, diam_mean=diam_mean, - model_type=model_names[noise_type] + "_cyto3", - device=torch.device("cuda")) - imgs2 = model.eval([test_noisy[i][0] for i in range(len(test_noisy))], - diameter=diam_test, channel_axis=0) - - seg_model = models.CellposeModel(gpu=True, model_type=ctype, - device=torch.device("cuda")) - masks2, flows2, styles2 = seg_model.eval(imgs2, channels=[1, 0], - diameter=diam_test, channel_axis=0, - normalize=True) - istr = "generalist" - ap, tp, fp, fn = metrics.average_precision(masks_true, masks2) - if ctype == "cyto2": - print(f"{istr} AP@0.5 \t = {ap[:68,0].mean(axis=0):.3f}") - else: - print(f"{istr} AP@0.5 \t = {ap[:,0].mean(axis=0):.3f}") - - dat[f"test_{istr}"] = imgs2 - dat[f"masks_{istr}"] = masks2 - - np.save(root / "noisy_test" / f"test_{noise_type}_generalist_cp3.npy", dat) - - def cyto3_comparisons(folder): """ diameters computed from generalist model cyto3 will need segmentation_models_pytorch to run transformer """ @@ -369,7 +334,7 @@ def cyto3_comparisons(folder): pretrained_model = "/home/carsen/.cellpose/models/transformer_cp3" seg_model = models.CellposeModel(gpu=True, backbone="transformer", pretrained_model=pretrained_model) - for f in folders[:3]: + for f in folders: if net_type == "specialist": seg_model = models.CellposeModel(gpu=True, model_type=f"{f}_cp3") diff --git a/paper/3.0/figures.py b/paper/3.0/figures.py index 97803a20..02bef712 100644 --- a/paper/3.0/figures.py +++ b/paper/3.0/figures.py @@ -71,7 +71,7 @@ def load_benchmarks(folder, noise_type="poisson", ctype="cyto2", imgs_all.append(test_n2s) masks_all.append(masks_n2s) - dat = np.load(root / "noisy_test" / f"test_{noise_type}_cp3.npy", + dat = np.load(root / "noisy_test" / f"test_{noise_type}_cp3_all.npy", allow_pickle=True).item() istr = ["rec", "per", "seg", "perseg"] for k in range(len(istr)): @@ -621,13 +621,14 @@ def suppfig_nuclei(folder, save_fig=False): def fig2(folder, folder2="/media/carsen/ssd4/denoising/Projection_Flywing/test_data", + folder3="/media/carsen/ssd4/denoising/ribo_denoise/", save_fig=False): thresholds = np.arange(0.5, 1.05, 0.05) - fig = plt.figure(figsize=(14, 8), dpi=100) + fig = plt.figure(figsize=(14, 12), dpi=100) yratio = 14 / 8 - grid = plt.GridSpec(5, 8, figure=fig, left=0.02, right=0.98, top=0.96, bottom=0.1, - wspace=0.05, hspace=0.25) + grid = plt.GridSpec(7, 8, figure=fig, left=0.02, right=0.97, top=0.98, bottom=0.08, + wspace=0.12, hspace=0.25) transl = mtransforms.ScaledTranslation(-18 / 72, 10 / 72, fig.dpi_scale_trans) il = 0 @@ -639,7 +640,7 @@ def fig2(folder, folder2="/media/carsen/ssd4/denoising/Projection_Flywing/test_d kk=[0, 1, 5], seg=[0, 0, 0], dy=0.015) dat = np.load(f"{folder2}/cp_masks.npy", allow_pickle=True).item() - grid1 = matplotlib.gridspec.GridSpecFromSubplotSpec(2, 5, subplot_spec=grid[-2:, :], + grid1 = matplotlib.gridspec.GridSpecFromSubplotSpec(2, 5, subplot_spec=grid[-4:-2, :], wspace=0.05, hspace=0.1) iex = 10 @@ -666,7 +667,7 @@ def fig2(folder, folder2="/media/carsen/ssd4/denoising/Projection_Flywing/test_d ax = plt.subplot(grid1[0, k]) pos = ax.get_position().bounds - ax.set_position([pos[0], pos[1] - 0.05, pos[2], pos[3]]) + ax.set_position([pos[0], pos[1] - 0.01, pos[2], pos[3]]) ax.imshow(img, vmin=0, vmax=1, cmap="gray") ax.set_title(titlesd[k], color="k" if k < 2 else cols[-2], fontsize="medium") ax.set_xlim(xlim) @@ -679,7 +680,7 @@ def fig2(folder, folder2="/media/carsen/ssd4/denoising/Projection_Flywing/test_d ax = plt.subplot(grid1[1, k]) pos = ax.get_position().bounds - ax.set_position([pos[0], pos[1] - 0.05, pos[2], pos[3]]) + ax.set_position([pos[0], pos[1] - 0.01, pos[2], pos[3]]) ax.imshow(img, vmin=0, vmax=1, cmap="gray") ax.set_xlim(xlim) ax.set_ylim(ylim) @@ -706,7 +707,7 @@ def fig2(folder, folder2="/media/carsen/ssd4/denoising/Projection_Flywing/test_d transl = mtransforms.ScaledTranslation(-40 / 72, 15 / 72, fig.dpi_scale_trans) ax = plt.subplot(grid1[:, 3]) pos = ax.get_position().bounds - ax.set_position([pos[0] + 0.05, pos[1] - 0.03, pos[2] * 0.7, pos[3]]) + ax.set_position([pos[0] + 0.05, pos[1] - 0.0, pos[2] * 0.7, pos[3]]) nl = 0 titlesd = titles.copy() titlesd[7] = "Cellpose3" @@ -731,7 +732,7 @@ def fig2(folder, folder2="/media/carsen/ssd4/denoising/Projection_Flywing/test_d ax = plt.subplot(grid1[:, 4]) pos = ax.get_position().bounds - ax.set_position([pos[0] + 0.05, pos[1] - 0.03, pos[2] * 0.7, pos[3]]) + ax.set_position([pos[0] + 0.05, pos[1] - 0.0, pos[2] * 0.7, pos[3]]) kk = [1, 2, 3, 7] for k in range(len(aps)): means = np.array([aps[k][nl][:, 0].mean(axis=0) for nl in [0, 2, 1]]) @@ -746,6 +747,117 @@ def fig2(folder, folder2="/media/carsen/ssd4/denoising/Projection_Flywing/test_d ax.set_xticklabels(["2%", "3%", "5%"]) ax.set_xlabel("laser power") + + dat = np.load(f"{folder3}/ribo_denoise_n2v.npy", allow_pickle=True).item() + ap_n2v = dat["ap_n2v"] + dat = np.load(f"{folder3}/ribo_denoise_n2s.npy", allow_pickle=True).item() + ap_n2s = dat["ap_n2s"] + dat = np.load(f"{folder3}/ribo_denoise.npy", allow_pickle=True).item() + navgs = dat["navgs"] + + grid1 = matplotlib.gridspec.GridSpecFromSubplotSpec(2, 5, subplot_spec=grid[-2:, :], + wspace=0.05, hspace=0.1) + iex = 3 + nl = 2 + ylim = [350, 500] + xlim = [200, 500] + + transl = mtransforms.ScaledTranslation(-18 / 72, 26 / 72, fig.dpi_scale_trans) + outlines_gt = utils.outlines_list(dat["masks_clean"][iex].copy(), + multiprocessing=False) + titlest = ["clean (300 frames averaged)", "noisy (4 frames averaged)", "denoised (Cellpose3)"] + for k in range(3): + if k == 0: + img = dat["clean"][iex].copy() + elif k == 1: + img = dat["noisy"][nl][iex].copy() + maskk = dat["masks_noisy"][nl][iex].copy() + ap = dat["ap_noisy"][nl][iex, 0] + else: + img = dat["imgs_dn"][nl][iex].copy() + maskk = dat["masks_dn"][nl][iex].copy() + ap = dat["ap_dn"][nl][iex, 0] + img = transforms.normalize99(img) + + ax = plt.subplot(grid1[0, k]) + pos = ax.get_position().bounds + ax.set_position([pos[0], pos[1] - 0.05, pos[2], pos[3]]) + ax.imshow(img, vmin=0., vmax=0.75, cmap="gray") + ax.set_title(titlest[k], color="k" if k < 2 else [0,0.5,0], fontsize="medium") + ax.set_xlim(xlim) + ax.set_ylim(ylim) + ax.axis("off") + if k == 0: + ax.text(0, 1.25, "Denoising two-photon imaging in mice", fontsize="large", + fontstyle="italic", transform=ax.transAxes) + il = plot_label(ltr, il, ax, transl, fs_title) + + ax = plt.subplot(grid1[1, k]) + pos = ax.get_position().bounds + ax.set_position([pos[0], pos[1] - 0.05, pos[2], pos[3]]) + ax.imshow(img, vmin=0, vmax=1, cmap="gray") + ax.set_xlim(xlim) + ax.set_ylim(ylim) + ax.axis("off") + #ax.set_title("segmentation") + if k == 0: + for o in outlines_gt: + ax.plot(o[:, 0], o[:, 1], color=[1, 0, 1], lw=1, ls="--") + ax.text(1, -0.15, "ground-truth", ha="right", transform=ax.transAxes) + else: + outlines = utils.outlines_list(maskk, multiprocessing=False) + for o in outlines: + ax.plot(o[:, 0], o[:, 1], color=[1, 1, 0.3], lw=1, ls="--") + ax.text(1, -0.15, f"AP@0.5 = {ap:.2f}", ha="right", transform=ax.transAxes) + + grid11 = matplotlib.gridspec.GridSpecFromSubplotSpec(1, 3, subplot_spec=grid1[:, -2:], + wspace=0.3, hspace=0.1) + + + transl = mtransforms.ScaledTranslation(-35 / 72, 25 / 72, fig.dpi_scale_trans) + ax = plt.subplot(grid11[:, 0]) + pos = ax.get_position().bounds + ax.set_position([pos[0] + 0.05, pos[1] - 0.04, pos[2] * 0.8, pos[3]*0.9]) + aps = [dat["ap_noisy"], ap_n2v, ap_n2s, dat["ap_dn"]] + theight = [-0.9, 3, 2, 4] + kk = [1, 2, 3, 7] + titlesd[1] = "noisy\n(4 frames\naveraged)" + for k in range(len(aps)): + means = aps[k][nl, :12].mean(axis=0) + ax.plot(thresholds, means, color=cols[kk[k]]) + ax.text(1.15, 0.62 + theight[k] * 0.09, titlesd[kk[k]], + color=cols[kk[k]], + transform=ax.transAxes, ha="right") + ax.set_ylim([0, 0.8]) + ax.text(-0.18, 1.13, "Segmentation performance", fontstyle="italic", + transform=ax.transAxes, fontsize="large") + ax.set_ylabel("average precision (AP)") + ax.set_xlabel("IoU threshold") + ax.set_xticks(np.arange(0.5, 1.05, 0.25)) + ax.set_yticks(np.arange(0, 1.1, 0.2)) + ax.set_ylim([0, 0.83]) + ax.set_xlim([0.5, 1.0]) + il = plot_label(ltr, il, ax, transl, fs_title) + + ifrs = [slice(0, 12), slice(12, 20)] + for i, ifr in enumerate(ifrs): + ax = plt.subplot(grid11[:, i + 1]) + pos = ax.get_position().bounds + ax.set_position([pos[0] + 0.04 - i*0.01, pos[1] - 0.04, pos[2] * 0.8, pos[3]*0.9]) + nifr = ifr.stop - ifr.start + for k in range(len(aps)): + means = np.array([aps[k][nl][ifr, 0].mean(axis=0) for nl in range(len(aps[k]))]) + sems = np.array([aps[k][nl][ifr, 0].std(axis=0) / (nifr**0.5) for nl in range(len(aps[k]))]) + ax.errorbar(np.arange(0, len(means)), means, sems, color=cols[kk[k]]) + + ax.set_xticks(np.arange(0, len(navgs), 2)) + ax.set_xticklabels([f"{navg}" for navg in navgs[::2]]) + ax.set_yticks(np.arange(0, 1.1, 0.2)) + ax.set_ylim([0, 0.83]) + ax.set_xlabel("# of frames averaged") + ax.set_title("dense expression" if i==0 else "sparse expression", fontsize="medium") + + if save_fig: os.makedirs("figs/", exist_ok=True) fig.savefig("figs/fig2.pdf", dpi=100) @@ -1031,7 +1143,7 @@ def load_benchmarks_specialist(folder, thresholds=np.arange(0.5, 1.05, 0.05)): imgs_all.append(test_dns) masks_all.append(masks_dns) - dat = np.load(root / "noisy_test" / f"test_{noise_type}_cp3.npy", + dat = np.load(root / "noisy_test" / f"test_{noise_type}_cp3_all.npy", allow_pickle=True).item() istr = ["rec", "per", "seg", "perseg"] for k in range(len(istr)): @@ -1094,7 +1206,7 @@ def suppfig_specialist(folder, save_fig=True): fig = plt.figure(figsize=(9, 5), dpi=100) yratio = 9 / 5 - grid = plt.GridSpec(2, 4, figure=fig, left=0.02, right=0.98, top=0.96, bottom=0.1, + grid = plt.GridSpec(2, 4, figure=fig, left=0.02, right=0.96, top=0.96, bottom=0.1, wspace=0.15, hspace=0.2) titles = ["train - clean", "train - noisy", "test - noisy"] @@ -1114,7 +1226,7 @@ def suppfig_specialist(folder, save_fig=True): dy, dx = 20, 30 ni = 5 img0 = np.ones((ly + (ni - 1) * dy, lx + (ni - 1) * dx)) - ii = np.arange(0, 5)[::-1] if j == 2 else np.arange(0, 20 * ni, 20) + ii = np.arange(0, 5)[::-1] if j == 2 else np.arange(1, 20 * ni, 20)[::-1] if j < 2: x0, y0 = 20, 20 else: @@ -1186,6 +1298,128 @@ def suppfig_specialist(folder, save_fig=True): os.makedirs("figs/", exist_ok=True) fig.savefig("figs/suppfig_specialist.pdf", dpi=100) +def suppfig_impr(folder, save_fig=True): + aps_all = [[], []] + imgs_all, masks_all = [[], []], [[], []] + inds_all = [[], []] + diams = [[], []] + noise_types = ["poisson", "blur", "downsample"] + for noise_type in noise_types: + for j, ctype in enumerate(["cyto2", "nuclei"]): + nimg_test = 68 if ctype == "cyto2" else 111 + folder_name = ctype + root = Path(f"{folder}/images_{folder_name}/") + + dat = np.load(root / "noisy_test" / f"test_{noise_type}.npy", + allow_pickle=True).item() + test_data = dat["test_data"][:nimg_test] + test_noisy = dat["test_noisy"][:nimg_test] + masks_noisy = dat["masks_noisy"][:nimg_test] + masks_true = dat["masks_true"][:nimg_test] + masks_data = dat["masks_orig"][:nimg_test] + diam_test = dat["diam_test"][:nimg_test] + noise_levels = dat["noise_levels"][:nimg_test] + + dat = np.load(root / "noisy_test" / f"test_{noise_type}_cp3_all.npy", + allow_pickle=True).item() + + masks_denoised = dat["masks_perseg"][:nimg_test] + test_denoised = dat["test_perseg"][:nimg_test] + thresholds=np.arange(0.5, 1.05, 0.05) + ap_c, tp_d, fp_d, fn_d = metrics.average_precision(masks_true, masks_data, + threshold=thresholds) + ap_d, tp_d, fp_d, fn_d = metrics.average_precision(masks_true, masks_denoised, + threshold=thresholds) + ap_n, tp_n, fp_n, fn_n = metrics.average_precision(masks_true, masks_noisy, + threshold=thresholds) + + aps_all[j].append([ap_c, ap_n, ap_d]) + igood = np.nonzero(ap_d[:,0] > 0)[0] + impr = (ap_d[igood,0] - ap_n[igood,0]) / ap_n[igood,0] + ii = np.hstack((impr.argsort()[-2:][::-1], impr.argsort()[:2])) + ii = igood[ii] + imgs_all[j].append([np.array([test_data[i].squeeze(), test_noisy[i].squeeze(), test_denoised[i].squeeze()]) + for i in ii]) + masks_all[j].append([np.array([masks_data[i].squeeze(), masks_noisy[i].squeeze(), masks_denoised[i].squeeze()]) + for i in ii]) + diams[j].append(dat["diam_test"][ii]) + inds_all[j].append(ii) + + colors = [["darkblue", "royalblue", [0.46, 1, 0], "cyan", "orange", "maroon"], + ["darkblue", [0.46, 1, 0], "dodgerblue"]] + + titles = [["CellImageLibrary", "Cells : fluorescent", "Cells : nonfluorescent", + "Cell membranes", "Microscopy : other", "Non-microscopy"], + ["DSB 2018 / kaggle", "MoNuSeg (H&E)", "ISBI 2009 (fluorescent)"]] + + cinds = [[np.arange(0, 11), np.arange(11,28,1,int), np.arange(28,33,1,int), + np.arange(33,42,1,int), np.arange(42,55,1,int), + np.arange(55,68,1,int)], + [np.arange(0, 75), np.arange(75, 103), np.arange(103, 111)]] + + ddeg = ["noisy", "blurry", "downsampled"] + dcorr = ["denoised", "deblurred", "upsampled"] + dtitle = ["Denoising", "Deblurring", "Upsampling"] + + fig = plt.figure(figsize=(14,8)) + yratio = 14/10 + grid = plt.GridSpec(2, 5, hspace=0.3, wspace=0.5, + left=0.05, right=0.97, top=0.95, bottom=0.05) + il = 0 + transl = mtransforms.ScaledTranslation(-45 / 72, 5 / 72, fig.dpi_scale_trans) + + for c, ctype in enumerate(["cyto2", "nuclei"]): + for d in range(3): + imgs = imgs_all[c][d] + masks = masks_all[c][d] + inds = inds_all[c][d] + aps = aps_all[c][d] + + ax = plt.subplot(grid[c, d + 2*(d>0)]) + pos = ax.get_position().bounds + ax.set_position([pos[0], pos[1]+(pos[3]-pos[2]*yratio), pos[2], pos[2]*yratio]) + for k in range(len(cinds[c])): + ax.scatter(aps[1][cinds[c][k],0], aps[2][cinds[c][k],0], marker="x", + label=titles[c][k], color=colors[c][k]) + ax.plot([0, 1], [0, 1], color="k", lw=1, ls="--") + ax.set_xlabel(f"{ddeg[d]}, AP@0.5") + ax.set_ylabel(f"{dcorr[d]}, AP@0.5", color=[0, 0.5, 0]) + ax.text(-0.2, 1.05, dtitle[d], fontsize="large", transform=ax.transAxes, + fontstyle="italic") + il = plot_label(ltr, il, ax, transl, fs_title) + if d==0: + ax.legend(loc="lower center", bbox_to_anchor=(0.5, -1.3+c*0.4), fontsize="small") + + dstr = ["clean", "noisy", "denoised"] + diam_mean = 30 if ctype=="cyto2" else 17 + grid1 = matplotlib.gridspec.GridSpecFromSubplotSpec(3, 4, subplot_spec=grid[c, 1:3], + wspace=0.15, hspace=0.05) + for j in range(4): + Ly, Lx = imgs[j][0].shape + yinds, xinds = plot.interesting_patch(masks[j][0], + bsize=min(Ly, Lx, int(300 * diams[0][0][j] / diam_mean))) + for k in range(3): + ax = plt.subplot(grid1[k, j]) + pos = ax.get_position().bounds + ax.set_position([pos[0]-0.01, pos[1] - 0.015*k, *pos[2:]]) + ax.imshow(imgs[j][k], vmin=0, vmax=1, cmap="gray") + ax.axis("off") + #outlines = utils.outlines_list(masks[j][k], multiprocessing=False) + #for o in outlines: + # ax.plot(o[:, 0], o[:, 1], color=[1, 1, 0.3], lw=1.5, ls="--") + ax.set_ylim([yinds[0], yinds[-1]+1]) + ax.set_xlim([xinds[0], xinds[-1]+1]) + ax.text(1, -0.01, f"AP@0.5 = {aps[k][inds[j],0]:.2f}", ha="right", + va="top", fontsize="small", transform=ax.transAxes) + if j%2==0 and k==0: + istr = ["most improved", "least improved"] + ax.set_title(f"{istr[j//2]}", fontsize="medium", fontstyle="italic") + if j==0: + ax.text(-0.05, 0.5, dstr[k], ha="right", va="center", + rotation=90, transform=ax.transAxes, + color="k" if k<2 else [0., 0.5, 0]) + fig.savefig("figs/suppfig_impr.pdf", dpi=300) + def load_benchmarks_generalist(folder, noise_type="poisson", ctype="cyto2", thresholds=np.arange(0.5, 1.05, 0.05)): @@ -1209,20 +1443,15 @@ def load_benchmarks_generalist(folder, noise_type="poisson", ctype="cyto2", imgs_all.append(test_data) imgs_all.append(test_noisy) - dat = np.load(root / "noisy_test" / f"test_{noise_type}_cp3.npy", - allow_pickle=True).item() - test_dn = dat["test_perseg"][:nimg_test] - masks_dn = dat["masks_perseg"][:nimg_test] - imgs_all.append(test_dn) - masks_all.append(masks_dn) - - dat = np.load(root / "noisy_test" / f"test_{noise_type}_generalist_cp3.npy", - allow_pickle=True).item() - test_dn = dat["test_generalist"][:nimg_test] - masks_dn = dat["masks_generalist"][:nimg_test] - imgs_all.append(test_dn) - masks_all.append(masks_dn) - + dat = np.load(root / "noisy_test" / f"test_{noise_type}_cp3_all.npy", + allow_pickle=True).item() + istrs = ["perseg", "noise_spec", "data_spec", "gen"] + for istr in istrs: + test_dn = dat[f"test_{istr}"][:nimg_test] + masks_dn = dat[f"masks_{istr}"][:nimg_test] + imgs_all.append(test_dn) + masks_all.append(masks_dn) + # benchmarking aps = [] tps = [] @@ -1241,7 +1470,7 @@ def load_benchmarks_generalist(folder, noise_type="poisson", ctype="cyto2", diam_test) -def fig5(folder, save_fig=True): +def fig6(folder, save_fig=True): folders = [ "cyto2", "nuclei", "tissuenet", "livecell", "yeast_BF", "yeast_PhC", "bact_phase", "bact_fluor", "deepbacs" @@ -1264,8 +1493,9 @@ def fig5(folder, save_fig=True): diams = [utils.diameters(lbl)[0] for lbl in lbls] + gen_model = "/home/carsen/dm11_string/datasets_cellpose/models/per_1.00_seg_1.50_rec_0.00_poisson_blur_downsample_2024_08_20_11_46_25.557039" model = denoise.DenoiseModel(gpu=True, nchan=1, diam_mean=diam_mean, - model_type="denoise_cyto3") + pretrained_model=gen_model) seg_model = models.CellposeModel(gpu=True, model_type="cyto3") pscales = [1.5, 20., 1.5, 1., 5., 40., 3.] denoise.deterministic() @@ -1283,7 +1513,6 @@ def fig5(folder, save_fig=True): imgs[j][i], diameter=diams[i], channels=[0, 0], tile_overlap=0.5, flow_threshold=0.4, augment=True, bsize=224, niter=2000 if folders[i - 2] == "bact_phase" else None)[0]) - api = np.array( [metrics.average_precision(lbls, masks[i])[0][:, 0] for i in range(3)]) @@ -1299,65 +1528,70 @@ def fig5(folder, save_fig=True): print(ctype, noise_type, aps0[1:, :, 0].mean(axis=1)) aps.append(aps0) - fig = plt.figure(figsize=(14, 10), dpi=100) - yratio = 14 / 10 - grid = plt.GridSpec(4, 14, figure=fig, left=0.02, right=0.98, top=0.96, bottom=0.1, - wspace=0.05, hspace=0.3) + fig = plt.figure(figsize=(14, 7), dpi=100) + yratio = 14 / 7 + grid = plt.GridSpec(3, 14, figure=fig, left=0.02, right=0.97, top=0.97, bottom=0.1, + wspace=0.05, hspace=0.2) - grid1 = matplotlib.gridspec.GridSpecFromSubplotSpec(1, 6, subplot_spec=grid[0, :], + grid1 = matplotlib.gridspec.GridSpecFromSubplotSpec(1, 8, subplot_spec=grid[0, :], wspace=0.4, hspace=0.15) - transl = mtransforms.ScaledTranslation(-40 / 72, 10 / 72, fig.dpi_scale_trans) + transl = mtransforms.ScaledTranslation(-0 / 72, 3 / 72, fig.dpi_scale_trans) il = 0 noise_type = ["poisson", "blur", "downsample"][i % 3] + ax = plt.subplot(grid1[0:2]) + pos = ax.get_position().bounds + im = plt.imread("figs/cellpose3_models.png") + yr = im.shape[0] / im.shape[1] + w = 0.22 + ax.set_position([0.0, pos[1]-0.08, w, w*yratio*yr]) + plt.imshow(im) + ax.axis("off") + ax.text(0.08, 1.02, "General restoration models", transform=ax.transAxes, + fontstyle="italic", fontsize="large") + il = plot_label(ltr, il, ax, transl, fs_title) + + transl = mtransforms.ScaledTranslation(-40 / 72, 20 / 72, fig.dpi_scale_trans) thresholds = np.arange(0.5, 1.05, 0.05) - cols0 = np.array(cols)[[0, 0, 7, 7]].copy() - cols0[-1] = np.array([0, 1, 0]) - cols0 = np.clip(cols0, 0, 1) - lss0 = ["-", "-", "-", "--"] - legstr0 = ["", u"\u2013 noisy image", u"\u2013 dataset-specific", u"-- one-click"] - theight = [0, 1, 3, 2] + cols0 = np.array([[0, 0, 0], [0, 0, 0], [0, 128, 0], [180, 229, 162], + [246, 198, 173], [192, 71, 29], ]) + cols0 = cols0 / 255 + lss0 = ["-", "-", "-","-", "-", "-"] + legstr0 = ["", u"\u2013 noisy image", u"\u2013 original", + u"\u2013 noise-specific", "\u2013 data-specific", u"-- one-click"] + theight = [0, 0,4,3,2,1] for i in range(6): ctype = "cellpose test set" if i < 3 else "nuclei test set" noise_type = ["denoising", "deblurring", "upsampling"][i % 3] - ax = plt.subplot(grid1[i]) + ax = plt.subplot(grid1[i+2]) pos = ax.get_position().bounds ax.set_position([ - pos[0] + (5 - i) * 0.01 - 0.02 + 0.03 * (i > 2), pos[1] - 0.05, + pos[0] + 0.025 * (i>2), pos[1] - 0.05, # (5 - i) * 0.01 - 0.02 + 0.03 * (i > 2) pos[2] * 0.92, pos[3] ]) - for k in range(1, len(aps[0])): - ax.plot(thresholds, aps[i][k].mean(axis=0), color=cols0[k], ls=lss0[k]) - if i == 0 or i == 3: - ax.text(0.43, 0.62 + 0.1 * theight[k], legstr0[k], color=cols0[k], - transform=ax.transAxes) + ax.plot(thresholds, aps[i][k].mean(axis=0), color=cols0[k], ls=lss0[k], lw=1) if i == 0 or i == 3: ax.set_ylabel("average precision (AP)") + ax.set_xlabel("IoU threshold") il = plot_label(ltr, il, ax, transl, fs_title) if i == 1 or i == 4: - ax.text(0.5, 1.3, ctype, transform=ax.transAxes, ha="center", + ax.text(0.5, 1.18, ctype, transform=ax.transAxes, ha="center", fontsize="large") - if i == 0: - ax.text(-0.35, 1.35, "One-click models", transform=ax.transAxes, - fontstyle="italic", fontsize="large") - + ax.set_ylim([0, 0.72]) - ax.set_xlabel("IoU threshold") - ax.set_xticks(np.arange(0.5, 1.05, 0.1)) + ax.set_xticks(np.arange(0.5, 1.05, 0.25)) ax.set_xlim([0.5, 1.0]) ax.set_title(f"{noise_type}", fontsize="medium") - #yr, xr = 200, 240 - - titlesj = ["clean", "noisy", "denoised"] + titlesj = ["clean", "noisy", "denoised (one-click)"] titlesi = [ "Tissuenet", "Livecell", "Yeaz bright-field", "YeaZ phase-contrast", "Omnipose phase-contrast", "Omnipose fluorescent", "DeepBacs" ] - colsj = cols0[[0, 1, 3]] + colsj = cols0[[0, 1, -1]] ly0 = 250 @@ -1371,15 +1605,10 @@ def fig5(folder, save_fig=True): mask_gt = lbls[i].copy() #outlines_gt = utils.outlines_list(mask_gt, multiprocessing=False) - for j in range(3): - #img = np.zeros((*imgs[j][i].shape[1:], 3)) - #img[:,:,1:] = imgs[j][i][[0,1]].transpose(1,2,0) - if imgs[j][i].ndim == 3: - imgs[j][i] = imgs[j][i][0] - - img = np.clip(transforms.normalize99(imgs[j][i].copy()), 0, 1) + for j in range(1, 3): + img = np.clip(transforms.normalize99(imgs[j][i].copy().squeeze()), 0, 1) for k in range(2): - ax = plt.subplot(grid[j + 1, 2 * i + k]) + ax = plt.subplot(grid[j, 2 * i + k]) pos = ax.get_position().bounds ax.set_position([ pos[0] + 0.003 * i - 0.00 * k, pos[1] - (2 - j) * 0.025 - 0.07, @@ -1417,11 +1646,9 @@ def fig5(folder, save_fig=True): if k == 0 and j == 0: ax.text(0.0, 1.05, titlesi[i], transform=ax.transAxes, fontsize="medium") - if save_fig: os.makedirs("figs/", exist_ok=True) - fig.savefig("figs/fig5.pdf", dpi=100) - + fig.savefig("figs/fig6.pdf", dpi=150) def load_seg_generalist(folder): folders = [ @@ -1489,7 +1716,7 @@ def load_seg_generalist(folder): return apcs, api, imgs, masks_true, masks_pred -def suppfig_generalist(folder, save_fig=True): +def fig5(folder, save_fig=True): thresholds = np.arange(0.5, 1.05, 0.05) apcs, api, imgs, masks_true, masks_pred = load_seg_generalist(folder) titlesi = [ @@ -1599,7 +1826,7 @@ def suppfig_generalist(folder, save_fig=True): if save_fig: os.makedirs("figs/", exist_ok=True) - fig.savefig("figs/suppfig_generalist.pdf", dpi=100) + fig.savefig("figs/fig5.pdf", dpi=100) if __name__ == "__main__": @@ -1633,10 +1860,14 @@ def suppfig_generalist(folder, save_fig=True): fig4(folder, save_fig=0, ctype="nuclei") plt.show() + # ex images + suppfig_impr(folder, save_fig=0) + plt.show() + # one-click + supergeneralist fig5(folder, save_fig=0) plt.show() - suppfig_generalist(folder, save_fig=0) + fig6(folder, save_fig=0) plt.show() From 1af4546ff48aa00aea1aaaec1d14160b507cbcbf Mon Sep 17 00:00:00 2001 From: unknown Date: Sat, 7 Sep 2024 06:13:19 -0400 Subject: [PATCH 21/22] adding oneclick button --- cellpose/gui/gui.py | 46 ++++++++++++++++++++++----------------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/cellpose/gui/gui.py b/cellpose/gui/gui.py index 94296419..dd518b13 100644 --- a/cellpose/gui/gui.py +++ b/cellpose/gui/gui.py @@ -733,32 +733,31 @@ def make_buttons(self): self.l0.addWidget(self.denoiseBox, b, 0, 1, 9) b0 = 0 - self.denoiseBoxG.addWidget(QLabel("mode:"), b0, 0, 1, 3) - + # DENOISING self.DenoiseButtons = [] nett = [ - "filter image (settings below)", "clear restore/filter", + "filter image (settings below)", "denoise (please set cell diameter first)", "deblur (please set cell diameter first)", "upsample to 30. diameter (cyto3) or 17. diameter (nuclei) (please set cell diameter first) (disabled in 3D)", + "one-click model trained to denoise+deblur+upsample (please set cell diameter first)" ] - self.denoise_text = ["filter", "none", "denoise", "deblur", "upsample"] + self.denoise_text = ["none", "filter", "denoise", "deblur", "upsample", "one-click"] self.restore = None self.ratio = 1. - jj = 3 + jj = 0 + w = 3 for j in range(len(self.denoise_text)): self.DenoiseButtons.append( guiparts.DenoiseButton(self, self.denoise_text[j])) - w = 3 self.denoiseBoxG.addWidget(self.DenoiseButtons[-1], b0, jj, 1, w) - jj += w self.DenoiseButtons[-1].setFixedWidth(75) self.DenoiseButtons[-1].setToolTip(nett[j]) self.DenoiseButtons[-1].setFont(self.medfont) - b0 += 1 if j == 1 else 0 - jj = 0 if j == 1 else jj + b0 += 1 if j%2==1 else 0 + jj = 0 if j%2==1 else jj + w # b0+=1 self.save_norm = QCheckBox("save restored/filtered image") @@ -767,22 +766,23 @@ def make_buttons(self): self.save_norm.setChecked(True) # self.denoiseBoxG.addWidget(self.save_norm, b0, 0, 1, 8) - b0 += 1 - label = QLabel("Cellpose3 model type:") + b0 -= 3 + label = QLabel("restore-dataset:") label.setToolTip( - "choose model type and click [denoise], [deblur], or [upsample]") + "choose dataset and click [denoise], [deblur], [upsample], or [one-click]") label.setFont(self.medfont) - self.denoiseBoxG.addWidget(label, b0, 0, 1, 4) + self.denoiseBoxG.addWidget(label, b0, 6, 1, 3) + b0 += 1 self.DenoiseChoose = QComboBox() self.DenoiseChoose.setFont(self.medfont) - self.DenoiseChoose.addItems(["one-click", "nuclei"]) - self.DenoiseChoose.setFixedWidth(100) + self.DenoiseChoose.addItems(["cyto3", "cyto2", "nuclei"]) + self.DenoiseChoose.setFixedWidth(85) tipstr = "choose model type and click [denoise], [deblur], or [upsample]" self.DenoiseChoose.setToolTip(tipstr) - self.denoiseBoxG.addWidget(self.DenoiseChoose, b0, 5, 1, 4) + self.denoiseBoxG.addWidget(self.DenoiseChoose, b0, 6, 1, 3) - b0 += 1 + b0 += 2 # FILTERING self.filtBox = QCollapsible("custom filter settings") self.filtBox._toggle_btn.setFont(self.medfont) @@ -1019,7 +1019,7 @@ def enable_buttons(self): for i in range(len(self.DenoiseButtons)): self.DenoiseButtons[i].setEnabled(True) if self.load_3D: - self.DenoiseButtons[-1].setEnabled(False) + self.DenoiseButtons[-2].setEnabled(False) self.ModelButtonB.setEnabled(True) self.SizeButton.setEnabled(True) self.newmodel.setEnabled(True) @@ -2213,7 +2213,7 @@ def compute_restore(self): self.DenoiseChoose.setCurrentIndex(1) if "upsample" in self.restore: i = self.DenoiseChoose.currentIndex() - diam_up = 30. if i == 0 else 17. + diam_up = 30. if i==0 or i==1 else 17. print(diam_up, self.ratio) self.Diameter.setText(str(diam_up / self.ratio)) self.compute_denoise_model(model_type=model_type) @@ -2264,16 +2264,16 @@ def compute_denoise_model(self, model_type=None): self.progress.setValue(0) try: tic = time.time() - nstr = "cyto3" if self.DenoiseChoose.currentText( - ) == "one-click" else "nuclei" - print(model_type) + nstr = self.DenoiseChoose.currentText() + nstr.replace("-", "") self.clear_restore() model_name = model_type + "_" + nstr + print(model_name) # denoising model self.denoise_model = denoise.DenoiseModel(gpu=self.useGPU.isChecked(), model_type=model_name) self.progress.setValue(10) - diam_up = 30. if "cyto3" in model_name else 17. + diam_up = 30. if "cyto" in model_name else 17. # params channels = self.get_channels() From 77b75cb1b05a212d28507fc772e787995e02d794 Mon Sep 17 00:00:00 2001 From: unknown Date: Sat, 7 Sep 2024 09:11:00 -0400 Subject: [PATCH 22/22] adding training documentation for denoising --- docs/restore.rst | 48 +++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 41 insertions(+), 7 deletions(-) diff --git a/docs/restore.rst b/docs/restore.rst index c3864c44..be15e18c 100644 --- a/docs/restore.rst +++ b/docs/restore.rst @@ -5,11 +5,14 @@ Image Restoration The image restoration module ``denoise`` provides functions for restoring degraded images. There are two main classes, ``DenoiseModel`` for image restoration only, and -``CellposeDenoiseModel`` for image restoration and then segmentation. There are three types -of image restoration provided, denoising, deblurring, and upsampling, and for each of these -there are two models, one trained on the full ``cyto3`` training set and one trained on -the ``nuclei`` training set: ``'denoise_cyto3'``, ``'deblur_cyto3'``, ``'upsample_cyto3'``, -``'denoise_nuclei'``, ``'deblur_nuclei'``, ``'upsample_nuclei'``. +``CellposeDenoiseModel`` for image restoration and then segmentation. There are four types +of image restoration provided, denoising, deblurring, upsampling and one-click (trained on +all degradation types), and for each of these +there are three models, one trained on the full ``cyto3`` training set, one trained on the +``cyto2`` training set, and one trained on the ``nuclei`` training set: +``'denoise_cyto3'``, ``'deblur_cyto3'``, ``'upsample_cyto3'``, ``'oneclick_cyto3'``, +``'denoise_cyto2'``, ``'deblur_cyto2'``, ``'upsample_cyto2'``, ``'oneclick_cyto2'``, +``'denoise_nuclei'``, ``'deblur_nuclei'``, ``'upsample_nuclei'``, ``'oneclick_nuclei'``. DenoiseModel -------------- @@ -70,5 +73,36 @@ For more details refer to the API section. Command line usage --------------------- -These models can be used on the command line with input ``--restore_type`` and flag -``--chan2_restore``. +These models can be used on the command line with model_type input using ``--restore_type`` +and add flag ``--chan2_restore`` for restoring the optional nuclear channel, e.g.: + +:: + + python -m cellpose --dir /path/to/images --model_type cyto3 --restore_type denoise_cyto3 --diameter 25 --chan2_restore --chan 2 --chan2 1 + +Training your own models +-------------------------- + +It is also possible to train your own models for image restoration using the +``cellpose.denoise`` module. For example, to train a denoising (Poisson noise) +model with the cyto2 segmentation model with train_data and train_labels +(images and ``_flows.tif``): + +:: + + from cellpose import denoise + model = denoise.DenoiseModel(gpu=True, nchan=1) + + io.logger_setup() + model_path = model.train(train_data, train_labels, test_data=None, test_labels=None, + save_path=save_path, iso=True, + blur=0., downsample=0., poisson=0.8, + n_epochs=2000, learning_rate=0.001, + seg_model_type="/home/carsen/.cellpose/models/cyto2torch_0") + + +This training can also be performed on the command line: + +:: + + python cellpose/denoise.py --dir /path/to/images --noise_type poisson --seg_model_type cyto2 --diam_mean 30. \ No newline at end of file