From 5a92b64b9126b1447e165efd5da1a09208fad2bf Mon Sep 17 00:00:00 2001 From: Vladimir Iglovikov Date: Wed, 22 Jan 2025 12:59:35 -0800 Subject: [PATCH] Fix --- .pre-commit-config.yaml | 2 +- .../augmentations/geometric/functional.py | 102 +++++--------- .../augmentations/geometric/transforms.py | 5 - albumentations/core/bbox_utils.py | 75 +++++++++- tests/test_bbox.py | 129 ++++++++++++++++++ 5 files changed, 236 insertions(+), 77 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 11b34c684..f053ec75f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -68,7 +68,7 @@ repos: - id: python-use-type-annotations - id: text-unicode-replacement-char - repo: https://github.com/codespell-project/codespell - rev: v2.3.0 + rev: v2.4.0 hooks: - id: codespell additional_dependencies: ["tomli"] diff --git a/albumentations/augmentations/geometric/functional.py b/albumentations/augmentations/geometric/functional.py index 8fcad9528..7762345c2 100644 --- a/albumentations/augmentations/geometric/functional.py +++ b/albumentations/augmentations/geometric/functional.py @@ -20,7 +20,9 @@ from albumentations.augmentations.utils import angle_2pi_range, handle_empty_array from albumentations.core.bbox_utils import ( bboxes_from_masks, + bboxes_to_mask, denormalize_bboxes, + mask_to_bboxes, masks_from_bboxes, normalize_bboxes, ) @@ -1444,30 +1446,25 @@ def remap_keypoints( ) -> np.ndarray: height, width = image_shape[:2] - # Create inverse mappings - x_inv = np.arange(width).reshape(1, -1).repeat(height, axis=0) - y_inv = np.arange(height).reshape(-1, 1).repeat(width, axis=1) + # Create mask where each keypoint has unique index + kp_mask = np.zeros((height, width), dtype=np.int16) + for idx, kp in enumerate(keypoints, start=1): + x, y = round(kp[0]), round(kp[1]) + if 0 <= x < width and 0 <= y < height: + cv2.circle(kp_mask, (x, y), 1, idx, -1) - # Extract x and y coordinates - x, y = keypoints[:, 0], keypoints[:, 1] - - # Clip coordinates to image boundaries - x = np.clip(x, 0, width - 1, out=x) - y = np.clip(y, 0, height - 1, out=y) - - # Convert to integer indices - x_idx, y_idx = x.astype(int), y.astype(int) + # Remap the mask + transformed_kp_mask = cv2.remap(kp_mask, map_x, map_y, cv2.INTER_NEAREST) - # Apply the inverse mapping - new_x = x_inv[y_idx, x_idx] + (x - map_x[y_idx, x_idx]) - new_y = y_inv[y_idx, x_idx] + (y - map_y[y_idx, x_idx]) + # Extract transformed keypoints + new_points = [] + for idx, kp in enumerate(keypoints, start=1): + y_coords, x_coords = np.where(transformed_kp_mask == idx) + if len(y_coords) > 0: + # Take first occurrence of the point + new_points.append(np.concatenate([[x_coords[0], y_coords[0]], kp[2:]])) - # Clip the new coordinates to ensure they're within the image bounds - new_x = np.clip(new_x, 0, width - 1, out=new_x) - new_y = np.clip(new_y, 0, height - 1, out=new_y) - - # Create the transformed keypoints array - return np.column_stack([new_x, new_y, keypoints[:, 2:]]) + return np.array(new_points) if new_points else np.zeros((0, keypoints.shape[1])) @handle_empty_array("bboxes") @@ -1477,53 +1474,18 @@ def remap_bboxes( map_y: np.ndarray, image_shape: tuple[int, int], ) -> np.ndarray: - # Number of points to sample per dimension - grid_size = 5 - - num_boxes = len(bboxes) - all_points = [] - - for box in bboxes: - x_min, y_min, x_max, y_max = box[:4] - - # Create grid of points inside and on edges of box - x_points = np.linspace(x_min, x_max, grid_size) - y_points = np.linspace(y_min, y_max, grid_size) - xx, yy = np.meshgrid(x_points, y_points) - - points = np.column_stack([xx.ravel(), yy.ravel()]) - all_points.append(points) - - # Transform all points - all_points = np.vstack(all_points) - transformed_points = remap_keypoints( - np.column_stack( - [all_points, np.zeros(len(all_points)), np.zeros(len(all_points))], - ), - map_x, - map_y, - image_shape, - )[:, :2] + """Remap bounding boxes using displacement maps.""" + # Convert bboxes to mask + bbox_masks = bboxes_to_mask(bboxes, image_shape) - # Reshape back to per-box points - points_per_box = grid_size * grid_size - transformed_points = transformed_points.reshape(num_boxes, points_per_box, 2) + # Ensure maps are float32 + map_x = map_x.astype(np.float32) + map_y = map_y.astype(np.float32) - # Get min/max coordinates for each box - new_bboxes = np.column_stack( - [ - np.min(transformed_points[:, :, 0], axis=1), # x_min - np.min(transformed_points[:, :, 1], axis=1), # y_min - np.max(transformed_points[:, :, 0], axis=1), # x_max - np.max(transformed_points[:, :, 1], axis=1), # y_max - ], - ) + transformed_masks = remap(bbox_masks, map_x, map_y, cv2.INTER_NEAREST, cv2.BORDER_CONSTANT, value=0) - return ( - np.column_stack([new_bboxes, bboxes[:, 4:]]) - if bboxes.shape[1] > NUM_BBOXES_COLUMNS_IN_ALBUMENTATIONS - else new_bboxes - ) + # Convert masks back to bboxes + return mask_to_bboxes(transformed_masks, bboxes) def generate_displacement_fields( @@ -3270,7 +3232,6 @@ def tps_transform( def get_camera_matrix_distortion_maps( image_shape: tuple[int, int], k: float, - center_xy: tuple[float, float], ) -> tuple[np.ndarray, np.ndarray]: """Generate distortion maps using camera matrix model. @@ -3284,8 +3245,11 @@ def get_camera_matrix_distortion_maps( - map_y: Vertical displacement map """ height, width = image_shape[:2] + + center_x, center_y = width / 2, height / 2 + camera_matrix = np.array( - [[width, 0, center_xy[0]], [0, height, center_xy[1]], [0, 0, 1]], + [[width, 0, center_x], [0, height, center_y], [0, 0, 1]], dtype=np.float32, ) distortion = np.array([k, k, 0, 0, 0], dtype=np.float32) @@ -3302,7 +3266,6 @@ def get_camera_matrix_distortion_maps( def get_fisheye_distortion_maps( image_shape: tuple[int, int], k: float, - center_xy: tuple[float, float], ) -> tuple[np.ndarray, np.ndarray]: """Generate distortion maps using fisheye model. @@ -3317,8 +3280,7 @@ def get_fisheye_distortion_maps( """ height, width = image_shape[:2] - center_x, center_y = center_xy - + center_x, center_y = width / 2, height / 2 # Create coordinate grid y, x = np.mgrid[:height, :width].astype(np.float32) diff --git a/albumentations/augmentations/geometric/transforms.py b/albumentations/augmentations/geometric/transforms.py index 34f55ce8a..bd1200a1c 100644 --- a/albumentations/augmentations/geometric/transforms.py +++ b/albumentations/augmentations/geometric/transforms.py @@ -1479,21 +1479,16 @@ def get_params_dependent_on_data( # Get distortion coefficient k = self.py_random.uniform(*self.distort_limit) - # Calculate center shift - center_xy = fgeometric.center(image_shape) - # Get distortion maps based on mode if self.mode == "camera": map_x, map_y = fgeometric.get_camera_matrix_distortion_maps( image_shape, k, - center_xy, ) else: # fisheye map_x, map_y = fgeometric.get_fisheye_distortion_maps( image_shape, k, - center_xy, ) return {"map_x": map_x, "map_y": map_y} diff --git a/albumentations/core/bbox_utils.py b/albumentations/core/bbox_utils.py index 7b58e27f9..c47538c00 100644 --- a/albumentations/core/bbox_utils.py +++ b/albumentations/core/bbox_utils.py @@ -6,7 +6,7 @@ import numpy as np from albumentations.augmentations.utils import handle_empty_array -from albumentations.core.type_definitions import MONO_CHANNEL_DIMENSIONS +from albumentations.core.type_definitions import MONO_CHANNEL_DIMENSIONS, NUM_BBOXES_COLUMNS_IN_ALBUMENTATIONS from .utils import DataProcessor, Params, ShapeType @@ -593,3 +593,76 @@ def masks_from_bboxes(bboxes: np.ndarray, shape: ShapeType | tuple[int, int]) -> masks[i] = (x_min <= x) & (x < x_max) & (y_min <= y) & (y < y_max) return masks + + +def bboxes_to_mask( + bboxes: np.ndarray, + image_shape: tuple[int, int], +) -> np.ndarray: + """Convert bounding boxes to multi-channel binary mask. + + Args: + bboxes: Array of bboxes in format [x_min, y_min, x_max, y_max, ...] + image_shape: (height, width) of the target mask + + Returns: + Binary mask of shape (height, width, num_boxes) + """ + height, width = image_shape[:2] + num_boxes = len(bboxes) + + # Create multi-channel mask where each channel represents one bbox + bbox_masks = np.zeros((height, width, num_boxes), dtype=np.uint8) + + # Fill each bbox in its channel + for idx, box in enumerate(bboxes): + x_min, y_min, x_max, y_max = map(round, box[:4]) + x_min = max(0, min(width - 1, x_min)) + x_max = max(0, min(width - 1, x_max)) + y_min = max(0, min(height - 1, y_min)) + y_max = max(0, min(height - 1, y_max)) + bbox_masks[y_min : y_max + 1, x_min : x_max + 1, idx] = 1 + + return bbox_masks + + +def mask_to_bboxes( + masks: np.ndarray, + original_bboxes: np.ndarray, +) -> np.ndarray: + """Convert multi-channel binary mask back to bounding boxes. + + Args: + masks: Binary mask of shape (height, width, num_boxes) + original_bboxes: Original bboxes array to preserve extra columns + + Returns: + Array of bboxes in format [x_min, y_min, x_max, y_max, ...] + """ + num_boxes = masks.shape[-1] + new_bboxes = [] + + num_boxes = masks.shape[-1] + + if num_boxes == 0: + # Return empty array with correct shape + return np.zeros((0, original_bboxes.shape[1]), dtype=original_bboxes.dtype) + + for idx in range(num_boxes): + mask = masks[..., idx] + if np.any(mask): + y_coords, x_coords = np.where(mask) + x_min, x_max = x_coords.min(), x_coords.max() + y_min, y_max = y_coords.min(), y_coords.max() + new_bboxes.append([x_min, y_min, x_max, y_max]) + else: + # If bbox disappeared, use original coords + new_bboxes.append(original_bboxes[idx, :4]) + + new_bboxes = np.array(new_bboxes) + + return ( + np.column_stack([new_bboxes, original_bboxes[:, 4:]]) + if original_bboxes.shape[1] > NUM_BBOXES_COLUMNS_IN_ALBUMENTATIONS + else new_bboxes + ) diff --git a/tests/test_bbox.py b/tests/test_bbox.py index 99bd9b666..615d67ac4 100644 --- a/tests/test_bbox.py +++ b/tests/test_bbox.py @@ -22,6 +22,8 @@ masks_from_bboxes, normalize_bboxes, union_of_bboxes, + bboxes_to_mask, + mask_to_bboxes, ) from albumentations.core.composition import BboxParams, Compose, ReplayCompose from albumentations.core.transforms_interface import BasicTransform, NoOp @@ -1727,3 +1729,130 @@ def test_bboxes_grid_shuffle_multiple_components(): assert np.all(result >= 0) # All coordinates should be valid assert np.all(result[:, [0, 2]] <= image_shape[1]) # x coordinates within image width assert np.all(result[:, [1, 3]] <= image_shape[0]) # y coordinates within image height + + + +@pytest.mark.parametrize("test_case", [ + { + "name": "single_bbox", + "bboxes": np.array([[10, 20, 30, 40]]), + "image_shape": (100, 100), + "expected_mask_shape": (100, 100, 1), + "expected_nonzero": [(20, 10, 0), (20, 30, 0), (40, 10, 0), (40, 30, 0)] # y, x, channel + }, + { + "name": "multiple_bboxes", + "bboxes": np.array([ + [10, 20, 30, 40], + [50, 60, 70, 80] + ]), + "image_shape": (100, 100), + "expected_mask_shape": (100, 100, 2), + "expected_nonzero": [(20, 10, 0), (60, 50, 1)] # y, x, channel + }, + { + "name": "bbox_with_extra_fields", + "bboxes": np.array([[10, 20, 30, 40, 1, 0.8]]), + "image_shape": (100, 100), + "expected_mask_shape": (100, 100, 1), + "expected_nonzero": [(20, 10, 0)] + }, + { + "name": "bbox_at_edges", + "bboxes": np.array([[0, 0, 99, 99]]), + "image_shape": (100, 100), + "expected_mask_shape": (100, 100, 1), + "expected_nonzero": [(0, 0, 0), (99, 99, 0)] + } +]) +def test_bboxes_to_mask(test_case): + bboxes = test_case["bboxes"] + image_shape = test_case["image_shape"] + expected_shape = test_case["expected_mask_shape"] + expected_nonzero = test_case["expected_nonzero"] + + masks = bboxes_to_mask(bboxes, image_shape) + + # Check shape + assert masks.shape == expected_shape + + # Check dtype + assert masks.dtype == np.uint8 + + # Check if masks are binary + assert np.all(np.isin(masks, [0, 1])) + + # Check specific points + for y, x, c in expected_nonzero: + assert masks[y, x, c] == 1, f"Expected 1 at position ({y}, {x}, {c})" + +@pytest.mark.parametrize("test_case", [ + { + "name": "single_mask", + "masks": np.array([ # (3, 3, 1) mask + [[0], [0], [0]], + [[0], [1], [0]], + [[0], [0], [0]] + ], dtype=np.uint8), + "original_bboxes": np.array([[0, 0, 2, 2]]), + "expected_bboxes": np.array([[1, 1, 1, 1]]) + }, + { + "name": "multiple_masks", + "masks": np.array([ # (3, 3, 2) mask + [[0, 0], [0, 1], [0, 0]], + [[0, 1], [1, 1], [0, 1]], + [[0, 0], [0, 1], [0, 0]] + ], dtype=np.uint8), + "original_bboxes": np.array([ + [0, 0, 2, 2], + [0, 0, 2, 2] + ]), + "expected_bboxes": np.array([ + [1, 1, 1, 1], + [0, 0, 2, 2] + ]) + }, + { + "name": "mask_with_extra_fields", + "masks": np.array([ # (3, 3, 1) mask + [[0], [0], [0]], + [[0], [1], [0]], + [[0], [0], [0]] + ], dtype=np.uint8), + "original_bboxes": np.array([[0, 0, 2, 2, 1, 0.8]]), + "expected_bboxes": np.array([[1, 1, 1, 1, 1, 0.8]]) + }, + { + "name": "empty_mask", + "masks": np.zeros((3, 3, 1), dtype=np.uint8), + "original_bboxes": np.array([[10, 20, 30, 40]]), + "expected_bboxes": np.array([[10, 20, 30, 40]]) # Should preserve original bbox + } +]) +def test_mask_to_bboxes(test_case): + masks = test_case["masks"] + original_bboxes = test_case["original_bboxes"] + expected_bboxes = test_case["expected_bboxes"] + + result = mask_to_bboxes(masks, original_bboxes) + + # Check shape and values + assert result.shape == expected_bboxes.shape + np.testing.assert_array_equal(result, expected_bboxes) + + # Check extra fields preservation + if original_bboxes.shape[1] > 4: + assert np.all(result[:, 4:] == original_bboxes[:, 4:]) + +def test_empty_bboxes(): + empty_bboxes = np.zeros((0, 4)) + image_shape = (100, 100) + + # Test bboxes_to_mask with empty input + masks = bboxes_to_mask(empty_bboxes, image_shape) + assert masks.shape == (100, 100, 0) + + # Test mask_to_bboxes with empty input + result = mask_to_bboxes(masks, empty_bboxes) + assert result.shape == (0, 4)