From 9b1a3b782d6a977403abf1854f40cd058cea5542 Mon Sep 17 00:00:00 2001 From: Vladimir Iglovikov Date: Fri, 31 Jan 2025 12:33:17 -0800 Subject: [PATCH 1/4] Correctly process np labels --- albumentations/core/utils.py | 20 ++++++++- tests/test_core.py | 80 ++++++++++++++++++++++++++++++++++++ 2 files changed, 98 insertions(+), 2 deletions(-) diff --git a/albumentations/core/utils.py b/albumentations/core/utils.py index 45d7a8724..120e115b4 100644 --- a/albumentations/core/utils.py +++ b/albumentations/core/utils.py @@ -144,6 +144,8 @@ def __init__(self, params: Params, additional_targets: dict[str, str] | None = N self.label_encoders: dict[str, dict[str, LabelEncoder]] = defaultdict(dict) self.is_sequence_input: dict[str, bool] = {} self.is_numerical_label: dict[str, dict[str, bool]] = defaultdict(dict) + self.label_dtypes: dict[str, dict[str, np.dtype]] = defaultdict(dict) + self.label_input_types: dict[str, dict[str, type]] = defaultdict(dict) if additional_targets is not None: self.add_targets(additional_targets) @@ -276,6 +278,10 @@ def _validate_label_field_length(self, data: dict[str, Any], data_name: str, lab def _encode_label_field(self, data: dict[str, Any], data_name: str, label_field: str) -> np.ndarray: field_data = data[label_field] + self.label_input_types[data_name][label_field] = type(field_data) + if isinstance(field_data, np.ndarray): + self.label_dtypes[data_name][label_field] = field_data.dtype + # Check if input is numpy array or if all elements are numerical is_numerical = (isinstance(field_data, np.ndarray) and np.issubdtype(field_data.dtype, np.number)) or all( isinstance(label, (int, float)) for label in field_data @@ -284,7 +290,7 @@ def _encode_label_field(self, data: dict[str, Any], data_name: str, label_field: self.is_numerical_label[data_name][label_field] = is_numerical if is_numerical: - # For numerical values, preserve numpy arrays or convert to float32 + # For numerical values, convert to float32 for processing if isinstance(field_data, np.ndarray): return field_data.reshape(-1, 1).astype(np.float32) return np.array(field_data, dtype=np.float32).reshape(-1, 1) @@ -323,12 +329,22 @@ def _remove_label_fields(self, data: dict[str, Any], data_name: str) -> None: for idx, label_field in enumerate(self.params.label_fields): encoded_labels = data_array[:, non_label_columns + idx] decoded_labels = self._decode_label_field(data_name, label_field, encoded_labels) - data[label_field] = decoded_labels.tolist() + + # Convert back to original type (list or numpy array) + input_type = self.label_input_types[data_name][label_field] + if isinstance(input_type, list): + data[label_field] = decoded_labels.tolist() + else: # numpy array + data[label_field] = decoded_labels data[data_name] = data_array[:, :non_label_columns] def _decode_label_field(self, data_name: str, label_field: str, encoded_labels: np.ndarray) -> np.ndarray: if self.is_numerical_label[data_name][label_field]: + # Restore original dtype if it was stored + original_dtype = self.label_dtypes.get(data_name, {}).get(label_field) + if original_dtype is not None: + return encoded_labels.astype(original_dtype) return encoded_labels encoder = self.label_encoders.get(data_name, {}).get(label_field) diff --git a/tests/test_core.py b/tests/test_core.py index 68d483fef..ce430b104 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1608,3 +1608,83 @@ def test_transform_strict_with_valid_params(): transform = A.Blur(strict=True, p=0.7, blur_limit=(3, 5)) assert transform.p == 0.7 assert transform.blur_limit == (3, 5) + + +@pytest.mark.parametrize( + ["labels", "expected_type", "expected_dtype"], + [ + # Numpy arrays should stay numpy arrays + (np.array([1, 2, 3], dtype=np.int32), np.ndarray, np.int32), + (np.array([1, 2, 3], dtype=np.int64), np.ndarray, np.int64), + (np.array([1.0, 2.0, 3.0], dtype=np.float32), np.ndarray, np.float32), + (np.array([1.0, 2.0, 3.0], dtype=np.float64), np.ndarray, np.float64), + # Lists should stay lists + ([1, 2, 3], list, None), + ([1.0, 2.0, 3.0], list, None), + ], +) +def test_label_type_preservation(labels, expected_type, expected_dtype): + """Test that both type (list/ndarray) and dtype are preserved.""" + transform = Compose( + [NoOp(p=1.0)], + bbox_params=BboxParams( + format='pascal_voc', + label_fields=['labels'] + ) + ) + + transformed = transform( + image=np.zeros((100, 100, 3), dtype=np.uint8), + bboxes=[(0, 0, 10, 10), (10, 10, 20, 20), (20, 20, 30, 30)], + labels=labels + ) + + result_labels = transformed['labels'] + assert isinstance(result_labels, expected_type) + if expected_dtype is not None: + assert result_labels.dtype == expected_dtype + if expected_type == list: + assert result_labels == labels + else: + np.testing.assert_array_equal(result_labels, labels) + + +def test_string_labels(): + # Create sample data + bboxes = [(0, 0, 10, 10), (10, 10, 20, 20), (20, 20, 30, 30)] + labels = ['cat', 'dog', 'bird'] + + transform = Compose( + [NoOp(p=1.0)], + bbox_params=BboxParams( + format='pascal_voc', + label_fields=['labels'] + ) + ) + + transformed = transform( + image=np.zeros((100, 100, 3), dtype=np.uint8), + bboxes=bboxes, + labels=labels + ) + + # Check that string labels are preserved exactly + assert transformed['labels'] == labels + + +def test_empty_labels(): + transform = Compose( + [NoOp(p=1.0)], + bbox_params=BboxParams( + format='pascal_voc', + label_fields=['labels'] + ) + ) + + transformed = transform( + image=np.zeros((100, 100, 3), dtype=np.uint8), + bboxes=[], + labels=[] + ) + + assert transformed['labels'] == [] From b200d33f4b3ab947aa777532f1b2473aeea5ec2e Mon Sep 17 00:00:00 2001 From: Vladimir Iglovikov Date: Fri, 31 Jan 2025 13:04:50 -0800 Subject: [PATCH 2/4] Refactoring --- .../augmentations/dropout/coarse_dropout.py | 9 +- albumentations/core/label_manager.py | 145 ++++++++++++++++++ albumentations/core/utils.py | 114 +------------- 3 files changed, 158 insertions(+), 110 deletions(-) create mode 100644 albumentations/core/label_manager.py diff --git a/albumentations/augmentations/dropout/coarse_dropout.py b/albumentations/augmentations/dropout/coarse_dropout.py index d25fcc10a..ca57cf0a8 100644 --- a/albumentations/augmentations/dropout/coarse_dropout.py +++ b/albumentations/augmentations/dropout/coarse_dropout.py @@ -499,9 +499,14 @@ def get_boxes_from_bboxes(self, bboxes: np.ndarray) -> np.ndarray | None: label_fields = bbox_processor.params.label_fields if label_fields is None: raise ValueError("BboxParams.label_fields must be specified when using string labels") + first_class_label = label_fields[0] - label_encoder = bbox_processor.label_encoders["bboxes"][first_class_label] - target_labels = label_encoder.transform(self.bbox_labels) + # Access encoder through label_manager's metadata + metadata = bbox_processor.label_manager.metadata["bboxes"][first_class_label] + if metadata.encoder is None: + raise ValueError(f"No encoder found for label field {first_class_label}") + + target_labels = metadata.encoder.transform(self.bbox_labels) else: target_labels = np.array(self.bbox_labels) diff --git a/albumentations/core/label_manager.py b/albumentations/core/label_manager.py new file mode 100644 index 000000000..557e1eed3 --- /dev/null +++ b/albumentations/core/label_manager.py @@ -0,0 +1,145 @@ +from __future__ import annotations + +from collections import defaultdict +from collections.abc import Sequence +from dataclasses import dataclass +from numbers import Real +from typing import Any + +import numpy as np + + +def custom_sort(item: Any) -> tuple[int, Real | str]: + if isinstance(item, Real): + return (0, item) # Numerical items come first + return (1, str(item)) # Non-numerical items come second, converted to strings + + +class LabelEncoder: + def __init__(self) -> None: + self.classes_: dict[str | Real, int] = {} + self.inverse_classes_: dict[int, str | Real] = {} + self.num_classes: int = 0 + self.is_numerical: bool = True + + def fit(self, y: Sequence[Any] | np.ndarray) -> LabelEncoder: + if isinstance(y, np.ndarray): + y = y.flatten().tolist() + + self.is_numerical = all(isinstance(label, Real) for label in y) + + if self.is_numerical: + return self + + unique_labels = sorted(set(y), key=custom_sort) + for label in unique_labels: + if label not in self.classes_: + self.classes_[label] = self.num_classes + self.inverse_classes_[self.num_classes] = label + self.num_classes += 1 + return self + + def transform(self, y: Sequence[Any] | np.ndarray) -> np.ndarray: + if isinstance(y, np.ndarray): + y = y.flatten().tolist() + + if self.is_numerical: + return np.array(y) + + return np.array([self.classes_[label] for label in y]) + + def fit_transform(self, y: Sequence[Any] | np.ndarray) -> np.ndarray: + self.fit(y) + return self.transform(y) + + def inverse_transform(self, y: Sequence[Any] | np.ndarray) -> np.ndarray: + if isinstance(y, np.ndarray): + y = y.flatten().tolist() + + if self.is_numerical: + return np.array(y) + + return np.array([self.inverse_classes_[label] for label in y]) + + +@dataclass +class LabelMetadata: + """Stores metadata about a label field.""" + + input_type: type + is_numerical: bool + dtype: np.dtype | None = None + encoder: LabelEncoder | None = None + + +class LabelManager: + def __init__(self) -> None: + self.metadata: dict[str, dict[str, LabelMetadata]] = defaultdict(dict) + + def process_field(self, data_name: str, label_field: str, field_data: Any) -> np.ndarray: + """Process a label field and store its metadata.""" + metadata = self._analyze_input(field_data) + self.metadata[data_name][label_field] = metadata + return self._encode_data(field_data, metadata) + + def restore_field(self, data_name: str, label_field: str, encoded_data: np.ndarray) -> Any: + """Restore a label field to its original format.""" + metadata = self.metadata[data_name][label_field] + decoded_data = self._decode_data(encoded_data, metadata) + return self._restore_type(decoded_data, metadata) + + def _analyze_input(self, field_data: Any) -> LabelMetadata: + """Analyze input data and create metadata.""" + input_type = type(field_data) + dtype = field_data.dtype if isinstance(field_data, np.ndarray) else None + + is_numerical = (isinstance(field_data, np.ndarray) and np.issubdtype(field_data.dtype, np.number)) or all( + isinstance(label, (int, float)) for label in field_data + ) + + metadata = LabelMetadata( + input_type=input_type, + is_numerical=is_numerical, + dtype=dtype, + ) + + if not is_numerical: + metadata.encoder = LabelEncoder() + + return metadata + + def _encode_data(self, field_data: Any, metadata: LabelMetadata) -> np.ndarray: + """Encode field data for processing.""" + if metadata.is_numerical: + # For numerical values, convert to float32 for processing + if isinstance(field_data, np.ndarray): + return field_data.reshape(-1, 1).astype(np.float32) + return np.array(field_data, dtype=np.float32).reshape(-1, 1) + + # For non-numerical values, use LabelEncoder + if metadata.encoder is None: + raise ValueError("Encoder not initialized for non-numerical data") + return metadata.encoder.fit_transform(field_data).reshape(-1, 1) + + def _decode_data(self, encoded_data: np.ndarray, metadata: LabelMetadata) -> np.ndarray: + """Decode processed data.""" + if metadata.is_numerical: + if metadata.dtype is not None: + return encoded_data.astype(metadata.dtype) + return encoded_data.flatten() # Flatten for list conversion + + if metadata.encoder is None: + raise ValueError("Encoder not found for non-numerical data") + return metadata.encoder.inverse_transform(encoded_data.astype(int)) + + def _restore_type(self, decoded_data: np.ndarray, metadata: LabelMetadata) -> Any: + """Restore data to its original type.""" + if isinstance(metadata.input_type, list): + return decoded_data.tolist() + if isinstance(metadata.input_type, np.ndarray) and metadata.dtype is not None: + return decoded_data.astype(metadata.dtype) + return decoded_data + + def handle_empty_data(self) -> list[Any]: + """Handle empty data case.""" + return [] diff --git a/albumentations/core/utils.py b/albumentations/core/utils.py index 120e115b4..c34818471 100644 --- a/albumentations/core/utils.py +++ b/albumentations/core/utils.py @@ -1,13 +1,14 @@ from __future__ import annotations from abc import ABC, abstractmethod -from collections import defaultdict from collections.abc import Sequence from numbers import Real from typing import TYPE_CHECKING, Any, Literal, cast, overload import numpy as np +from albumentations.core.label_manager import LabelManager + from .serialization import Serializable from .type_definitions import PAIR, Number @@ -75,59 +76,6 @@ def format_args(args_dict: dict[str, Any]) -> str: return ", ".join(formatted_args) -def custom_sort(item: Any) -> tuple[int, Real | str]: - if isinstance(item, Real): - return (0, item) # Numerical items come first - return (1, str(item)) # Non-numerical items come second, converted to strings - - -class LabelEncoder: - def __init__(self) -> None: - self.classes_: dict[str | Real, int] = {} - self.inverse_classes_: dict[int, str | Real] = {} - self.num_classes: int = 0 - self.is_numerical: bool = True - - def fit(self, y: Sequence[Any] | np.ndarray) -> LabelEncoder: - if isinstance(y, np.ndarray): - y = y.flatten().tolist() - - self.is_numerical = all(isinstance(label, Real) for label in y) - - if self.is_numerical: - return self - - unique_labels = sorted(set(y), key=custom_sort) - for label in unique_labels: - if label not in self.classes_: - self.classes_[label] = self.num_classes - self.inverse_classes_[self.num_classes] = label - self.num_classes += 1 - return self - - def transform(self, y: Sequence[Any] | np.ndarray) -> np.ndarray: - if isinstance(y, np.ndarray): - y = y.flatten().tolist() - - if self.is_numerical: - return np.array(y) - - return np.array([self.classes_[label] for label in y]) - - def fit_transform(self, y: Sequence[Any] | np.ndarray) -> np.ndarray: - self.fit(y) - return self.transform(y) - - def inverse_transform(self, y: Sequence[Any] | np.ndarray) -> np.ndarray: - if isinstance(y, np.ndarray): - y = y.flatten().tolist() - - if self.is_numerical: - return np.array(y) - - return np.array([self.inverse_classes_[label] for label in y]) - - class Params(Serializable, ABC): def __init__(self, format: Any, label_fields: Sequence[str] | None): # noqa: A002 self.format = format @@ -141,11 +89,8 @@ class DataProcessor(ABC): def __init__(self, params: Params, additional_targets: dict[str, str] | None = None): self.params = params self.data_fields = [self.default_data_name] - self.label_encoders: dict[str, dict[str, LabelEncoder]] = defaultdict(dict) self.is_sequence_input: dict[str, bool] = {} - self.is_numerical_label: dict[str, dict[str, bool]] = defaultdict(dict) - self.label_dtypes: dict[str, dict[str, np.dtype]] = defaultdict(dict) - self.label_input_types: dict[str, dict[str, type]] = defaultdict(dict) + self.label_manager = LabelManager() if additional_targets is not None: self.add_targets(additional_targets) @@ -263,7 +208,7 @@ def _process_label_fields(self, data: dict[str, Any], data_name: str) -> np.ndar if self.params.label_fields is not None: for label_field in self.params.label_fields: self._validate_label_field_length(data, data_name, label_field) - encoded_labels = self._encode_label_field(data, data_name, label_field) + encoded_labels = self.label_manager.process_field(data_name, label_field, data[label_field]) data_array = np.hstack((data_array, encoded_labels)) del data[label_field] return data_array @@ -275,32 +220,6 @@ def _validate_label_field_length(self, data: dict[str, Any], data_name: str, lab f"Got {len(data[data_name])} and {len(data[label_field])} respectively.", ) - def _encode_label_field(self, data: dict[str, Any], data_name: str, label_field: str) -> np.ndarray: - field_data = data[label_field] - - self.label_input_types[data_name][label_field] = type(field_data) - if isinstance(field_data, np.ndarray): - self.label_dtypes[data_name][label_field] = field_data.dtype - - # Check if input is numpy array or if all elements are numerical - is_numerical = (isinstance(field_data, np.ndarray) and np.issubdtype(field_data.dtype, np.number)) or all( - isinstance(label, (int, float)) for label in field_data - ) - - self.is_numerical_label[data_name][label_field] = is_numerical - - if is_numerical: - # For numerical values, convert to float32 for processing - if isinstance(field_data, np.ndarray): - return field_data.reshape(-1, 1).astype(np.float32) - return np.array(field_data, dtype=np.float32).reshape(-1, 1) - - # For non-numerical values, use LabelEncoder - encoder = LabelEncoder() - encoded_labels = encoder.fit_transform(field_data).reshape(-1, 1) - self.label_encoders[data_name][label_field] = encoder - return encoded_labels - def remove_label_fields_from_data(self, data: dict[str, Any]) -> dict[str, Any]: if not self.params.label_fields: return data @@ -316,7 +235,7 @@ def remove_label_fields_from_data(self, data: dict[str, Any]) -> dict[str, Any]: def _handle_empty_data_array(self, data: dict[str, Any]) -> None: if self.params.label_fields is not None: for label_field in self.params.label_fields: - data[label_field] = [] + data[label_field] = self.label_manager.handle_empty_data() def _remove_label_fields(self, data: dict[str, Any], data_name: str) -> None: if self.params.label_fields is None: @@ -328,31 +247,10 @@ def _remove_label_fields(self, data: dict[str, Any], data_name: str) -> None: for idx, label_field in enumerate(self.params.label_fields): encoded_labels = data_array[:, non_label_columns + idx] - decoded_labels = self._decode_label_field(data_name, label_field, encoded_labels) - - # Convert back to original type (list or numpy array) - input_type = self.label_input_types[data_name][label_field] - if isinstance(input_type, list): - data[label_field] = decoded_labels.tolist() - else: # numpy array - data[label_field] = decoded_labels + data[label_field] = self.label_manager.restore_field(data_name, label_field, encoded_labels) data[data_name] = data_array[:, :non_label_columns] - def _decode_label_field(self, data_name: str, label_field: str, encoded_labels: np.ndarray) -> np.ndarray: - if self.is_numerical_label[data_name][label_field]: - # Restore original dtype if it was stored - original_dtype = self.label_dtypes.get(data_name, {}).get(label_field) - if original_dtype is not None: - return encoded_labels.astype(original_dtype) - return encoded_labels - - encoder = self.label_encoders.get(data_name, {}).get(label_field) - if encoder: - return encoder.inverse_transform(encoded_labels.astype(int)) - - raise ValueError(f"Label encoder for {label_field} not found") - def validate_args( low: float | Sequence[int] | Sequence[float] | None, From c95e4295ee865d12a3eba7f3f5db4837656a476f Mon Sep 17 00:00:00 2001 From: Vladimir Iglovikov Date: Fri, 31 Jan 2025 13:09:26 -0800 Subject: [PATCH 3/4] Refactoring --- albumentations/core/label_manager.py | 20 +++++++++++++++----- tests/test_core_utils.py | 2 +- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/albumentations/core/label_manager.py b/albumentations/core/label_manager.py index 557e1eed3..9b658b730 100644 --- a/albumentations/core/label_manager.py +++ b/albumentations/core/label_manager.py @@ -93,6 +93,7 @@ def _analyze_input(self, field_data: Any) -> LabelMetadata: input_type = type(field_data) dtype = field_data.dtype if isinstance(field_data, np.ndarray) else None + # Check if input is numpy array or if all elements are numerical is_numerical = (isinstance(field_data, np.ndarray) and np.issubdtype(field_data.dtype, np.number)) or all( isinstance(label, (int, float)) for label in field_data ) @@ -130,15 +131,24 @@ def _decode_data(self, encoded_data: np.ndarray, metadata: LabelMetadata) -> np. if metadata.encoder is None: raise ValueError("Encoder not found for non-numerical data") - return metadata.encoder.inverse_transform(encoded_data.astype(int)) + + decoded = metadata.encoder.inverse_transform(encoded_data.astype(int)) + return decoded.reshape(-1) # Ensure 1D array def _restore_type(self, decoded_data: np.ndarray, metadata: LabelMetadata) -> Any: """Restore data to its original type.""" - if isinstance(metadata.input_type, list): + # If original input was a list or sequence, convert back to list + if isinstance(metadata.input_type, type) and issubclass(metadata.input_type, (list, Sequence)): return decoded_data.tolist() - if isinstance(metadata.input_type, np.ndarray) and metadata.dtype is not None: - return decoded_data.astype(metadata.dtype) - return decoded_data + + # If original input was a numpy array, restore original dtype + if isinstance(metadata.input_type, type) and issubclass(metadata.input_type, np.ndarray): + if metadata.dtype is not None: + return decoded_data.astype(metadata.dtype) + return decoded_data + + # For any other type, convert to list by default + return decoded_data.tolist() def handle_empty_data(self) -> list[Any]: """Handle empty data case.""" diff --git a/tests/test_core_utils.py b/tests/test_core_utils.py index d60e8c220..47126e9b9 100644 --- a/tests/test_core_utils.py +++ b/tests/test_core_utils.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from albumentations.core.utils import LabelEncoder +from albumentations.core.label_manager import LabelEncoder @pytest.mark.parametrize( From 4aee5e34f185375e7a190b3aa81de2eacc09948a Mon Sep 17 00:00:00 2001 From: Vladimir Iglovikov Date: Fri, 31 Jan 2025 13:15:31 -0800 Subject: [PATCH 4/4] Refactoring --- albumentations/augmentations/dropout/coarse_dropout.py | 2 +- albumentations/core/label_manager.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/albumentations/augmentations/dropout/coarse_dropout.py b/albumentations/augmentations/dropout/coarse_dropout.py index ca57cf0a8..fdbb3ff7d 100644 --- a/albumentations/augmentations/dropout/coarse_dropout.py +++ b/albumentations/augmentations/dropout/coarse_dropout.py @@ -512,7 +512,7 @@ def get_boxes_from_bboxes(self, bboxes: np.ndarray) -> np.ndarray | None: # Filter boxes by labels (usually in column 4) mask = np.isin(bboxes[:, 4], target_labels) - filtered_boxes = bboxes[mask, :4] # Keep only x,y,w,h + filtered_boxes = bboxes[mask, :4] return filtered_boxes if len(filtered_boxes) > 0 else None diff --git a/albumentations/core/label_manager.py b/albumentations/core/label_manager.py index 9b658b730..b6a05838b 100644 --- a/albumentations/core/label_manager.py +++ b/albumentations/core/label_manager.py @@ -10,9 +10,7 @@ def custom_sort(item: Any) -> tuple[int, Real | str]: - if isinstance(item, Real): - return (0, item) # Numerical items come first - return (1, str(item)) # Non-numerical items come second, converted to strings + return (0, item) if isinstance(item, Real) else (1, str(item)) class LabelEncoder: