-
Notifications
You must be signed in to change notification settings - Fork 24
/
Copy pathgenerate_search_configurations_over_lstm_states_and_past_steps.py
84 lines (77 loc) · 3.5 KB
/
generate_search_configurations_over_lstm_states_and_past_steps.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import yaml
import os
def main(xml_dir_path, output_dir, nb_runs=5, pids=(559, 591),
         past_steps_grid=(6, 24, 48), lstm_states_grid=(8, 32, 96, 128)):
    """Write one basic-LSTM experiment config (YAML) per grid point.

    The grid is the Cartesian product of run index (``range(nb_runs)``),
    patient id, number of past steps and number of LSTM states.  With the
    default arguments this yields 5 * 2 * 3 * 4 = 120 files, identical to the
    original hard-coded search.

    Args:
        xml_dir_path: Directory holding the ``<pid>-ws-training.xml`` files;
            each config's ``dataset.xml_path`` points into it.
        output_dir: Directory the ``.yaml`` files are written to.  It is
            created if it does not exist.  The same string is also embedded
            in every config's ``train.artifacts_path`` as
            ``../artifacts/<output_dir>/<run name>/``, so a relative name is
            presumably intended (NOTE(review): confirm with callers).
        nb_runs: Number of repeated runs per hyper-parameter combination.
        pids: Patient ids to generate configs for.
        past_steps_grid: Values for ``dataset.nb_past_steps``.
        lstm_states_grid: Values for ``model.model_cfg.nb_lstm_states``.
    """
    import itertools

    train_fraction = 0.6
    valid_fraction = 0.2
    test_fraction = 0.2
    loss_function_script_path = '../loss_functions/nll_keras.py'
    nb_future_steps = 6

    # open() below raises FileNotFoundError when the target directory is
    # missing; an empty string means "current directory" and needs no mkdir.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # product() iterates its last axis fastest, matching the original
    # run -> pid -> past_steps -> lstm_states nested-loop order.
    grid = itertools.product(range(nb_runs), pids, past_steps_grid,
                             lstm_states_grid)
    for i_run, pid, nb_past_steps, nb_lstm_states in grid:
        # One name shared by the config file and its artifacts directory so
        # the two can never drift apart.
        run_name = 'basic_lstm_pid_{}_past_steps_{}_lstm_states_{}_run_{}'.format(
            pid, nb_past_steps, nb_lstm_states, i_run)
        config_path = os.path.join(output_dir, run_name + '.yaml')
        artifacts_path = '../artifacts/{}/{}/'.format(output_dir, run_name)
        xml_path = os.path.join(xml_dir_path,
                                '{}-ws-training.xml'.format(pid))

        # The relative '../...' script paths are resolved by whatever
        # consumes these configs; their base directory is not visible here.
        cfg = {
            'dataset': {
                'script_path': '../datasets/ohio.py',
                'xml_path': xml_path,
                'nb_past_steps': nb_past_steps,
                'nb_future_steps': nb_future_steps,
                'train_fraction': train_fraction,
                'valid_fraction': valid_fraction,
                'test_fraction': test_fraction,
                'scale': 0.01,
            },
            'model': {
                'script_path': '../models/basic_lstm_keras.py',
                'model_cfg': {
                    'nb_lstm_states': nb_lstm_states,
                },
            },
            'optimizer': {
                'script_path': '../optimizers/adam_keras.py',
                'learning_rate': 0.0001,
            },
            'loss_function': {
                'script_path': loss_function_script_path,
            },
            'train': {
                'script_path': '../train/train_keras.py',
                'artifacts_path': artifacts_path,
                'batch_size': 128,
                'epochs': 200,
                'patience': 8,
                'shuffle': True,
            },
        }
        with open(config_path, 'w') as outfile:
            yaml.dump(cfg, outfile, default_flow_style=False)
def get_parser():
    """Build the command-line parser for this script.

    Returns:
        argparse.ArgumentParser with two required options:
        ``-f/--file`` (stored as ``xml_dir_path``) and
        ``-o/--output_dir`` (stored as ``output_dir``).
    """
    from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
    parser = ArgumentParser(description=__doc__,
                            formatter_class=ArgumentDefaultsHelpFormatter)
    # Both options take directories, so the metavar says DIR, not FILE.
    parser.add_argument("-f", "--file",
                        dest="xml_dir_path",
                        help="absolute root directory path of Ohio patient XML data",
                        metavar="DIR",
                        required=True)
    parser.add_argument("-o", "--output_dir",
                        dest="output_dir",
                        help="output directory of the configuration files",
                        metavar="DIR",
                        required=True)
    return parser
if __name__ == '__main__':
    # Entry point: read the CLI flags, echo the XML directory to stdout,
    # then generate the configuration files.
    cli_args = get_parser().parse_args()
    xml_root = cli_args.xml_dir_path
    print(xml_root)
    main(xml_root, cli_args.output_dir)