Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix np array int labels #2325

Merged
merged 4 commits into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions albumentations/augmentations/dropout/coarse_dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,9 +499,14 @@ def get_boxes_from_bboxes(self, bboxes: np.ndarray) -> np.ndarray | None:
label_fields = bbox_processor.params.label_fields
if label_fields is None:
raise ValueError("BboxParams.label_fields must be specified when using string labels")

first_class_label = label_fields[0]
label_encoder = bbox_processor.label_encoders["bboxes"][first_class_label]
target_labels = label_encoder.transform(self.bbox_labels)
# Access encoder through label_manager's metadata
metadata = bbox_processor.label_manager.metadata["bboxes"][first_class_label]
if metadata.encoder is None:
raise ValueError(f"No encoder found for label field {first_class_label}")

target_labels = metadata.encoder.transform(self.bbox_labels)
ternaus marked this conversation as resolved.
Show resolved Hide resolved
else:
target_labels = np.array(self.bbox_labels)

Expand Down
155 changes: 155 additions & 0 deletions albumentations/core/label_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
from __future__ import annotations

from collections import defaultdict
from collections.abc import Sequence
from dataclasses import dataclass
from numbers import Real
from typing import Any

import numpy as np


def custom_sort(item: Any) -> tuple[int, Real | str]:
if isinstance(item, Real):
return (0, item) # Numerical items come first
return (1, str(item)) # Non-numerical items come second, converted to strings
ternaus marked this conversation as resolved.
Show resolved Hide resolved


class LabelEncoder:
def __init__(self) -> None:
self.classes_: dict[str | Real, int] = {}
self.inverse_classes_: dict[int, str | Real] = {}
self.num_classes: int = 0
self.is_numerical: bool = True

def fit(self, y: Sequence[Any] | np.ndarray) -> LabelEncoder:
if isinstance(y, np.ndarray):
y = y.flatten().tolist()

self.is_numerical = all(isinstance(label, Real) for label in y)

if self.is_numerical:
return self

unique_labels = sorted(set(y), key=custom_sort)
for label in unique_labels:
if label not in self.classes_:
self.classes_[label] = self.num_classes
self.inverse_classes_[self.num_classes] = label
self.num_classes += 1
return self

def transform(self, y: Sequence[Any] | np.ndarray) -> np.ndarray:
if isinstance(y, np.ndarray):
y = y.flatten().tolist()

if self.is_numerical:
return np.array(y)

return np.array([self.classes_[label] for label in y])

def fit_transform(self, y: Sequence[Any] | np.ndarray) -> np.ndarray:
self.fit(y)
return self.transform(y)

def inverse_transform(self, y: Sequence[Any] | np.ndarray) -> np.ndarray:
if isinstance(y, np.ndarray):
y = y.flatten().tolist()

if self.is_numerical:
return np.array(y)

return np.array([self.inverse_classes_[label] for label in y])


@dataclass
class LabelMetadata:
"""Stores metadata about a label field."""

input_type: type
is_numerical: bool
dtype: np.dtype | None = None
encoder: LabelEncoder | None = None


class LabelManager:
def __init__(self) -> None:
self.metadata: dict[str, dict[str, LabelMetadata]] = defaultdict(dict)

def process_field(self, data_name: str, label_field: str, field_data: Any) -> np.ndarray:
"""Process a label field and store its metadata."""
metadata = self._analyze_input(field_data)
self.metadata[data_name][label_field] = metadata
return self._encode_data(field_data, metadata)

def restore_field(self, data_name: str, label_field: str, encoded_data: np.ndarray) -> Any:
ternaus marked this conversation as resolved.
Show resolved Hide resolved
"""Restore a label field to its original format."""
metadata = self.metadata[data_name][label_field]
decoded_data = self._decode_data(encoded_data, metadata)
return self._restore_type(decoded_data, metadata)

def _analyze_input(self, field_data: Any) -> LabelMetadata:
"""Analyze input data and create metadata."""
input_type = type(field_data)
dtype = field_data.dtype if isinstance(field_data, np.ndarray) else None

# Check if input is numpy array or if all elements are numerical
is_numerical = (isinstance(field_data, np.ndarray) and np.issubdtype(field_data.dtype, np.number)) or all(
isinstance(label, (int, float)) for label in field_data
)

metadata = LabelMetadata(
input_type=input_type,
is_numerical=is_numerical,
dtype=dtype,
)

if not is_numerical:
metadata.encoder = LabelEncoder()

return metadata

def _encode_data(self, field_data: Any, metadata: LabelMetadata) -> np.ndarray:
"""Encode field data for processing."""
if metadata.is_numerical:
# For numerical values, convert to float32 for processing
if isinstance(field_data, np.ndarray):
return field_data.reshape(-1, 1).astype(np.float32)
return np.array(field_data, dtype=np.float32).reshape(-1, 1)

# For non-numerical values, use LabelEncoder
if metadata.encoder is None:
raise ValueError("Encoder not initialized for non-numerical data")
return metadata.encoder.fit_transform(field_data).reshape(-1, 1)

def _decode_data(self, encoded_data: np.ndarray, metadata: LabelMetadata) -> np.ndarray:
"""Decode processed data."""
if metadata.is_numerical:
if metadata.dtype is not None:
return encoded_data.astype(metadata.dtype)
return encoded_data.flatten() # Flatten for list conversion

if metadata.encoder is None:
raise ValueError("Encoder not found for non-numerical data")

decoded = metadata.encoder.inverse_transform(encoded_data.astype(int))
return decoded.reshape(-1) # Ensure 1D array

def _restore_type(self, decoded_data: np.ndarray, metadata: LabelMetadata) -> Any:
"""Restore data to its original type."""
# If original input was a list or sequence, convert back to list
if isinstance(metadata.input_type, type) and issubclass(metadata.input_type, (list, Sequence)):
return decoded_data.tolist()

# If original input was a numpy array, restore original dtype
if isinstance(metadata.input_type, type) and issubclass(metadata.input_type, np.ndarray):
if metadata.dtype is not None:
return decoded_data.astype(metadata.dtype)
return decoded_data

# For any other type, convert to list by default
return decoded_data.tolist()

def handle_empty_data(self) -> list[Any]:
"""Handle empty data case."""
return []
98 changes: 6 additions & 92 deletions albumentations/core/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Sequence
from numbers import Real
from typing import TYPE_CHECKING, Any, Literal, cast, overload

import numpy as np

from albumentations.core.label_manager import LabelManager

from .serialization import Serializable
from .type_definitions import PAIR, Number

Expand Down Expand Up @@ -75,59 +76,6 @@ def format_args(args_dict: dict[str, Any]) -> str:
return ", ".join(formatted_args)


def custom_sort(item: Any) -> tuple[int, Real | str]:
if isinstance(item, Real):
return (0, item) # Numerical items come first
return (1, str(item)) # Non-numerical items come second, converted to strings


class LabelEncoder:
def __init__(self) -> None:
self.classes_: dict[str | Real, int] = {}
self.inverse_classes_: dict[int, str | Real] = {}
self.num_classes: int = 0
self.is_numerical: bool = True

def fit(self, y: Sequence[Any] | np.ndarray) -> LabelEncoder:
if isinstance(y, np.ndarray):
y = y.flatten().tolist()

self.is_numerical = all(isinstance(label, Real) for label in y)

if self.is_numerical:
return self

unique_labels = sorted(set(y), key=custom_sort)
for label in unique_labels:
if label not in self.classes_:
self.classes_[label] = self.num_classes
self.inverse_classes_[self.num_classes] = label
self.num_classes += 1
return self

def transform(self, y: Sequence[Any] | np.ndarray) -> np.ndarray:
if isinstance(y, np.ndarray):
y = y.flatten().tolist()

if self.is_numerical:
return np.array(y)

return np.array([self.classes_[label] for label in y])

def fit_transform(self, y: Sequence[Any] | np.ndarray) -> np.ndarray:
self.fit(y)
return self.transform(y)

def inverse_transform(self, y: Sequence[Any] | np.ndarray) -> np.ndarray:
if isinstance(y, np.ndarray):
y = y.flatten().tolist()

if self.is_numerical:
return np.array(y)

return np.array([self.inverse_classes_[label] for label in y])


class Params(Serializable, ABC):
def __init__(self, format: Any, label_fields: Sequence[str] | None): # noqa: A002
self.format = format
Expand All @@ -141,9 +89,8 @@ class DataProcessor(ABC):
def __init__(self, params: Params, additional_targets: dict[str, str] | None = None):
self.params = params
self.data_fields = [self.default_data_name]
self.label_encoders: dict[str, dict[str, LabelEncoder]] = defaultdict(dict)
self.is_sequence_input: dict[str, bool] = {}
self.is_numerical_label: dict[str, dict[str, bool]] = defaultdict(dict)
self.label_manager = LabelManager()

if additional_targets is not None:
self.add_targets(additional_targets)
Expand Down Expand Up @@ -261,7 +208,7 @@ def _process_label_fields(self, data: dict[str, Any], data_name: str) -> np.ndar
if self.params.label_fields is not None:
for label_field in self.params.label_fields:
self._validate_label_field_length(data, data_name, label_field)
encoded_labels = self._encode_label_field(data, data_name, label_field)
encoded_labels = self.label_manager.process_field(data_name, label_field, data[label_field])
data_array = np.hstack((data_array, encoded_labels))
del data[label_field]
return data_array
Expand All @@ -273,28 +220,6 @@ def _validate_label_field_length(self, data: dict[str, Any], data_name: str, lab
f"Got {len(data[data_name])} and {len(data[label_field])} respectively.",
)

def _encode_label_field(self, data: dict[str, Any], data_name: str, label_field: str) -> np.ndarray:
field_data = data[label_field]

# Check if input is numpy array or if all elements are numerical
is_numerical = (isinstance(field_data, np.ndarray) and np.issubdtype(field_data.dtype, np.number)) or all(
isinstance(label, (int, float)) for label in field_data
)

self.is_numerical_label[data_name][label_field] = is_numerical

if is_numerical:
# For numerical values, preserve numpy arrays or convert to float32
if isinstance(field_data, np.ndarray):
return field_data.reshape(-1, 1).astype(np.float32)
return np.array(field_data, dtype=np.float32).reshape(-1, 1)

# For non-numerical values, use LabelEncoder
encoder = LabelEncoder()
encoded_labels = encoder.fit_transform(field_data).reshape(-1, 1)
self.label_encoders[data_name][label_field] = encoder
return encoded_labels

def remove_label_fields_from_data(self, data: dict[str, Any]) -> dict[str, Any]:
if not self.params.label_fields:
return data
Expand All @@ -310,7 +235,7 @@ def remove_label_fields_from_data(self, data: dict[str, Any]) -> dict[str, Any]:
def _handle_empty_data_array(self, data: dict[str, Any]) -> None:
if self.params.label_fields is not None:
for label_field in self.params.label_fields:
data[label_field] = []
data[label_field] = self.label_manager.handle_empty_data()

def _remove_label_fields(self, data: dict[str, Any], data_name: str) -> None:
if self.params.label_fields is None:
Expand All @@ -322,21 +247,10 @@ def _remove_label_fields(self, data: dict[str, Any], data_name: str) -> None:

for idx, label_field in enumerate(self.params.label_fields):
encoded_labels = data_array[:, non_label_columns + idx]
decoded_labels = self._decode_label_field(data_name, label_field, encoded_labels)
data[label_field] = decoded_labels.tolist()
data[label_field] = self.label_manager.restore_field(data_name, label_field, encoded_labels)

data[data_name] = data_array[:, :non_label_columns]

def _decode_label_field(self, data_name: str, label_field: str, encoded_labels: np.ndarray) -> np.ndarray:
if self.is_numerical_label[data_name][label_field]:
return encoded_labels

encoder = self.label_encoders.get(data_name, {}).get(label_field)
if encoder:
return encoder.inverse_transform(encoded_labels.astype(int))

raise ValueError(f"Label encoder for {label_field} not found")


def validate_args(
low: float | Sequence[int] | Sequence[float] | None,
Expand Down
Loading