Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimized Area thresholding #983

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 23 additions & 10 deletions topostats/grains.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,20 +360,33 @@ def area_thresholding(self, image: npt.NDArray, area_thresholds: tuple) -> npt.N
if lower_size_limit is None:
lower_size_limit = 0
# Get array of grain numbers (discounting zero)
uniq = np.delete(np.unique(image), 0)
grain_count = 0
# uniq = np.delete(np.unique(image), 0)
# grain_count = 0
LOGGER.debug(
f"[{self.filename}] : Area thresholding grains | Thresholds: L: {(lower_size_limit / self.pixel_to_nm_scaling**2):.2f},"
f"U: {(upper_size_limit / self.pixel_to_nm_scaling**2):.2f} px^2, L: {lower_size_limit:.2f}, U: {upper_size_limit:.2f} nm^2."
)
for grain_no in uniq: # Calculate grian area in nm^2
grain_area = np.sum(image_cp == grain_no) * (self.pixel_to_nm_scaling**2)
# Compare area in nm^2 to area thresholds
if grain_area > upper_size_limit or grain_area < lower_size_limit:
image_cp[image_cp == grain_no] = 0
else:
grain_count += 1
image_cp[image_cp == grain_no] = grain_count

grain_counts = np.bincount(image_cp.ravel())
grain_counts = grain_counts[1:]
# Calculate areas in nm^2
grain_areas = grain_counts * (self.pixel_to_nm_scaling**2)

# Create a mask for valid grains
valid_grains = (grain_areas > lower_size_limit) & (grain_areas < upper_size_limit)

# Create a new mapping for valid grain numbers
new_indices = np.arange(1, valid_grains.sum() + 1) # New indices for valid grains
valid_grain_numbers = np.where(valid_grains)[0] + 1 # Original grain numbers that are valid

# Step 1: Create a boolean mask for valid grains
valid_mask = np.isin(image_cp, valid_grain_numbers)
# Step 2: Set invalid values to 0
image_cp[~valid_mask] = 0 # Invert the mask to find invalid values
# Map old grain numbers to new ones
for new_idx, old_idx in enumerate(valid_grain_numbers):
image_cp[image_cp == old_idx] = new_indices[new_idx]

return image_cp

def colour_regions(self, image: npt.NDArray, **kwargs) -> npt.NDArray:
Expand Down
Loading