Skip to content

Commit

Permalink
Batch mode implemented
Browse files Browse the repository at this point in the history
  • Loading branch information
c-h-benedetti committed Nov 22, 2024
1 parent 3ad82a6 commit 85c892c
Show file tree
Hide file tree
Showing 9 changed files with 450 additions and 94 deletions.
3 changes: 1 addition & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -83,5 +83,4 @@ venv/
# written by setuptools_scm
**/_version.py

models∕
models/*
src/microglia_analyzer/models/
16 changes: 11 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,25 +30,31 @@ dependencies = [
"scikit-image",
"opencv-python-headless",
"pint",
"tifffile"
"tifffile",
"torch",
"tensorflow==2.16.1",
"skan"
]

# Required for documentation: tabs, sphinx_sphinx, myst_parser, sphinx_rtd_theme

[project.optional-dependencies]
testing = [
"tox",
"pytest", # https://docs.pytest.org/en/latest/contents.html
"pytest-cov", # https://pytest-cov.readthedocs.io/en/latest/
"pytest-qt", # https://pytest-qt.readthedocs.io/en/latest/
"pytest",
"pytest-cov",
"pytest-qt",
"napari",
"pyqt5",
"numpy",
"qtpy",
"scikit-image",
"opencv-python-headless",
"pint",
"tifffile"
"tifffile",
"torch",
"tensorflow==2.16.1",
"skan"
]

[project.entry-points."napari.manifest"]
Expand Down
193 changes: 159 additions & 34 deletions src/microglia_analyzer/_widget.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from qtpy.QtWidgets import (QWidget, QVBoxLayout, QGroupBox,
QSpinBox, QHBoxLayout, QPushButton,
QFileDialog, QComboBox, QLabel,
from qtpy.QtWidgets import (QWidget, QVBoxLayout, QGroupBox, QTableWidget,
QSpinBox, QHBoxLayout, QPushButton, QHeaderView,
QFileDialog, QComboBox, QLabel, QTableWidgetItem,
QSlider, QSpinBox, QFrame, QLineEdit)

from qtpy.QtCore import QThread, Qt

from PyQt5.QtGui import QFont, QDoubleValidator
from PyQt5.QtCore import pyqtSignal
from PyQt5.QtGui import QFont, QDoubleValidator, QColor
from PyQt5.QtCore import pyqtSignal, Qt

import napari
from napari.utils.notifications import show_info
Expand All @@ -18,11 +18,16 @@
import re

from microglia_analyzer import TIFF_REGEX
from microglia_analyzer.utils import boxes_as_napari_shapes, BBOX_COLORS
from microglia_analyzer.ma_worker import MicrogliaAnalyzer
from microglia_analyzer.qt_workers import QtSegmentMicroglia, QtClassifyMicroglia
from microglia_analyzer.qt_workers import (QtSegmentMicroglia, QtClassifyMicroglia,
QtMeasureMicroglia, QtBatchRunners)

_IMAGE_LAYER_NAME = "µ-Image"
_SEGMENTATION_LAYER_NAME = "µ-Segmentation"
_IMAGE_LAYER_NAME = "µ-Image"
_SEGMENTATION_LAYER_NAME = "µ-Segmentation"
_CLASSIFICATION_LAYER_NAME = "µ-Classification"
_YOLO_LAYER_NAME = "µ-YOLO"
_SKELETON_LAYER_NAME = "µ-Skeleton"

class MicrogliaAnalyzerWidget(QWidget):

Expand Down Expand Up @@ -141,18 +146,47 @@ def segment_microglia_panel(self):
h_layout.addWidget(self.probability_threshold_label)
self.probability_threshold_slider = QSlider(Qt.Horizontal)
self.probability_threshold_slider.setRange(0, 100)
self.probability_threshold_slider.setValue(10)
self.probability_threshold_slider.setValue(5)
self.probability_threshold_slider.setTickInterval(1)
self.probability_threshold_slider.setTickPosition(QSlider.TicksBelow)
self.probability_threshold_slider.valueChanged.connect(self.proba_threshold_update)
h_layout.addWidget(self.probability_threshold_slider)
self.proba_value_label = QLabel("10%")
self.proba_value_label = QLabel("5%")
h_layout.addWidget(self.proba_value_label)
layout.addLayout(h_layout)

self.segment_microglia_group.setLayout(layout)
self.layout.addWidget(self.segment_microglia_group)

def reset_table(self):
classes = self.mam.classes
self.table.setRowCount(0)

items = []
if classes is not None:
items = [(QColor(BBOX_COLORS[i]), c) for i, c in classes.items()]
self.table.setRowCount(len(classes))

for row, (color, word) in enumerate(items):
color_item = QTableWidgetItem()
color_item.setBackground(color)
color_item.setFlags(Qt.ItemIsEnabled)
self.table.setItem(row, 0, color_item)

word_item = QTableWidgetItem(word)
word_item.setFlags(Qt.ItemIsEnabled)
self.table.setItem(row, 1, word_item)

def classes_table_ui(self):
self.table = QTableWidget(self)
self.table.setColumnCount(2)
self.table.setHorizontalHeaderLabels(["Color", "Class"])
self.table.horizontalHeader().setStretchLastSection(True)
self.table.horizontalHeader().setSectionResizeMode(0, QHeaderView.Fixed)
self.table.setColumnWidth(0, 60)
self.table.horizontalHeader().setSectionResizeMode(1, QHeaderView.Stretch)
return self.table

def classify_microglia_panel(self):
self.classify_microglia_group = QGroupBox("Classification")
layout = QVBoxLayout()
Expand All @@ -163,15 +197,20 @@ def classify_microglia_panel(self):
self.classify_microglia_button.clicked.connect(self.classify_microglia)
layout.addWidget(self.classify_microglia_button)

# List of classes
self.table = self.classes_table_ui()
self.reset_table()
layout.addWidget(self.table)

# Minimum score for classification
h_layout = QHBoxLayout()
self.minimal_score_label = QLabel("Min score:")
h_layout.addWidget(self.minimal_score_label)
self.minimal_score_slider = QSlider(Qt.Horizontal)
self.minimal_score_slider.setRange(0, 100)
self.minimal_score_slider.setValue(50)
self.minimal_score_slider.setValue(15)
h_layout.addWidget(self.minimal_score_slider)
self.min_score_label = QLabel("50%")
self.min_score_label = QLabel("15%")
h_layout.addWidget(self.min_score_label)
self.minimal_score_slider.valueChanged.connect(self.min_score_update)
layout.addLayout(h_layout)
Expand All @@ -183,21 +222,16 @@ def measures_panel(self):
self.microglia_group = QGroupBox("Measures")
layout = QVBoxLayout()

self.skeletonize_microglia_button = QPushButton("🦴 Skeletonize")
self.skeletonize_microglia_button.setFont(self.font)
self.skeletonize_microglia_button.clicked.connect(self.skeletonize_microglia)
layout.addWidget(self.skeletonize_microglia_button)

self.export_control_images_button = QPushButton("📸 Export control image")
self.export_control_images_button.setFont(self.font)
self.export_control_images_button.clicked.connect(self.export_control_images)
layout.addWidget(self.export_control_images_button)

self.export_measures_button = QPushButton("📊 Export measures")
self.export_measures_button = QPushButton("📊 Measure")
self.export_measures_button.setFont(self.font)
self.export_measures_button.clicked.connect(self.export_measures)
layout.addWidget(self.export_measures_button)

self.run_batch_button = QPushButton("☑️ Run batch")
self.run_batch_button.setFont(self.font)
self.run_batch_button.clicked.connect(self.run_batch)
layout.addWidget(self.run_batch_button)

self.microglia_group.setLayout(layout)
self.layout.addWidget(self.microglia_group)

Expand Down Expand Up @@ -242,6 +276,7 @@ def segment_microglia(self):
self.mam.set_proba_threshold(self.probability_threshold_slider.value() / 100)

self.worker = QtSegmentMicroglia(self.pbr, self.mam)
self.total = 0
self.worker.update.connect(self.update_pbr)
self.active_worker = True

Expand All @@ -263,29 +298,119 @@ def classify_microglia(self):
self.set_active_ui(False)
self.thread = QThread()

self.mam.set_min_score(self.minimal_score_input.value() / 100)
self.mam.set_min_score(self.minimal_score_slider.value() / 100)

self.worker = QtSegmentMicroglia(self.pbr, self.mam)
self.worker = QtClassifyMicroglia(self.pbr, self.mam)
self.total = 0
self.worker.update.connect(self.update_pbr)
self.active_worker = True

self.worker.moveToThread(self.thread)
self.worker.finished.connect(self.show_microglia)
self.worker.finished.connect(self.show_classification)
self.thread.started.connect(self.worker.run)

self.thread.start()

def skeletonize_microglia(self):
pass
def export_measures(self):
self.pbr = progress()
self.pbr.set_description("Measuring microglia...")
self.set_active_ui(False)
self.thread = QThread()

self.worker = QtMeasureMicroglia(self.pbr, self.mam)
self.total = 0
self.worker.update.connect(self.update_pbr)
self.active_worker = True

def export_control_images(self):
pass
self.worker.moveToThread(self.thread)
self.worker.finished.connect(self.write_measures)
self.thread.started.connect(self.worker.run)

self.thread.start()

def run_batch(self):
self.total = len(self.get_all_tiff_files(self.sources_folder))
self.pbr = progress(total=self.total)
self.pbr.set_description("Running on folder...")
self.set_active_ui(False)
self.thread = QThread()

def export_measures(self):
pass
settings = {
'calibration': self.mam.calibration,
'cc_min_size': self.minimal_area_input.value(),
'proba_threshold': self.probability_threshold_slider.value() / 100,
'unet_path': os.path.dirname(self.mam.segmentation_model_path),
'yolo_path': os.path.dirname(os.path.dirname(self.mam.classification_model_path)),
'min_score': self.minimal_score_slider.value() / 100
}

self.worker = QtBatchRunners(self.pbr, self.sources_folder, settings)
self.worker.update.connect(self.update_pbr)
self.active_worker = True

self.worker.moveToThread(self.thread)
self.worker.finished.connect(self.end_batch)
self.thread.started.connect(self.worker.run)

self.thread.start()

# -------- Methods: ----------------------------------

def end_batch(self):
self.pbr.close()
self.set_active_ui(True)
show_info("Batch completed.")

def write_measures(self):
self.end_worker()
measures = self.mam.as_csv(self.images_combo.currentText())
skeleton = self.mam.skeleton
if _SKELETON_LAYER_NAME not in self.viewer.layers:
layer = self.viewer.add_image(skeleton, name=_SKELETON_LAYER_NAME, colormap='red', blending='additive')
else:
layer = self.viewer.layers[_SKELETON_LAYER_NAME]
layer.data = skeleton
if self.mam.calibration is not None:
self.set_calibration(*self.mam.calibration)
root_folder = os.path.join(self.sources_folder, "controls")
if not os.path.exists(root_folder):
os.makedirs(root_folder)
measures_path = os.path.join(root_folder, os.path.splitext(self.images_combo.currentText())[0] + "_measures.csv")
control_path = os.path.join(root_folder, os.path.splitext(self.images_combo.currentText())[0] + "_control.tif")
tifffile.imwrite(control_path, np.stack([self.mam.skeleton, self.mam.mask], axis=0))
with open(measures_path, 'w') as f:
f.write("\n".join(measures))
show_info(f"Microglia measured.")

def show_classification(self):
self.end_worker()
bindings = self.mam.bindings
self.reset_table()
# Showing bound classification
boxes, colors = boxes_as_napari_shapes(bindings.values())
layer = None
if _CLASSIFICATION_LAYER_NAME not in self.viewer.layers:
layer = self.viewer.add_shapes(boxes, name=_CLASSIFICATION_LAYER_NAME, edge_color=colors, face_color='#00000000', edge_width=4)
else:
layer = self.viewer.layers[_CLASSIFICATION_LAYER_NAME]
layer.data = boxes
layer.edge_colors = colors
# Showing raw classification
classification = self.mam.classifications
tps = [(c, None, b) for c, b in zip(classification['classes'], classification['boxes'])]
boxes, colors = boxes_as_napari_shapes(tps, True)
if _YOLO_LAYER_NAME not in self.viewer.layers:
layer = self.viewer.add_shapes(boxes, name=_YOLO_LAYER_NAME, edge_color=colors, face_color='#00000000', edge_width=4)
else:
layer = self.viewer.layers[_YOLO_LAYER_NAME]
layer.data = boxes
layer.edge_colors = colors
layer.edge_width = 4
# Update calibration
if self.mam.calibration is not None:
self.set_calibration(*self.mam.calibration)
show_info(f"Microglia classified.")

def show_microglia(self):
self.end_worker()
labeled = self.mam.mask
Expand All @@ -310,8 +435,8 @@ def set_active_ui(self, state):
self.minimal_area_input.setEnabled(state)
self.probability_threshold_slider.setEnabled(state)
self.classify_microglia_button.setEnabled(state)
self.skeletonize_microglia_button.setEnabled(state)
self.export_control_images_button.setEnabled(state)
self.run_batch_button.setEnabled(state)
self.run_batch_button.setEnabled(state)
self.export_measures_button.setEnabled(state)

def end_worker(self):
Expand Down
4 changes: 2 additions & 2 deletions src/microglia_analyzer/dl/unet2d_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,8 +628,8 @@ def open_pair(input_path, mask_path, training, img_only):
raw_img = tifffile.imread(input_path)
raw_img = np.expand_dims(raw_img, axis=-1)
raw_mask = tifffile.imread(mask_path)
raw_mask = skeletonize(raw_mask)
raw_mask = binary_dilation(raw_mask)
# raw_mask = skeletonize(raw_mask)
# raw_mask = binary_dilation(raw_mask)
raw_mask = raw_mask.astype(np.float32)
raw_mask /= np.max(raw_mask)
raw_mask = np.expand_dims(raw_mask, axis=-1)
Expand Down
Loading

0 comments on commit 85c892c

Please sign in to comment.