Skip to content

Commit

Permalink
Added CenterCrop3D and RandomCrop3D
Browse files Browse the repository at this point in the history
  • Loading branch information
Vladimir Iglovikov committed Dec 14, 2024
1 parent ce7f41a commit 94a36de
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 24 deletions.
33 changes: 21 additions & 12 deletions albumentations/augmentations/transforms3d/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,21 @@ def __init__(
self.fill_mask = fill_mask
self.pad_position = pad_position

def _random_pad(self, pad: int) -> tuple[int, int]:
"""Helper function to calculate random padding for one dimension."""
if pad > 0:
pad_start = self.py_random.randint(0, pad)
pad_end = pad - pad_start
else:
pad_start = pad_end = 0
return pad_start, pad_end

def _center_pad(self, pad: int) -> tuple[int, int]:
"""Helper function to calculate center padding for one dimension."""
pad_start = pad // 2
pad_end = pad - pad_start
return pad_start, pad_end

def _get_pad_params(
self,
image_shape: tuple[int, int, int],
Expand All @@ -280,20 +295,14 @@ def _get_pad_params(

# For center padding, split equally
if self.pad_position == "center":
z_front = z_pad // 2
z_back = z_pad - z_front
h_top = h_pad // 2
h_bottom = h_pad - h_top
w_left = w_pad // 2
w_right = w_pad - w_left
z_front, z_back = self._center_pad(z_pad)
h_top, h_bottom = self._center_pad(h_pad)
w_left, w_right = self._center_pad(w_pad)
# For random padding, randomly distribute the padding
else: # random
z_front = self.py_random.randint(0, z_pad + 1) if z_pad > 0 else 0
z_back = z_pad - z_front
h_top = self.py_random.randint(0, h_pad + 1) if h_pad > 0 else 0
h_bottom = h_pad - h_top
w_left = self.py_random.randint(0, w_pad + 1) if w_pad > 0 else 0
w_right = w_pad - w_left
z_front, z_back = self._random_pad(z_pad)
h_top, h_bottom = self._random_pad(h_pad)
w_left, w_right = self._random_pad(w_pad)

return {
"pad_front": z_front,
Expand Down
14 changes: 2 additions & 12 deletions tests/transforms3d/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,12 +381,7 @@ def test_crop_3d_shapes(transform, input_shape, expected_shape):
def test_crop_3d_padding(transform_cls, input_shape, target_shape, description):
volume = np.random.randint(0, 256, input_shape, dtype=np.uint8)

transform = transform_cls(
size=target_shape,
pad_if_needed=True,
fill=0,
fill_mask=0,
)
transform = A.Compose([transform_cls(p=1, size=target_shape, pad_if_needed=True, fill=0, fill_mask=0)], seed=0)

transformed = transform(images=volume)
assert transformed["images"].shape == target_shape
Expand All @@ -411,12 +406,7 @@ def test_crop_3d_fill_values(transform_cls, size, fill, fill_mask):
volume = np.ones((3, 50, 50), dtype=np.uint8)
mask = np.zeros((3, 50, 50), dtype=np.uint8)

transform = transform_cls(
size=size,
pad_if_needed=True,
fill=fill,
fill_mask=fill_mask,
)
transform = A.Compose([transform_cls(p=1, size=size, pad_if_needed=True, fill=fill, fill_mask=fill_mask)], seed=0)

transformed = transform(images=volume, masks=mask)
padded_volume = transformed["images"]
Expand Down

0 comments on commit 94a36de

Please sign in to comment.