Skip to content

Commit

Permalink
Fixed bug
Browse files Browse the repository at this point in the history
  • Loading branch information
Gonzalo Mateo Garcia committed May 16, 2024
1 parent 052dc69 commit 013d35d
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 13 deletions.
28 changes: 17 additions & 11 deletions satalign/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,12 @@ def __init__(
(bands, height, width).
channel (Union[int, str], optional): The channel or feature to be used for
alignment. Defaults to "gradients". The options are:
- "gradients": The gradients of the image. It uses the Sobel operator
to calculate the gradients.
- "gradients": The gradients of the rgb channels of the image.
It uses the Sobel operator to calculate the gradients.
- "mean": The mean of all the bands.
- "luminance": The luminance of the image. It uses the following
- "luminance": The luminance of the rgb channels of the image. It uses the following
formula: 0.299 * R + 0.587 * G + 0.114 * B.
- "rgb_mean": The mean of the RGB bands.
- int: The index of the band to be used.
interpolation (int, optional): Interpolation type used when transforming
the stack of images. Defaults to cv2.INTER_LINEAR + cv2.WARP_FILL_OUTLIERS.
Expand Down Expand Up @@ -164,11 +165,6 @@ def create_layer(self, img: np.ndarray) -> np.ndarray:
# If the image has more than 3 bands, select the RGB bands
C, H, W = img.shape

if C > 3:
layer = img[self.rgb_bands]
else:
layer = img

# Crop the image with respect to the centroid
if self.crop_center is not None:

Expand All @@ -177,22 +173,32 @@ def create_layer(self, img: np.ndarray) -> np.ndarray:

radius_x = (img.shape[-1] - self.crop_center) // 2
radius_y = (img.shape[-2] - self.crop_center) // 2
layer = layer[:, radius_y:-radius_y, radius_x:-radius_x]
img = img[:, radius_y:-radius_y, radius_x:-radius_x]

# From RGB to grayscale (image feature)
if isinstance(self.channel, str):
if C > 3:
layer = img[self.rgb_bands]
else:
layer = img

if self.channel == "gradients":
global_reference = cv2.Sobel(
layer.mean(0).astype(np.float32), cv2.CV_32F, 1, 1
)
elif self.channel == "mean":
global_reference = layer.mean(0)
global_reference = img.mean(0)
elif self.channel == "luminance":
global_reference = (
layer[0] * 0.299 + layer[1] * 0.587 + layer[2] * 0.114
)
elif self.channel == "rgb_mean":
global_reference = layer.mean(0)
else:
raise ValueError("The channel should be 'gradients', 'mean', 'luminance' or 'rgb_mean'")

elif isinstance(self.channel, int):
global_reference = layer[self.channel].copy()
global_reference = img[self.channel].copy()
else:
raise ValueError(
"The channel should be a string (a specific method) or an integer (a band index)"
Expand Down
4 changes: 2 additions & 2 deletions satalign/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ def plot_s2_scatter(warp_df: pd.DataFrame) -> Tuple[plt.Figure, plt.Axes]:
x2, y2 = warp_df[~warp_df["after"]]["x"], warp_df[~warp_df["after"]]["y"]

# Build the scatter plot
ax.scatter(x1, y1, label="After", color="blue", alpha=0.5)
ax.scatter(x2, y2, label="Before", color="red", alpha=0.5)
ax.scatter(x1, y1, label="After Cutoff date", color="blue", alpha=0.5)
ax.scatter(x2, y2, label="Before Cutoff date", color="red", alpha=0.5)
ax.legend()

return fig, ax
Expand Down

0 comments on commit 013d35d

Please sign in to comment.