From 013d35ded6817fde702b1edd7c603900d6f83b54 Mon Sep 17 00:00:00 2001 From: Gonzalo Mateo Garcia Date: Thu, 16 May 2024 18:06:26 +0200 Subject: [PATCH] Fixed bug --- satalign/main.py | 28 +++++++++++++++++----------- satalign/utils.py | 4 ++-- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/satalign/main.py b/satalign/main.py index 3f15609..5ee3d6c 100644 --- a/satalign/main.py +++ b/satalign/main.py @@ -44,11 +44,12 @@ def __init__( (bands, height, width). channel (Union[int, str], optional): The channel or feature to be used for alignment. Defaults to "gradients". The options are: - - "gradients": The gradients of the image. It uses the Sobel operator - to calculate the gradients. + - "gradients": The gradients of the rgb channels of the image. + It uses the Sobel operator to calculate the gradients. - "mean": The mean of all the bands. - - "luminance": The luminance of the image. It uses the following + - "luminance": The luminance of the rgb channels of the image. It uses the following formula: 0.299 * R + 0.587 * G + 0.114 * B. + - "rgb_mean": The mean of the RGB bands. - int: The index of the band to be used. interpolation (int, optional): Interpolation type used when transforming the stack of images. Defaults to cv2.INTER_LINEAR + cv2.WARP_FILL_OUTLIERS. @@ -164,11 +165,6 @@ def create_layer(self, img: np.ndarray) -> np.ndarray: # If the image has more than 3 bands, select the RGB bands C, H, W = img.shape - if C > 3: - layer = img[self.rgb_bands] - else: - layer = img - # Crop the image with respect to the centroid if self.crop_center is not None: @@ -177,22 +173,32 @@ def create_layer(self, img: np.ndarray) -> np.ndarray: radius_x = (img.shape[-1] - self.crop_center) // 2 radius_y = (img.shape[-2] - self.crop_center) // 2 - layer = layer[:, radius_y:-radius_y, radius_x:-radius_x] + img = img[:, radius_y:-radius_y, radius_x:-radius_x] # From RGB to grayscale (image feature) if isinstance(self.channel, str): + if C > 3: + layer = img[self.rgb_bands] + else: + layer = img + if self.channel == "gradients": global_reference = cv2.Sobel( layer.mean(0).astype(np.float32), cv2.CV_32F, 1, 1 ) elif self.channel == "mean": - global_reference = layer.mean(0) + global_reference = img.mean(0) elif self.channel == "luminance": global_reference = ( layer[0] * 0.299 + layer[1] * 0.587 + layer[2] * 0.114 ) + elif self.channel == "rgb_mean": + global_reference = layer.mean(0) + else: + raise ValueError("The channel should be 'gradients', 'mean', 'luminance' or 'rgb_mean'") + elif isinstance(self.channel, int): - global_reference = layer[self.channel].copy() + global_reference = img[self.channel].copy() else: raise ValueError( "The channel should be a string (a specific method) or an integer (a band index)" diff --git a/satalign/utils.py b/satalign/utils.py index 1ee80fd..c99e7ca 100644 --- a/satalign/utils.py +++ b/satalign/utils.py @@ -131,8 +131,8 @@ def plot_s2_scatter(warp_df: pd.DataFrame) -> Tuple[plt.Figure, plt.Axes]: x2, y2 = warp_df[~warp_df["after"]]["x"], warp_df[~warp_df["after"]]["y"] # Build the scatter plot - ax.scatter(x1, y1, label="After", color="blue", alpha=0.5) - ax.scatter(x2, y2, label="Before", color="red", alpha=0.5) + ax.scatter(x1, y1, label="After Cutoff date", color="blue", alpha=0.5) + ax.scatter(x2, y2, label="Before Cutoff date", color="red", alpha=0.5) ax.legend() return fig, ax