Skip to content

Commit

Permalink
refactored elastic
Browse files Browse the repository at this point in the history
  • Loading branch information
ternaus committed Jan 30, 2025
1 parent 0eae070 commit 2dac564
Showing 1 changed file with 35 additions and 32 deletions.
67 changes: 35 additions & 32 deletions albumentations/augmentations/geometric/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1580,40 +1580,43 @@ def generate_displacement_fields(
random_generator: np.random.Generator,
noise_distribution: Literal["gaussian", "uniform"],
) -> tuple[np.ndarray, np.ndarray]:
"""Generate displacement fields for elastic transform.
Args:
image_shape: Shape of the image (height, width)
alpha: Scaling factor for displacement
sigma: Standard deviation for Gaussian blur
same_dxdy: Whether to use same displacement field for both directions
kernel_size: Size of Gaussian blur kernel
random_generator: NumPy random number generator
noise_distribution: Type of noise distribution to use ("gaussian" or "uniform")
Returns:
tuple: (dx, dy) displacement fields
"""

def generate_noise_field() -> np.ndarray:
# Generate noise based on distribution type
if noise_distribution == "gaussian":
field = random_generator.standard_normal(size=image_shape[:2])
else: # uniform
field = random_generator.uniform(low=-1, high=1, size=image_shape[:2])

# Common operations for both distributions
field = field.astype(np.float32)
cv2.GaussianBlur(field, kernel_size, sigma, dst=field)
return field * alpha

# Generate first displacement field
dx = generate_noise_field()
"""Generate displacement fields for elastic transform."""
# Pre-allocate memory and generate noise in one step
if noise_distribution == "gaussian":
# Generate and normalize in one step, directly as float32
fields = random_generator.standard_normal(
(1 if same_dxdy else 2, *image_shape[:2]),
dtype=np.float32,
)
# Normalize inplace
max_abs = np.abs(fields, out=np.empty_like(fields)).max()
if max_abs > 1e-6:
fields /= max_abs
else: # uniform is already normalized to [-1, 1]
fields = random_generator.uniform(
-1,
1,
size=(1 if same_dxdy else 2, *image_shape[:2]),
).astype(np.float32)

# Apply Gaussian blur if needed using fast OpenCV operations
if kernel_size != (0, 0):
# Use faster OpenCV operations for Gaussian blur
for i in range(fields.shape[0]):
# Use cv2.BORDER_REPLICATE to avoid edge artifacts
cv2.GaussianBlur(
fields[i],
kernel_size,
sigma,
dst=fields[i],
borderType=cv2.BORDER_REPLICATE,
)

# Generate or copy second displacement field
dy = dx if same_dxdy else generate_noise_field()
# Scale by alpha inplace
fields *= alpha

return dx, dy
# Return views of the array to avoid copies
return (fields[0], fields[0]) if same_dxdy else (fields[0], fields[1])


@handle_empty_array("bboxes")
Expand Down

0 comments on commit 2dac564

Please sign in to comment.