Skip to content

Commit

Permalink
Fix in SaltAndPepper
Browse files Browse the repository at this point in the history
  • Loading branch information
ternaus committed Feb 2, 2025
1 parent e85ec53 commit 3bd82e3
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 13 deletions.
6 changes: 5 additions & 1 deletion albumentations/augmentations/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -2409,7 +2409,11 @@ def apply_salt_and_pepper(
pepper_mask: np.ndarray,
) -> np.ndarray:
"""Apply salt and pepper noise to image using pre-computed masks."""
# Avoid copy if possible by using np.where
# Add channel dimension to masks if image is 3D
if img.ndim == 3:
salt_mask = salt_mask[..., None]
pepper_mask = pepper_mask[..., None]

max_value = MAX_VALUES_BY_DTYPE[img.dtype]
return np.where(salt_mask, max_value, np.where(pepper_mask, 0, img))

Expand Down
23 changes: 11 additions & 12 deletions albumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5772,28 +5772,27 @@ def get_params_dependent_on_data(
data: dict[str, Any],
) -> dict[str, Any]:
image = data["image"] if "image" in data else data["images"][0]
height, width = image.shape[-2:] # Get spatial dimensions only
height, width = image.shape[:2]

# Sample total amount and salt ratio
total_amount = self.py_random.uniform(*self.amount)
salt_ratio = self.py_random.uniform(*self.salt_vs_pepper)

# Calculate number of pixels to affect (only for H x W, not channels)
num_pixels = int(height * width * total_amount)
num_salt = int(num_pixels * salt_ratio)

# Generate flat indices for salt and pepper (for H x W only)
total_pixels = height * width
indices = self.random_generator.choice(total_pixels, size=num_pixels, replace=False)
# Generate all positions at once
all_positions = np.arange(height * width)
noise_positions = self.random_generator.choice(all_positions, size=num_pixels, replace=False)

# Create 2D masks using advanced indexing
salt_mask = np.zeros(total_pixels, dtype=bool)
pepper_mask = np.zeros(total_pixels, dtype=bool)
# Create masks
salt_mask = np.zeros(height * width, dtype=bool)
pepper_mask = np.zeros(height * width, dtype=bool)

salt_mask[indices[:num_salt]] = True
pepper_mask[indices[num_salt:]] = True
# Set salt and pepper positions
salt_mask[noise_positions[:num_salt]] = True
pepper_mask[noise_positions[num_salt:]] = True

# Reshape masks to 2D and broadcast to all channels
# Reshape to 2D
salt_mask = salt_mask.reshape(height, width)
pepper_mask = pepper_mask.reshape(height, width)

Expand Down
83 changes: 83 additions & 0 deletions tests/test_augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1225,3 +1225,86 @@ def test_pixel_dropout_per_channel():
assert len(unique_values) == 2 # Should only have original value and drop value
assert drop_val in unique_values
assert 255 in unique_values



def test_salt_and_pepper_noise():
# Test image setup - create all gray image instead of black with gray square
image = np.full((100, 100, 3), 128, dtype=np.uint8) # All gray image

# Fixed parameters for deterministic testing
amount = (0.05, 0.05) # Exactly 5% of pixels
salt_vs_pepper = (0.6, 0.6) # Exactly 60% salt, 40% pepper

transform = A.SaltAndPepper(
amount=amount,
salt_vs_pepper=salt_vs_pepper,
p=1.0
)
transform.set_random_seed(137)

# Apply transform
transformed = transform(image=image)["image"]

# Count all changes
salt_pixels = (transformed == 255).all(axis=2)
pepper_pixels = (transformed == 0).all(axis=2)

total_changes = salt_pixels.sum() + pepper_pixels.sum()

expected_pixels = int(image.shape[0] * image.shape[1] * amount[0])
assert total_changes == expected_pixels, \
f"Expected {expected_pixels} noisy pixels, got {total_changes}"

# Verify salt vs pepper ratio
expected_salt = int(expected_pixels * salt_vs_pepper[0])
assert salt_pixels.sum() == expected_salt, \
f"Expected {expected_salt} salt pixels, got {salt_pixels.sum()}"


def test_salt_and_pepper_float_image():
"""Test salt and pepper noise on float32 images"""
image = np.zeros((100, 100, 3), dtype=np.float32)
image[25:75, 25:75] = 0.5 # Gray square

transform = A.SaltAndPepper(
amount=(0.05, 0.05),
salt_vs_pepper=(0.6, 0.6),
p=1.0
)
transform.set_random_seed(137)

transformed = transform(image=image)["image"]

# Check that salt pixels are 1.0 and pepper pixels are 0.0
changed_mask = (transformed != image).any(axis=2)
assert np.allclose(transformed[transformed > 0.9], 1.0), \
"Salt pixels should be exactly 1.0 for float images"
assert np.allclose(transformed[transformed < 0.1], 0.0), \
"Pepper pixels should be exactly 0.0 for float images"

def test_salt_and_pepper_grayscale():
"""Test salt and pepper noise on single-channel images"""
image = np.zeros((100, 100), dtype=np.uint8)
image[25:75, 25:75] = 128

transform = A.SaltAndPepper(
amount=(0.05, 0.05),
salt_vs_pepper=(0.6, 0.6),
p=1.0
)
transform.set_random_seed(137)

transformed = transform(image=image)["image"]

# Verify shape is preserved
assert transformed.shape == image.shape, \
"Transform should preserve single-channel image shape"

# Check noise values
changed_mask = transformed != image
salt_pixels = (transformed == 255) & changed_mask
pepper_pixels = (transformed == 0) & changed_mask

assert (salt_pixels | pepper_pixels | ~changed_mask).all(), \
"Changed pixels should only be salt (255) or pepper (0)"

0 comments on commit 3bd82e3

Please sign in to comment.