From 5b9a4c66bbae616a5169eba49f44194ebf248082 Mon Sep 17 00:00:00 2001 From: Vladimir Iglovikov Date: Wed, 29 Jan 2025 18:34:22 -0800 Subject: [PATCH 1/5] Speed up in RandomRain --- albumentations/augmentations/functional.py | 61 +++++----- tests/functional/test_functional.py | 131 +++++++++++++++++++++ 2 files changed, 160 insertions(+), 32 deletions(-) diff --git a/albumentations/augmentations/functional.py b/albumentations/augmentations/functional.py index a78591122..6452e0b14 100644 --- a/albumentations/augmentations/functional.py +++ b/albumentations/augmentations/functional.py @@ -748,43 +748,40 @@ def add_rain( brightness_coefficient: float, rain_drops: list[tuple[int, int]], ) -> np.ndarray: - """Adds rain drops to the image. + """Adds rain drops to the image with optimized performance.""" + img = img.copy() - Args: - img (np.ndarray): Input image. - slant (int): The angle of the rain drops. - drop_length (int): The length of each rain drop. - drop_width (int): The width of each rain drop. - drop_color (tuple[int, int, int]): The color of the rain drops in RGB format. - blur_value (int): The size of the kernel used to blur the image. Rainy views are blurry. - brightness_coefficient (float): Coefficient to adjust the brightness of the image. Rainy days are usually shady. - rain_drops (list[tuple[int, int]]): A list of tuples where each tuple represents the (x, y) - coordinates of the starting point of a rain drop. + # Vectorize rain drop coordinates + rain_drops = np.array(rain_drops) + if len(rain_drops) == 0: + return img - Returns: - np.ndarray: Image with rain effect added. + # Calculate all end points at once + end_points = rain_drops + np.array([slant, drop_length]) - Reference: - https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library - """ - img = img.copy() - for rain_drop_x0, rain_drop_y0 in rain_drops: - rain_drop_x1 = rain_drop_x0 + slant - rain_drop_y1 = rain_drop_y0 + drop_length - - cv2.line( - img, - (rain_drop_x0, rain_drop_y0), - (rain_drop_x1, rain_drop_y1), - drop_color, - drop_width, - ) + # Draw all lines at once using polylines + lines = np.stack([rain_drops, end_points], axis=1) + cv2.polylines( + img, + lines.astype(np.int32), + False, + drop_color, + drop_width, + lineType=cv2.LINE_AA, # Use anti-aliasing for better quality + ) + + # Optimize blur operation + if blur_value > 1: + # Use a faster blur method + img = cv2.boxFilter(img, -1, (blur_value, blur_value), normalize=True) - img = cv2.blur(img, (blur_value, blur_value)) # rainy view are blurry - image_hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV).astype(np.float32) - image_hsv[:, :, 2] *= brightness_coefficient + # Optimize brightness adjustment using HSV + if brightness_coefficient != 1.0: + img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) + img[:, :, 2] = cv2.multiply(img[:, :, 2], brightness_coefficient) + img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB) - return cv2.cvtColor(image_hsv.astype(np.uint8), cv2.COLOR_HSV2RGB) + return img def get_fog_particle_radiuses( diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index 6ac744b98..e9546f7c0 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -2326,3 +2326,134 @@ def test_gaussian_illumination_sigma(sigma, expected_pattern): if expected_pattern == "narrow": assert diff > wide_diff # Narrow should have steeper falloff than wide + + + +@pytest.mark.parametrize( + ["img", "slant", "drop_length", "drop_width", "drop_color", "blur_value", "brightness_coefficient", "rain_drops", "expected_shape"], + [ + # Test basic functionality with small image + ( + np.zeros((10, 10, 3), dtype=np.uint8), + 5, + 3, + 1, + (200, 200, 200), + 3, + 0.7, + [(2, 2)], + (10, 10, 3), + ), + # Test with no rain drops + ( + np.zeros((20, 20, 3), dtype=np.uint8), + 5, + 3, + 1, + (200, 200, 200), + 3, + 0.7, + [], + (20, 20, 3), + ), + # Test with multiple rain drops + ( + np.zeros((30, 30, 3), dtype=np.uint8), + -5, + 5, + 2, + (255, 255, 255), + 5, + 0.8, + [(5, 5), (10, 10), (15, 15)], + (30, 30, 3), + ), + ] +) +def test_add_rain_shape_and_type( + img, slant, drop_length, drop_width, drop_color, blur_value, brightness_coefficient, rain_drops, expected_shape +): + result = fmain.add_rain( + img, slant, drop_length, drop_width, drop_color, blur_value, brightness_coefficient, rain_drops + ) + assert result.shape == expected_shape + assert result.dtype == np.uint8 + + +@pytest.mark.parametrize("brightness_coefficient", [0.5, 0.7, 1.0]) +def test_add_rain_brightness(brightness_coefficient): + """Test that brightness coefficient correctly affects image brightness""" + img = np.full((20, 20, 3), 100, dtype=np.uint8) + rain_drops = [(5, 5)] + + result = fmain.add_rain( + img=img, + slant=5, + drop_length=3, + drop_width=1, + drop_color=(200, 200, 200), + blur_value=3, + brightness_coefficient=brightness_coefficient, + rain_drops=rain_drops, + ) + + # Convert to HSV to check brightness + original_hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) + result_hsv = cv2.cvtColor(result, cv2.COLOR_RGB2HSV) + + if brightness_coefficient < 1.0: + # For darkening coefficients, brightness should decrease + assert np.mean(result_hsv[:, :, 2]) < np.mean(original_hsv[:, :, 2]) + np.testing.assert_allclose( + np.mean(result_hsv[:, :, 2]) / np.mean(original_hsv[:, :, 2]), + brightness_coefficient, + rtol=0.1 # Allow 10% tolerance due to rounding and blur effects + ) + else: + # For brightness_coefficient = 1.0, brightness might slightly increase + # due to bright rain drops and blur, but shouldn't change dramatically + np.testing.assert_allclose( + np.mean(result_hsv[:, :, 2]) / np.mean(original_hsv[:, :, 2]), + 1.0, + rtol=0.1 # Allow 10% tolerance + ) + + +def test_add_rain_drops_visibility(): + """Test that rain drops are actually visible in the output""" + img = np.zeros((20, 20, 3), dtype=np.uint8) + rain_drops = [(5, 5)] + drop_color = (255, 255, 255) + + result = fmain.add_rain( + img=img, + slant=0, + drop_length=5, + drop_width=1, + drop_color=drop_color, + blur_value=1, # Minimal blur to check drop visibility + brightness_coefficient=1.0, # No brightness change + rain_drops=rain_drops, + ) + + # Check if any pixels have the rain drop color + assert np.any(result > 0) + + +def test_add_rain_preserves_input(): + """Test that the function doesn't modify the input image""" + img = np.zeros((10, 10, 3), dtype=np.uint8) + img_copy = img.copy() + + fmain.add_rain( + img=img, + slant=5, + drop_length=3, + drop_width=1, + drop_color=(200, 200, 200), + blur_value=3, + brightness_coefficient=0.7, + rain_drops=[(5, 5)], + ) + + np.testing.assert_array_equal(img, img_copy) From d6f8c057b628e65ad96cdc7a6efb88829cb9cbd3 Mon Sep 17 00:00:00 2001 From: Vladimir Iglovikov Date: Wed, 29 Jan 2025 18:51:27 -0800 Subject: [PATCH 2/5] Speed up in RandomRain --- albumentations/augmentations/functional.py | 46 ++++++++++++---------- albumentations/augmentations/transforms.py | 17 +++++--- tests/functional/test_functional.py | 12 +++--- 3 files changed, 43 insertions(+), 32 deletions(-) diff --git a/albumentations/augmentations/functional.py b/albumentations/augmentations/functional.py index 6452e0b14..f24a4ce37 100644 --- a/albumentations/augmentations/functional.py +++ b/albumentations/augmentations/functional.py @@ -746,36 +746,42 @@ def add_rain( drop_color: tuple[int, int, int], blur_value: int, brightness_coefficient: float, - rain_drops: list[tuple[int, int]], + rain_drops: np.ndarray, ) -> np.ndarray: - """Adds rain drops to the image with optimized performance.""" + """Optimized version of add_rain using vectorized operations.""" img = img.copy() - # Vectorize rain drop coordinates - rain_drops = np.array(rain_drops) - if len(rain_drops) == 0: + if not rain_drops.size: return img - # Calculate all end points at once - end_points = rain_drops + np.array([slant, drop_length]) + # Generate all points for all rain drops at once + steps = max(drop_length, abs(slant)) + t = np.linspace(0, 1, steps) - # Draw all lines at once using polylines - lines = np.stack([rain_drops, end_points], axis=1) - cv2.polylines( - img, - lines.astype(np.int32), - False, - drop_color, - drop_width, - lineType=cv2.LINE_AA, # Use anti-aliasing for better quality - ) + # Calculate all points along all rain drops + x_steps = rain_drops[:, 0, None] + (slant * t[None, :]).astype(int) + y_steps = rain_drops[:, 1, None] + (drop_length * t[None, :]).astype(int) + + # Clip coordinates to image boundaries + height, width = img.shape[:2] + x_steps = np.clip(x_steps, 0, width - 1) + y_steps = np.clip(y_steps, 0, height - 1) + + # Create mask for valid coordinates + valid_mask = (x_steps >= 0) & (x_steps < width) & (y_steps >= 0) & (y_steps < height) + + # Draw all rain drops at once + if drop_width == 1: + img[y_steps[valid_mask], x_steps[valid_mask]] = drop_color + else: + for w in range(-drop_width // 2, drop_width // 2 + 1): + x = np.clip(x_steps + w, 0, width - 1) + img[y_steps[valid_mask], x[valid_mask]] = drop_color - # Optimize blur operation + # Apply blur and brightness adjustments if blur_value > 1: - # Use a faster blur method img = cv2.boxFilter(img, -1, (blur_value, blur_value), normalize=True) - # Optimize brightness adjustment using HSV if brightness_coefficient != 1.0: img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) img[:, :, 2] = cv2.multiply(img[:, :, 2], brightness_coefficient) diff --git a/albumentations/augmentations/transforms.py b/albumentations/augmentations/transforms.py index 91bbad2d3..3ecd35bd3 100644 --- a/albumentations/augmentations/transforms.py +++ b/albumentations/augmentations/transforms.py @@ -835,13 +835,18 @@ def get_params_dependent_on_data( drop_length = self.drop_length num_drops = area // 600 - rain_drops = [] - - for _ in range(num_drops): # If You want heavy rain, try increasing this - x = self.py_random.randint(slant, width) if slant < 0 else self.py_random.randint(0, max(width - slant, 0)) - y = self.py_random.randint(0, max(height - drop_length, 0)) + # Vectorized rain drop generation + num_drops = int(num_drops) + if num_drops > 0: + if slant < 0: + x = self.random_generator.integers(slant, width, size=num_drops) + else: + x = self.random_generator.integers(0, max(width - slant, 1), size=num_drops) - rain_drops.append((x, y)) + y = self.random_generator.integers(0, max(height - drop_length, 1), size=num_drops) + rain_drops = np.column_stack((x, y)) + else: + rain_drops = np.empty((0, 2), dtype=np.int32) return {"drop_length": drop_length, "slant": slant, "rain_drops": rain_drops} diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index e9546f7c0..797a534fb 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -2341,7 +2341,7 @@ def test_gaussian_illumination_sigma(sigma, expected_pattern): (200, 200, 200), 3, 0.7, - [(2, 2)], + np.array([(2, 2)]), (10, 10, 3), ), # Test with no rain drops @@ -2353,7 +2353,7 @@ def test_gaussian_illumination_sigma(sigma, expected_pattern): (200, 200, 200), 3, 0.7, - [], + np.array([]).reshape(0, 2), (20, 20, 3), ), # Test with multiple rain drops @@ -2365,7 +2365,7 @@ def test_gaussian_illumination_sigma(sigma, expected_pattern): (255, 255, 255), 5, 0.8, - [(5, 5), (10, 10), (15, 15)], + np.array([(5, 5), (10, 10), (15, 15)]), (30, 30, 3), ), ] @@ -2384,7 +2384,7 @@ def test_add_rain_shape_and_type( def test_add_rain_brightness(brightness_coefficient): """Test that brightness coefficient correctly affects image brightness""" img = np.full((20, 20, 3), 100, dtype=np.uint8) - rain_drops = [(5, 5)] + rain_drops = np.array([(5, 5)]) result = fmain.add_rain( img=img, @@ -2422,7 +2422,7 @@ def test_add_rain_brightness(brightness_coefficient): def test_add_rain_drops_visibility(): """Test that rain drops are actually visible in the output""" img = np.zeros((20, 20, 3), dtype=np.uint8) - rain_drops = [(5, 5)] + rain_drops = np.array([(5, 5)]) drop_color = (255, 255, 255) result = fmain.add_rain( @@ -2453,7 +2453,7 @@ def test_add_rain_preserves_input(): drop_color=(200, 200, 200), blur_value=3, brightness_coefficient=0.7, - rain_drops=[(5, 5)], + rain_drops=np.array([(5, 5)]), ) np.testing.assert_array_equal(img, img_copy) From 6b71405e798ccd6949cb17f3935f72025a5250a9 Mon Sep 17 00:00:00 2001 From: Vladimir Iglovikov Date: Wed, 29 Jan 2025 18:55:58 -0800 Subject: [PATCH 3/5] Speed up in RandomRain --- albumentations/augmentations/functional.py | 59 +++++++++++++--------- 1 file changed, 34 insertions(+), 25 deletions(-) diff --git a/albumentations/augmentations/functional.py b/albumentations/augmentations/functional.py index f24a4ce37..a03088688 100644 --- a/albumentations/augmentations/functional.py +++ b/albumentations/augmentations/functional.py @@ -748,44 +748,53 @@ def add_rain( brightness_coefficient: float, rain_drops: np.ndarray, ) -> np.ndarray: - """Optimized version of add_rain using vectorized operations.""" + """Optimized version of add_rain using vectorized operations and OpenCV.""" + if not rain_drops.size: + return img.copy() + img = img.copy() - if not rain_drops.size: - return img + height, width = img.shape[:2] + + # Create rain layer instead of modifying image directly + rain_layer = np.zeros_like(img) - # Generate all points for all rain drops at once + # Generate all points for all rain drops at once (fewer steps, use integers) steps = max(drop_length, abs(slant)) - t = np.linspace(0, 1, steps) + t = np.linspace(0, 1, steps, dtype=np.float32) # Use float32 for faster computation - # Calculate all points along all rain drops - x_steps = rain_drops[:, 0, None] + (slant * t[None, :]).astype(int) - y_steps = rain_drops[:, 1, None] + (drop_length * t[None, :]).astype(int) + # Pre-compute all coordinates at once + x_steps = (rain_drops[:, 0, None] + (slant * t[None, :])).astype(np.int32) + y_steps = (rain_drops[:, 1, None] + (drop_length * t[None, :])).astype(np.int32) - # Clip coordinates to image boundaries - height, width = img.shape[:2] - x_steps = np.clip(x_steps, 0, width - 1) - y_steps = np.clip(y_steps, 0, height - 1) + # Vectorize the width handling + if drop_width > 1: + w_offsets = np.arange(-drop_width // 2, drop_width // 2 + 1) + x_steps = x_steps[..., None] + w_offsets + y_steps = np.repeat(y_steps[..., None], len(w_offsets), axis=2) + x_steps = x_steps.reshape(-1) + y_steps = y_steps.reshape(-1) - # Create mask for valid coordinates + # Single mask operation valid_mask = (x_steps >= 0) & (x_steps < width) & (y_steps >= 0) & (y_steps < height) + x_steps = x_steps[valid_mask] + y_steps = y_steps[valid_mask] - # Draw all rain drops at once - if drop_width == 1: - img[y_steps[valid_mask], x_steps[valid_mask]] = drop_color - else: - for w in range(-drop_width // 2, drop_width // 2 + 1): - x = np.clip(x_steps + w, 0, width - 1) - img[y_steps[valid_mask], x[valid_mask]] = drop_color + # Draw all points at once + rain_layer[y_steps, x_steps] = drop_color - # Apply blur and brightness adjustments + # Use OpenCV operations for blending and adjustments if blur_value > 1: - img = cv2.boxFilter(img, -1, (blur_value, blur_value), normalize=True) + rain_layer = cv2.blur(rain_layer, (blur_value, blur_value)) + + # Blend rain with original image + cv2.add(img, rain_layer, dst=img) + # Adjust brightness if needed if brightness_coefficient != 1.0: - img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) - img[:, :, 2] = cv2.multiply(img[:, :, 2], brightness_coefficient) - img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB) + # Use LUT for faster brightness adjustment + brightness_lut = np.clip(np.arange(0, 256) * brightness_coefficient, 0, 255).astype(np.uint8) + return cv2.LUT(img, brightness_lut) return img From 39553de96eebbe2f6759a4664b4eb5a371c70da6 Mon Sep 17 00:00:00 2001 From: Vladimir Iglovikov Date: Wed, 29 Jan 2025 19:18:53 -0800 Subject: [PATCH 4/5] Speed up in RandomRain --- albumentations/augmentations/functional.py | 51 +++++++--------------- albumentations/augmentations/transforms.py | 42 +++++++++--------- 2 files changed, 36 insertions(+), 57 deletions(-) diff --git a/albumentations/augmentations/functional.py b/albumentations/augmentations/functional.py index a03088688..61f561788 100644 --- a/albumentations/augmentations/functional.py +++ b/albumentations/augmentations/functional.py @@ -748,53 +748,32 @@ def add_rain( brightness_coefficient: float, rain_drops: np.ndarray, ) -> np.ndarray: - """Optimized version of add_rain using vectorized operations and OpenCV.""" + """Optimized version using OpenCV line drawing.""" if not rain_drops.size: return img.copy() img = img.copy() - height, width = img.shape[:2] - - # Create rain layer instead of modifying image directly - rain_layer = np.zeros_like(img) - - # Generate all points for all rain drops at once (fewer steps, use integers) - steps = max(drop_length, abs(slant)) - t = np.linspace(0, 1, steps, dtype=np.float32) # Use float32 for faster computation - - # Pre-compute all coordinates at once - x_steps = (rain_drops[:, 0, None] + (slant * t[None, :])).astype(np.int32) - y_steps = (rain_drops[:, 1, None] + (drop_length * t[None, :])).astype(np.int32) - - # Vectorize the width handling - if drop_width > 1: - w_offsets = np.arange(-drop_width // 2, drop_width // 2 + 1) - x_steps = x_steps[..., None] + w_offsets - y_steps = np.repeat(y_steps[..., None], len(w_offsets), axis=2) - x_steps = x_steps.reshape(-1) - y_steps = y_steps.reshape(-1) - - # Single mask operation - valid_mask = (x_steps >= 0) & (x_steps < width) & (y_steps >= 0) & (y_steps < height) - x_steps = x_steps[valid_mask] - y_steps = y_steps[valid_mask] - - # Draw all points at once - rain_layer[y_steps, x_steps] = drop_color + # Pre-allocate rain layer + rain_layer = np.zeros_like(img, dtype=np.uint8) + + # Single polylines call with direct end points calculation + cv2.polylines( + rain_layer, + np.stack([rain_drops, [*rain_drops, [slant, drop_length]]], axis=1), + False, + drop_color, + drop_width, + lineType=cv2.LINE_4, + ) - # Use OpenCV operations for blending and adjustments if blur_value > 1: - rain_layer = cv2.blur(rain_layer, (blur_value, blur_value)) + cv2.blur(rain_layer, (blur_value, blur_value), dst=rain_layer) - # Blend rain with original image cv2.add(img, rain_layer, dst=img) - # Adjust brightness if needed if brightness_coefficient != 1.0: - # Use LUT for faster brightness adjustment - brightness_lut = np.clip(np.arange(0, 256) * brightness_coefficient, 0, 255).astype(np.uint8) - return cv2.LUT(img, brightness_lut) + cv2.multiply(img, brightness_coefficient, dst=img, dtype=cv2.CV_8U) return img diff --git a/albumentations/augmentations/transforms.py b/albumentations/augmentations/transforms.py index 3ecd35bd3..fbf57550e 100644 --- a/albumentations/augmentations/transforms.py +++ b/albumentations/augmentations/transforms.py @@ -796,7 +796,7 @@ def apply( img: np.ndarray, slant: int, drop_length: int, - rain_drops: list[tuple[int, int]], + rain_drops: np.ndarray, **params: Any, ) -> np.ndarray: non_rgb_error(img) @@ -817,34 +817,34 @@ def get_params_dependent_on_data( params: dict[str, Any], data: dict[str, Any], ) -> dict[str, Any]: - slant = int(self.py_random.uniform(*self.slant_range)) - height, width = params["shape"][:2] - area = height * width + # Simpler calculations, directly following Kornia if self.rain_type == "drizzle": - num_drops = area // 770 - drop_length = 10 + num_drops = height // 4 elif self.rain_type == "heavy": - num_drops = width * height // 600 - drop_length = 30 + num_drops = height elif self.rain_type == "torrential": - num_drops = area // 500 - drop_length = 60 + num_drops = height * 2 else: - drop_length = self.drop_length - num_drops = area // 600 + num_drops = height // 3 - # Vectorized rain drop generation - num_drops = int(num_drops) - if num_drops > 0: - if slant < 0: - x = self.random_generator.integers(slant, width, size=num_drops) - else: - x = self.random_generator.integers(0, max(width - slant, 1), size=num_drops) + # Fixed proportion for drop length (like Kornia) + drop_length = max(1, height // 8) - y = self.random_generator.integers(0, max(height - drop_length, 1), size=num_drops) - rain_drops = np.column_stack((x, y)) + # Simplified slant calculation + slant = self.random_generator.integers(-width // 50, width // 50) + + # Single random call for all coordinates + if num_drops > 0: + # Generate all coordinates in one call + coords = self.random_generator.integers( + low=[0, 0], + high=[width, height - drop_length], + size=(num_drops, 2), + dtype=np.int32, + ) + rain_drops = coords else: rain_drops = np.empty((0, 2), dtype=np.int32) From 14b86cdc486a31ce473a1a76c8c03a2f5f835391 Mon Sep 17 00:00:00 2001 From: Vladimir Iglovikov Date: Thu, 30 Jan 2025 12:01:41 -0800 Subject: [PATCH 5/5] Speed up in RandomRain --- .pre-commit-config.yaml | 2 +- albumentations/augmentations/functional.py | 9 +++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9335cdab2..d223c9ad1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -51,7 +51,7 @@ repos: files: setup.py - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.9.3 + rev: v0.9.4 hooks: # Run the linter. - id: ruff diff --git a/albumentations/augmentations/functional.py b/albumentations/augmentations/functional.py index 61f561788..a2c7b45ab 100644 --- a/albumentations/augmentations/functional.py +++ b/albumentations/augmentations/functional.py @@ -757,10 +757,15 @@ def add_rain( # Pre-allocate rain layer rain_layer = np.zeros_like(img, dtype=np.uint8) - # Single polylines call with direct end points calculation + # Calculate end points correctly + end_points = rain_drops + np.array([[slant, drop_length]]) # This creates correct shape + + # Stack arrays properly - both must be same shape arrays + lines = np.stack((rain_drops, end_points), axis=1) # Use tuple and proper axis + cv2.polylines( rain_layer, - np.stack([rain_drops, [*rain_drops, [slant, drop_length]]], axis=1), + lines.astype(np.int32), False, drop_color, drop_width,