Skip to content

Commit

Permalink
Refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
ternaus committed Feb 10, 2025
1 parent 657aad1 commit 490df6e
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 17 deletions.
35 changes: 18 additions & 17 deletions albumentations/augmentations/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -3191,6 +3191,18 @@ def get_mud_params(
}


def rgb_to_optical_density(img: np.ndarray, eps: float = 1e-6) -> np.ndarray:
max_value = MAX_VALUES_BY_DTYPE[img.dtype]
pixel_matrix = img.reshape(-1, 3).astype(np.float32)
pixel_matrix = np.maximum(pixel_matrix / max_value, eps)
return -np.log(pixel_matrix)


def normalize_vectors(vectors: np.ndarray) -> np.ndarray:
norms = np.sqrt(np.sum(vectors**2, axis=1, keepdims=True))
return vectors / norms


def get_normalizer(method: Literal["vahadane", "macenko"]) -> StainNormalizer:
"""Get stain normalizer based on method."""
return VahadaneNormalizer() if method == "vahadane" else MacenkoNormalizer()
Expand Down Expand Up @@ -3225,7 +3237,7 @@ def fit_transform(self, optical_density: np.ndarray) -> tuple[np.ndarray, np.nda

# Initialize concentrations based on projection onto initial colors
# This gives us a physically meaningful starting point
stain_colors_normalized = stain_colors / np.sqrt(np.sum(stain_colors**2, axis=1, keepdims=True))
stain_colors_normalized = normalize_vectors(stain_colors)
stain_concentrations = np.maximum(optical_density @ stain_colors_normalized.T, 0)

# Iterative updates with careful normalization
Expand All @@ -3246,7 +3258,7 @@ def fit_transform(self, optical_density: np.ndarray) -> tuple[np.ndarray, np.nda

# Ensure non-negativity and normalize
stain_colors = np.maximum(stain_colors, 0)
stain_colors /= np.sqrt(np.sum(stain_colors**2, axis=1, keepdims=True))
stain_colors = normalize_vectors(stain_colors)

return stain_concentrations, stain_colors

Expand All @@ -3258,7 +3270,7 @@ def order_stains_combined(stain_colors: np.ndarray) -> tuple[int, int]:
for more robust identification.
"""
# Normalize stain vectors
stain_colors = stain_colors / np.sqrt(np.sum(stain_colors**2, axis=1, keepdims=True))
stain_colors = normalize_vectors(stain_colors)

# Calculate angles (Macenko)
angles = np.mod(np.arctan2(stain_colors[:, 1], stain_colors[:, 0]), np.pi)
Expand All @@ -3280,13 +3292,7 @@ def order_stains_combined(stain_colors: np.ndarray) -> tuple[int, int]:

class VahadaneNormalizer(StainNormalizer):
def fit(self, img: np.ndarray) -> None:
max_value = MAX_VALUES_BY_DTYPE[img.dtype]

pixel_matrix = img.reshape((-1, 3)).astype(np.float32)

pixel_matrix = np.maximum(pixel_matrix / max_value, 1e-6)

optical_density = -np.log(pixel_matrix)
optical_density = rgb_to_optical_density(img)

nmf = SimpleNMF(n_iter=100)
_, stain_colors = nmf.fit_transform(optical_density)
Expand All @@ -3311,11 +3317,8 @@ def __init__(self, angular_percentile: float = 99):

def fit(self, img: np.ndarray, angular_percentile: float = 99) -> None:
"""Extract H&E stain matrix using optimized Macenko's method."""
max_value = MAX_VALUES_BY_DTYPE[img.dtype]
# Step 1: Convert RGB to optical density (OD) space
pixel_matrix = img.reshape(-1, 3).astype(np.float32)
pixel_matrix = np.maximum(pixel_matrix / max_value, 1e-6)
optical_density = -np.log(pixel_matrix)
optical_density = rgb_to_optical_density(img)

# Step 2: Remove background pixels
od_threshold = 0.05
Expand Down Expand Up @@ -3409,9 +3412,7 @@ def apply_he_stain_augmentation(
augment_background: bool,
) -> np.ndarray:
# Step 1: Convert RGB to optical density space
rgb_pixels = img.reshape(-1, 3)
rgb_pixels_clipped = np.maximum(rgb_pixels, 1e-6) # Prevent log(0)
optical_density = -np.log(rgb_pixels_clipped)
optical_density = rgb_to_optical_density(img)

# Step 2: Calculate stain concentrations using regularized pseudo-inverse
stain_matrix = np.ascontiguousarray(stain_matrix, dtype=np.float32)
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ pytest_cov>=5.0.0
pytest_mock>=3.14.0
requests>=2.31.0
scikit-image
scikit-learn
tomli>=2.0.1
torch>=2.3.1
torchvision>=0.18.1
Expand Down

0 comments on commit 490df6e

Please sign in to comment.