-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtest_train_mp_wds_cifar.py
441 lines (387 loc) · 15.2 KB
/
test_train_mp_wds_cifar.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
import torch_xla.test.test_utils as test_utils
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.core.xla_model as xm
import torch_xla.utils.utils as xu
import torch_xla.distributed.parallel_loader as pl
import torch_xla.debug.metrics as met
import torch_xla
import torchvision.transforms as transforms
import torchvision
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
import torch
import numpy as np
import sys
import os
import webdataset as wds
import datetime
import time
from itertools import islice
import torch_xla.debug.profiler as xp
from google.cloud import storage
from google.cloud.storage.bucket import Bucket
from google.cloud.storage.blob import Blob
# import torch_xla.utils.serialization as xser
# Make the torch_xla test helpers (the `schedulers` and `args_parse` modules
# imported right below) importable from whichever of these install locations
# exist. Each hit is pushed to the front of sys.path, so a later entry in the
# tuple ends up ahead of an earlier one.
for extra in (
    '/usr/share/torch-xla-1.8/pytorch/xla/test',
    '/pytorch/xla/test',
    '/usr/share/pytorch/xla/test',
):
    if not os.path.exists(extra):
        continue
    sys.path.insert(0, extra)
import schedulers
import args_parse
# Accepted values for --model; each is looked up by name on torchvision.models.
SUPPORTED_MODELS = [
    'alexnet',
    'densenet121', 'densenet161', 'densenet169', 'densenet201',
    'inception_v3',
    'resnet101', 'resnet152', 'resnet18', 'resnet34', 'resnet50',
    'squeezenet1_0', 'squeezenet1_1',
    'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn',
    'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn',
]

# Script-specific command-line options, handed to args_parse as
# (flag, keyword-arguments) pairs (presumably forwarded to argparse's
# add_argument).
MODEL_OPTS = {
    '--model': dict(choices=SUPPORTED_MODELS, default='resnet50'),
    '--test_set_batch_size': dict(type=int),
    '--lr_scheduler_type': dict(type=str),
    '--lr_scheduler_divide_every_n_epochs': dict(type=int),
    '--lr_scheduler_divisor': dict(type=int),
    '--dataset': dict(
        choices=['gcsdataset', 'torchdataset'],
        default='gcsdataset',
        type=str,
    ),
    # Shard specifications handed to wds.WebDataset for training / evaluation.
    '--wds_traindir': dict(type=str, default='/tmp/cifar'),
    '--wds_testdir': dict(type=str, default='/tmp/cifar'),
    # Sample counts used to size epochs and the WebDataset lengths.
    # NOTE(review): 1280000 / 50000 look like ImageNet sizes; CIFAR-10 has
    # 50000 train and 10000 test images -- pass --trainsize/--testsize for it.
    '--trainsize': dict(type=int, default=1280000),
    '--testsize': dict(type=int, default=50000),
    # Local path the checkpoint is written to with xm.save; "" disables saving.
    '--save_model': dict(type=str, default=""),
    # Object name inside --model_bucket to resume from; "" disables resuming.
    '--load_chkpt_file': dict(type=str, default=""),
    # Local path the resume checkpoint is downloaded to.
    '--load_chkpt_dir': dict(type=str, default=""),
    # GCS bucket holding the resume checkpoint.
    '--model_bucket': dict(type=str, default=""),
}
# Parse the command line: the common torch_xla options plus MODEL_OPTS.
# batch_size, num_epochs, momentum, lr and target_accuracy are given None
# here; any of them still None after parsing is filled in from the
# per-model default tables defined further down in this file.
FLAGS = args_parse.parse_common_options(
    opts=MODEL_OPTS.items(),
    datadir='/tmp/imagenet',
    profiler_port=9012,
    batch_size=None,
    num_epochs=None,
    momentum=None,
    lr=None,
    target_accuracy=None,
)
# Baseline hyper-parameters for any model without its own entry below.
DEFAULT_KWARGS = {
    'batch_size': 128,
    'test_set_batch_size': 64,
    'num_epochs': 18,
    'momentum': 0.9,
    'lr': 0.1,
    'target_accuracy': 0.0,
}

# Per-model tables: each starts from a copy of DEFAULT_KWARGS, then overrides
# existing keys or adds new ones.
MODEL_SPECIFIC_DEFAULTS = {
    'resnet50': {
        **DEFAULT_KWARGS,
        'lr': 0.5,
        'lr_scheduler_divide_every_n_epochs': 20,
        'lr_scheduler_divisor': 5,
        'lr_scheduler_type': 'WarmupAndExponentialDecayScheduler',
    },
}
# Fill in every hyper-parameter that is still None after parsing, using the
# table for FLAGS.model when one exists and DEFAULT_KWARGS otherwise.
default_value_dict = MODEL_SPECIFIC_DEFAULTS.get(FLAGS.model, DEFAULT_KWARGS)
for arg, value in default_value_dict.items():
    if getattr(FLAGS, arg) is not None:
        continue  # keep a value that is already set (e.g. from the CLI)
    setattr(FLAGS, arg, value)
def get_model_property(key):
    """Return the property named ``key`` for the model selected by FLAGS.model.

    Two keys exist: ``'img_dim'`` (224, or 299 for inception_v3) and
    ``'model_fn'``, a zero-argument callable that builds the torchvision model.
    inception_v3 is built with ``aux_logits=False``. Every other model uses
    ``torchvision.models.<FLAGS.model>`` directly.

    Raises:
      KeyError: if ``key`` is not one of the two property names.
    """
    # Resolved eagerly, even when an override below applies.
    generic = {
        'img_dim': 224,
        'model_fn': getattr(torchvision.models, FLAGS.model),
    }
    overrides = {
        'inception_v3': {
            'img_dim': 299,
            'model_fn': lambda: torchvision.models.inception_v3(aux_logits=False),
        },
    }
    if FLAGS.model in overrides:
        properties = overrides[FLAGS.model]
    else:
        properties = generic
    return properties[key]
def _train_update(device, step, loss, tracker, epoch, writer):
    """Print (and log to ``writer``) a progress line for one training step.

    Scheduled with ``xm.add_step_closure`` from the training loop. It reads the
    scalar loss with ``loss.item()`` and the current and overall throughput
    from the ``RateTracker``.
    """
    loss_value = loss.item()
    step_rate = tracker.rate()
    overall_rate = tracker.global_rate()
    test_utils.print_training_update(
        device, step, loss_value, step_rate, overall_rate, epoch,
        summary_writer=writer)
# Sample counts, captured once at import time. They size the per-epoch step
# caps and the WebDataset lengths.
trainsize, testsize = FLAGS.trainsize, FLAGS.testsize
def _upload_blob_gcs(gcs_uri, source_file_name, destination_blob_name):
    """Upload a local file to Google Cloud Storage.

    Args:
      gcs_uri: destination directory, expected to be a ``gs://bucket[/prefix]``
        URI (it is parsed by ``Blob.from_string``).
      source_file_name: path of the local file to upload.
      destination_blob_name: object name, joined onto ``gcs_uri``.
    """
    # os.path.join uses '/' on POSIX hosts, matching GCS object-name separators.
    destination_uri = os.path.join(gcs_uri, destination_blob_name)
    # Hand the client in through the public ``client=`` argument instead of
    # patching the private ``Bucket._client`` attribute afterwards.
    blob = Blob.from_string(destination_uri, client=storage.Client())
    blob.upload_from_filename(source_file_name)
    xm.master_print("Saved Model Checkpoint file {} and uploaded to {}.".format(
        source_file_name, destination_uri))
def _read_blob_gcs(BUCKET, CHKPT_FILE, DESTINATION):
    """Download object ``CHKPT_FILE`` of bucket ``BUCKET`` to local ``DESTINATION``.

    Raises:
      FileNotFoundError: if the bucket has no object named ``CHKPT_FILE``.
    """
    client = storage.Client()
    bucket = client.get_bucket(BUCKET)
    blob = bucket.get_blob(CHKPT_FILE)
    # get_blob returns None for a missing object; report that clearly instead
    # of failing with an AttributeError on None.
    if blob is None:
        raise FileNotFoundError(
            'Checkpoint gs://{}/{} not found'.format(BUCKET, CHKPT_FILE))
    blob.download_to_filename(DESTINATION)
def identity(x):
    """Return ``x`` unchanged; used as the no-op label transform in map_tuple."""
    unchanged = x
    return unchanged
def my_worker_splitter(urls):
    """Select the shard URLs handled by the current DataLoader worker.

    Used as the WebDataset ``splitter`` in place of ``wds.split_by_worker``.
    Inside a DataLoader worker, worker ``i`` of ``n`` gets every ``n``-th URL
    starting at index ``i``. Outside a worker (no worker info), every URL is
    returned.

    Args:
      urls: any iterable of shard URLs; it is materialized into a list.

    Returns:
      A list of URLs.
    """
    urls = list(urls)
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is None:
        return urls
    return urls[worker_info.id::worker_info.num_workers]
def my_node_splitter(urls):
    """Select the shard URLs handled by the current XLA replica.

    Replica ``r`` of ``world_size`` gets every ``world_size``-th URL starting
    at index ``r``, so replicas read disjoint subsets of the shards.

    Args:
      urls: any iterable of shard URLs. It is materialized into a list, like
        in ``my_worker_splitter``, so generators can be sliced too.

    Returns:
      The list of URLs for this replica.
    """
    # NOTE(review): with fewer shards than replicas, some replicas receive an
    # empty list; confirm the shard count is at least the world size.
    rank = xm.get_ordinal()
    num_replicas = xm.xrt_world_size()
    return list(urls)[rank::num_replicas]
# Per-channel normalization with commonly used CIFAR-10 statistics. (The
# ImageNet variant would be mean=[0.485, 0.456, 0.406],
# std=[0.229, 0.224, 0.225].)
normalize = transforms.Normalize(
    mean=[0.4914, 0.4822, 0.4465],
    std=[0.2023, 0.1994, 0.2010],
)
# Crop size in pixels passed to the training RandomCrop.
cifar_img_dim = 32
def make_train_loader(cifar_img_dim, shuffle=10000, batch_size=FLAGS.batch_size):
    """Build the shuffled, augmented WebDataset training loader.

    Shards under ``FLAGS.wds_traindir`` are split across XLA replicas
    (``my_node_splitter``) and then across DataLoader workers
    (``my_worker_splitter``). Each sample is:
      * randomly cropped with 4px padding;
      * randomly flipped horizontally;
      * converted to a tensor and normalized.
    Samples are batched inside the dataset, so the DataLoader runs with
    ``batch_size=None``.

    Args:
      cifar_img_dim: crop size passed to ``RandomCrop``.
      shuffle: sample-shuffle buffer size passed to the WebDataset ``shuffle``.
      batch_size: per-replica batch size; the last batch may be partial.

    Returns:
      A ``torch.utils.data.DataLoader`` yielding ``(images, labels)`` batches.
    """
    # With num_workers == 0 the DataLoader reads in the main process, which is
    # still one dataset instance per replica; max(..., 1) avoids dividing by 0.
    num_dataset_instances = xm.xrt_world_size() * max(FLAGS.num_workers, 1)
    # Length each dataset instance reports to WebDataset.
    epoch_size = trainsize // num_dataset_instances
    image_transform = transforms.Compose([
        transforms.RandomCrop(cifar_img_dim, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    dataset = (
        wds.WebDataset(FLAGS.wds_traindir,
                       splitter=my_worker_splitter,
                       nodesplitter=my_node_splitter,
                       shardshuffle=True,
                       length=epoch_size)
        .shuffle(shuffle)
        .decode("pil")
        .to_tuple("ppm;jpg;jpeg;png", "cls")
        .map_tuple(image_transform, identity)
        .batched(batch_size, partial=True)
    )
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=None, shuffle=False, drop_last=False,
        num_workers=FLAGS.num_workers)
    return loader
def make_val_loader(cifar_img_dim, batch_size=FLAGS.test_set_batch_size):
    """Build the WebDataset evaluation loader (no augmentation, no shuffling).

    Shards under ``FLAGS.wds_testdir`` are split across XLA replicas and then
    across DataLoader workers, with shard shuffling disabled. Images are only
    converted to tensors and normalized. Batching happens inside the dataset.

    Args:
      cifar_img_dim: unused; kept for symmetry with ``make_train_loader``.
      batch_size: per-replica evaluation batch size; the last batch may be
        partial.

    Returns:
      A ``torch.utils.data.DataLoader`` yielding ``(images, labels)`` batches.
    """
    # With num_workers == 0 the DataLoader reads in the main process (one
    # dataset instance per replica); max(..., 1) avoids dividing by zero.
    num_dataset_instances = xm.xrt_world_size() * max(FLAGS.num_workers, 1)
    epoch_test_size = testsize // num_dataset_instances
    val_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    val_dataset = (
        wds.WebDataset(FLAGS.wds_testdir,
                       splitter=my_worker_splitter,
                       nodesplitter=my_node_splitter,
                       shardshuffle=False,
                       length=epoch_test_size)
        .decode("pil")
        .to_tuple("ppm;jpg;jpeg;png", "cls")
        .map_tuple(val_transform, identity)
        .batched(batch_size, partial=True)
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=None, shuffle=False,
        num_workers=FLAGS.num_workers)
    return val_loader
def train_imagenet():
    """Train ``FLAGS.model`` on this XLA replica and return its best accuracy.

    Despite the name (kept for callers), the data comes from the CIFAR
    WebDataset loaders built by ``make_train_loader`` / ``make_val_loader``.
      * Optionally resumes model weights from object ``FLAGS.load_chkpt_file``
        in bucket ``FLAGS.model_bucket``, downloaded to the local path
        ``FLAGS.load_chkpt_dir``.
      * Each epoch trains for ``trainsize // (batch_size * world_size)`` steps,
        then evaluates on at most
        ``testsize // (test_set_batch_size * world_size)`` batches.
      * When ``FLAGS.save_model`` is set, saves a checkpoint with ``xm.save``
        (and uploads it to ``FLAGS.logdir``) whenever the replica-averaged
        test accuracy improves.

    Returns:
      The highest replica-averaged test accuracy (percent) reached by the
      epochs run in this call, or 0.0 if no epoch ran.
    """
    print('==> Preparing data..')
    train_loader = make_train_loader(
        cifar_img_dim, batch_size=FLAGS.batch_size, shuffle=10000)
    test_loader = make_val_loader(
        cifar_img_dim, batch_size=FLAGS.test_set_batch_size)
    torch.manual_seed(42)

    # NOTE(review): the handle is presumably kept so the profiler server is
    # not garbage-collected while training runs.
    server = xp.start_server(FLAGS.profiler_port)

    device = xm.xla_device()
    model = get_model_property('model_fn')().to(device)
    writer = None
    if xm.is_master_ordinal():
        writer = test_utils.get_summary_writer(FLAGS.logdir)
    optimizer = optim.SGD(
        model.parameters(),
        lr=FLAGS.lr,
        momentum=FLAGS.momentum,
        weight_decay=1e-4)
    num_training_steps_per_epoch = trainsize // (
        FLAGS.batch_size * xm.xrt_world_size())
    lr_scheduler = schedulers.wrap_optimizer_with_scheduler(
        optimizer,
        scheduler_type=getattr(FLAGS, 'lr_scheduler_type', None),
        scheduler_divisor=getattr(FLAGS, 'lr_scheduler_divisor', None),
        scheduler_divide_every_n_epochs=getattr(
            FLAGS, 'lr_scheduler_divide_every_n_epochs', None),
        num_steps_per_epoch=num_training_steps_per_epoch,
        summary_writer=writer)
    loss_fn = nn.CrossEntropyLoss()

    checkpoint = None
    if FLAGS.load_chkpt_file != "":
        xm.master_print("Loading saved model {}".format(FLAGS.load_chkpt_file))
        # All processes on a host read the same local path, so only the local
        # master downloads; the rendezvous keeps the other replicas from
        # reading a partially written file.
        if xm.is_master_ordinal(local=True):
            _read_blob_gcs(FLAGS.model_bucket, FLAGS.load_chkpt_file,
                           FLAGS.load_chkpt_dir)
        xm.rendezvous('train_imagenet.checkpoint_download')
        checkpoint = torch.load(FLAGS.load_chkpt_dir)
        model.load_state_dict(checkpoint['model_state_dict'])
        model = model.to(device)
        # NOTE(review): 'opt_state_dict' is saved but not restored, and the LR
        # scheduler is not advanced, so momentum and the LR schedule restart.

    def train_loop_fn(loader, epoch):
        train_steps = trainsize // (FLAGS.batch_size * xm.xrt_world_size())
        tracker = xm.RateTracker()
        total_samples = 0
        model.train()
        # islice caps the epoch at exactly `train_steps` batches, the same
        # per-epoch step count the LR scheduler was configured with.
        for step, (data, target) in enumerate(islice(loader, train_steps)):
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            # Batches are built with partial=True, so count the real size.
            batch_samples = data.size()[0]
            tracker.add(batch_samples)
            total_samples += batch_samples
            if lr_scheduler:
                lr_scheduler.step()
            if step % FLAGS.log_steps == 0:
                xm.add_step_closure(
                    _train_update,
                    args=(device, step, loss, tracker, epoch, writer))
                test_utils.write_to_summary(
                    writer, step,
                    dict_to_write={'Rate_step': tracker.rate()},
                    write_xla_metrics=False)
        reduced_global = xm.mesh_reduce(
            'reduced_global', tracker.global_rate(), np.mean)
        return total_samples, reduced_global

    def test_loop_fn(loader, epoch):
        test_steps = testsize // (
            FLAGS.test_set_batch_size * xm.xrt_world_size())
        total_samples, correct = 0, 0
        model.eval()
        # Evaluate on exactly `test_steps` batches at most.
        for step, (data, target) in enumerate(islice(loader, test_steps)):
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum()
            total_samples += data.size()[0]
            if step % FLAGS.log_steps == 0:
                xm.add_step_closure(
                    test_utils.print_test_update,
                    args=(device, None, epoch, step))
        # `correct` is still the int 0 when no batch was seen.
        correct_val = correct.item() if torch.is_tensor(correct) else correct
        if total_samples:
            accuracy_replica = 100.0 * correct_val / total_samples
        else:
            accuracy_replica = 0.0
        accuracy = xm.mesh_reduce('test_accuracy', accuracy_replica, np.mean)
        return accuracy, accuracy_replica, total_samples

    # Set up the epoch loop.
    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    test_device_loader = pl.MpDeviceLoader(test_loader, device)
    accuracy, max_accuracy = 0.0, 0.0
    # Defined up front so the final summary works even if no epoch runs
    # (e.g. when resuming from the last epoch's checkpoint).
    reduced_global = 0.0
    training_start_time = time.time()
    if checkpoint is not None:
        best_valid_acc = checkpoint['best_valid_acc']
        saved_epoch = checkpoint['epoch']
        # The checkpoint is written after `saved_epoch` was trained and
        # evaluated, so continue with the following epoch.
        start_epoch = saved_epoch + 1
        xm.master_print(
            'Loaded Model CheckPoint: Epoch={}/{}, Val Accuracy={:.2f}%'.format(
                saved_epoch, FLAGS.num_epochs, best_valid_acc))
    else:
        best_valid_acc = 0.0
        start_epoch = 1

    for epoch in range(start_epoch, FLAGS.num_epochs + 1):
        xm.master_print('Epoch {} train begin {}'.format(
            epoch, test_utils.now()))
        replica_epoch_start = time.time()
        replica_train_samples, reduced_global = train_loop_fn(
            train_device_loader, epoch)
        replica_epoch_time = time.time() - replica_epoch_start
        avg_epoch_time_mesh = xm.mesh_reduce(
            'epoch_time', replica_epoch_time, np.mean)
        # mesh_reduce averaged the per-replica rates; scale to the total rate.
        reduced_global = reduced_global * xm.xrt_world_size()
        xm.master_print(
            'Epoch {} train end {}, Epoch Time={}, Replica Train Samples={}, Reduced GlobalRate={:.2f}'.format(
                epoch, test_utils.now(),
                str(datetime.timedelta(seconds=avg_epoch_time_mesh)).split('.')[0],
                replica_train_samples,
                reduced_global))
        accuracy, accuracy_replica, replica_test_samples = test_loop_fn(
            test_device_loader, epoch)
        xm.master_print(
            'Epoch {} test end {}, Reduced Accuracy={:.2f}%, Replica Accuracy={:.2f}%, Replica Test Samples={}'.format(
                epoch, test_utils.now(),
                accuracy, accuracy_replica,
                replica_test_samples))
        # `accuracy` is mesh-reduced, so every replica takes the same branch;
        # all replicas must call xm.save.
        if FLAGS.save_model != "" and accuracy > best_valid_acc:
            xm.master_print(
                'Epoch {} validation accuracy improved from {:.2f}% to {:.2f}% - saving model...'.format(
                    epoch, best_valid_acc, accuracy))
            best_valid_acc = accuracy
            xm.save(
                {
                    "epoch": epoch,
                    "nepochs": FLAGS.num_epochs,
                    "model_state_dict": model.state_dict(),
                    "best_valid_acc": best_valid_acc,
                    "opt_state_dict": optimizer.state_dict(),
                },
                FLAGS.save_model,
            )
            if xm.is_master_ordinal():
                _upload_blob_gcs(FLAGS.logdir, FLAGS.save_model, 'model-chkpt.pt')
        max_accuracy = max(accuracy, max_accuracy)
        test_utils.write_to_summary(
            writer,
            epoch,
            dict_to_write={'Accuracy/test': accuracy,
                           'Global Rate': reduced_global},
            write_xla_metrics=False)
        if FLAGS.metrics_debug:
            xm.master_print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    total_train_time = time.time() - training_start_time
    xm.master_print('Total Train Time: {}'.format(
        str(datetime.timedelta(seconds=total_train_time)).split('.')[0]))
    xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
    # reduced_global holds the last completed epoch's rate (0.0 if none ran).
    xm.master_print('Avg. Global Rate: {:.2f} examples per second'.format(
        reduced_global))
    return max_accuracy
def _mp_fn(index, flags):
    """Per-process entry point passed to ``xmp.spawn``.

    Installs ``flags`` as the module-global ``FLAGS``, runs training, and
    exits with status 21 if the best accuracy is below
    ``FLAGS.target_accuracy``. ``index`` (the process index) is unused.
    """
    global FLAGS
    FLAGS = flags
    torch.set_default_tensor_type('torch.FloatTensor')
    best_accuracy = train_imagenet()
    below_target = best_accuracy < FLAGS.target_accuracy
    if below_target:
        print('Accuracy {} is below target {}'.format(
            best_accuracy, FLAGS.target_accuracy))
        sys.exit(21)
if __name__ == '__main__':
    # One training process per core. 'fork' (rather than 'spawn') presumably
    # lets the children reuse the module state built at import time.
    xmp.spawn(
        _mp_fn,
        args=(FLAGS,),
        start_method='fork',
        nprocs=FLAGS.num_cores,
    )