From 85c892c7d4cb2103ca7ddeb0f05df1123cffe9cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20H=2E=20Benedetti?= Date: Fri, 22 Nov 2024 12:33:44 +0100 Subject: [PATCH] Batch mode implemented --- .gitignore | 3 +- pyproject.toml | 16 +- src/microglia_analyzer/_widget.py | 193 +++++++++++++++--- src/microglia_analyzer/dl/unet2d_training.py | 4 +- .../experimental/classify_microglia.py | 30 ++- .../experimental/graph_explo.py | 49 +++++ src/microglia_analyzer/ma_worker.py | 133 ++++++++++-- src/microglia_analyzer/qt_workers.py | 72 ++++++- src/microglia_analyzer/utils.py | 44 ++-- 9 files changed, 450 insertions(+), 94 deletions(-) create mode 100644 src/microglia_analyzer/experimental/graph_explo.py diff --git a/.gitignore b/.gitignore index 282c6ae..21c5247 100644 --- a/.gitignore +++ b/.gitignore @@ -83,5 +83,4 @@ venv/ # written by setuptools_scm **/_version.py -models∕ -models/* +src/microglia_analyzer/models/ diff --git a/pyproject.toml b/pyproject.toml index e7155ea..1c2856f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,10 @@ dependencies = [ "scikit-image", "opencv-python-headless", "pint", - "tifffile" + "tifffile", + "torch", + "tensorflow==2.16.1", + "skan" ] # Required for documentation: tabs, sphinx_sphinx, myst_parser, sphinx_rtd_theme @@ -38,9 +41,9 @@ dependencies = [ [project.optional-dependencies] testing = [ "tox", - "pytest", # https://docs.pytest.org/en/latest/contents.html - "pytest-cov", # https://pytest-cov.readthedocs.io/en/latest/ - "pytest-qt", # https://pytest-qt.readthedocs.io/en/latest/ + "pytest", + "pytest-cov", + "pytest-qt", "napari", "pyqt5", "numpy", @@ -48,7 +51,10 @@ testing = [ "scikit-image", "opencv-python-headless", "pint", - "tifffile" + "tifffile", + "torch", + "tensorflow==2.16.1", + "skan" ] [project.entry-points."napari.manifest"] diff --git a/src/microglia_analyzer/_widget.py b/src/microglia_analyzer/_widget.py index 2a227ed..3d6d467 100644 --- a/src/microglia_analyzer/_widget.py +++ b/src/microglia_analyzer/_widget.py @@ -1,12 +1,12 @@ -from qtpy.QtWidgets import (QWidget, QVBoxLayout, QGroupBox, - QSpinBox, QHBoxLayout, QPushButton, - QFileDialog, QComboBox, QLabel, +from qtpy.QtWidgets import (QWidget, QVBoxLayout, QGroupBox, QTableWidget, + QSpinBox, QHBoxLayout, QPushButton, QHeaderView, + QFileDialog, QComboBox, QLabel, QTableWidgetItem, QSlider, QSpinBox, QFrame, QLineEdit) from qtpy.QtCore import QThread, Qt -from PyQt5.QtGui import QFont, QDoubleValidator -from PyQt5.QtCore import pyqtSignal +from PyQt5.QtGui import QFont, QDoubleValidator, QColor +from PyQt5.QtCore import pyqtSignal, Qt import napari from napari.utils.notifications import show_info @@ -18,11 +18,16 @@ import re from microglia_analyzer import TIFF_REGEX +from microglia_analyzer.utils import boxes_as_napari_shapes, BBOX_COLORS from microglia_analyzer.ma_worker import MicrogliaAnalyzer -from microglia_analyzer.qt_workers import QtSegmentMicroglia, QtClassifyMicroglia +from microglia_analyzer.qt_workers import (QtSegmentMicroglia, QtClassifyMicroglia, + QtMeasureMicroglia, QtBatchRunners) -_IMAGE_LAYER_NAME = "µ-Image" -_SEGMENTATION_LAYER_NAME = "µ-Segmentation" +_IMAGE_LAYER_NAME = "µ-Image" +_SEGMENTATION_LAYER_NAME = "µ-Segmentation" +_CLASSIFICATION_LAYER_NAME = "µ-Classification" +_YOLO_LAYER_NAME = "µ-YOLO" +_SKELETON_LAYER_NAME = "µ-Skeleton" class MicrogliaAnalyzerWidget(QWidget): @@ -141,18 +146,47 @@ def segment_microglia_panel(self): h_layout.addWidget(self.probability_threshold_label) self.probability_threshold_slider = QSlider(Qt.Horizontal) self.probability_threshold_slider.setRange(0, 100) - self.probability_threshold_slider.setValue(10) + self.probability_threshold_slider.setValue(5) self.probability_threshold_slider.setTickInterval(1) self.probability_threshold_slider.setTickPosition(QSlider.TicksBelow) self.probability_threshold_slider.valueChanged.connect(self.proba_threshold_update) h_layout.addWidget(self.probability_threshold_slider) - self.proba_value_label = QLabel("10%") + self.proba_value_label = QLabel("5%") h_layout.addWidget(self.proba_value_label) layout.addLayout(h_layout) self.segment_microglia_group.setLayout(layout) self.layout.addWidget(self.segment_microglia_group) + def reset_table(self): + classes = self.mam.classes + self.table.setRowCount(0) + + items = [] + if classes is not None: + items = [(QColor(BBOX_COLORS[i]), c) for i, c in classes.items()] + self.table.setRowCount(len(classes)) + + for row, (color, word) in enumerate(items): + color_item = QTableWidgetItem() + color_item.setBackground(color) + color_item.setFlags(Qt.ItemIsEnabled) + self.table.setItem(row, 0, color_item) + + word_item = QTableWidgetItem(word) + word_item.setFlags(Qt.ItemIsEnabled) + self.table.setItem(row, 1, word_item) + + def classes_table_ui(self): + self.table = QTableWidget(self) + self.table.setColumnCount(2) + self.table.setHorizontalHeaderLabels(["Color", "Class"]) + self.table.horizontalHeader().setStretchLastSection(True) + self.table.horizontalHeader().setSectionResizeMode(0, QHeaderView.Fixed) + self.table.setColumnWidth(0, 60) + self.table.horizontalHeader().setSectionResizeMode(1, QHeaderView.Stretch) + return self.table + def classify_microglia_panel(self): self.classify_microglia_group = QGroupBox("Classification") layout = QVBoxLayout() @@ -163,15 +197,20 @@ def classify_microglia_panel(self): self.classify_microglia_button.clicked.connect(self.classify_microglia) layout.addWidget(self.classify_microglia_button) + # List of classes + self.table = self.classes_table_ui() + self.reset_table() + layout.addWidget(self.table) + # Minimum score for classification h_layout = QHBoxLayout() self.minimal_score_label = QLabel("Min score:") h_layout.addWidget(self.minimal_score_label) self.minimal_score_slider = QSlider(Qt.Horizontal) self.minimal_score_slider.setRange(0, 100) - self.minimal_score_slider.setValue(50) + self.minimal_score_slider.setValue(15) h_layout.addWidget(self.minimal_score_slider) - self.min_score_label = QLabel("50%") + self.min_score_label = QLabel("15%") h_layout.addWidget(self.min_score_label) self.minimal_score_slider.valueChanged.connect(self.min_score_update) layout.addLayout(h_layout) @@ -183,21 +222,16 @@ def measures_panel(self): self.microglia_group = QGroupBox("Measures") layout = QVBoxLayout() - self.skeletonize_microglia_button = QPushButton("🦴 Skeletonize") - self.skeletonize_microglia_button.setFont(self.font) - self.skeletonize_microglia_button.clicked.connect(self.skeletonize_microglia) - layout.addWidget(self.skeletonize_microglia_button) - - self.export_control_images_button = QPushButton("📸 Export control image") - self.export_control_images_button.setFont(self.font) - self.export_control_images_button.clicked.connect(self.export_control_images) - layout.addWidget(self.export_control_images_button) - - self.export_measures_button = QPushButton("📊 Export measures") + self.export_measures_button = QPushButton("📊 Measure") self.export_measures_button.setFont(self.font) self.export_measures_button.clicked.connect(self.export_measures) layout.addWidget(self.export_measures_button) + self.run_batch_button = QPushButton("☑️ Run batch") + self.run_batch_button.setFont(self.font) + self.run_batch_button.clicked.connect(self.run_batch) + layout.addWidget(self.run_batch_button) + self.microglia_group.setLayout(layout) self.layout.addWidget(self.microglia_group) @@ -242,6 +276,7 @@ def segment_microglia(self): self.mam.set_proba_threshold(self.probability_threshold_slider.value() / 100) self.worker = QtSegmentMicroglia(self.pbr, self.mam) + self.total = 0 self.worker.update.connect(self.update_pbr) self.active_worker = True @@ -263,29 +298,119 @@ def classify_microglia(self): self.set_active_ui(False) self.thread = QThread() - self.mam.set_min_score(self.minimal_score_input.value() / 100) + self.mam.set_min_score(self.minimal_score_slider.value() / 100) - self.worker = QtSegmentMicroglia(self.pbr, self.mam) + self.worker = QtClassifyMicroglia(self.pbr, self.mam) + self.total = 0 self.worker.update.connect(self.update_pbr) self.active_worker = True self.worker.moveToThread(self.thread) - self.worker.finished.connect(self.show_microglia) + self.worker.finished.connect(self.show_classification) self.thread.started.connect(self.worker.run) self.thread.start() - def skeletonize_microglia(self): - pass + def export_measures(self): + self.pbr = progress() + self.pbr.set_description("Measuring microglia...") + self.set_active_ui(False) + self.thread = QThread() + + self.worker = QtMeasureMicroglia(self.pbr, self.mam) + self.total = 0 + self.worker.update.connect(self.update_pbr) + self.active_worker = True - def export_control_images(self): - pass + self.worker.moveToThread(self.thread) + self.worker.finished.connect(self.write_measures) + self.thread.started.connect(self.worker.run) + + self.thread.start() + + def run_batch(self): + self.total = len(self.get_all_tiff_files(self.sources_folder)) + self.pbr = progress(total=self.total) + self.pbr.set_description("Running on folder...") + self.set_active_ui(False) + self.thread = QThread() - def export_measures(self): - pass + settings = { + 'calibration': self.mam.calibration, + 'cc_min_size': self.minimal_area_input.value(), + 'proba_threshold': self.probability_threshold_slider.value() / 100, + 'unet_path': os.path.dirname(self.mam.segmentation_model_path), + 'yolo_path': os.path.dirname(os.path.dirname(self.mam.classification_model_path)), + 'min_score': self.minimal_score_slider.value() / 100 + } + + self.worker = QtBatchRunners(self.pbr, self.sources_folder, settings) + self.worker.update.connect(self.update_pbr) + self.active_worker = True + + self.worker.moveToThread(self.thread) + self.worker.finished.connect(self.end_batch) + self.thread.started.connect(self.worker.run) + + self.thread.start() # -------- Methods: ---------------------------------- + def end_batch(self): + self.pbr.close() + self.set_active_ui(True) + show_info("Batch completed.") + + def write_measures(self): + self.end_worker() + measures = self.mam.as_csv(self.images_combo.currentText()) + skeleton = self.mam.skeleton + if _SKELETON_LAYER_NAME not in self.viewer.layers: + layer = self.viewer.add_image(skeleton, name=_SKELETON_LAYER_NAME, colormap='red', blending='additive') + else: + layer = self.viewer.layers[_SKELETON_LAYER_NAME] + layer.data = skeleton + if self.mam.calibration is not None: + self.set_calibration(*self.mam.calibration) + root_folder = os.path.join(self.sources_folder, "controls") + if not os.path.exists(root_folder): + os.makedirs(root_folder) + measures_path = os.path.join(root_folder, os.path.splitext(self.images_combo.currentText())[0] + "_measures.csv") + control_path = os.path.join(root_folder, os.path.splitext(self.images_combo.currentText())[0] + "_control.tif") + tifffile.imwrite(control_path, np.stack([self.mam.skeleton, self.mam.mask], axis=0)) + with open(measures_path, 'w') as f: + f.write("\n".join(measures)) + show_info(f"Microglia measured.") + + def show_classification(self): + self.end_worker() + bindings = self.mam.bindings + self.reset_table() + # Showing bound classification + boxes, colors = boxes_as_napari_shapes(bindings.values()) + layer = None + if _CLASSIFICATION_LAYER_NAME not in self.viewer.layers: + layer = self.viewer.add_shapes(boxes, name=_CLASSIFICATION_LAYER_NAME, edge_color=colors, face_color='#00000000', edge_width=4) + else: + layer = self.viewer.layers[_CLASSIFICATION_LAYER_NAME] + layer.data = boxes + layer.edge_colors = colors + # Showing raw classification + classification = self.mam.classifications + tps = [(c, None, b) for c, b in zip(classification['classes'], classification['boxes'])] + boxes, colors = boxes_as_napari_shapes(tps, True) + if _YOLO_LAYER_NAME not in self.viewer.layers: + layer = self.viewer.add_shapes(boxes, name=_YOLO_LAYER_NAME, edge_color=colors, face_color='#00000000', edge_width=4) + else: + layer = self.viewer.layers[_YOLO_LAYER_NAME] + layer.data = boxes + layer.edge_colors = colors + layer.edge_width = 4 + # Update calibration + if self.mam.calibration is not None: + self.set_calibration(*self.mam.calibration) + show_info(f"Microglia classified.") + def show_microglia(self): self.end_worker() labeled = self.mam.mask @@ -310,8 +435,8 @@ def set_active_ui(self, state): self.minimal_area_input.setEnabled(state) self.probability_threshold_slider.setEnabled(state) self.classify_microglia_button.setEnabled(state) - self.skeletonize_microglia_button.setEnabled(state) - self.export_control_images_button.setEnabled(state) + self.run_batch_button.setEnabled(state) + self.run_batch_button.setEnabled(state) self.export_measures_button.setEnabled(state) def end_worker(self): diff --git a/src/microglia_analyzer/dl/unet2d_training.py b/src/microglia_analyzer/dl/unet2d_training.py index 355e56c..016dad7 100644 --- a/src/microglia_analyzer/dl/unet2d_training.py +++ b/src/microglia_analyzer/dl/unet2d_training.py @@ -628,8 +628,8 @@ def open_pair(input_path, mask_path, training, img_only): raw_img = tifffile.imread(input_path) raw_img = np.expand_dims(raw_img, axis=-1) raw_mask = tifffile.imread(mask_path) - raw_mask = skeletonize(raw_mask) - raw_mask = binary_dilation(raw_mask) + # raw_mask = skeletonize(raw_mask) + # raw_mask = binary_dilation(raw_mask) raw_mask = raw_mask.astype(np.float32) raw_mask /= np.max(raw_mask) raw_mask = np.expand_dims(raw_mask, axis=-1) diff --git a/src/microglia_analyzer/experimental/classify_microglia.py b/src/microglia_analyzer/experimental/classify_microglia.py index c619b48..478da61 100644 --- a/src/microglia_analyzer/experimental/classify_microglia.py +++ b/src/microglia_analyzer/experimental/classify_microglia.py @@ -29,7 +29,7 @@ def calculate_iou(box1, box2): return 0.0 if union_area == 0 else (inter_area / union_area) class MicrogliaClassifier(object): - def __init__(self, model_path, image_path, iou_tr=0.8, score_tr=0.5, reload_yolo=False): + def __init__(self, model_path, image_path, iou_tr=0.25, score_tr=0.3, reload_yolo=False): if not os.path.isfile(model_path): raise FileNotFoundError(f"Model file {model_path} not found") if not model_path.endswith(".pt"): @@ -85,16 +85,24 @@ def remove_useless_boxes(self, boxes): return clean_boxes def inference(self): - results = self.model(self.image) - for img_results in results.xyxy: + tiles_manager = ImageTiler2D(640, 220, self.image.shape) + tiles = tiles_manager.image_to_tiles(self.image, False) # , True, 0, 255, np.uint8 + results = self.model(tiles) + self.bboxes = { + 'boxes' : [], + 'scores' : [], + 'classes': [], + } + for i, img_results in enumerate(results.xyxy): + print(i) + y, x = tiles_manager.layout[i].ul_corner boxes = img_results[:, :4].tolist() + boxes = [[box[0] + x, box[1] + y, box[2] + x, box[3] + y] for box in boxes] scores = img_results[:, 4].tolist() classes = img_results[:, 5].tolist() - self.bboxes = { - 'boxes' : boxes, - 'scores' : scores, - 'classes': classes, - } + self.bboxes['boxes'] += boxes + self.bboxes['scores'] += scores + self.bboxes['classes'] += classes def get_cleaned_bboxes(self): return self.remove_useless_boxes(self.bboxes) @@ -119,7 +127,7 @@ def draw_bounding_boxes(image, predictions, classes, exclude_class=1, thickness= box_colors=[(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255), (0, 255, 255), (255, 255, 255), (0, 0, 0), (128, 128, 128), (128, 0, 0), (0, 128, 0), (0, 0, 128), (128, 128, 0), (128, 0, 128), (0, 128, 128), (128, 128, 128)] - image_with_boxes = image.copy() + image_with_boxes = normalize(image, 0, 255, np.uint8) image_with_boxes = cv2.cvtColor(image_with_boxes, cv2.COLOR_GRAY2BGR) for box, cls, score in zip(predictions['boxes'], predictions['classes'], predictions['scores']): @@ -136,9 +144,11 @@ def draw_bounding_boxes(image, predictions, classes, exclude_class=1, thickness= if __name__ == "__main__": mc = MicrogliaClassifier( "/home/benedetti/Documents/projects/2060-microglia/µyolo/µyolo-V051/weights/best.pt", - "/home/benedetti/Documents/projects/2060-microglia/data/raw-data/tiff-data/adulte 3.tif" + "/home/benedetti/Documents/projects/2060-microglia/data/raw-data/tiff-data/adulte 4.tif" ) mc.inference() + print(len(mc.bboxes['boxes'])) cleaned_bboxes = mc.get_cleaned_bboxes() + print(len(cleaned_bboxes['boxes'])) visual = draw_bounding_boxes(mc.image, mc.get_cleaned_bboxes(), mc.classes) cv2.imwrite("/tmp/visual.png", visual) \ No newline at end of file diff --git a/src/microglia_analyzer/experimental/graph_explo.py b/src/microglia_analyzer/experimental/graph_explo.py new file mode 100644 index 0000000..48a477c --- /dev/null +++ b/src/microglia_analyzer/experimental/graph_explo.py @@ -0,0 +1,49 @@ +import numpy as np +from skimage.morphology import skeletonize +from skimage.measure import label +import networkx as nx +import os +import tifffile +from skan import Skeleton, summarize + +def analyze_skeleton(mask, pixel_size=1.0): + skeleton = skeletonize(mask) + skel = Skeleton(skeleton) + branch_data = summarize(skel, separator='_') + + num_branches = len(branch_data) + num_leaves = np.sum(branch_data['branch_type'] == 1) + num_junctions = np.sum(branch_data['branch_type'] == 2) + avg_branch_length = np.mean(branch_data['branch_distance']) * pixel_size + total_length = branch_data['branch_distance'].sum() * pixel_size + max_branch_length = branch_data['branch_distance'].max() * pixel_size + + results = { + "number_of_branches" : num_branches, + "number_of_leaves" : num_leaves, + "number_of_junctions" : num_junctions, + "average_branch_length": round(avg_branch_length, 2), + "total_length" : round(total_length, 2), + "max_branch_length" : round(max_branch_length, 2) + } + + return results, skeleton + +# Exemple d'utilisation +if __name__ == "__main__": + source_dir = "/tmp/unet_working/predictions/epoch_261/" + img_name = "prediction_00004.tif" + img_path = os.path.join(source_dir, img_name) + + raw_mask = (tifffile.imread(img_path) > 0.1).astype(np.uint8) + labeled = label(raw_mask, connectivity=1) + mask = labeled == 70 + + tifffile.imwrite("/tmp/labeled.tif", labeled.astype(np.uint16)) + results, skeleton = analyze_skeleton(mask, 0.325) + tifffile.imwrite("/tmp/skeleton.tif", skeleton.astype(np.uint8) * 255) + + print("Résultats de l'analyse du squelette :") + max_key_len = max([len(key) for key in results.keys()]) + for key, value in results.items(): + print(f" | {key.replace('_', ' ').capitalize()}{' ' * (max_key_len - len(key) + 1)}: {value}") diff --git a/src/microglia_analyzer/ma_worker.py b/src/microglia_analyzer/ma_worker.py index 2b18d15..80caa20 100644 --- a/src/microglia_analyzer/ma_worker.py +++ b/src/microglia_analyzer/ma_worker.py @@ -7,6 +7,8 @@ import numpy as np from skimage.measure import label, regionprops from skimage import morphology +from skimage.morphology import skeletonize +from skan import Skeleton, summarize from microglia_analyzer.tiles.tiler import normalize from microglia_analyzer.utils import calculate_iou, normalize_batch @@ -40,9 +42,13 @@ def __init__(self, logging_f=None): # Object responsible for cutting images into tiles. self.tiles_manager = None # Size of the tiles (in pixels). - self.tile_size = None - # Overlap between the tiles (in pixels). - self.overlap = None + self.unet_tile_size = None + # Size of the tiles (in pixels). + self.yolo_tile_size = None + # Overlap for YOLO + self.yolo_overlap = None + # unet_overlap between the tiles (in pixels). + self.unet_overlap = None # Probability threshold for the segmentation (%). self.segmentation_threshold = 0.5 # Importance of the skeleton in the loss function. @@ -50,7 +56,7 @@ def __init__(self, logging_f=None): # Importance of the BCE in the BCE-dice loss function. self.unet_bce_coef = 0.7 # Score threshold for the classification (%). - self.score_threshold = 0.5 + self.score_threshold = 0.35 # Probability map of the segmentation. self.probability_map = None # Connected component minimum size threshold. @@ -60,19 +66,23 @@ def __init__(self, logging_f=None): # Set of bounding-boxes guessed by the classification model. self.bboxes = None # Maximum IoU threshold (%) for the classification. Beyond that, BBs are merged. - self.iou_threshold = 0.85 + self.iou_threshold = 0.25 # Bounding-boxes after they were cleaned. self.classifications = None # Dictionary of the bindings between the segmentation and the classification. self.bindings = None # Final mask of the segmentation. self.mask = None + # Graph metrics extracted from each label + self.graph_metrics = None + # Skeleton of the segmentation. + self.skeleton = None def log(self, message): if self.logging: self.logging(message) - def set_input_image(self, image, tile_size=512, overlap=128): + def set_input_image(self, image, unet_tile_size=512, unet_overlap=128, yolo_tile_size=640, yolo_overlap=128): """ Setter of the input image. Checks that the image is 2D before using it. @@ -81,8 +91,10 @@ def set_input_image(self, image, tile_size=512, overlap=128): if len(image.shape) != 2: raise ValueError("The input image must be 2D.") self.image = image - self.tile_size = tile_size - self.overlap = overlap + self.unet_tile_size = unet_tile_size + self.unet_overlap = unet_overlap + self.yolo_tile_size = yolo_tile_size + self.yolo_overlap = yolo_overlap def set_calibration(self, pixel_size, unit): """ @@ -161,11 +173,13 @@ def set_classification_model(self, path, use="best", reload=False): ) device = 'cuda' if torch.cuda.is_available() else 'cpu' self.classification_model.to(device) + self.classes = self.classification_model.names def segmentation_inference(self): shape = self.image.shape - tiles_manager = ImageTiler2D(self.tile_size, self.overlap, shape) - tiles = np.array(tiles_manager.image_to_tiles(self.image)) + tiles_manager = ImageTiler2D(self.unet_tile_size, self.unet_overlap, shape) + input_unet = normalize(self.image, 0, 1, np.float32) + tiles = np.array(tiles_manager.image_to_tiles(input_unet, False)) predictions = np.squeeze(self.segmentation_model.predict(tiles, batch_size=8)) normalize_batch(predictions) self.probability_map = tiles_manager.tiles_to_image(predictions) @@ -219,19 +233,22 @@ def segmentation_postprocessing(self): self.mask = label(self.mask, connectivity=2) def classification_inference(self): - yolo_input = normalize(self.image, 0, 255, np.uint8) - results = self.classification_model(yolo_input) - for img_results in results.xyxy: + yolo_input = self.image.copy() # normalize(self.image, 0, 255, np.uint8) + tiles_manager = ImageTiler2D(self.yolo_tile_size, self.yolo_overlap, self.image.shape) + tiles = tiles_manager.image_to_tiles(yolo_input, True, 0, 255, np.uint8) + results = self.classification_model(tiles) + self.bboxes = {'boxes': [], 'scores': [], 'classes': []} + for i, img_results in enumerate(results.xyxy): boxes = img_results[:, :4].tolist() + y, x = tiles_manager.layout[i].ul_corner + boxes = [[box[0] + x, box[1] + y, box[2] + x, box[3] + y] for box in boxes] scores = img_results[:, 4].tolist() classes = img_results[:, 5].tolist() - self.bboxes = { - 'boxes' : boxes, - 'scores' : scores, - 'classes': classes, - } + self.bboxes['boxes'] += boxes + self.bboxes['scores'] += scores + self.bboxes['classes'] += classes - def classification_postprocess(self): + def classification_postprocessing(self): """ Fusions boxes with an IoU greater than `iou_threshold`. The box with the highest score is kept, whatever the two classes were. @@ -276,12 +293,80 @@ def bind_classifications(self): bindings = {int(l): (None, 0.0, None) for l in np.unique(labeled) if l != 0} # label: (class, IoU) for region in regions: seg_bbox = list(map(int, region.bbox)) + bindings[region.label] = (0, 0.0, seg_bbox) for box, cls in zip(self.classifications['boxes'], self.classifications['classes']): - detect_bbox = list(map(int, box)) + x1, y1, x2, y2 = list(map(int, box)) + detect_bbox = [y1, x1, y2, x2] iou = calculate_iou(seg_bbox, detect_bbox) - if iou > bindings[region.label][1]: + if iou > 0.8 and iou > bindings[region.label][1]: bindings[region.label] = (cls, iou, seg_bbox) self.bindings = bindings + + def analyze_skeleton(self, mask): + skeleton = skeletonize(mask) + skel = Skeleton(skeleton) + branch_data = summarize(skel, separator='_') + factor = self.calibration[0] if self.calibration else 1.0 + + num_branches = len(branch_data) + num_leaves = np.sum(branch_data['branch_type'] == 1) + num_junctions = np.sum(branch_data['branch_type'] == 2) + avg_branch_length = np.mean(branch_data['branch_distance']) * factor + total_length = branch_data['branch_distance'].sum() * factor + max_branch_length = branch_data['branch_distance'].max() * factor + + results = { + "number_of_branches" : num_branches, + "number_of_leaves" : num_leaves, + "number_of_junctions" : num_junctions, + "average_branch_length": round(avg_branch_length, 2), + "total_length" : round(total_length, 2), + "max_branch_length" : round(max_branch_length, 2) + } + + return results, skeleton + + def analyze_as_graph(self): + labels = np.unique(self.mask) + results = {} + skeletons = np.zeros_like(self.mask) + for label in labels: + if label == 0: + continue + mask = (self.mask == label).astype(np.uint8) + results[label], skeleton = self.analyze_skeleton(mask) + skeletons = np.maximum(skeletons, skeleton) + self.graph_metrics = results + self.skeleton = skeletons + + def as_csv(self, identifier): + """ + Merge two dictionaries into a CSV file. + - dict1: A dictionary containing nested dictionaries with graph measures. + - dict2: A dictionary containing tuples where only the first two elements (IoU and Class) are relevant. + - output_file: The name of the output CSV file. + """ + common_labels = set(self.graph_metrics.keys()) & set(self.bindings.keys()) + if len(common_labels) == 0: + return None + + first_label = next(iter(common_labels)) + graph_measure_keys = list(self.graph_metrics[first_label].keys()) + headers = ["Identifier"] + graph_measure_keys + ["IoU", "Class"] + buffer = [", ".join(headers)] + + for i, label in enumerate(common_labels): + values = [""] + if i == 0: + values[0] = identifier + graph_measures = self.graph_metrics[label] + iou, class_value = self.bindings[label][:2] + class_value = self.classes[int(class_value)] if class_value is not None else "" + values += [graph_measures[key] for key in graph_measure_keys] + [iou, class_value] + line = ", ".join([str(v) for v in values]) + buffer.append(line) + + return buffer if __name__ == "__main__": @@ -295,5 +380,9 @@ def bind_classifications(self): ma.segmentation_inference() ma.segmentation_postprocessing() ma.classification_inference() - ma.classification_postprocess() + ma.classification_postprocessing() ma.bind_classifications() + ma.analyze_as_graph() + csv = ma.as_csv("adulte 3") + with open("/tmp/metrics.csv", "w") as f: + f.write("\n".join(csv)) diff --git a/src/microglia_analyzer/qt_workers.py b/src/microglia_analyzer/qt_workers.py index 3936ced..6d5b712 100644 --- a/src/microglia_analyzer/qt_workers.py +++ b/src/microglia_analyzer/qt_workers.py @@ -2,7 +2,10 @@ from PyQt5.QtCore import pyqtSignal import requests import os +import numpy as np from microglia_analyzer.utils import download_from_web +from microglia_analyzer.ma_worker import MicrogliaAnalyzer +import tifffile _MODELS = "https://raw.githubusercontent.com/MontpellierRessourcesImagerie/microglia-analyzer/refs/heads/main/src/microglia_analyzer/models.json" @@ -30,7 +33,7 @@ def _check_updates(self): local_version = 0 with open(v_path, 'r') as f: local_version = int(f.read().strip()) - if local_version < self.versions['µnet']['version']: + if local_version < int(self.versions['µnet']['version']): download_from_web(self.versions['µnet']['url'], self.model_path) print("Model updated.") return @@ -76,7 +79,7 @@ def _check_updates(self): local_version = 0 with open(v_path, 'r') as f: local_version = int(f.read().strip()) - if local_version < self.versions['µyolo']['version']: + if local_version < int(self.versions['µyolo']['version']): download_from_web(self.versions['µyolo']['url'], self.model_path) print("Model updated.") return @@ -95,4 +98,69 @@ def run(self): self.mga.set_classification_model(self.model_path) self.mga.classification_inference() self.mga.classification_postprocessing() + self.mga.bind_classifications() + self.finished.emit() + +class QtMeasureMicroglia(QObject): + + finished = pyqtSignal() + update = pyqtSignal(str, int, int) + + def __init__(self, pbr, mga): + super().__init__() + self.pbr = pbr + self.mga = mga + + def run(self): + self.mga.analyze_as_graph() + self.finished.emit() + +class QtBatchRunners(QObject): + + finished = pyqtSignal() + update = pyqtSignal(str, int, int) + + def __init__(self, pbr, source_dir, settings): + super().__init__() + self.pbr = pbr + self.source_dir = source_dir + self.settings = settings + self.images_pool = [f for f in os.listdir(source_dir) if f.endswith(".tif")] + self.csv_lines = [] + + def workflow(self, index): + img_path = os.path.join(self.source_dir, self.images_pool[index]) + img_data = tifffile.imread(img_path) + s = self.settings + ma = MicrogliaAnalyzer(lambda x: print(x)) + ma.set_input_image(img_data) + ma.set_calibration(*s['calibration']) + ma.set_segmentation_model(s['unet_path']) + ma.set_classification_model(s['yolo_path']) + ma.set_cc_min_size(s['cc_min_size']) + ma.set_proba_threshold(s['proba_threshold']) + ma.segmentation_inference() + ma.segmentation_postprocessing() + ma.set_min_score(s['min_score']) + ma.classification_inference() + ma.classification_postprocessing() + ma.bind_classifications() + ma.analyze_as_graph() + csv = ma.as_csv(self.images_pool[index]) + if index == 0: + self.csv_lines += csv + else: + self.csv_lines += csv[1:] + control_path = os.path.join(self.source_dir, "controls", self.images_pool[index]) + tifffile.imwrite(control_path, np.stack([ma.skeleton, ma.mask], axis=0)) + + def write_csv(self): + with open(os.path.join(self.source_dir, "controls", "results.csv"), 'w') as f: + f.write("\n".join(self.csv_lines)) + + def run(self): + for i in range(len(self.images_pool)): + self.workflow(i) + self.write_csv() + self.update.emit(self.images_pool[i], i+1, len(self.images_pool)) self.finished.emit() \ No newline at end of file diff --git a/src/microglia_analyzer/utils.py b/src/microglia_analyzer/utils.py index 8003951..0f371e2 100644 --- a/src/microglia_analyzer/utils.py +++ b/src/microglia_analyzer/utils.py @@ -3,27 +3,20 @@ import zipfile import tempfile import os -import json +import numpy as np import shutil from microglia_analyzer.tiles.tiler import normalize BBOX_COLORS = [ - (255, 0, 0), - ( 0, 255, 0), - ( 0, 0, 255), - (255, 255, 0), - (255, 0, 255), - ( 0, 255, 255), - (255, 255, 255), - ( 0, 0, 0), - (128, 128, 128), - (128, 0, 0), - ( 0, 128, 0), - ( 0, 0, 128), - (128, 128, 0), - (128, 0, 128), - ( 0, 128, 128), - (128, 128, 128) + '#FF000000', # Red + '#0000FF', # Blue + '#FFFF00', # Yellow + '#008000', # Dark green + '#FFA500', # Orange + '#00FFFF', # Cyan + '#800080', # Purple + '#FF1493', # Vivid pink + '#ADFF2F', # Lime ] def calculate_iou(box1, box2): @@ -78,6 +71,23 @@ def draw_bounding_boxes(image, predictions, classes, thickness=2): return image_with_boxes +def boxes_as_napari_shapes(collection, swap=False): + items = [] + colors = [] + for (cls, _, seg_bbox) in collection: + y1, x1, y2, x2 = map(int, seg_bbox) + if swap: + x1, y1, x2, y2 = map(int, seg_bbox) + rect = np.array([ + [y1, x1], # Haut-gauche + [y1, x2], # Haut-droite + [y2, x2], # Bas-droite + [y2, x1], # Bas-gauche + ]) + colors.append(BBOX_COLORS[int(cls)]) + items.append(rect) + return items, colors + def download_from_web(url, extract_to, timeout=100): if os.path.isdir(extract_to): shutil.rmtree(extract_to)