Add more training and validation integration tests
dblasko committed Nov 9, 2023
1 parent 683968a commit ab77363
Showing 1 changed file with 51 additions and 3 deletions.
54 changes: 51 additions & 3 deletions tests/test_training.py
@@ -2,6 +2,7 @@
 import torch
 import torch.nn as nn
 import warnings
+import math
 import random
 from torch.utils.data import DataLoader

@@ -67,9 +68,10 @@ def test_training_loop_runs(self):
             self.assertLessEqual(epoch_loss, initial_loss, "Loss did not decrease or remained the same; something might be wrong in the training loop.")
             initial_loss = epoch_loss
 
-    def test_loss_decreases(self):
+    def test_loss_decreases_no_nan_or_inf_params(self):
         """
-        Ensure that the loss decreases over multiple epochs.
+        Ensure that the loss decreases over multiple epochs and that no model weights become NaN or Inf.
+        Both checks share a single test function to limit how many times the training loop runs on the free CI server, making this more of an integration test.
         """
         device = torch.device("cpu")
         self.model.to(device)
@@ -78,11 +80,57 @@ def test_loss_decreases(self):
         for epoch in range(3):
             epoch_loss, _ = train(self.train_data, self.model, self.criterion, self.optimizer, epoch, device)
             losses.append(epoch_loss)
+
+            for param in self.model.parameters():
+                self.assertFalse(torch.isnan(param).any(), f"NaNs found in model parameters after {epoch+1} epochs")
+                self.assertFalse(torch.isinf(param).any(), f"Infs found in model parameters after {epoch+1} epochs")
 
         for i in range(1, len(losses)):
             self.assertLess(losses[i], losses[i-1], "Loss did not decrease after an epoch; training might not be functioning correctly.")
 
 
+class TestValidationLoop(unittest.TestCase):
+    def setUp(self):
+        torch.manual_seed(0)
+        random.seed(0)
+        # We use a subset of the data to be able to run on the CI server
+        self.batch_size = 2
+        self.img_size = 128
+        self.train_dataset = PretrainingDataset('tests/data_examples/pretraining/test/imgs', 'tests/data_examples/pretraining/test/targets', img_size=self.img_size)
+        self.train_data = DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=1)
+
+        self.model = MIRNet(num_features=64, number_msrb=2, number_rrg=2)  # smaller model for CI
+
+        self.criterion = CharbonnierLoss()
+
+    def test_validation_integration(self):
+        """
+        Integration test for the validation loop.
+        Ensures that the validation loop does not modify the model's weights,
+        that the loss and PSNR calculations neither raise errors nor produce NaN/Inf values,
+        and that the computed PSNR is realistic.
+        """
+        device = torch.device("cpu")
+        self.model.to(device)
+
+        initial_state_dict = {k: v.clone() for k, v in self.model.state_dict().items()}  # clone: state_dict() returns references to the live tensors
+        try:
+            validation_loss, validation_psnr = validate(self.train_data, self.model, self.criterion, device)
+        except Exception as e:
+            self.fail(f"Validation loss computation raised an exception: {e}")
+        post_validation_state_dict = self.model.state_dict()
+
+        self.assertFalse(math.isnan(float(validation_loss)), "Validation loss is NaN")
+        self.assertFalse(math.isinf(float(validation_loss)), "Validation loss is Inf")
+        self.assertFalse(math.isnan(float(validation_psnr)), "Validation PSNR is NaN")
+        self.assertFalse(math.isinf(float(validation_psnr)), "Validation PSNR is Inf")
+        self.assertGreater(validation_psnr, 0, "Validation PSNR should be positive")
+        self.assertLess(validation_psnr, 100, "Validation PSNR should be less than 100 in practice")
+
+        for param_before, param_after in zip(initial_state_dict.values(), post_validation_state_dict.values()):
+            self.assertTrue(torch.equal(param_before, param_after),
+                            "Model weights changed during validation")
+
+
 if __name__ == '__main__':
     unittest.main()
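
For reference, the `train` helper these tests call is assumed to look roughly like the sketch below. The signature and the `(epoch_loss, _)` return shape are inferred from the call sites above, not taken from the repository's actual implementation:

# Minimal sketch of the train() helper exercised by the tests — inferred from
# the call sites, NOT the repository's actual code.
def train(loader, model, criterion, optimizer, epoch, device):
    model.train()
    total_loss = 0.0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()   # NaN/Inf weights typically originate here, e.g. from a too-high learning rate
        optimizer.step()
        total_loss += loss.item()
    epoch_loss = total_loss / len(loader)  # per-batch average; summing would work for the tests too
    return epoch_loss, None  # the second value is ignored by these tests

Averaging versus summing the per-batch losses does not affect the monotonic-decrease assertion, as long as the convention is consistent across epochs.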

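Likewise, a minimal sketch of a `validate` helper consistent with the new integration test's assertions (again an assumption, not the project's actual code). Evaluation mode plus `torch.no_grad()` is what guarantees the weights cannot change, and computing PSNR as 10·log10(MAX²/MSE) with MAX = 1 for normalized images is why the test can bound it between 0 and 100:

import math
import torch

# Sketch of a validate() consistent with the assertions above — an assumption,
# not the repository's implementation.
def validate(loader, model, criterion, device):
    model.eval()
    total_loss, total_psnr = 0.0, 0.0
    with torch.no_grad():  # no gradients are computed, so parameters cannot be modified
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            total_loss += criterion(outputs, targets).item()
            mse = torch.mean((outputs.clamp(0, 1) - targets) ** 2).item()
            # PSNR = 10 * log10(MAX^2 / MSE); MAX = 1 for [0, 1]-normalized images
            total_psnr += 10 * math.log10(1.0 / max(mse, 1e-10))
    return total_loss / len(loader), total_psnr / len(loader)

Under this formula a PSNR of 100 dB would require an MSE of 1e-10, so values at or above 100 almost certainly indicate a broken metric rather than a near-perfect model, which is what the upper-bound assertion guards against.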