Skip to content

Commit

Permalink
Merge pull request #1008 from MouseLand/rev3
Browse files Browse the repository at this point in the history
Rev3
  • Loading branch information
carsen-stringer authored Sep 7, 2024
2 parents dc3848d + 5711b85 commit 94ed43d
Show file tree
Hide file tree
Showing 11 changed files with 738 additions and 426 deletions.
8 changes: 5 additions & 3 deletions cellpose/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def run_net(net, imgs, batch_size=8, augment=False, tile=True, tile_overlap=0.1,
# slices from padding
# slc = [slice(0, self.nclasses) for n in range(imgs.ndim)] # changed from imgs.shape[n]+1 for first slice size
slc = [slice(0, imgs.shape[n] + 1) for n in range(imgs.ndim)]
slc[-3] = slice(0, 3)
slc[-3] = slice(0, net.nout)
slc[-2] = slice(ysub[0], ysub[-1] + 1)
slc[-1] = slice(xsub[0], xsub[-1] + 1)
slc = tuple(slc)
Expand Down Expand Up @@ -286,7 +286,8 @@ def _run_tiled(net, imgi, batch_size=8, augment=False, bsize=224, tile_overlap=0
yf = np.zeros((Lz, nout, imgi.shape[-2], imgi.shape[-1]), np.float32)
styles = []
if ny * nx > batch_size:
ziterator = trange(Lz, file=tqdm_out)
ziterator = (trange(Lz, file=tqdm_out, mininterval=30)
if Lz > 1 else range(Lz))
for i in ziterator:
yfi, stylei = _run_tiled(net, imgi[i], augment=augment, bsize=bsize,
tile_overlap=tile_overlap)
Expand All @@ -297,7 +298,8 @@ def _run_tiled(net, imgi, batch_size=8, augment=False, bsize=224, tile_overlap=0
ntiles = ny * nx
nimgs = max(2, int(np.round(batch_size / ntiles)))
niter = int(np.ceil(Lz / nimgs))
ziterator = trange(niter, file=tqdm_out)
ziterator = (trange(niter, file=tqdm_out, mininterval=30)
if Lz > 1 else range(niter))
for k in ziterator:
IMGa = np.zeros((ntiles * nimgs, nchan, ly, lx), np.float32)
for i in range(min(Lz - k * nimgs, nimgs)):
Expand Down
320 changes: 204 additions & 116 deletions cellpose/denoise.py

Large diffs are not rendered by default.

46 changes: 23 additions & 23 deletions cellpose/gui/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,32 +733,31 @@ def make_buttons(self):
self.l0.addWidget(self.denoiseBox, b, 0, 1, 9)

b0 = 0
self.denoiseBoxG.addWidget(QLabel("mode:"), b0, 0, 1, 3)


# DENOISING
self.DenoiseButtons = []
nett = [
"filter image (settings below)",
"clear restore/filter",
"filter image (settings below)",
"denoise (please set cell diameter first)",
"deblur (please set cell diameter first)",
"upsample to 30. diameter (cyto3) or 17. diameter (nuclei) (please set cell diameter first) (disabled in 3D)",
"one-click model trained to denoise+deblur+upsample (please set cell diameter first)"
]
self.denoise_text = ["filter", "none", "denoise", "deblur", "upsample"]
self.denoise_text = ["none", "filter", "denoise", "deblur", "upsample", "one-click"]
self.restore = None
self.ratio = 1.
jj = 3
jj = 0
w = 3
for j in range(len(self.denoise_text)):
self.DenoiseButtons.append(
guiparts.DenoiseButton(self, self.denoise_text[j]))
w = 3
self.denoiseBoxG.addWidget(self.DenoiseButtons[-1], b0, jj, 1, w)
jj += w
self.DenoiseButtons[-1].setFixedWidth(75)
self.DenoiseButtons[-1].setToolTip(nett[j])
self.DenoiseButtons[-1].setFont(self.medfont)
b0 += 1 if j == 1 else 0
jj = 0 if j == 1 else jj
b0 += 1 if j%2==1 else 0
jj = 0 if j%2==1 else jj + w

# b0+=1
self.save_norm = QCheckBox("save restored/filtered image")
Expand All @@ -767,22 +766,23 @@ def make_buttons(self):
self.save_norm.setChecked(True)
# self.denoiseBoxG.addWidget(self.save_norm, b0, 0, 1, 8)

b0 += 1
label = QLabel("Cellpose3 model type:")
b0 -= 3
label = QLabel("restore-dataset:")
label.setToolTip(
"choose model type and click [denoise], [deblur], or [upsample]")
"choose dataset and click [denoise], [deblur], [upsample], or [one-click]")
label.setFont(self.medfont)
self.denoiseBoxG.addWidget(label, b0, 0, 1, 4)
self.denoiseBoxG.addWidget(label, b0, 6, 1, 3)

b0 += 1
self.DenoiseChoose = QComboBox()
self.DenoiseChoose.setFont(self.medfont)
self.DenoiseChoose.addItems(["one-click", "nuclei"])
self.DenoiseChoose.setFixedWidth(100)
self.DenoiseChoose.addItems(["cyto3", "cyto2", "nuclei"])
self.DenoiseChoose.setFixedWidth(85)
tipstr = "choose model type and click [denoise], [deblur], or [upsample]"
self.DenoiseChoose.setToolTip(tipstr)
self.denoiseBoxG.addWidget(self.DenoiseChoose, b0, 5, 1, 4)
self.denoiseBoxG.addWidget(self.DenoiseChoose, b0, 6, 1, 3)

b0 += 1
b0 += 2
# FILTERING
self.filtBox = QCollapsible("custom filter settings")
self.filtBox._toggle_btn.setFont(self.medfont)
Expand Down Expand Up @@ -1019,7 +1019,7 @@ def enable_buttons(self):
for i in range(len(self.DenoiseButtons)):
self.DenoiseButtons[i].setEnabled(True)
if self.load_3D:
self.DenoiseButtons[-1].setEnabled(False)
self.DenoiseButtons[-2].setEnabled(False)
self.ModelButtonB.setEnabled(True)
self.SizeButton.setEnabled(True)
self.newmodel.setEnabled(True)
Expand Down Expand Up @@ -2213,7 +2213,7 @@ def compute_restore(self):
self.DenoiseChoose.setCurrentIndex(1)
if "upsample" in self.restore:
i = self.DenoiseChoose.currentIndex()
diam_up = 30. if i == 0 else 17.
diam_up = 30. if i==0 or i==1 else 17.
print(diam_up, self.ratio)
self.Diameter.setText(str(diam_up / self.ratio))
self.compute_denoise_model(model_type=model_type)
Expand Down Expand Up @@ -2264,16 +2264,16 @@ def compute_denoise_model(self, model_type=None):
self.progress.setValue(0)
try:
tic = time.time()
nstr = "cyto3" if self.DenoiseChoose.currentText(
) == "one-click" else "nuclei"
print(model_type)
nstr = self.DenoiseChoose.currentText()
nstr.replace("-", "")
self.clear_restore()
model_name = model_type + "_" + nstr
print(model_name)
# denoising model
self.denoise_model = denoise.DenoiseModel(gpu=self.useGPU.isChecked(),
model_type=model_name)
self.progress.setValue(10)
diam_up = 30. if "cyto3" in model_name else 17.
diam_up = 30. if "cyto" in model_name else 17.

# params
channels = self.get_channels()
Expand Down
34 changes: 18 additions & 16 deletions cellpose/gui/gui3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ def avg3d(C):
"""
Ly, Lx = C.shape
# pad T by 2
T = np.zeros((Ly + 2, Lx + 2), np.float32)
M = np.zeros((Ly, Lx), np.float32)
T = np.zeros((Ly + 2, Lx + 2), "float32")
M = np.zeros((Ly, Lx), "float32")
T[1:-1, 1:-1] = C.copy()
y, x = np.meshgrid(np.arange(0, Ly, 1, int), np.arange(0, Lx, 1, int),
indexing="ij")
Expand Down Expand Up @@ -244,7 +244,7 @@ def add_mask(self, points=None, color=(100, 200, 50), dense=True):
vc = stroke[iz, 2]
if iz.sum() > 0:
# get points inside drawn points
mask = np.zeros((np.ptp(vr) + 4, np.ptp(vc) + 4), np.uint8)
mask = np.zeros((np.ptp(vr) + 4, np.ptp(vc) + 4), "uint8")
pts = np.stack((vc - vc.min() + 2, vr - vr.min() + 2),
axis=-1)[:, np.newaxis, :]
mask = cv2.fillPoly(mask, [pts], (255, 0, 0))
Expand All @@ -265,7 +265,7 @@ def add_mask(self, points=None, color=(100, 200, 50), dense=True):
elif ioverlap.sum() > 0:
ar, ac = ar[~ioverlap], ac[~ioverlap]
# compute outline of new mask
mask = np.zeros((np.ptp(ar) + 4, np.ptp(ac) + 4), np.uint8)
mask = np.zeros((np.ptp(ar) + 4, np.ptp(ac) + 4), "uint8")
mask[ar - ar.min() + 2, ac - ac.min() + 2] = 1
contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
cv2.CHAIN_APPROX_NONE)
Expand All @@ -282,7 +282,7 @@ def add_mask(self, points=None, color=(100, 200, 50), dense=True):
pix = np.append(pix, np.vstack((ars, acs)), axis=-1)

mall = mall[:, pix[0].min():pix[0].max() + 1,
pix[1].min():pix[1].max() + 1].astype(np.float32)
pix[1].min():pix[1].max() + 1].astype("float32")
ymin, xmin = pix[0].min(), pix[1].min()
if len(zdraw) > 1:
mall, zfill = interpZ(mall, zdraw - zmin)
Expand Down Expand Up @@ -422,15 +422,15 @@ def update_ortho(self):
for j in range(2):
if j == 0:
if self.view == 0:
image = self.stack[zmin:zmax, :, x].transpose(1, 0, 2)
image = self.stack[zmin:zmax, :, x].transpose(1, 0, 2).copy()
else:
image = self.stack_filtered[zmin:zmax, :,
x].transpose(1, 0, 2)
x].transpose(1, 0, 2).copy()
else:
image = self.stack[
zmin:zmax,
y, :] if self.view == 0 else self.stack_filtered[zmin:zmax,
y, :]
y, :].copy() if self.view == 0 else self.stack_filtered[zmin:zmax,
y, :].copy()
if self.nchan == 1:
# show single channel
image = image[..., 0]
Expand Down Expand Up @@ -458,28 +458,30 @@ def update_ortho(self):
self.imgOrtho[j].setLevels(
self.saturation[0][self.currentZ])
elif self.color == 4:
image = image.astype(np.float32).mean(axis=-1).astype(np.uint8)
if image.ndim > 2:
image = image.astype("float32").mean(axis=2).astype("uint8")
self.imgOrtho[j].setImage(image, autoLevels=False, lut=None)
self.imgOrtho[j].setLevels(self.saturation[0][self.currentZ])
elif self.color == 5:
image = image.astype(np.float32).mean(axis=-1).astype(np.uint8)
if image.ndim > 2:
image = image.astype("float32").mean(axis=2).astype("uint8")
self.imgOrtho[j].setImage(image, autoLevels=False,
lut=self.cmap[0])
self.imgOrtho[j].setLevels(self.saturation[0][self.currentZ])
self.pOrtho[0].setAspectLocked(lock=True, ratio=self.zaspect)
self.pOrtho[1].setAspectLocked(lock=True, ratio=1. / self.zaspect)

else:
image = np.zeros((10, 10), np.uint8)
image = np.zeros((10, 10), "uint8")
self.imgOrtho[0].setImage(image, autoLevels=False, lut=None)
self.imgOrtho[0].setLevels([0.0, 255.0])
self.imgOrtho[1].setImage(image, autoLevels=False, lut=None)
self.imgOrtho[1].setLevels([0.0, 255.0])

zrange = zmax - zmin
self.layer_ortho = [
np.zeros((self.Ly, zrange, 4), np.uint8),
np.zeros((zrange, self.Lx, 4), np.uint8)
np.zeros((self.Ly, zrange, 4), "uint8"),
np.zeros((zrange, self.Lx, 4), "uint8")
]
if self.masksOn:
for j in range(2):
Expand All @@ -488,7 +490,7 @@ def update_ortho(self):
else:
cp = self.cellpix[zmin:zmax, y]
self.layer_ortho[j][..., :3] = self.cellcolors[cp, :]
self.layer_ortho[j][..., 3] = self.opacity * (cp > 0).astype(np.uint8)
self.layer_ortho[j][..., 3] = self.opacity * (cp > 0).astype("uint8")
if self.selected > 0:
self.layer_ortho[j][cp == self.selected] = np.array(
[255, 255, 255, self.opacity])
Expand All @@ -499,7 +501,7 @@ def update_ortho(self):
op = self.outpix[zmin:zmax, :, x].T
else:
op = self.outpix[zmin:zmax, y]
self.layer_ortho[j][op > 0] = np.array(self.outcolor).astype(np.uint8)
self.layer_ortho[j][op > 0] = np.array(self.outcolor).astype("uint8")

for j in range(2):
self.layerOrtho[j].setImage(self.layer_ortho[j])
Expand Down
49 changes: 17 additions & 32 deletions cellpose/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
"lowhigh": None,
"percentile": None,
"normalize": True,
"norm3D": False,
"norm3D": True,
"sharpen_radius": 0,
"smooth_radius": 0,
"tile_norm_blocksize": 0,
Expand Down Expand Up @@ -263,7 +263,7 @@ def __init__(self, gpu=False, pretrained_model=False, model_type=None,
if (pretrained_model and not Path(pretrained_model).exists() and
np.any([pretrained_model == s for s in all_models])):
model_type = pretrained_model

# check if model_type is builtin or custom user model saved in .cellpose/models
if model_type is not None and np.any([model_type == s for s in all_models]):
if np.any([model_type == s for s in MODEL_NAMES]):
Expand All @@ -286,6 +286,10 @@ def __init__(self, gpu=False, pretrained_model=False, model_type=None,
models_logger.warning(
"pretrained_model path does not exist, using default model")
use_default = True
elif pretrained_model:
if pretrained_model[-13:] == "nucleitorch_0":
builtin = True
self.diam_mean = 17.

builtin = True if use_default else builtin
self.pretrained_model = model_path(
Expand Down Expand Up @@ -503,37 +507,18 @@ def _run_cp(self, x, compute_masks=True, normalize=True, invert=False, niter=Non
del yf
else:
tqdm_out = utils.TqdmToLogger(models_logger, level=logging.INFO)
iterator = trange(nimg, file=tqdm_out,
mininterval=30) if nimg > 1 else range(nimg)
styles = np.zeros((nimg, self.nbase[-1]), np.float32)
img = np.asarray(x)
if do_normalization:
img = transforms.normalize_img(img, **normalize_params)
if rescale != 1.0:
img = transforms.resize_image(img, rsz=rescale)
yf, style = run_net(self.net, img, bsize=bsize, augment=augment,
tile=tile, tile_overlap=tile_overlap)
if resample:
dP = np.zeros((2, nimg, shape[1], shape[2]), np.float32)
cellprob = np.zeros((nimg, shape[1], shape[2]), np.float32)
else:
dP = np.zeros(
(2, nimg, int(shape[1] * rescale), int(shape[2] * rescale)),
np.float32)
cellprob = np.zeros(
(nimg, int(shape[1] * rescale), int(shape[2] * rescale)),
np.float32)
for i in iterator:
img = np.asarray(x[i])
if do_normalization:
img = transforms.normalize_img(img, **normalize_params)
if rescale != 1.0:
img = transforms.resize_image(img, rsz=rescale)
yf, style = run_net(self.net, img, bsize=bsize, augment=augment,
tile=tile, tile_overlap=tile_overlap)
if resample:
yf = transforms.resize_image(yf, shape[1], shape[2])

cellprob[i] = yf[:, :, 2]
dP[:, i] = yf[:, :, :2].transpose((2, 0, 1))
if self.nclasses == 4:
if i == 0:
bd = np.zeros_like(cellprob)
bd[i] = yf[:, :, 3]
styles[i][:len(style)] = style
yf = transforms.resize_image(yf, shape[1], shape[2])
dP = np.moveaxis(yf[..., :2], source=-1, destination=0).copy()
cellprob = yf[..., 2]
styles = style
del yf, style
styles = styles.squeeze()

Expand Down
1 change: 1 addition & 0 deletions cellpose/resnet_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ class CPnet(nn.Module):
def __init__(self, nbase, nout, sz, mkldnn=False, conv_3D=False, max_pool=True,
diam_mean=30.):
super().__init__()
self.nchan = nbase[0]
self.nbase = nbase
self.nout = nout
self.sz = sz
Expand Down
23 changes: 15 additions & 8 deletions cellpose/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None,
test_probs=None, load_files=True, batch_size=8, learning_rate=0.005,
n_epochs=2000, weight_decay=1e-5, momentum=0.9, SGD=False, channels=None,
channel_axis=None, rgb=False, normalize=True, compute_flows=False,
save_path=None, save_every=100, nimg_per_epoch=None,
save_path=None, save_every=100, save_each=False, nimg_per_epoch=None,
nimg_test_per_epoch=None, rescale=True, scale_range=None, bsize=224,
min_train_masks=5, model_name=None):
"""
Expand Down Expand Up @@ -362,6 +362,7 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None,
compute_flows (bool, optional): Boolean - whether to compute flows during training. Defaults to False.
save_path (str, optional): String - where to save the trained model. Defaults to None.
save_every (int, optional): Integer - save the network every [save_every] epochs. Defaults to 100.
save_each (bool, optional): Boolean - save the network to a new filename at every [save_each] epoch. Defaults to False.
nimg_per_epoch (int, optional): Integer - minimum number of images to train on per epoch. Defaults to None.
nimg_test_per_epoch (int, optional): Integer - minimum number of images to test on per epoch. Defaults to None.
rescale (bool, optional): Boolean - whether or not to rescale images during training. Defaults to True.
Expand Down Expand Up @@ -444,10 +445,10 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None,
t0 = time.time()
model_name = f"cellpose_{t0}" if model_name is None else model_name
save_path = Path.cwd() if save_path is None else Path(save_path)
model_path = save_path / "models" / model_name
filename = save_path / "models" / model_name
(save_path / "models").mkdir(exist_ok=True)

train_logger.info(f">>> saving model to {model_path}")
train_logger.info(f">>> saving model to {filename}")

lavg, nsum = 0, 0
for iepoch in range(n_epochs):
Expand Down Expand Up @@ -518,15 +519,21 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None,
lavgt /= len(rperm)
lavg /= nsum
train_logger.info(
f"{iepoch}, train_loss={lavg:.4f}, test_loss={lavgt:.4f}, LR={LR[iepoch]:.4f}, time {time.time()-t0:.2f}s"
f"{iepoch}, train_loss={lavg:.4f}, test_loss={lavgt:.4f}, LR={LR[iepoch]:.6f}, time {time.time()-t0:.2f}s"
)
lavg, nsum = 0, 0

if iepoch > 0 and iepoch % save_every == 0:
net.save_model(model_path)
net.save_model(model_path)
if iepoch == n_epochs - 1 or (iepoch % save_every == 0 and iepoch != 0):
if save_each and iepoch != n_epochs - 1: #separate files as model progresses
filename0 = str(filename) + f"_epoch_{iepoch:%04d}"
else:
filename0 = filename
train_logger.info(f"saving network parameters to {filename0}")
net.save_model(filename0)

net.save_model(filename)

return model_path
return filename


def train_size(net, pretrained_model, train_data=None, train_labels=None,
Expand Down
Loading

0 comments on commit 94ed43d

Please sign in to comment.