Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add stain augment #2337

Merged
merged 7 commits into from
Feb 10, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
310 changes: 310 additions & 0 deletions albumentations/augmentations/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -3136,3 +3136,313 @@ def get_mud_params(
"mud": mud.astype(np.float32),
"non_mud": non_mud.astype(np.float32),
}


# Standard reference H&E stain matrices
ternaus marked this conversation as resolved.
Show resolved Hide resolved
# Preset name -> 2x3 array of stain vectors. The Ruifrok entry labels its rows
# as hematoxylin then eosin; the remaining presets presumably follow the same
# row layout (verify against callers).
STAIN_MATRICES = {
    preset: np.array(rows)
    for preset, rows in (
        # Ruifrok & Johnston standard reference
        (
            "ruifrok",
            [
                [0.644211, 0.716556, 0.266844],  # Hematoxylin
                [0.092789, 0.954111, 0.283111],  # Eosin
            ],
        ),
        # Macenko's reference
        ("macenko", [[0.5626, 0.7201, 0.4062], [0.2159, 0.8012, 0.5581]]),
        # Standard bright-field microscopy
        ("standard", [[0.65, 0.70, 0.29], [0.07, 0.99, 0.11]]),
        # Enhanced contrast
        ("high_contrast", [[0.55, 0.88, 0.11], [0.12, 0.86, 0.49]]),
        # Hematoxylin dominant
        ("h_heavy", [[0.75, 0.61, 0.32], [0.04, 0.93, 0.36]]),
        # Eosin dominant
        ("e_heavy", [[0.60, 0.75, 0.28], [0.17, 0.95, 0.25]]),
        # Darker staining
        ("dark", [[0.78, 0.55, 0.28], [0.09, 0.97, 0.21]]),
        # Lighter staining
        ("light", [[0.57, 0.71, 0.38], [0.15, 0.89, 0.42]]),
    )
}


def get_normalizer(method: Literal["vahadane", "macenko"]) -> StainNormalizer:
    """Create the stain normalizer that corresponds to ``method``.

    Args:
        method: ``"vahadane"`` selects ``VahadaneNormalizer``; every other
            value selects ``MacenkoNormalizer``.

    Returns:
        A newly constructed normalizer instance.
    """
    if method == "vahadane":
        return VahadaneNormalizer()
    return MacenkoNormalizer()


class StainNormalizer:
    """Common interface for objects that estimate a stain matrix from an image."""

    def __init__(self) -> None:
        """Start with no stain matrix estimated."""
        # Stays ``None`` in this base class; ``fit`` implementations are
        # expected to store their estimate here.
        self.stain_matrix_target: np.ndarray | None = None

    def fit(self, img: np.ndarray) -> None:
        """Estimate the stain matrix of ``img``.

        Args:
            img: Image to analyse.

        Raises:
            NotImplementedError: Always; subclasses provide the implementation.
        """
        raise NotImplementedError


class SimpleNMF:
    """Two-component non-negative matrix factorisation with multiplicative updates.

    Factorises an optical-density matrix into non-negative per-pixel stain
    concentrations and two unit-length stain colour vectors, starting from the
    Ruifrok hematoxylin/eosin reference colours.
    """

    def __init__(self, n_iter: int = 100):
        """Store the iteration budget and the initial stain colours.

        Args:
            n_iter: Number of multiplicative-update iterations to run.
        """
        self.n_iter = n_iter
        # Initialize with standard H&E colors from Ruifrok
        self.initial_colors = np.array(
            [
                [0.644211, 0.716556, 0.266844],  # Hematoxylin
                [0.092789, 0.954111, 0.283111],  # Eosin
            ],
            dtype=np.float32,
        )

    def fit_transform(self, optical_density: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Factorise ``optical_density`` into concentrations and stain colours.

        Args:
            optical_density: Array of shape (n_pixels, 3) with non-negative
                optical-density values.

        Returns:
            ``(stain_concentrations, stain_colors)`` with shapes (n_pixels, 2)
            and (2, 3). After at least one iteration each colour row has unit
            length. If an update collapses a colour row to zero (for example
            when the input carries no stain signal at all), that row falls back
            to the corresponding initial colour instead of becoming NaN.
        """
        eps = 1e-6

        # Start with known H&E colors
        stain_colors = self.initial_colors.copy()

        # Initialize concentrations by projecting onto the unit-length initial
        # colours; this gives a physically meaningful starting point.
        stain_colors_normalized = stain_colors / np.sqrt(np.sum(stain_colors**2, axis=1, keepdims=True))
        stain_concentrations = np.maximum(optical_density @ stain_colors_normalized.T, 0)

        for _ in range(self.n_iter):
            # Update concentrations
            numerator = optical_density @ stain_colors.T
            denominator = stain_concentrations @ (stain_colors @ stain_colors.T)
            stain_concentrations *= numerator / (denominator + eps)

            # Ensure non-negativity
            stain_concentrations = np.maximum(stain_concentrations, 0)

            # Update colors
            numerator = stain_concentrations.T @ optical_density
            denominator = (stain_concentrations.T @ stain_concentrations) @ stain_colors
            stain_colors *= numerator / (denominator + eps)

            # Ensure non-negativity
            stain_colors = np.maximum(stain_colors, 0)

            # Normalize rows to unit length. A row with (near-)zero norm cannot
            # be normalised (0 / 0 would yield NaN), so it is reset to the
            # reference colour it started from.
            norms = np.sqrt(np.sum(stain_colors**2, axis=1, keepdims=True))
            collapsed = norms < eps
            if collapsed.any():
                stain_colors = np.where(collapsed, self.initial_colors, stain_colors)
                norms = np.sqrt(np.sum(stain_colors**2, axis=1, keepdims=True))
            stain_colors /= np.maximum(norms, eps)

        return stain_concentrations, stain_colors


def order_stains_combined(stain_colors: np.ndarray) -> tuple[int, int]:
    """Decide which of two optical-density stain vectors is hematoxylin.

    Combines the vector's angle in the red/green plane with its red and blue
    channel shares. In optical-density space hematoxylin absorbs red strongly
    (Ruifrok reference ``[0.64, 0.72, 0.27]``), whereas eosin absorbs almost
    only green (``[0.09, 0.95, 0.28]``). Eosin therefore has the larger
    red/green angle and the smaller red share, which gives it the higher
    score below; hematoxylin is the row with the lower score.

    Args:
        stain_colors: Array of shape (2, 3) holding one stain vector per row.
            The function assumes exactly two rows.

    Returns:
        ``(hematoxylin_idx, eosin_idx)`` as plain Python ints. When both rows
        score equally the result is ``(0, 1)``.
    """
    eps = 1e-6

    # Normalize stain vectors; the floor avoids 0 / 0 for an all-zero row.
    norms = np.sqrt(np.sum(stain_colors**2, axis=1, keepdims=True))
    unit_colors = stain_colors / np.maximum(norms, eps)

    # Angle of each vector in the red/green plane, folded into [0, pi)
    angles = np.mod(np.arctan2(unit_colors[:, 1], unit_colors[:, 0]), np.pi)

    # Share of the blue and red channels in each vector
    channel_sum = np.sum(unit_colors, axis=1) + eps
    blue_ratio = unit_colors[:, 2] / channel_sum
    red_ratio = unit_colors[:, 0] / channel_sum

    # Large angle and small red share -> eosin; small angle and large red
    # share -> hematoxylin.
    eosin_scores = angles * blue_ratio - red_ratio

    hematoxylin_idx = int(np.argmin(eosin_scores))
    eosin_idx = 1 - hematoxylin_idx

    return hematoxylin_idx, eosin_idx


class VahadaneNormalizer(StainNormalizer):
    """Stain normalizer that estimates the stain matrix with ``SimpleNMF``."""

    def fit(self, img: np.ndarray) -> None:
        """Estimate the stain matrix of ``img`` and store it on the instance.

        The image is flattened to rows of three channels, scaled by the
        maximum value registered for its dtype, converted to optical density
        and factorised with ``SimpleNMF``. The two resulting colour rows are
        stored in ``stain_matrix_target`` in the order reported by
        ``order_stains_combined`` (hematoxylin index first, eosin second).

        Args:
            img: Image whose last dimension holds three colour channels and
                whose dtype is a key of ``MAX_VALUES_BY_DTYPE``.
        """
        dtype_max = MAX_VALUES_BY_DTYPE[img.dtype]

        # Scale to (0, 1]; the 1e-6 floor keeps the logarithm finite.
        scaled_pixels = np.maximum(img.reshape((-1, 3)).astype(np.float32) / dtype_max, 1e-6)
        od = -np.log(scaled_pixels)

        _, colors = SimpleNMF(n_iter=100).fit_transform(od)

        # Use combined method for robust stain ordering
        h_idx, e_idx = order_stains_combined(colors)

        self.stain_matrix_target = np.array([colors[h_idx], colors[e_idx]])


class MacenkoNormalizer(StainNormalizer):
    """Macenko stain normalizer with optimized computations."""

    def __init__(self, angular_percentile: float = 99):
        """Store the default percentile used to pick the extreme stain angles.

        Args:
            angular_percentile: Percentile (0-100) used by ``fit`` when no
                explicit value is passed to it.
        """
        super().__init__()
        self.angular_percentile = angular_percentile

    def fit(self, img: np.ndarray, angular_percentile: float | None = None) -> None:
        """Extract H&E stain matrix using optimized Macenko's method.

        Args:
            img: Image whose last dimension holds three colour channels and
                whose dtype is a key of ``MAX_VALUES_BY_DTYPE``.
            angular_percentile: Percentile (0-100) for the robust angle
                estimates. ``None`` (the default) uses the value given to the
                constructor.

        Raises:
            ValueError: If no pixel exceeds the optical-density background
                threshold in any channel.
        """
        # Honour the instance-level setting unless the caller overrides it.
        if angular_percentile is None:
            angular_percentile = self.angular_percentile

        max_value = MAX_VALUES_BY_DTYPE[img.dtype]

        # Step 1: Convert RGB to optical density (OD) space
        pixel_matrix = img.reshape(-1, 3).astype(np.float32)
        pixel_matrix = np.maximum(pixel_matrix / max_value, 1e-6)  # floor keeps log() finite
        optical_density = -np.log(pixel_matrix)

        # Step 2: Remove background pixels
        od_threshold = 0.05
        threshold_mask = (optical_density > od_threshold).any(axis=1)
        tissue_density = optical_density[threshold_mask]

        if len(tissue_density) < 1:
            raise ValueError(f"No tissue pixels found (threshold={od_threshold})")

        # Step 3: Compute covariance matrix
        tissue_density = np.ascontiguousarray(tissue_density, dtype=np.float32)
        od_covariance = cv2.calcCovarMatrix(
            tissue_density,
            None,
            cv2.COVAR_NORMAL | cv2.COVAR_ROWS | cv2.COVAR_SCALE,
        )[0]

        # Step 4: Get principal components
        # cv2.eigen stores each eigenvector as a ROW of its output, in the same
        # order as the eigenvalues, so the two principal ones are selected by
        # row and transposed into the columns of a (3, 2) projection matrix.
        eigenvalues, eigenvectors = cv2.eigen(od_covariance)[1:]
        idx = np.argsort(eigenvalues.ravel())[-2:]
        principal_eigenvectors = np.ascontiguousarray(eigenvectors[idx].T, dtype=np.float32)

        # Step 5: Project onto eigenvector plane
        plane_coordinates = tissue_density @ principal_eigenvectors

        # Step 6: Find angles of extreme points
        polar_angles = np.arctan2(
            plane_coordinates[:, 1],
            plane_coordinates[:, 0],
        )

        # Get robust angle estimates
        hematoxylin_angle = np.percentile(polar_angles, 100 - angular_percentile)
        eosin_angle = np.percentile(polar_angles, angular_percentile)

        # Step 7: Convert angles back to RGB space
        hem_cos, hem_sin = np.cos(hematoxylin_angle), np.sin(hematoxylin_angle)
        eos_cos, eos_sin = np.cos(eosin_angle), np.sin(eosin_angle)

        angle_to_vector = np.array(
            [[hem_cos, hem_sin], [eos_cos, eos_sin]],
            dtype=np.float32,
        )
        stain_vectors = cv2.gemm(
            angle_to_vector,
            principal_eigenvectors.T,
            1,
            None,
            0,
        )

        # Step 8: Ensure non-negativity by taking absolute values
        # This is valid because stain vectors represent absorption coefficients
        stain_vectors = np.abs(stain_vectors)

        # Step 9: Normalize vectors to unit length
        stain_vectors = stain_vectors / np.sqrt(np.sum(stain_vectors**2, axis=1, keepdims=True))

        # Step 10: Order vectors as [hematoxylin, eosin]
        # Hematoxylin typically has larger red component
        self.stain_matrix_target = stain_vectors if stain_vectors[0, 0] > stain_vectors[1, 0] else stain_vectors[::-1]


def get_tissue_mask(img: np.ndarray, threshold: float = 0.85) -> np.ndarray:
    """Flag the pixels that are dark enough to count as tissue.

    Args:
        img: RGB image in float32 format, range [0, 1]
        threshold: Luminosity cut-off in the range 0 to 1. A pixel is tissue
            when its luminosity is strictly below this value. Default: 0.85

    Returns:
        Flattened boolean mask with one entry per pixel; True marks tissue.
    """
    red, green, blue = img[..., 0], img[..., 1], img[..., 2]

    # Weighted grayscale conversion: R*0.299 + G*0.587 + B*0.114
    luminosity = red * 0.299 + green * 0.587 + blue * 0.114

    # Tissue is darker than the background, hence the "<" comparison.
    return (luminosity < threshold).reshape(-1)


@clipped
@float32_io
def apply_he_stain_augmentation(
    img: np.ndarray,
    stain_matrix: np.ndarray,
    scale_factors: np.ndarray,
    shift_values: np.ndarray,
    augment_background: bool,
) -> np.ndarray:
    """Perturb the stain concentrations of an H&E image and rebuild the RGB image.

    The image is converted to optical density, decomposed into two stain
    concentrations with a regularised least-squares solve against
    ``stain_matrix``, the concentrations are scaled and shifted, and the
    result is mapped back to RGB.

    Args:
        img: Image whose last dimension holds three colour channels; values
            are presumably float32 in [0, 1] after the ``float32_io``
            decorator (confirm against that decorator).
        stain_matrix: Array of shape (2, 3) with one stain vector per row.
        scale_factors: Multipliers for the two stain concentrations; must
            broadcast against an (n_pixels, 2) array.
        shift_values: Offsets added to the two stain concentrations; must
            broadcast against an (n_pixels, 2) array.
        augment_background: If True every pixel is modified; otherwise only
            pixels flagged by ``get_tissue_mask``.

    Returns:
        Array with the same shape as ``img`` holding the augmented image.
    """
    # Step 1: Convert RGB to optical density space
    rgb_pixels = img.reshape(-1, 3)
    rgb_pixels_clipped = np.maximum(rgb_pixels, 1e-6)  # Prevent log(0)
    optical_density = -np.log(rgb_pixels_clipped)

    # Step 2: Calculate stain concentrations using regularized pseudo-inverse
    stain_matrix = np.ascontiguousarray(stain_matrix, dtype=np.float32)

    # Add small regularization term for numerical stability
    regularization = 1e-6
    stain_correlation = stain_matrix @ stain_matrix.T + regularization * np.eye(2)
    density_projection = stain_matrix @ optical_density.T

    try:
        # Solve the (2, 2) normal equations for all pixels at once
        stain_concentrations = np.linalg.solve(stain_correlation, density_projection).T
    except np.linalg.LinAlgError:
        # Fallback to least squares if the direct solve fails. lstsq needs the
        # right-hand side as (3, n_pixels) to match the (3, 2) coefficient
        # matrix, hence the transpose of optical_density.
        stain_concentrations = np.linalg.lstsq(
            stain_matrix.T,
            optical_density.T,
            rcond=regularization,
        )[0].T

    # Step 3: Apply concentration adjustments
    if not augment_background:
        # Only modify tissue regions (the mask is already flattened per pixel)
        tissue_mask = get_tissue_mask(img)
        stain_concentrations[tissue_mask] = stain_concentrations[tissue_mask] * scale_factors + shift_values
    else:
        # Modify all pixels
        stain_concentrations = stain_concentrations * scale_factors + shift_values

    # Step 4: Reconstruct RGB image
    optical_density_result = stain_concentrations @ stain_matrix
    rgb_result = np.exp(-optical_density_result)

    return rgb_result.reshape(img.shape)
Loading
Loading