Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Fix/Inference mode for Imputers #75

Open
wants to merge 6 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/anemoi/models/preprocessing/imputer.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,11 @@ def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
if not in_place:
x = x.clone()

# Initilialize mask once
# Reset NaN locations outside of training for validation and inference.
if not self.training:
self.nan_locations = None

# Initialise mask if not cached.
if self.nan_locations is None:
# The mask is only saved for the last two dimensions (grid, variable)
idx = [slice(0, 1)] * (x.ndim - 2) + [slice(None), slice(None)]
Expand Down
113 changes: 59 additions & 54 deletions tests/preprocessing/test_preprocessor_imputer.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,14 +130,17 @@ def default_constant_data():
return base, expected


# (imputer fixture name, data fixture name) pairs shared by every parametrized
# test in this module; each test resolves the names to fixture values through
# ``request.getfixturevalue``.
fixture_combinations = (
    ("default_constant_imputer", "default_constant_data"),
    ("non_default_constant_imputer", "non_default_constant_data"),
    ("default_input_imputer", "default_input_data"),
    ("non_default_input_imputer", "non_default_input_data"),
)


@pytest.mark.parametrize(
("imputer_fixture", "data_fixture"),
[
("default_constant_imputer", "default_constant_data"),
("non_default_constant_imputer", "non_default_constant_data"),
("default_input_imputer", "default_input_data"),
("non_default_input_imputer", "non_default_input_data"),
],
fixture_combinations,
)
def test_imputer_not_inplace(imputer_fixture, data_fixture, request) -> None:
"""Check that the imputer does not modify the input tensor when in_place=False."""
Expand All @@ -150,12 +153,7 @@ def test_imputer_not_inplace(imputer_fixture, data_fixture, request) -> None:

@pytest.mark.parametrize(
("imputer_fixture", "data_fixture"),
[
("default_constant_imputer", "default_constant_data"),
("non_default_constant_imputer", "non_default_constant_data"),
("default_input_imputer", "default_input_data"),
("non_default_input_imputer", "non_default_input_data"),
],
fixture_combinations,
)
def test_imputer_inplace(imputer_fixture, data_fixture, request) -> None:
"""Check that the imputer modifies the input tensor when in_place=True."""
Expand All @@ -169,12 +167,7 @@ def test_imputer_inplace(imputer_fixture, data_fixture, request) -> None:

@pytest.mark.parametrize(
("imputer_fixture", "data_fixture"),
[
("default_constant_imputer", "default_constant_data"),
("non_default_constant_imputer", "non_default_constant_data"),
("default_input_imputer", "default_input_data"),
("non_default_input_imputer", "non_default_input_data"),
],
fixture_combinations,
)
def test_transform_with_nan(imputer_fixture, data_fixture, request):
"""Check that the imputer correctly transforms a tensor with NaNs."""
Expand All @@ -186,12 +179,7 @@ def test_transform_with_nan(imputer_fixture, data_fixture, request):

@pytest.mark.parametrize(
("imputer_fixture", "data_fixture"),
[
("default_constant_imputer", "default_constant_data"),
("non_default_constant_imputer", "non_default_constant_data"),
("default_input_imputer", "default_input_data"),
("non_default_input_imputer", "non_default_input_data"),
],
fixture_combinations,
)
def test_transform_with_nan_small(imputer_fixture, data_fixture, request):
"""Check that the imputer correctly transforms a tensor with NaNs."""
Expand All @@ -211,12 +199,7 @@ def test_transform_with_nan_small(imputer_fixture, data_fixture, request):

@pytest.mark.parametrize(
("imputer_fixture", "data_fixture"),
[
("default_constant_imputer", "default_constant_data"),
("non_default_constant_imputer", "non_default_constant_data"),
("default_input_imputer", "default_input_data"),
("non_default_input_imputer", "non_default_input_data"),
],
fixture_combinations,
)
def test_transform_with_nan_inference(imputer_fixture, data_fixture, request):
"""Check that the imputer correctly transforms a tensor with NaNs in inference."""
Expand Down Expand Up @@ -244,12 +227,7 @@ def test_transform_with_nan_inference(imputer_fixture, data_fixture, request):

@pytest.mark.parametrize(
("imputer_fixture", "data_fixture"),
[
("default_constant_imputer", "default_constant_data"),
("non_default_constant_imputer", "non_default_constant_data"),
("default_input_imputer", "default_input_data"),
("non_default_input_imputer", "non_default_input_data"),
],
fixture_combinations,
)
def test_transform_noop(imputer_fixture, data_fixture, request):
"""Check that the imputer does not modify a tensor without NaNs."""
Expand All @@ -262,12 +240,7 @@ def test_transform_noop(imputer_fixture, data_fixture, request):

@pytest.mark.parametrize(
("imputer_fixture", "data_fixture"),
[
("default_constant_imputer", "default_constant_data"),
("non_default_constant_imputer", "non_default_constant_data"),
("default_input_imputer", "default_input_data"),
("non_default_input_imputer", "non_default_input_data"),
],
fixture_combinations,
)
def test_inverse_transform(imputer_fixture, data_fixture, request):
"""Check that the imputer correctly inverts the transformation."""
Expand All @@ -281,12 +254,7 @@ def test_inverse_transform(imputer_fixture, data_fixture, request):

@pytest.mark.parametrize(
("imputer_fixture", "data_fixture"),
[
("default_constant_imputer", "default_constant_data"),
("non_default_constant_imputer", "non_default_constant_data"),
("default_input_imputer", "default_input_data"),
("non_default_input_imputer", "non_default_input_data"),
],
fixture_combinations,
)
def test_mask_saving(imputer_fixture, data_fixture, request):
"""Check that the imputer saves the NaN mask correctly."""
Expand All @@ -299,12 +267,7 @@ def test_mask_saving(imputer_fixture, data_fixture, request):

@pytest.mark.parametrize(
("imputer_fixture", "data_fixture"),
[
("default_constant_imputer", "default_constant_data"),
("non_default_constant_imputer", "non_default_constant_data"),
("default_input_imputer", "default_input_data"),
("non_default_input_imputer", "non_default_input_data"),
],
fixture_combinations,
)
def test_reuse_imputer(imputer_fixture, data_fixture, request):
"""Check that the imputer reuses the mask correctly on subsequent runs."""
Expand All @@ -316,3 +279,45 @@ def test_reuse_imputer(imputer_fixture, data_fixture, request):
assert torch.allclose(
transformed2, expected, equal_nan=True
), "Imputer does not reuse mask correctly on subsequent runs."


@pytest.mark.parametrize(
    ("imputer_fixture", "data_fixture"),
    fixture_combinations,
)
def test_inference_imputer(imputer_fixture, data_fixture, request):
    """Check that the imputer resets its mask during inference.

    The imputer is first run in training mode so that it caches a NaN mask for
    the original data. It is then switched to evaluation mode and fed a rolled
    copy of the data; the transform must work with, and afterwards hold, the
    mask of the rolled data rather than the one cached during training.
    """
    x, expected = request.getfixturevalue(data_fixture)
    imputer = request.getfixturevalue(imputer_fixture)

    # The first pass below relies on the fixture being in training mode.
    assert imputer.training, "Imputer is not set to training mode."

    # Training pass: the mask of the original data gets cached.
    expected_mask = torch.isnan(x)
    transformed = imputer.transform(x, in_place=False)
    assert torch.allclose(transformed, expected, equal_nan=True), "Transform does not handle NaNs correctly."
    restored = imputer.inverse_transform(transformed, in_place=False)
    assert torch.allclose(restored, x, equal_nan=True), "Inverse transform does not restore NaNs correctly."
    assert torch.equal(imputer.nan_locations, expected_mask), "Mask not saved correctly after first run."

    imputer.eval()
    with torch.no_grad():
        # Roll along the first dimension to obtain data with a different NaN pattern.
        x2 = x.roll(-1, dims=0)
        expected2 = expected.roll(-1, dims=0)
        expected_mask2 = torch.isnan(x2)

        # Calling eval() on its own must leave the cached mask untouched; the
        # reset is only expected once transform() runs.
        assert torch.equal(
            imputer.nan_locations, expected_mask
        ), "Cached mask changed by switching to evaluation mode alone."

        assert not imputer.training, "Imputer is not set to evaluation mode."

        # Sanity checks: the rolled inputs really differ from the originals,
        # otherwise the assertions further down would pass trivially.
        assert not torch.allclose(x, x2, equal_nan=True), "Failed to modify the input data."
        assert not torch.allclose(expected, expected2, equal_nan=True), "Failed to modify the expected data."
        # Masks are boolean tensors, so exact comparison is the right check.
        assert not torch.equal(expected_mask, expected_mask2), "Failed to modify the nan mask."

        # Evaluation pass: the stale training mask must not be reused.
        transformed = imputer.transform(x2, in_place=False)
        assert torch.allclose(
            transformed, expected2, equal_nan=True
        ), "Transform does not handle NaNs correctly in evaluation mode."
        restored = imputer.inverse_transform(transformed, in_place=False)
        assert torch.allclose(
            restored, x2, equal_nan=True
        ), "Inverse transform does not restore NaNs correctly in evaluation mode."

        assert torch.equal(imputer.nan_locations, expected_mask2), "Mask not saved correctly after evaluation run."
Loading