from typing import Optional, List, Union
from enum import Enum
import os
import yaml
import socket
from pydantic import BaseModel, AnyUrl, PositiveInt, DirectoryPath, Field
from pydantic.networks import IPvAnyAddress
# ------ config file names ------ #
explorer_cfg_fname = '.db.secrets.yml'
terra_cfg_fname = 'dataflow_config.yml'
user_pass_fname = '.explorer_user-passwords.yml'
# ------ enum aliases for Environment Variables ------ #
class env_var(Enum):
    TERRA = 'TERRA_CONFIG_LOC'
    EXPLORER = 'EXPLORER_CONFIG_LOC'
    USERPASS = 'EXPLORER_USER_PASS'
# ------ config value pydantic specs ------ #
class databaseArgs(BaseModel):
    db_name: str
    db_user: str
    password: str
    host: Union[str, IPvAnyAddress, AnyUrl]
    # this allows for values such as localhost or 127.0.0.1
    # or 192.168.xxx.xxx or <server_name>.nmr.mgh.harvard.edu
class dataflowArgs(BaseModel):
    reserve_threshold_bytes: PositiveInt
    suitable_volumes: List[DirectoryPath]
    delete_threshold: float = Field(ge=0, le=1)
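# A minimal illustration of how the specs above are used: passing a raw dict (or
# keyword arguments) to a model validates and type-coerces the values, and
# out-of-range values raise a pydantic ValidationError. The literal values below
# are placeholders, not real configuration:
#
#   dataflowArgs(reserve_threshold_bytes=500_000_000,
#                suitable_volumes=['/tmp'],
#                delete_threshold=0.9)    # passes validation
#   dataflowArgs(reserve_threshold_bytes=500_000_000,
#                suitable_volumes=['/tmp'],
#                delete_threshold=1.5)    # raises ValidationError (must be <= 1)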
# ----------------------------------------- #
def read_environment_variable(environment_variable_name: str) -> os.PathLike:
    ''' Reads an environment variable
        Environ vars store the location of the corresponding config files
    '''
    config_file_location = os.environ.get(environment_variable_name)
    if config_file_location is None:
        raise Exception(f'got None when retrieving {environment_variable_name} environment variable')
    return config_file_location
def validate_config_fpath(config_fpath: os.PathLike) -> None:
    ''' validates that any path-to-file exists,
        meant to ensure config file exists
    '''
    if not os.path.exists(config_fpath):
        raise Exception(f'config file at {config_fpath} does not exist')
def get_config_file_path(config_file_name: str,
                         env_var_name: str) -> os.PathLike:
    '''Returns validated config file path defined by environment variable'''
    try:
        config_file_location = read_environment_variable(env_var_name)
    except Exception as e:
        raise Exception(f'Could not get config file location from environment variable {env_var_name}') from e
    config_fpath = os.path.join(config_file_location, config_file_name)
    validate_config_fpath(config_fpath)
    return config_fpath
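# Example of how these helpers compose (illustrative path only): with the
# EXPLORER_CONFIG_LOC environment variable set to '/home/user/configs',
# get_config_file_path(explorer_cfg_fname, env_var.EXPLORER.value) returns
# '/home/user/configs/.db.secrets.yml', provided that file exists.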
def load_yaml_file_into_dict(yaml_file_path):
    with open(yaml_file_path) as yaml_file:
        param_dict = yaml.load(yaml_file, yaml.FullLoader)
    return param_dict
def read_db_secrets(config_fpath: Optional[str] = None):
    ''' Returns a dictionary of database credentials with keys:
            'database' for the name of the postgres database
            'user' for the pg username
            'password' for the pg user password
            'host' for the database host
        The credential file is assumed to be at a location which
        is defined in the EXPLORER_CONFIG_LOC environment variable
    '''
    # First get config file path
    if config_fpath is None:
        config_fpath = get_config_file_path(explorer_cfg_fname, env_var.EXPLORER.value)
    # Then load the config file
    db_config_dict = load_yaml_file_into_dict(config_fpath)
    # Next get the name of the server on which the script is being run
    try:
        hostname = socket.gethostname().split('.')[0]
    except Exception as e:
        raise Exception('Something went wrong in trying to get hostname') from e
    if hostname not in db_config_dict.keys():
        text = f'Hostname {hostname} does not match any key in config file {config_fpath}'
        raise RuntimeError(text)
    # Now extract db args specific to host
    host_specific_db_args = db_config_dict[hostname]
    # Then apply pydantic validation to db args values
    db_args = databaseArgs(**host_specific_db_args)
    credentials = {'database': db_args.db_name,
                   'user': db_args.db_user,
                   'password': db_args.password,
                   'host': db_args.host}
    return credentials
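# A sketch of the expected .db.secrets.yml layout, inferred from the hostname
# lookup and the databaseArgs spec above; the hostname key and all values are
# placeholders:
#
#   somehost:
#     db_name: some_database
#     db_user: some_user
#     password: some_password
#     host: localhost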
def read_dataflow_configs(config_fpath: Optional[str] = None):
    ''' Returns a dictionary of dataflow parameters with keys:
            'reserve_threshold_bytes'
            'suitable_volumes'
            'delete_threshold'
        See config yaml for context on these keys
    '''
    if config_fpath is None:
        config_fpath = get_config_file_path(terra_cfg_fname, env_var.TERRA.value)
    dataflow_config_dict = load_yaml_file_into_dict(config_fpath)
    dataflow_args = dataflowArgs(**dataflow_config_dict)
    # this validates dataflow config values
    dataflow_configs = {'reserve_threshold_bytes': dataflow_args.reserve_threshold_bytes,
                        'suitable_volumes': dataflow_args.suitable_volumes,
                        'delete_threshold': dataflow_args.delete_threshold}
    # check that all volumes in suitable_volumes actually exist
    for volume in dataflow_configs['suitable_volumes']:
        if not os.path.exists(volume):
            raise Exception(f'volume path {volume} does not exist')
    return dataflow_configs
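# A sketch of the expected dataflow_config.yml layout, inferred from the
# dataflowArgs spec; all values are placeholders:
#
#   reserve_threshold_bytes: 500000000000
#   suitable_volumes:
#     - /path/to/volume_1
#     - /path/to/volume_2
#   delete_threshold: 0.9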
def get_ssh_args():
    ''' SSH arguments don't change, hence are hardcoded here
        and returned as is
    '''
    ssh_args = dict(
        ssh_address_or_host='neurodoor.nmr.mgh.harvard.edu',
        ssh_pkey='/space/neurobooth/1/applications/config/id_rsa',  # this is user sp1022's id_rsa
        remote_bind_address=('192.168.100.1', 5432),
        local_bind_address=('localhost', 6543)
    )
    return ssh_args
def get_user_pass_pairs():
    '''Read username-password pairs from yml file'''
    config_fpath = get_config_file_path(user_pass_fname, env_var.USERPASS.value)
    user_pass_dict = load_yaml_file_into_dict(config_fpath)
    return user_pass_dict
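# The structure of the user-password file is not enforced anywhere in this
# module; it is assumed here to be a flat YAML mapping of usernames to
# passwords, e.g.:
#
#   some_user: some_password
#   another_user: another_password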
if __name__ == '__main__':
    ''' Run this script standalone to test config reading or config value validation
        OPTIONAL: Pass config file paths as command line arguments
    '''
    import sys
    if len(sys.argv) == 3:
        db_args = read_db_secrets(sys.argv[1])
        dataflow_args = read_dataflow_configs(sys.argv[2])
    else:
        db_args = read_db_secrets()
        dataflow_args = read_dataflow_configs()
    ssh_args = get_ssh_args()
    user_pass_pairs = get_user_pass_pairs()
    print()
    for ky in db_args.keys():
        print(ky, db_args[ky])
    print()
    for ky in dataflow_args.keys():
        print(ky, dataflow_args[ky])
    print()
    for ky in ssh_args.keys():
        print(ky, ssh_args[ky])
    print()
    for ky in user_pass_pairs.keys():
        print(ky, user_pass_pairs[ky])