Skip to content

Commit

Permalink
Updated GaussNoise and deleted targets_as_params (#2268)
Browse files (browse the repository at this point in the history)
  • Loading branch information
ternaus authored Jan 8, 2025
1 parent 669b60f commit e4d6cce
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 62 deletions.
48 changes: 1 addition & 47 deletions albumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from collections.abc import Sequence
from types import LambdaType
from typing import Annotated, Any, Callable, Union, cast
from warnings import warn

import albucore
import cv2
Expand Down Expand Up @@ -2219,9 +2218,6 @@ class GaussNoise(ImageOnlyTransform):
- The noise parameters (std_range and mean_range) are normalized to [0, 1] range:
* For uint8 images, they are multiplied by 255
* For float32 images, they are used directly
- The behavior differs between old and new parameters:
* When using var_limit (deprecated): samples variance uniformly and takes sqrt to get std dev
* When using std_range: samples standard deviation directly (aligned with torchvision/kornia)
- Setting per_channel=False is faster but applies the same noise to all channels
- The noise_scale_factor parameter allows for a trade-off between transform speed and noise granularity
Expand All @@ -2233,15 +2229,9 @@ class GaussNoise(ImageOnlyTransform):
>>> # Apply Gaussian noise with normalized std_range
>>> transform = A.GaussNoise(std_range=(0.1, 0.2), p=1.0) # 10-20% of max value
>>> noisy_image = transform(image=image)['image']
>>>
>>> # Using deprecated var_limit (will be converted to std_range)
>>> transform = A.GaussNoise(var_limit=(50.0, 100.0), mean=10, p=1.0)
>>> noisy_image = transform(image=image)['image']
"""

class InitSchema(BaseTransformInitSchema):
var_limit: ScaleFloatType | None
mean: float | None
std_range: Annotated[
tuple[float, float],
AfterValidator(check_range_bounds(0, 1)),
Expand All @@ -2255,36 +2245,8 @@ class InitSchema(BaseTransformInitSchema):
per_channel: bool
noise_scale_factor: float = Field(gt=0, le=1)

@model_validator(mode="after")
def check_range(self) -> Self:
if self.var_limit is not None:
warnings.warn("`var_limit` deprecated. Use `std_range` instead.", DeprecationWarning, stacklevel=2)
self.var_limit = to_tuple(self.var_limit, 0)
if self.var_limit[1] > 1:
# Convert legacy uint8 variance to normalized std dev
self.std_range = (math.sqrt(10 / 255), math.sqrt(50 / 255))
else:
# Already normalized variance, convert to std dev
self.std_range = (
math.sqrt(self.var_limit[0]),
math.sqrt(self.var_limit[1]),
)

if self.mean is not None:
warn("`mean` deprecated. Use `mean_range` instead.", DeprecationWarning, stacklevel=2)
if self.mean >= 1:
# Convert legacy uint8 mean to normalized range
self.mean_range = (self.mean / 255, self.mean / 255)
else:
# Already normalized mean
self.mean_range = (self.mean, self.mean)

return self

def __init__(
self,
var_limit: ScaleFloatType | None = None,
mean: float | None = None,
std_range: tuple[float, float] = (0.2, 0.44), # sqrt(10 / 255), sqrt(50 / 255)
mean_range: tuple[float, float] = (0.0, 0.0),
per_channel: bool = True,
Expand All @@ -2297,8 +2259,6 @@ def __init__(
self.per_channel = per_channel
self.noise_scale_factor = noise_scale_factor

self.var_limit = var_limit

def apply(
self,
img: np.ndarray,
Expand All @@ -2315,13 +2275,7 @@ def get_params_dependent_on_data(
image = data["image"] if "image" in data else data["images"][0]
max_value = MAX_VALUES_BY_DTYPE[image.dtype]

if self.var_limit is not None:
# Legacy behavior: sample variance uniformly then take sqrt
var = self.py_random.uniform(self.std_range[0] ** 2, self.std_range[1] ** 2)
sigma = math.sqrt(var)
else:
# New behavior: sample std dev directly (aligned with torchvision/kornia)
sigma = self.py_random.uniform(*self.std_range)
sigma = self.py_random.uniform(*self.std_range)

mean = self.py_random.uniform(*self.mean_range)

Expand Down
15 changes: 0 additions & 15 deletions albumentations/core/transforms_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,6 @@ def __call__(self, *args: Any, force_apply: bool = False, **kwargs: Any) -> Any:
params_dependent_on_data = self.get_params_dependent_on_data(params=params, data=kwargs)
params.update(params_dependent_on_data)

if self.targets_as_params: # this block will be removed after removing `get_params_dependent_on_targets`
targets_as_params = {k: kwargs.get(k) for k in self.targets_as_params}
if missing_keys: # here we expecting case when missing_keys == {"image"} and "images" in kwargs
targets_as_params["image"] = kwargs["images"][0]
params_dependent_on_targets = self.get_params_dependent_on_targets(targets_as_params)
params.update(params_dependent_on_targets)

# Store the final params
self.params = params

Expand Down Expand Up @@ -337,14 +330,6 @@ def targets_as_params(self) -> list[str]:
"""
return []

def get_params_dependent_on_targets(self, params: dict[str, Any]) -> dict[str, Any]:
"""This method is deprecated.
Use `get_params_dependent_on_data` instead.
Returns parameters dependent on targets.
Dependent target is defined in `self.targets_as_params`
"""
return {}

@classmethod
def get_class_fullname(cls) -> str:
return get_shortest_class_fullname(cls)
Expand Down

0 comments on commit e4d6cce

Please sign in to comment.