From e37e11fe8dc74e44fc660719c14e1e12ea30232a Mon Sep 17 00:00:00 2001 From: ian-coccimiglio Date: Fri, 4 Oct 2024 12:50:53 -0700 Subject: [PATCH 1/2] Added GUI selection for RAdam parameter solver --- cellpose/gui/gui.py | 5 ++++- cellpose/gui/guiparts.py | 10 ++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/cellpose/gui/gui.py b/cellpose/gui/gui.py index 4d92732a..2ee9dafe 100644 --- a/cellpose/gui/gui.py +++ b/cellpose/gui/gui.py @@ -295,6 +295,7 @@ def __init__(self, image=None, logger=None): "learning_rate": 0.1, "weight_decay": 0.0001, "n_epochs": 100, + "SGD": True, "model_name": "CP" + d.strftime("_%Y%m%d_%H%M%S"), } @@ -2170,13 +2171,15 @@ def train_model(self, restore=None, normalize_params=None): save_path = os.path.dirname(self.filename) print("GUI_INFO: name of new model: " + self.training_params["model_name"]) + print(f"GUI_INFO: SGD activated: {self.training_params['SGD']}") self.new_model_path, train_losses = train.train_seg( self.model.net, train_data=self.train_data, train_labels=self.train_labels, channels=self.channels, normalize=normalize_params, min_train_masks=0, - save_path=save_path, nimg_per_epoch=max(8, len(self.train_data)), SGD=True, + save_path=save_path, nimg_per_epoch=max(8, len(self.train_data)), learning_rate=self.training_params["learning_rate"], weight_decay=self.training_params["weight_decay"], n_epochs=self.training_params["n_epochs"], + SGD=self.training_params["SGD"], model_name=self.training_params["model_name"])[:2] # save train losses np.save(str(self.new_model_path) + "_train_losses.npy", train_losses) diff --git a/cellpose/gui/guiparts.py b/cellpose/gui/guiparts.py index 6a71a0f7..7c6deb56 100644 --- a/cellpose/gui/guiparts.py +++ b/cellpose/gui/guiparts.py @@ -233,6 +233,15 @@ def __init__(self, parent, model_strings): self.edits[-1].setFixedWidth(200) self.l0.addWidget(self.edits[-1], i + yoff, 1, 1, 1) + yoff += 1 + use_SGD = "SGD" + self.useSGD = QCheckBox(f"{use_SGD}") + self.useSGD.setChecked(True) + # self.edits[-1].setText(str(use_SGD)) + # self.edits[-1].setFixedWidth(200) + # self.l0.addWidget(self.edits[-1], i + yoff, 1, 1, 1) + self.l0.addWidget(self.useSGD, i+yoff, 1, 1, 1) + yoff += len(labels) yoff += 1 @@ -289,6 +298,7 @@ def accept(self, parent): "weight_decay": float(self.edits[1].text()), "n_epochs": int(self.edits[2].text()), "model_name": self.edits[3].text(), + "SGD": True if self.useSGD.isChecked() else False, #"use_norm": True if self.use_norm.isChecked() else False, } self.done(1) From 0d59d07565f5739c8cbc9ddf4594e6b2b555c0e4 Mon Sep 17 00:00:00 2001 From: ian-coccimiglio Date: Fri, 4 Oct 2024 13:11:14 -0700 Subject: [PATCH 2/2] Removed commented --- cellpose/gui/guiparts.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/cellpose/gui/guiparts.py b/cellpose/gui/guiparts.py index 7c6deb56..caacfe1d 100644 --- a/cellpose/gui/guiparts.py +++ b/cellpose/gui/guiparts.py @@ -237,9 +237,6 @@ def __init__(self, parent, model_strings): use_SGD = "SGD" self.useSGD = QCheckBox(f"{use_SGD}") self.useSGD.setChecked(True) - # self.edits[-1].setText(str(use_SGD)) - # self.edits[-1].setFixedWidth(200) - # self.l0.addWidget(self.edits[-1], i + yoff, 1, 1, 1) self.l0.addWidget(self.useSGD, i+yoff, 1, 1, 1) yoff += len(labels)