diff --git a/src/microglia_analyzer/__init__.py b/src/microglia_analyzer/__init__.py index deb452b..a26cc84 100644 --- a/src/microglia_analyzer/__init__.py +++ b/src/microglia_analyzer/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.0.2" +__version__ = "0.1.0" import re diff --git a/src/microglia_analyzer/_tests/test_tiles.py b/src/microglia_analyzer/_tests/test_tiles.py index b89bc46..4f6bb27 100644 --- a/src/microglia_analyzer/_tests/test_tiles.py +++ b/src/microglia_analyzer/_tests/test_tiles.py @@ -20,17 +20,20 @@ """ +# (patch_size, overlap, (height, width)) _PATCHES_SETTINGS = [ - (64, 0 , (64 , 64)), # no overlap - (64, 0 , (64 , 128)), # no overlap - (64, 0 , (128, 64)), # no overlap - (64, 0 , (128, 128)), # no overlap - (64, 32, (64 , 64)), # overlap - (64, 32, (64 , 128)), # overlap - (64, 32, (128, 64)), # overlap - (64, 32, (128, 128)) # overlap + ( 64, 0, ( 64, 64)), + ( 64, 0, ( 64, 128)), + ( 64, 0, ( 128, 64)), + ( 64, 0, ( 128, 128)), + ( 64, 32, ( 64, 64)), + ( 64, 32, ( 64, 128)), + ( 64, 32, ( 128, 64)), + ( 64, 32, ( 128, 128)), + (512, 128, (2048, 2048)) ] +# (Number of patches Y-axis, Number of patches X-axis) _GRID_GROUND_TRUTH = [ (1, 1), (1, 2), @@ -39,13 +42,17 @@ (1, 1), (1, 3), (3, 1), - (3, 3) + (3, 3), + (5, 5) ] +# Blending between patches (for tiles_to_image) _BLENDINGS = ['flat', 'gradient'] +# Number of channels for the image _N_CHANNELS = [1, 3, 5] +# Normalization targets: (target data type, lower bound, upper bound) _NORMALIZE_BOUNDS = [ (np.uint8 , 0 , 255), (np.uint16 , 0 , 65535), @@ -54,6 +61,7 @@ (np.float32, 0.2, 0.8) ] +# Random images to be normalized _IMAGE_NORMALIZE = [ np.random.randint(50, 200, (128, 128)).astype(np.uint8), np.random.randint(50, 200, (128, 128, 3)).astype(np.uint8), @@ -88,7 +96,8 @@ def test_normalize(dtype, lower_bound, upper_bound, image): """ Tests that for each pixel, the sum of the blending coefficients is equal to 1. - To do so, we just check that cutting and remerging the image doesn't change it. + To do so, we just check that cutting and remerging an image doesn't change it. + Input images are randomly generated. """ normalized_image = normalize(image, lower_bound, upper_bound, dtype) is_zero, is_unique_val, vals_match = False, False, False @@ -124,7 +133,7 @@ def test_wrong_settings(patch_size, overlap, shape, errType): ) def test_grid_size(patch_size, overlap, shape, target): """ - Tests that the number of tiles on each axis is correct. + Tests that the number of tiles on each axis is what we expect. 
""" pe = ImageTiler2D(patch_size, overlap, shape) assert pe.get_grid_size() == target @@ -141,7 +150,7 @@ def test_n_tiles(patch_size, overlap, shape, target): """ pe = ImageTiler2D(patch_size, overlap, shape) assert len(pe.get_layout()) == target[0] * target[1] - assert len(pe.get_blending_tiles()) == target[0] * target[1] + assert len(pe.blending_coefs) == target[0] * target[1] @pytest.mark.parametrize("patch_size, overlap, shape", _PATCHES_SETTINGS) diff --git a/src/microglia_analyzer/_widget_yolo_annotations.py b/src/microglia_analyzer/_widget_yolo_annotations.py deleted file mode 100644 index ef423bc..0000000 --- a/src/microglia_analyzer/_widget_yolo_annotations.py +++ /dev/null @@ -1,667 +0,0 @@ -from qtpy.QtWidgets import (QWidget, QVBoxLayout, QLineEdit, - QHBoxLayout, QPushButton, QLabel, - QFileDialog, QComboBox, QGroupBox) - -from scipy.ndimage import binary_fill_holes -from napari.utils.notifications import show_info - -import tifffile -import skimage - -import numpy as np -import os - -# Prefix of a layer name to be considered as a YOLO class. -_CLASS_PREFIX = "class." -# Name of the layer containing the current image. -_IMAGE_LAYER = "Image" -# Layer containing the manual annotations (masks) of microglia. -_MASKS_LAYER = "Masks" -# Colors assigned to each YOLO class. -_COLORS = [ - "#FF4D4D", - "#4DFF4D", - "#4D4DFF", - "#FFD700", - "#FF66FF", - "#66FFFF", - "#FF9900", - "#9933FF", - "#3399FF", - "#99FF33", - "#FF3399", - "#33FF99", - "#B20000", - "#006600", - "#800080", - "#808000" -] -# Function used to read images. -imread = tifffile.imread -# Indices to have the width and height from an image shape -_WIDTH_HEIGHT = (0, 2) -# Arguments to pass to the imread function. -ARGS = {} - -# A YOLO bounding-box == a tuple of 5 elements: -# - (int) The class to which this box belongs. -# - (float) x component of the box's center (between 0.0 and 1.0) in percentage of image width. -# - (float) y component of the box's center (between 0.0 and 1.0) in percentage of image height. -# - (float) width of the box (between 0.0 and 1.0) in percentage of image width. -# - (float) height of the box (between 0.0 and 1.0) in percentage of image height. - -class AnnotateBoundingBoxesWidget(QWidget): - - def __init__(self, napari_viewer): - super().__init__() - - # Active Napari viewer. - self.viewer = napari_viewer - # Folder in which are located the 'images' and 'annotations' folders. - self.root_directory = None - # Name of the folder in which images are located - self.sources_directory = None - # Name of the folder in which annotations are located - self.annotations_directory = None - # Name of the folder in which masks are located - self.masks_directory = None - # List of images ('.tif') in the 'images' folder. 
- self.images_list = [] - - self.layout = QVBoxLayout() - self.init_ui() - self.setLayout(self.layout) - - # -------- UI: ---------------------------------- - - def add_media_management_group_ui(self): - box = QGroupBox("Media management") - layout = QVBoxLayout() - box.setLayout(layout) - - # Label + button to select the source directory: - self.select_root_directory_button = QPushButton("📂 Root directory") - self.select_root_directory_button.clicked.connect(self.select_root_directory) - layout.addWidget(self.select_root_directory_button) - - # Label + text box for the inputs sub-folder's name: - inputs_name_label = QLabel("Inputs sub-folder:") - self.inputs_name = QComboBox() - self.inputs_name.currentIndexChanged.connect(self.set_sources_directory) - self.inputs_name.setEnabled(False) - self.inputs_name.addItem("---") - h_layout = QHBoxLayout() - h_layout.addWidget(inputs_name_label) - h_layout.addWidget(self.inputs_name) - layout.addLayout(h_layout) - - # Label + text box for the annotations sub-folder's name: - annotations_name_label = QLabel("Annotations sub-folder:") - self.annotations_name = QLabel() - self.annotations_name.setText("---") - self.annotations_name.setMaximumHeight(20) - font = self.annotations_name.font() - font.setBold(True) - self.annotations_name.setFont(font) - h_layout = QHBoxLayout() - h_layout.addWidget(annotations_name_label) - h_layout.addWidget(self.annotations_name) - layout.addLayout(h_layout) - - self.layout.addWidget(box) - - def add_classes_management_group_ui(self): - box = QGroupBox("Classes management") - layout = QVBoxLayout() - box.setLayout(layout) - - # Adds a new class with the current name in the text box: - h_laytout = QHBoxLayout() - self.new_name = QLineEdit() - h_laytout.addWidget(self.new_name) - self.add_yolo_class_button = QPushButton("🎯 New class") - self.add_yolo_class_button.clicked.connect(self.add_yolo_class) - h_laytout.addWidget(self.add_yolo_class_button) - layout.addLayout(h_laytout) - - # Label showing the number of boxes in each class - self.count_display_label = QLabel("") - layout.addWidget(self.count_display_label) - - self.layout.addWidget(box) - - def add_annotations_management_group_ui(self): - box = QGroupBox("Annotations management") - layout = QVBoxLayout() - box.setLayout(layout) - - # Label + combobox containing inputs list: - self.image_selector = QComboBox() - self.image_selector.currentIndexChanged.connect(self.open_image) - self.image_selector.addItem("---") - layout.addWidget(self.image_selector) - - # Button to save the annotations: - self.save_button = QPushButton("💾 Save annotations") - self.save_button.clicked.connect(self.save_state) - layout.addWidget(self.save_button) - - self.layout.addWidget(box) - - def add_masks_group_ui(self): - box = QGroupBox("Masks") - layout = QVBoxLayout() - box.setLayout(layout) - - # Button to add a new mask layer. - self.add_mask_button = QPushButton("🎭 Add mask") - self.add_mask_button.clicked.connect(self.add_mask_layer) - layout.addWidget(self.add_mask_button) - - # Button to fill the current label. - self.fill_button = QPushButton("🎨 Fill label") - self.fill_button.clicked.connect(self.fill_current_label) - layout.addWidget(self.fill_button) - - # Button to save the masks. 
- self.save_masks_button = QPushButton("💾 Save masks") - self.save_masks_button.clicked.connect(self.save_masks) - layout.addWidget(self.save_masks_button) - - self.layout.addWidget(box) - - def init_ui(self): - self.add_media_management_group_ui() - self.add_classes_management_group_ui() - self.add_annotations_management_group_ui() - self.add_masks_group_ui() - - # ----------------- CALLBACKS ------------------------------------------- - - def select_root_directory(self): - """ - Select the folder containing the "images" and "labels" sub-folders. - """ - directory = QFileDialog.getExistingDirectory(self, "Select root directory") - if not os.path.isdir(directory): - show_info("Invalid directory.") - return - self.set_root_directory(directory) - - def set_sources_directory(self): - """ - Whenever the user selects a new source directory, the content of the 'sources' folder is probed. - This function also checks if the 'annotations' folder exists, and creates it if not. - """ - source_folder = self.inputs_name.currentText() - annotations_folder = source_folder + "-labels" - masks_folder = source_folder + "-masks" - if (source_folder is None) or (source_folder == "---") or (source_folder == ""): - return - inputs_path = os.path.join(self.root_directory, source_folder) - annotations_path = os.path.join(self.root_directory, annotations_folder) - masks_path = os.path.join(self.root_directory, masks_folder) - if not os.path.isdir(inputs_path): - return False - if not os.path.isdir(annotations_path): - os.makedirs(annotations_path) - if not os.path.isdir(masks_path): - os.makedirs(masks_path) - self.sources_directory = source_folder - self.annotations_directory = annotations_folder - self.masks_directory = masks_path - self.annotations_name.setText(annotations_folder) - self.open_sources_directory() - - def add_yolo_class(self): - """ - Adds a new layer representing a YOLO class. - """ - class_name = self.get_new_class_name() - if _IMAGE_LAYER not in self.viewer.layers: - show_info("No image loaded.") - return - n_classes = len(self.get_classes()) - color = _COLORS[n_classes % len(_COLORS)] - l = self.viewer.layers.selection.active - if class_name is None: - return - # Clears the selection of the current layer. - if (l is not None) and (l.name.startswith(_CLASS_PREFIX)): - l.selected_data = set() - self.viewer.add_shapes( - name=class_name, - edge_color=color, - face_color="transparent", - opacity=0.8, - edge_width=3 - ) - - def open_image(self): - """ - Uses the value contained in the `self.image_selector` to find and open an image. - Reloads the annotations if some were already made for this image. 
- """ - current_image = self.image_selector.currentText() - self.clear_mask() - if (self.root_directory is None) or (current_image is None) or (current_image == "---") or (current_image == ""): - return - image_path = os.path.join(self.root_directory, self.sources_directory, current_image) - name, _ = os.path.splitext(current_image) - current_as_txt = name + ".txt" - current_as_tif = name + ".tif" - labels_path = os.path.join(self.root_directory, self.annotations_directory, current_as_txt) - masks_path = os.path.join(self.root_directory, self.masks_directory, current_as_tif) - if not os.path.isfile(image_path): - print(f"The image: '{current_image}' doesn't exist.") - return - print(image_path, ARGS, imread, current_image) - image_path = os.fsencode(image_path).decode('utf-8') - data = imread(image_path, **ARGS) - if _IMAGE_LAYER in self.viewer.layers: - self.viewer.layers[_IMAGE_LAYER].data = data - self.viewer.layers[_IMAGE_LAYER].contrast_limits = (np.min(data), np.max(data)) - self.viewer.layers[_IMAGE_LAYER].reset_contrast_limits_range() - else: - self.viewer.add_image(data, name=_IMAGE_LAYER) - self.deselect_all() - self.restore_classes_layers() - self.clear_classes_layers() - if os.path.isfile(labels_path): # If some annotations already exist for this image. - self.load_annotations(labels_path) - self.count_boxes() - if os.path.isfile(masks_path): - self.restore_mask_layer(masks_path) - - def restore_mask_layer(self, masks_path): - data = tifffile.imread(masks_path) - if _MASKS_LAYER in self.viewer.layers: - self.viewer.layers[_MASKS_LAYER].data = data - else: - self.viewer.add_labels(data, name=_MASKS_LAYER) - - def save_state(self): - """ - Saves the current annotations (bounding-boxes) in a '.txt' file. - The class index corresponds to the rank of the layers in the Napari stack. - """ - count = 0 - lines = [] - for l in self.viewer.layers: - if not l.name.startswith(_CLASS_PREFIX): - continue - lines += self.layer2yolo(l.name, count) - count += 1 - self.write_annotations(lines) - self.count_boxes() - - def add_mask_layer(self): - if not self.is_image_opened(): - return - shape = self.viewer.layers[_IMAGE_LAYER].data.shape[_WIDTH_HEIGHT[0]:_WIDTH_HEIGHT[1]] - self.viewer.add_labels(np.zeros(shape, dtype=np.uint8), name=_MASKS_LAYER) - - def save_masks(self): - name, _ = os.path.splitext(self.image_selector.currentText()) - current_as_tif = name + ".tif" - mask_path = os.path.join(self.masks_directory, current_as_tif) - if _MASKS_LAYER in self.viewer.layers: - tifffile.imwrite(mask_path, self.viewer.layers[_MASKS_LAYER].data) - show_info("Masks saved.") - - # ----------------- METHODS ------------------------------------------- - - def is_image_opened(self): - """ - Checks if an image is currently opened in the Napari viewer. - """ - return (_IMAGE_LAYER in self.viewer.layers) and (self.viewer.layers[_IMAGE_LAYER].data is not None) and len(self.viewer.layers[_IMAGE_LAYER].data) > 0 - - def clear_mask(self): - """ - Clears the mask layer. 
- """ - if _MASKS_LAYER in self.viewer.layers: - self.viewer.layers[_MASKS_LAYER].data = np.zeros_like(self.viewer.layers[_MASKS_LAYER].data) - - def fill_current_label(self): - if not _MASKS_LAYER in self.viewer.layers: - return - current_label = self.viewer.layers[_MASKS_LAYER].selected_label - mask = self.viewer.layers[_MASKS_LAYER].data == current_label - filled = binary_fill_holes(mask) > 0 - filled = filled.astype(np.uint8) * current_label - self.viewer.layers[_MASKS_LAYER].data = np.maximum(self.viewer.layers[_MASKS_LAYER].data, filled) - - def upper_corner(self, box): - """ - Locates the upper-right corner of a bounding-box having the Napari format. - Works only for rectangles. - The order of coordinates changes depending on the drawing direction of the rectangle!!! - Once transposed (.T), a Napari bounding box is composed of 2 arrays, one for Y and one for X. - - Args: - - box (np.array): A Napari bounding-box, as it can be found in the '.data' of a shape layer. - - Returns: - (np.array): A (Y, X) 2D representing the maximal coordinates on both axis. - """ - y, x = [np.max(axis) for axis in box.T] - height, width = self.viewer.layers[_IMAGE_LAYER].data.shape[_WIDTH_HEIGHT[0]:_WIDTH_HEIGHT[1]] - return [min(y, height-1), min(x, width-1)] - - def lower_corner(self, box): - """ - Locates the lower-left corner of a bounding-box having the Napari format. - Works only for rectangles. - The order of coordinates changes depending on the drawing direction of the rectangle!!! - Once transposed (.T), a Napari bounding box is composed of 2 arrays, one for Y and one for X. - - Args: - - box (np.array): A Napari bounding-box, as it can be found in the '.data' of a shape layer. - - Returns: - (np.array): A (Y, X) 2D representing the minimal coordinates on both axis. - """ - return [max(0, np.min(axis)) for axis in box.T] - - def yolo2bbox(self, bboxes): - """ - Takes a YOLO bounding-box and converts it to the Napari format. - The output can be used in a shape layer. - """ - xmin, ymin = bboxes[0]-bboxes[2]/2, bboxes[1]-bboxes[3]/2 - xmax, ymax = bboxes[0]+bboxes[2]/2, bboxes[1]+bboxes[3]/2 - return xmin, ymin, xmax, ymax - - def bbox2yolo(self, bbox): - """ - Converts a Napari bounding-box (from a shape layer) into a YOLO bounding-box. - Doesn't handle the class (int) by itself, only the coordinates. - - Args: - - bbox (np.array): Array containing coordinates of a bounding-box as in a shape layer. - - Returns: - (tuple): 4 floats representing the x-centroid, the y-centroid, the width and the height in YOLO format. - It means that all these coordinates are percentages of the image's dimensions. - """ - ymax, xmax = self.upper_corner(bbox) - ymin, xmin = self.lower_corner(bbox) - height, width = self.viewer.layers[_IMAGE_LAYER].data.shape[_WIDTH_HEIGHT[0]:_WIDTH_HEIGHT[1]] - x = (xmin + xmax) / 2 / width - y = (ymin + ymax) / 2 / height - w = (xmax - xmin) / width - h = (ymax - ymin) / height - return round(x, 3), round(y, 3), round(w, 3), round(h, 3) - - def layer2yolo(self, layer_name, index): - data = self.viewer.layers[layer_name].data - tuples = [] - for rectangle in data: - bbox = (index,) + self.bbox2yolo(rectangle) - tuples.append(bbox) - return tuples - - def write_annotations(self, tuples): - """ - Responsible for writing the annotations in a '.txt' file, and updating the '-classes.txt' file. - - Args: - - tuples (list): A list of tuples, each tuple containing the class index and the YOLO bounding-box. 
- """ - labels_folder = os.path.join(self.root_directory, self.annotations_directory) - name, _ = os.path.splitext(self.image_selector.currentText()) - current_as_txt = name + ".txt" - labels_path = os.path.join(labels_folder, current_as_txt) - with open(labels_path, "w") as f: - for row in tuples: - f.write(" ".join(map(str, row)) + "\n") - with open(os.path.join(self.root_directory, os.path.basename(self.sources_directory)+"-classes.txt"), "w") as f: - for c in self.get_classes(): - f.write(c + "\n") - show_info("Annotations saved.") - - def get_new_class_name(self): - """ - Probes the input text box and returns the name of the new class. - """ - name_candidate = self.new_name.text().lower().replace(" ", "-") - if name_candidate == "": - show_info("Empty name.") - return None - if name_candidate.startswith(_CLASS_PREFIX): - full_name = name_candidate - else: - full_name = _CLASS_PREFIX + name_candidate - self.new_name.setText("") - return full_name - - def set_root_directory(self, directory): - """ - Sets the root directory (folder in which the 'sources' and 'annotations' folders are) of the application. - Probes the content of the 'sources' folder to find all sub-folders, to propose them in the GUI's dropdown. - No sub-folder is selected by default. - - Args: - - directory (str): The absolute path to the root directory. - """ - folders = sorted([f for f in os.listdir(directory) if (not f.endswith('-labels')) and os.path.isdir(os.path.join(directory, f))]) - folders = ["---"] + folders - self.inputs_name.clear() - self.inputs_name.addItems(folders) - self.inputs_name.setEnabled(len(folders) > 1) - self.root_directory = directory - - def update_reader_fx(self): - global imread - global _WIDTH_HEIGHT - global ARGS - ARGS = {} - ext = self.images_list[0] - if (ext == "---") or (ext == ""): - return - ext = ext.split('.')[-1].lower() - if (ext == "tif") or (ext == "tiff"): - imread = tifffile.imread - _WIDTH_HEIGHT = (0, 2) - im = imread(os.path.join(self.root_directory, self.sources_directory, self.images_list[0])) - if len(im.shape) == 3: - _WIDTH_HEIGHT = (1, 3) - else: - imread = skimage.io.imread - _WIDTH_HEIGHT = (0, 2) - ARGS = {'as_gray': True} - - def open_sources_directory(self): - """ - Triggered when the user chooses a new source directory in the GUI's dropdown. - The content of the provided folder is probed to find all TIFF files. - If the folder is empty, a message is displayed, and the images list is set to ['---']. - The first item of the list is selected by default. - It gets opened automatically due to the signal 'currentIndexChanged'. - """ - inputs_path = os.path.join(self.root_directory, self.sources_directory) - self.images_list = sorted([f for f in os.listdir(inputs_path)]) - if len(self.images_list) == 0: # Didn't find any file in the folder. - show_info("Didn't find any image in the provided folder.") - self.images_list = ['---'] - else: - self.update_reader_fx() - self.image_selector.clear() - self.image_selector.addItems(self.images_list) - return True - - def get_classes(self): - """ - Probes the layers stack of Napari to find the classes layers. - These layers are found through the prefix '_CLASS_PREFIX'. - The order matters. - - Returns: - (list): A list of strings containing the name of the classes. - """ - classes = [] - for l in self.viewer.layers: - if l.name.startswith(_CLASS_PREFIX): - classes.append(l.name[len(_CLASS_PREFIX):]) - return classes - - def clear_classes_layers(self): - """ - Reset the data of each shape layer representing a YOLO class. 
- Used before loading the annotations of a new image. - """ - names = [l.name for l in self.viewer.layers] - for n in names: - if n.startswith(_CLASS_PREFIX): - self.viewer.layers[n].data = [] - - def restore_classes_layers(self): - """ - Parses the '-classes.txt' file to restore the classes layers. - The file contains the name of the classes, one per line and nothing else. - Creates the associated shape layers with the right colors. - """ - classes_path = os.path.join(self.root_directory, os.path.basename(self.sources_directory)+"-classes.txt") - if not os.path.isfile(classes_path): - show_info("No classes file found.") - return - classes = [] - with open(classes_path, "r") as f: - classes = [item for item in f.read().split('\n') if len(item.strip()) > 0] - for i, c in enumerate(classes): - basis = c.strip() - if len(basis) == 0: - continue - name = _CLASS_PREFIX + basis - if name in self.viewer.layers: - continue - color = _COLORS[i % len(_COLORS)] - self.viewer.add_shapes( - name=name, - edge_color=color, - face_color="transparent", - opacity=0.8, - edge_width=3 - ) - show_info(f"Classes restored: {classes}") - - def add_labels(self, data): - """ - Uses the dictionary of data to reset and refill shapes layers containing YOLO bounding-boxes. - The class index refers to the index in which shape layers appear in the layers stack of Napari. - """ - # Boxes are created according to the current image's size. - h, w = self.viewer.layers[_IMAGE_LAYER].data.shape[_WIDTH_HEIGHT[0]:_WIDTH_HEIGHT[1]] - class_layers = [l.name for l in self.viewer.layers if l.name.startswith(_CLASS_PREFIX)] - if max(data.keys()) >= len(class_layers): - show_info("Some classes are missing: Abort loading.") - return - for c, bbox_list in data.items(): - rectangles = [] - for bbox in bbox_list: - x1, y1, x2, y2 = self.yolo2bbox(bbox) - xmin = int(x1*w) - ymin = int(y1*h) - xmax = int(x2*w) - ymax = int(y2*h) - points = np.array([[ymin, xmin], [ymin, xmax], [ymax, xmax], [ymax, xmin]]) - rectangles.append(np.array(points)) - layer = self.viewer.layers[class_layers[c]] - layer.data = rectangles - layer.face_color='transparent' - layer.edge_color=_COLORS[c % len(_COLORS)] - layer.edge_width=3 - - def load_annotations(self, labels_path): - """ - Loads a file (.txt) containing annotations over the currently opened image. - Parses the file and expects for each line: - - (int) The box's class - - (float) The box's x component - - (float) The box's y component - - (float) The box's width - - (float) The box's height - All these values are separated with spaces. - The internal structure (a dictionary) makes a list of boxes per class index: - data[class_index] = [(x1, y1, w1, h1), (x2, y2, w2, h2)] - - Args: - - labels_path (str): The absolute path to the ".txt" file. - """ - lines = [] - with open(labels_path, "r") as f: - lines = f.read().split('\n') - if len(lines) == 0: - return - data = dict() - for line in lines: - if line == "": - continue - c, x, y, w, h = line.split(" ") - c, x, y, w, h = int(c), float(x), float(y), float(w), float(h) - data.setdefault(c, []).append((x, y, w, h)) - self.add_labels(data) - - def deselect_all(self): - """ - Deselects everything in a shape layer. - It is required when you flush the content of a shape layer when you open a new image. - If you flush with an active selection, you will get an "index out of range" right away. - Only targets shape layers representing YOLO classes. 
- """ - for l in self.viewer.layers: - if not l.name.startswith(_CLASS_PREFIX): - continue - l.mode = 'pan_zoom' - l.selected_data = set() - - def count_boxes(self): - """ - Counts the number of boxes in each class to make sure annotations are balanced. - No fix is provded, just a display of the current state. - The whole annotations folder is probed to count the boxes. - The current image is not taken into account. - """ - annotations_path = os.path.join(self.root_directory, self.annotations_directory) - counts = dict() - for f in os.listdir(annotations_path): - if not f.endswith(".txt"): - continue - with open(os.path.join(annotations_path, f), "r") as file: - lines = file.read().split('\n') - for line in lines: - if line == "": - continue - c, _, _, _, _ = line.split(" ") - c = int(c) - counts[c] = counts.get(c, 0) + 1 - self.update_count_display(counts) - - def update_count_display(self, counts): - """ - Updates the content of the table (in the GUI) displaying how many boxes are in each class. - Beyond 8 classes , the display could have an issue with the height of the QGroupBox. - - Args: - - counts (str): is a dictionary where the key is the class index and the value is the number of boxes. - """ - classes = self.get_classes() - total_count = sum(counts.values()) - text = '' - for class_idx, class_name in enumerate(classes): - count = counts.get(class_idx, 0) - text += f''' - - - - - - ''' - text += "
- {class_name} - - {count} ({round((count/total_count*100) if total_count > 0 else 0, 1)}%) -
" - self.count_display_label.setText(text) diff --git a/src/microglia_analyzer/dl/losses.py b/src/microglia_analyzer/dl/losses.py index 0b32d61..b4bc5ac 100644 --- a/src/microglia_analyzer/dl/losses.py +++ b/src/microglia_analyzer/dl/losses.py @@ -1,25 +1,5 @@ import tensorflow as tf -def jaccard_loss(y_true, y_pred): - y_true = tf.cast(y_true, tf.float32) - y_pred = tf.cast(y_pred, tf.float32) - intersection = tf.reduce_sum(y_true * y_pred) - union = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) - intersection - return 1 - (intersection + 1) / (union + 1) - -def dice_loss(y_true, y_pred): - y_true = tf.cast(y_true, tf.float32) - y_pred = tf.cast(y_pred, tf.float32) - intersection = tf.reduce_sum(y_true * y_pred) - return 1 - (2. * intersection + 1) / (tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) + 1) - -def bce_dice_loss(bce_coef=0.5): - def bcl(y_true, y_pred): - bce = tf.keras.losses.binary_crossentropy(y_true, y_pred) - dice = dice_loss(y_true, y_pred) - return bce_coef * bce + (1.0 - bce_coef) * dice - return bcl - def focal_loss(gamma=2.0, alpha=5.75): def focal_loss_fixed(y_true, y_pred): y_true = tf.cast(y_true, tf.float32) @@ -41,20 +21,36 @@ def loss(y_true, y_pred): return 1 - (true_pos + 1) / (true_pos + alpha * false_neg + beta * false_pos + 1) return loss +def skeleton_loss(y_true, y_pred): + inter = tf.reduce_sum(y_true * y_pred) / tf.reduce_sum(y_true) + mse_score = tf.reduce_mean(tf.square(y_true - y_pred)) + mean_constraint = tf.abs(tf.reduce_mean(y_pred) - tf.reduce_mean(y_true)) + return 1.0 - inter + mse_score + 0.1 * mean_constraint + + +# - - - - - Loss depending on the objects skeleton - - - - - # + def skeleton_recall(y_true, y_pred): intersection = tf.reduce_sum(y_true * y_pred) recall = intersection / (tf.reduce_sum(y_true) + 1e-8) return 1 - recall +def dice_loss(y_true, y_pred): + y_true = tf.cast(y_true, tf.float32) + y_pred = tf.cast(y_pred, tf.float32) + intersection = tf.reduce_sum(y_true * y_pred) + return 1 - (2. 
* intersection + 1) / (tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) + 1) + +def bce_dice_loss(bce_coef=0.5): + def bcl(y_true, y_pred): + bce = tf.keras.losses.binary_crossentropy(y_true, y_pred) + dice = dice_loss(y_true, y_pred) + return bce_coef * bce + (1.0 - bce_coef) * dice + return bcl + def dice_skeleton_loss(skeleton_coef=0.5, bce_coef=0.5): bdl = bce_dice_loss(bce_coef) - def dsl(y_true, y_pred): + def _dice_skeleton_loss(y_true, y_pred): y_pred = tf.square(y_pred) return (1.0 - skeleton_coef) * bdl(y_true, y_pred) + skeleton_coef * skeleton_recall(y_true, y_pred) - return dsl - -def skeleton_loss(y_true, y_pred): - inter = tf.reduce_sum(y_true * y_pred) / tf.reduce_sum(y_true) - mse_score = tf.reduce_mean(tf.square(y_true - y_pred)) - mean_constraint = tf.abs(tf.reduce_mean(y_pred) - tf.reduce_mean(y_true)) - return 1.0 - inter + mse_score + 0.1 * mean_constraint \ No newline at end of file + return _dice_skeleton_loss \ No newline at end of file diff --git a/src/microglia_analyzer/dl/unet2d_training.py b/src/microglia_analyzer/dl/unet2d_training.py index 016dad7..d6b715d 100644 --- a/src/microglia_analyzer/dl/unet2d_training.py +++ b/src/microglia_analyzer/dl/unet2d_training.py @@ -12,28 +12,30 @@ import pandas as pd from tabulate import tabulate -from microglia_analyzer.dl.losses import (dice_loss, bce_dice_loss, - skeleton_recall, dice_skeleton_loss) +# from microglia_analyzer.dl.losses import (dice_loss, bce_dice_loss, +# skeleton_recall, dice_skeleton_loss) +from losses import (dice_loss, bce_dice_loss, skeleton_recall, dice_skeleton_loss) os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' -# os.environ["CUDA_VISIBLE_DEVICES"] = "-1" +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" import tensorflow as tf from tensorflow.keras.models import Model from tensorflow.keras.callbacks import (ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, Callback) -from tensorflow.keras.layers import (Input, Conv2D, MaxPooling2D, Dropout, BatchNormalization, - UpSampling2D, concatenate, Activation, Conv2DTranspose) +from tensorflow.keras.layers import (Input, Conv2D, MaxPooling2D, + Dropout, BatchNormalization, + UpSampling2D, concatenate, Add, + Conv2DTranspose, Activation, Multiply) from tensorflow.keras.optimizers import Adam from tensorflow.keras.utils import plot_model -from tensorflow.keras.losses import BinaryCrossentropy # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # SETTINGS # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # -#@markdown # ⭐ 1. SETTINGS +# ⭐ 1. SETTINGS """ @@ -45,6 +47,8 @@ - `working_directory`: Folder in which the training, validation and testing folders will be created. - `model_name_prefix`: Prefix of the model name. Will be part of the folder name in `models_path`. - `reset_local_data` : If True, the locally copied training, validation and testing folders will be re-imported. +- `remove_wrong_data`: If True, the data that is not useful will be deleted from the data folder. +- `data_usage` : Path to a JSON file containing how each input file should be used (for training, validation or testing). - `validation_percentage`: Percentage of the data that will be used for validation. This data will be moved to the validation folder. - `batch_size` : Number of images per batch. @@ -54,30 +58,39 @@ - `dropout_rate` : Dropout rate. - `optimizer` : Optimizer used for the training. - `learning_rate` : Learning rate at which the optimizer is initialized +- `skeleton_coef` : Coefficient of the skeleton loss. 
+- `bce_coef` : Coefficient of the binary cross-entropy loss. +- `early_stop_patience` : Number of epochs without improvement before stopping the training. +- `dilation_kernel` : Kernel used for the dilation of the skeleton. +- `loss` : Loss function used for the training. - `use_data_augmentation`: If True, data augmentation will be used. - `use_mirroring` : If True, random mirroring will be used. - `use_gaussian_noise` : If True, random gaussian noise will be used. -- `use_random_rotations` : If True, random rotation of 90, 180 or 270 degrees will be used. +- `noise_scale` : Scale of the gaussian noise (range of values). +- `use_random_rotations` : If True, random rotations will be used. +- `angle_range` : Range of the random rotations. The angle will be in [angle_range[0], angle_range[1]]. - `use_gamma_correction` : If True, random gamma correction will be used. - `gamma_range` : Range of the gamma correction. The gamma will be in [1 - gamma_range, 1 + gamma_range] (1.0 == neutral). +- `use_holes` : If True, holes will be created in the input images to teach the network to gap them. +- `export_aug_sample` : If True, an augmented sample will be exported to the working directory as a preview. """ -#@markdown ## 📍 a. Data paths +## 📍 a. Data paths -data_folder = "/home/benedetti/Documents/projects/2060-microglia/data/training-data/clean" +data_folder = "/home/benedetti/Downloads/training-audrey" qc_folder = None inputs_name = "inputs" masks_name = "masks" -models_path = "/home/benedetti/Documents/projects/2060-microglia/µnet" +models_path = "/home/benedetti/Downloads/training-audrey/models" working_directory = "/tmp/unet_working/" -model_name_prefix = "µnet" +model_name_prefix = "unet" reset_local_data = True -remove_wrong_data = False +remove_wrong_data = True data_usage = None -#@markdown ## 📍 b. Network architecture +## 📍 b. Network architecture validation_percentage = 0.15 batch_size = 8 @@ -88,16 +101,16 @@ optimizer = 'Adam' learning_rate = 0.001 skeleton_coef = 0.2 -bce_coef = 0.7 +bce_coef = 0.25 early_stop_patience = 50 dilation_kernel = diamond(1) loss = dice_skeleton_loss(skeleton_coef, bce_coef) -#@markdown ## 📍 c. Data augmentation +## 📍 c. Data augmentation use_data_augmentation = True use_mirroring = True -use_gaussian_noise = True +use_gaussian_noise = False noise_scale = 0.001 use_random_rotations = True angle_range = (-90, 90) @@ -110,7 +123,7 @@ # SANITY CHECK # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # -#@markdown # ⭐ 2. SANITY CHECK +# ⭐ 2. SANITY CHECK """ The goal of this section is to make sure that the data located in the `data_folder` is consistent. @@ -123,7 +136,7 @@ | Masks must be binary masks (on 8-bits with only 0 and another value). """ -#@markdown ## 📍 a. Data check +## 📍 a. Data check # Regex matching a TIFF file, whatever the case and the number of 'f'. _TIFF_REGEX = r".+\.tiff?" @@ -276,7 +289,7 @@ def merge_dicts(d1, d2): d1[key] = value -#@markdown ## 📍 b. Sanity check launcher +## 📍 b. Sanity check launcher _SANITY_CHECK = [ ("extension", is_extension_correct), @@ -321,7 +334,7 @@ def sanity_check(root_folder): return (all(assessment), results) -#@markdown ## 📍 c. Remove dirty data +## 📍 c. Remove dirty data def remove_dirty_data(root_folder, folders, results): """ @@ -344,9 +357,9 @@ def remove_dirty_data(root_folder, folders, results): # DATA MIGRATION # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # -#@markdown # ⭐ 3. DATA MIGRATION +# ⭐ 3. 
DATA MIGRATION -#@markdown ## 📍 a. Utils +## 📍 a. Utils _LOCAL_FOLDERS = ["training", "validation", "testing"] @@ -440,9 +453,9 @@ def migrate_data(targets, source): # DATA AUGMENTATION # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # -#@markdown # ⭐ 4. DATA AUGMENTATION +# ⭐ 4. DATA AUGMENTATION -#@markdown ## 📍 a. Data augmentation functions +## 📍 a. Data augmentation functions def deteriorate_image(image, mask, num_points=25): """ @@ -581,7 +594,7 @@ def apply_data_augmentation(image, mask): image, mask = normalize(image, mask) return image, mask -#@markdown ## 📍 b. Datasets visualization +## 📍 b. Datasets visualization def visualize_augmentations(model_path, num_examples=6): """ @@ -620,9 +633,9 @@ def visualize_augmentations(model_path, num_examples=6): # DATASET GENERATOR # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # -#@markdown # ⭐ 5. DATASET GENERATOR +# ⭐ 5. DATASET GENERATOR -#@markdown ## 📍 a. Datasets generator +## 📍 a. Datasets generator def open_pair(input_path, mask_path, training, img_only): raw_img = tifffile.imread(input_path) @@ -719,9 +732,9 @@ def export_data_usage(model_path): # MODEL GENERATOR # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # -#@markdown # ⭐ 6. MODEL GENERATOR +# ⭐ 6. MODEL GENERATOR -#@markdown ## 📍 a. Utils +## 📍 a. Utils def get_version(): """ @@ -740,7 +753,40 @@ def get_version(): else: return int(content[-1].split('-')[-1].replace('V', '')) + 1 -#@markdown ## 📍 b. UNet2D architecture +## 📍 b. Structure of an attention block + +def attention_block(x, g, intermediate_channels): + """ + Attention Block pour UNet. + + Args: + x: TensorFlow tensor des caractéristiques de l'encodeur (skip connection). + g: TensorFlow tensor des caractéristiques du décodeur. + intermediate_channels: Nombre de canaux intermédiaires. + + Returns: + Tensor avec attention appliquée sur `x`. + """ + # Transformation de la caractéristique du décodeur + g1 = Conv2D(intermediate_channels, kernel_size=1, strides=1, padding="same")(g) + g1 = BatchNormalization()(g1) + + # Transformation de la caractéristique de l'encodeur + x1 = Conv2D(intermediate_channels, kernel_size=1, strides=1, padding="same")(x) + x1 = BatchNormalization()(x1) + + # Calcul de l'attention (g1 + x1 -> ReLU -> Sigmoid) + psi = Add()([g1, x1]) + psi = Activation('relu')(psi) + psi = Conv2D(1, kernel_size=1, strides=1, padding="same")(psi) + psi = BatchNormalization()(psi) + psi = Activation('sigmoid')(psi) + + # Application de l'attention sur x + out = Multiply()([x, psi]) + return out + +## 📍 c. UNet2D architecture def create_unet2d_model(input_shape): """ @@ -774,6 +820,7 @@ def create_unet2d_model(input_shape): num_filters = num_filters_start * 2**i x = UpSampling2D(2)(x) x = Conv2DTranspose(num_filters, (3, 3), strides=(1, 1), padding='same')(x) + x = attention_block(skip_connections[i], x, intermediate_channels=8) x = concatenate([x, skip_connections[i]]) x = Conv2D(num_filters, 3, activation='relu', padding='same', kernel_initializer='he_normal')(x) # x = BatchNormalization()(x) @@ -785,7 +832,7 @@ def create_unet2d_model(input_shape): return model -#@markdown ## 📍 c. Model instanciator +## 📍 d. Model instanciator def instanciate_model(): input_shape = get_shape() @@ -806,9 +853,9 @@ def instanciate_model(): # TRAINING THE MODEL # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # -#@markdown # ⭐ 7. TRAINING THE MODEL +# ⭐ 7. 
TRAINING THE MODEL -#@markdown ## 📍 a. Creating callback for validation +## 📍 a. Creating callback for validation class SavePredictionsCallback(Callback): def __init__(self, model_path, num_examples=5): @@ -855,7 +902,7 @@ def on_train_end(self, logs=None): shutil.move(last_epoch_path, last_epoch_dest) -#@markdown ## 📍 b. Training launcher +## 📍 b. Training launcher import math @@ -934,7 +981,7 @@ def export_settings(model_path): # EVALUATE THE MODEL # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # -#@markdown # ⭐ 8. EVALUATE THE MODEL +# ⭐ 8. EVALUATE THE MODEL def plot_training_history(history, model_path): """ diff --git a/src/microglia_analyzer/experimental/segment_microglia.py b/src/microglia_analyzer/experimental/segment_microglia.py index a4b8e21..6e43097 100644 --- a/src/microglia_analyzer/experimental/segment_microglia.py +++ b/src/microglia_analyzer/experimental/segment_microglia.py @@ -53,8 +53,8 @@ def inference(self): return probabilities if __name__ == "__main__": - output_path = "/tmp/inference/" - model_path = "/home/benedetti/Documents/projects/2060-microglia/µnet/µnet-V208/best.keras" + output_path = "/home/benedetti/Downloads/training-audrey/output/" + model_path = "/home/benedetti/Downloads/training-audrey/models/unet-V007/best.keras" folder_path = "/home/benedetti/Documents/projects/2060-microglia/data/raw-data/tiff-data" content = [f for f in os.listdir(folder_path) if f.endswith(".tif")] for i, image_name in enumerate(content): diff --git a/src/microglia_analyzer/napari.yaml b/src/microglia_analyzer/napari.yaml index 3c2f0b6..ead50ba 100644 --- a/src/microglia_analyzer/napari.yaml +++ b/src/microglia_analyzer/napari.yaml @@ -9,17 +9,17 @@ contributions: - id: microglia-analyzer.microglia_analyzer python_name: microglia_analyzer._widget:MicrogliaAnalyzerWidget title: Microglia Analyzer - - id: microglia-analyzer.yolo_annotator - python_name: microglia_analyzer._widget_yolo_annotations:AnnotateBoundingBoxesWidget - title: YOLO Annotator + - id: microglia-analyzer.annotations_helper + python_name: microglia_analyzer._widget_annotations_helper:AnnotationsWidget + title: Annotations Helper - id: microglia-analyzer.tiles_creator python_name: microglia_analyzer._widget_tiles:TilesCreatorWidget title: Tiles Creator widgets: - command: microglia-analyzer.tiles_creator display_name: Tiles Creator - - command: microglia-analyzer.yolo_annotator - display_name: YOLO Annotator + - command: microglia-analyzer.annotations_helper + display_name: Annotations Helper - command: microglia-analyzer.microglia_analyzer display_name: Microglia Analyzer \ No newline at end of file
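
The new (512, 128, (2048, 2048)) entry in _PATCHES_SETTINGS pairs with (5, 5) in _GRID_GROUND_TRUTH. Below is a minimal sketch of the underlying arithmetic, assuming (this is not read off ImageTiler2D itself) that consecutive patches are shifted by patch_size - overlap pixels and that the last patch is clamped to the image border; that assumption is consistent with every pair in the ground-truth table.

import math

def expected_patches_on_axis(axis_length, patch_size, overlap):
    # Consecutive patches are shifted by (patch_size - overlap) pixels;
    # one extra patch covers the remainder up to the image border.
    step = patch_size - overlap
    if axis_length <= patch_size:
        return 1
    return math.ceil((axis_length - patch_size) / step) + 1

# New test case: patch_size=512, overlap=128, shape=(2048, 2048).
# step = 384, offsets 0, 384, 768, 1152, 1536  ->  5 patches per axis.
print(expected_patches_on_axis(2048, 512, 128))  # 5

The same formula reproduces the existing entries, e.g. (3, 3) for (64, 32, (128, 128)) and (1, 3) for (64, 32, (64, 128)).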
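
In the reorganized losses.py, dice_skeleton_loss still returns a closure (now named _dice_skeleton_loss) that squares the prediction before mixing the BCE/Dice term with the skeleton-recall term. A minimal usage sketch on toy tensors, assuming the package from this repository is installed so that microglia_analyzer.dl.losses is importable:

import numpy as np
import tensorflow as tf
from microglia_analyzer.dl.losses import dice_skeleton_loss

# Same coefficients as in unet2d_training.py (skeleton_coef=0.2, bce_coef=0.25).
loss_fn = dice_skeleton_loss(skeleton_coef=0.2, bce_coef=0.25)

# Toy batch: binary masks as ground truth, sigmoid-like values as predictions.
y_true = tf.constant(np.random.randint(0, 2, (2, 64, 64, 1)), dtype=tf.float32)
y_pred = tf.random.uniform((2, 64, 64, 1))

# tf.keras.losses.binary_crossentropy keeps one value per pixel, so reduce
# explicitly here; Keras applies the same reduction in model.compile().
value = tf.reduce_mean(loss_fn(y_true, y_pred))
print(float(value))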
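
The updated settings docstring in unet2d_training.py describes the random gamma correction as drawing gamma from [1 - gamma_range, 1 + gamma_range], with 1.0 being neutral. The sketch below only illustrates that description; the helper name and the assumption that images are already normalized to [0, 1] are illustrative, not taken from apply_data_augmentation.

import numpy as np

def random_gamma(image, gamma_range=0.2, rng=None):
    # Draw gamma in [1 - gamma_range, 1 + gamma_range] (1.0 == neutral)
    # and apply it to an image assumed to be normalized to [0, 1].
    rng = rng or np.random.default_rng()
    gamma = rng.uniform(1.0 - gamma_range, 1.0 + gamma_range)
    return np.clip(image, 0.0, 1.0) ** gamma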
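
The decoder change in create_unet2d_model inserts an additive attention gate on each skip connection before concatenation. Below is a self-contained restatement of the added attention_block with the comments in English and a shape check; the 1x1-conv/BatchNorm/ReLU/sigmoid structure follows the diff, while the toy shapes are only illustrative. As wired in the diff, the gated skip is assigned back to x and then concatenated with the raw skip connection, so the up-sampled decoder features only act as the gating signal; the more common attention-U-Net wiring concatenates the gated skip with the decoder features, which may be worth double-checking.

import tensorflow as tf
from tensorflow.keras.layers import (Input, Conv2D, BatchNormalization,
                                     Add, Activation, Multiply)

def attention_block(x, g, intermediate_channels):
    # Project the skip connection (x, from the encoder) and the gating
    # signal (g, from the decoder) to a common channel count with 1x1
    # convolutions, each followed by batch normalization.
    g1 = Conv2D(intermediate_channels, kernel_size=1, strides=1, padding="same")(g)
    g1 = BatchNormalization()(g1)
    x1 = Conv2D(intermediate_channels, kernel_size=1, strides=1, padding="same")(x)
    x1 = BatchNormalization()(x1)
    # Additive attention: (g1 + x1) -> ReLU -> 1x1 conv -> sigmoid yields a
    # single-channel, per-pixel coefficient in [0, 1].
    psi = Add()([g1, x1])
    psi = Activation('relu')(psi)
    psi = Conv2D(1, kernel_size=1, strides=1, padding="same")(psi)
    psi = BatchNormalization()(psi)
    psi = Activation('sigmoid')(psi)
    # Re-weight the skip connection with that coefficient.
    return Multiply()([x, psi])

# Shape check with illustrative sizes: x and g must share spatial dimensions.
skip = Input((64, 64, 32))   # encoder features (skip connection)
gate = Input((64, 64, 64))   # up-sampled decoder features
out = attention_block(skip, gate, intermediate_channels=8)
print(tf.keras.Model([skip, gate], out).output_shape)  # (None, 64, 64, 32)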