-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
79 lines (63 loc) · 3.49 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import logging
import dotenv
import hydra
# Load variables from a .env file before anything else reads the environment;
# override=True lets the .env values win over already-exported shell variables.
dotenv.load_dotenv(override=True)
# Silence numexpr's noisy INFO message about detected cores at import time.
logging.getLogger('numexpr.utils').setLevel(logging.WARNING)
# Module-level logger, configured by Hydra's logging setup at runtime.
logger = logging.getLogger(__name__)
@hydra.main(version_base=None, config_path="config/", config_name="train.yaml")
def main(config):
    """Run a full Hydra-configured training job.

    Seeds RNGs, builds the dataset and train/validation dataloaders,
    instantiates callbacks, model and trainer from the config, trains
    (optionally resuming from ``config.checkpoint_path``), and — when
    ``config.run_test`` is set — evaluates on the test dataset using the
    best checkpoint reported by a checkpoint callback, if one exists.
    """
    # Heavy imports are deferred into the function so Hydra's CLI
    # (e.g. --help) stays fast and .env loading happens first.
    from torch.utils.data import DataLoader
    from accelerate import Accelerator
    from sqp_experiments.utils import seed_everything, register_resolvers, pretty_configs, model_summary, train_valid_split

    accel = Accelerator()

    # --- preamble: reproducibility, config resolvers, config dump -------
    seed_everything(config.seed)
    register_resolvers()
    logger.info(f"Current configs:\n{pretty_configs(config)}")

    # --- dataset ---------------------------------------------------------
    logger.info(f"Initializing training/validation dataset {config.dataset._target_}")
    full_dataset = hydra.utils.instantiate(config.dataset)

    # --- train/validation split and dataloaders --------------------------
    train_set, valid_set = train_valid_split(dataset=full_dataset, valid_split=config.train_valid_split)
    logger.info(f"Splitting dataset into training/validation sets: {len(train_set)} / {len(valid_set)}")
    train_loader = DataLoader(dataset=train_set, batch_size=config.batch_size, num_workers=config.workers, shuffle=True)
    valid_loader = DataLoader(dataset=valid_set, batch_size=config.valid_batch_size, num_workers=config.workers, shuffle=False)

    # --- callbacks -------------------------------------------------------
    callbacks = []
    for cb_name, cb_conf in config.callbacks.items():
        logger.info(f"Initializing callback {cb_conf._target_}")
        callbacks.append(hydra.utils.instantiate(cb_conf))

    # --- model -----------------------------------------------------------
    logger.info(f"Initializing model {config.model._target_}")
    model = hydra.utils.instantiate(config.model)
    logger.info(f"Model architecture:\n{model_summary(model, valid_loader)}")

    # --- trainer ---------------------------------------------------------
    # _recursive_=False so the trainer receives the raw config nodes
    # rather than having Hydra eagerly instantiate them.
    logger.info(f"Initializing trainer {config.trainer._target_}")
    trainer = hydra.utils.instantiate(config.trainer, all_config=config, accelerator=accel, callbacks=callbacks, _recursive_=False)

    # --- training (optionally resuming from a checkpoint) ----------------
    if config.checkpoint_path:
        logger.info(f"Loading checkpoint from {config.checkpoint_path}")
    trainer.train(model=model, dataloader_train=train_loader, dataloader_valid=valid_loader, checkpoint_path=config.checkpoint_path)

    # --- optional test phase ---------------------------------------------
    if config.run_test:
        # Locate the (single) checkpoint callback, if any, to recover the
        # best checkpoint produced during training.
        logger.info("Finding best checkpoint...")
        ckpt_callbacks = [cb for cb in callbacks if hasattr(cb, 'best_path')]
        assert len(ckpt_callbacks) <= 1, "There appear to be several checkpoint callbacks"
        best_ckpt_path = ckpt_callbacks[0].best_path if ckpt_callbacks else None
        if best_ckpt_path is not None:
            logger.info(f"Best checkpoint found at {best_ckpt_path}")
        else:
            logger.info("Best checkpoint not found! Using current state")

        # Test dataset/dataloader (validation batch size, no shuffling).
        logger.info(f"Initializing test dataset {config.dataset_test._target_}")
        test_set = hydra.utils.instantiate(config.dataset_test)
        test_loader = DataLoader(dataset=test_set, batch_size=config.valid_batch_size, num_workers=config.workers, shuffle=False)

        trainer.test(
            model=model,
            dataloader_test=test_loader,
            checkpoint_path=best_ckpt_path)
# Script entry point: hand control to the Hydra-decorated main().
if __name__ == "__main__":
    main()