Test training loop
dblasko committed Nov 9, 2023
1 parent f9a0cf1 commit 683968a
Showing 2 changed files with 61 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/continuous-integration.yml
@@ -12,6 +12,9 @@ permissions:
jobs:
  build:

    strategy:
      fail-fast: false
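      # presumably added so the remaining jobs in the strategy keep running when one fails (fail-fast defaults to true)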

    runs-on: ubuntu-latest

    steps:
58 changes: 58 additions & 0 deletions tests/test_training.py
@@ -2,10 +2,15 @@
import torch
import torch.nn as nn
import warnings
import random
from torch.utils.data import DataLoader

import sys
sys.path.append(".")
from model.MIRNet.model import MIRNet
from training.train import train, validate
from training.training_utils.CharbonnierLoss import CharbonnierLoss
from dataset_generation.PretrainingDataset import PretrainingDataset

class TestOptimizer(unittest.TestCase):
    def test_optimizer_updates_weights(self):
@@ -28,3 +33,56 @@ def test_optimizer_updates_weights(self):
        updated_weights = model.conv_start.weight.data
        self.assertFalse(torch.equal(initial_weights, updated_weights),
                         "The optimizer did not update the weights")


class TestTrainingLoop(unittest.TestCase):
    def setUp(self):
        torch.manual_seed(0)
        random.seed(0)
        # We use a subset of the data to be able to run on the CI server
        self.batch_size = 2
        self.img_size = 128
        self.train_dataset = PretrainingDataset('tests/data_examples/pretraining/train/imgs', 'tests/data_examples/pretraining/train/targets', img_size=self.img_size)
        self.train_data = DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=1)

        self.model = MIRNet(num_features=64, number_msrb=2, number_rrg=2)  # smaller model for CI

        self.criterion = CharbonnierLoss()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.0001)
        self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, 3, 0.00001)
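        # CosineAnnealingLR(optimizer, T_max=3, eta_min=1e-5) matches the 3 test epochs; the scheduler does not appear to be stepped within these tests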

    def test_training_loop_runs(self):
        """
        Ensure the training loop runs without raising errors and that the loss does not increase across epochs.
        """
        device = torch.device("cpu")  # CI environment constraint
        self.model.to(device)

        initial_loss = None
        for epoch in range(3):
            epoch_loss, _ = train(self.train_data, self.model, self.criterion, self.optimizer, epoch, device)
            if initial_loss is None:
                initial_loss = epoch_loss
            else:
                self.assertLessEqual(epoch_loss, initial_loss, "Loss increased between epochs; something might be wrong in the training loop.")
                initial_loss = epoch_loss

    def test_loss_decreases(self):
        """
        Ensure that the loss decreases over multiple epochs.
        """
        device = torch.device("cpu")
        self.model.to(device)

        losses = []
        for epoch in range(3):
            epoch_loss, _ = train(self.train_data, self.model, self.criterion, self.optimizer, epoch, device)
            losses.append(epoch_loss)

        for i in range(1, len(losses)):
            self.assertLess(losses[i], losses[i-1], "Loss did not decrease after an epoch; training might not be functioning correctly.")



if __name__ == '__main__':
    unittest.main()
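
For context on the interface these tests exercise: train (imported from training.train) is called as train(dataloader, model, criterion, optimizer, epoch, device), and its first return value is treated as the epoch loss. The actual implementation lives in training/train.py and is not part of this commit; the sketch below is only an assumed-compatible shape (it also assumes PretrainingDataset yields (input, target) tensor pairs and that the second return value is unused by these tests), not the repository's code.

def train(dataloader, model, criterion, optimizer, epoch, device):
    # One training epoch; returns (mean epoch loss, second value ignored by these tests).
    model.train()
    total_loss = 0.0
    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / max(len(dataloader), 1), None  # second return value is an assumption

The tests assume the repository root as the working directory (hence sys.path.append(".")), e.g. python -m unittest tests.test_training -v.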
