forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
core.py
3070 lines (2653 loc) · 116 KB
/
core.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
## @package core
# Module caffe2.python.core
from collections import namedtuple, OrderedDict, defaultdict
from past.builtins import basestring
from itertools import chain
from typing import Dict
from caffe2.proto import caffe2_pb2
from caffe2.python import scope, utils, workspace
from caffe2.python.lazy import TriggerLazyImport
from caffe2.python.control_ops_grad import \
gen_do_gradient, gen_if_gradient, gen_while_gradient, disambiguate_grad_if_op_output
import caffe2.python._import_c_extension as C
import copy
import pickle
import numpy as np
import sys
import traceback
import os
# On macOS with a homebrew leveldb registered as a DB backend, warn up front
# about a known, harmless allocator message.
if sys.platform == 'darwin' and 'leveldb' in C.registered_dbs():
    print(
        'If you are using homebrew leveldb on a Mac OS, you might see an '
        'error warning you that malloc_zone_unregister() failed. This is '
        'not a caffe2 issue but is due to the homebrew leveldb having an '
        'incompatible memory allocator. It does not affect usage.'
    )

# Re-export the scope context managers so callers can use core.DeviceScope /
# core.NameScope directly.
DeviceScope = scope.DeviceScope
NameScope = scope.NameScope
# Bring datatype enums to the main namespace
class DataType:
    """Plain-integer copy of caffe2_pb2.TensorProto.DataType.

    _CheckDataType() runs at import time and raises if any value declared in
    caffe2.proto is missing here or differs.
    """

    UNDEFINED = 0

    # Floating point, integer and byte-like element types.
    FLOAT = 1
    INT32 = 2
    BYTE = 3
    STRING = 4
    BOOL = 5
    UINT8 = 6
    INT8 = 7
    UINT16 = 8
    INT16 = 9
    INT64 = 10
    # Note: no constant is defined for value 11 here.
    FLOAT16 = 12
    DOUBLE = 13

    # Special-purpose types.
    ZERO_COLLISION_HASH = 14
    REBATCHING_BUFFER = 15
def _CheckDataType():
    """Verify DataType against the enum declared in caffe2.proto.

    Raises:
        AssertionError: if some TensorProto.DataType name from the proto is
            missing on DataType or maps to a different value. Only the
            proto -> DataType direction is checked.
    """
    proto_enum = caffe2_pb2.TensorProto.DataType
    for enum_name, enum_value in proto_enum.items():
        mirrored = getattr(DataType, enum_name, None)
        if mirrored == enum_value:
            continue
        raise AssertionError(
            f"DataType {enum_name} does not match the value defined in "
            f"caffe2.proto: {mirrored} vs {enum_value}"
        )


_CheckDataType()
def _GetRegisteredOperators():
    """Return a set snapshot of workspace.RegisteredOperators()."""
    registry_keys = workspace.RegisteredOperators()
    return set(registry_keys)


# Module-level cache consulted by IsOperatorWithEngine; rebuilt by
# RefreshRegisteredOperators().
_REGISTERED_OPERATORS = _GetRegisteredOperators()
def RefreshRegisteredOperators(trigger_lazy=True):
    """Rebuild the cached _REGISTERED_OPERATORS set.

    Args:
        trigger_lazy: when truthy, call TriggerLazyImport() before re-reading
            the registry.
    """
    global _REGISTERED_OPERATORS
    if trigger_lazy:
        TriggerLazyImport()
    _REGISTERED_OPERATORS = _GetRegisteredOperators()
# Arguments (minus args[0]) passed to every GlobalInit() call, in call order.
_GLOBAL_INIT_ARGS = []


def GlobalInit(args):
    """Initialize the C extension with ``args``.

    Everything after ``args[0]`` is also recorded so it can be read back via
    GetGlobalInitArgs(); the full list is forwarded to C.global_init.
    """
    TriggerLazyImport()
    _GLOBAL_INIT_ARGS.extend(args[1:])
    C.global_init(args)


def GetGlobalInitArgs():
    """Return a copy of the recorded GlobalInit arguments."""
    return list(_GLOBAL_INIT_ARGS)
def IsOperator(op_type):
    """Return True if ``op_type`` is registered for the DEFAULT engine."""
    return IsOperatorWithEngine(op_type, engine='DEFAULT')


def IsOperatorWithEngine(op_type, engine):
    """Return True if (op_type, engine) has a registry key in the cached
    _REGISTERED_OPERATORS set."""
    TriggerLazyImport()
    registry_key = C.op_registry_key(op_type, engine)
    return registry_key in _REGISTERED_OPERATORS
def IsGPUDeviceType(device_type):
    """Return True when ``device_type`` is caffe2_pb2.CUDA or caffe2_pb2.HIP."""
    gpu_device_types = {caffe2_pb2.CUDA, caffe2_pb2.HIP}
    return device_type in gpu_device_types
def DeviceOption(
    device_type,
    device_id=0,
    random_seed=None,
    node_name=None,
    numa_node_id=None,
    extra_info=None,
):
    """Build a caffe2_pb2.DeviceOption.

    device_type and device_id are always set. node_name and random_seed are
    set only when not None. numa_node_id is only allowed with a CPU device
    (asserted). extra_info, if given, is appended to the repeated field.
    """
    option = caffe2_pb2.DeviceOption()
    option.device_type = device_type
    option.device_id = device_id
    optional_scalars = (('node_name', node_name), ('random_seed', random_seed))
    for field_name, field_value in optional_scalars:
        if field_value is not None:
            setattr(option, field_name, field_value)
    if numa_node_id is not None:
        assert device_type == caffe2_pb2.CPU
        option.numa_node_id = numa_node_id
    if extra_info is not None:
        option.extra_info.extend(extra_info)
    return option
def device_option_equal(opt1, opt2, ignore_node_name=True, ignore_random_seed=True):
    """Compare two device options.

    A falsy option only equals another option via plain ``==``. node_name and
    random_seed are compared only when the matching ``ignore_*`` flag is
    False. An option whose device_type is falsy (0) is treated as CPU: two CPU
    options are equal regardless of device_id. If both are non-CPU, only
    device_id is compared (device_type itself is not).
    """
    if not (opt1 and opt2):
        return opt1 == opt2
    if not ignore_node_name and opt1.node_name != opt2.node_name:
        return False
    if not ignore_random_seed and opt1.random_seed != opt2.random_seed:
        return False
    first_is_cpu = not opt1.device_type
    second_is_cpu = not opt2.device_type
    if first_is_cpu or second_is_cpu:
        # At least one option is for CPU; equal only if both are.
        return first_is_cpu and second_is_cpu
    return opt1.device_id == opt2.device_id
def InferBlobDevices(net):
    '''
    Compute a mapping from each blob produced by ``net`` to the device option
    of the op that creates it.

    Ops are visited in order, so if several ops write the same blob, the last
    writer's device option wins.

    Args:
        net: object exposing ``Proto().op`` (an iterable of OperatorDef).

    Returns:
        dict mapping output blob name -> that op's ``device_option``.
    '''
    mapping = {}
    for op in net.Proto().op:
        # A protobuf singular message field is never None: when it is unset,
        # reading it yields a default (empty) DeviceOption. The former
        # ``is None`` fallback was therefore unreachable, and it would have
        # raised TypeError had it run (message constructors only accept
        # keyword arguments).
        op_device = op.device_option
        # TODO: T18892922, use device annotations
        for b in op.output:
            mapping[b] = op_device
    return mapping
def InferOpBlobDevicesAsDict(op):
    """Same as InferOpBlobDevices, but keyed by blob name.

    Returns:
        (inputs, outputs): two dicts mapping each of op.input / op.output to
        the DeviceOption inferred for it (a repeated name keeps the device of
        its last position).
    """
    input_devices, output_devices = InferOpBlobDevices(op)
    by_input = {
        blob: input_devices[position]
        for position, blob in enumerate(op.input)
    }
    by_output = {
        blob: output_devices[position]
        for position, blob in enumerate(op.output)
    }
    return by_input, by_output
def InferOpBlobDevices(op):
    """Ask the C extension which device each input/output blob of ``op`` is on.

    Returns:
        (input_devices, output_devices): lists of caffe2_pb2.DeviceOption,
        parsed from the serialized options returned by
        C.infer_op_input_output_device.
    """
    def _parse_options(serialized_options):
        parsed = []
        for payload in serialized_options:
            option = caffe2_pb2.DeviceOption()
            option.ParseFromString(payload)
            parsed.append(option)
        return parsed

    raw = C.infer_op_input_output_device(op.SerializeToString())
    input_info = _parse_options(raw[0])
    output_info = _parse_options(raw[1])
    return input_info, output_info
def InferOpDeviceAsBlobDevices(op):
    """Assume every input and output blob lives on the op's own device.

    Falls back to a default caffe2_pb2.DeviceOption() when op.device_option
    is falsy. The same option object is repeated in both returned lists.
    """
    if op.device_option:
        shared_device = op.device_option
    else:
        shared_device = caffe2_pb2.DeviceOption()
    return [shared_device] * len(op.input), [shared_device] * len(op.output)
# A sparse gradient, given as the pair of blobs holding its indices and values.
GradientSlice = namedtuple('GradientSlice', ('indices', 'values'))
class BlobReference:
"""A wrapper around a blob in a net.
BlobReference gives us a way to refer to the network that the blob is
generated from. Note that blobs are, essentially, just strings in the
current workspace.
"""
def __init__(self, name, net=None):
"""Initializes a blob reference.
Note that this does not prepends the namescope. If needed, use
ScopedBlobReference() to prepend the existing namespace.
"""
if isinstance(name, str):
self._name = name
elif isinstance(name, bytes):
self._name = name.decode('utf-8')
else:
self._name = str(name)
self._from_net = net
# meta allows helper functions to put whatever metainformation needed
# there.
self.meta = {}
def __hash__(self):
return hash(self._name)
def __eq__(self, other):
if isinstance(other, str):
return self._name == other
elif isinstance(other, bytes):
return self._name == other.decode('utf-8')
elif isinstance(other, BlobReference):
return self._name == other._name
else:
return False
def __ne__(self, other):
return not(self == other)
def __str__(self):
return self._name
def __repr__(self):
return 'BlobReference("{}")'.format(self._name)
def __add__(self, other):
if not isinstance(other, str):
raise RuntimeError('Cannot add BlobReference to a non-string.')
return BlobReference(self._name + other, self._from_net)
def __radd__(self, other):
if not isinstance(other, str):
raise RuntimeError('Cannot add a non-string to BlobReference.')
return BlobReference(other + self._name, self._from_net)
def Net(self):
return self._from_net
def GetNameScope(self):
return self._name[:self._name.rfind(scope._NAMESCOPE_SEPARATOR) + 1]
def GetUnscopedName(self):
return self._name[self._name.rfind(scope._NAMESCOPE_SEPARATOR) + 1:]
def _CreateAndAddToNet(self, op_type, inputs=None, *args, **kwargs):
"""Internal function that routes the operator generation to the
network's __getattr__ function.
"""
inputs = [] if inputs is None else inputs
if isinstance(inputs, BlobReference) or isinstance(inputs, str):
inputs = [inputs]
# add self to the input list.
inputs.insert(0, self)
return self._from_net.__getattr__(op_type)(inputs, *args, **kwargs)
def __getattr__(self, op_type):
"""A wrapper allowing one to initiate operators from a blob reference.
Example: for a blob reference b that comes from network n, doing
b.Relu(...)
is equivalent to doing
net.Relu([b], ...)
"""
if op_type.startswith('__'):
raise AttributeError('Attribute {} not found.'.format(op_type))
if self._from_net is None:
raise AttributeError(
'You cannot use a blob reference that does not have a net '
'source to create operators. Create the operator from an '
'explicit net object.')
if not IsOperator(op_type):
raise AttributeError(
'Method ' + op_type + ' is not a registered operator.' +
' Did you mean: [' +
",".join(workspace.C.nearby_opnames(op_type)) + ']'
)
return lambda *args, **kwargs: self._CreateAndAddToNet(
op_type, *args, **kwargs)
def __dir__(self):
TriggerLazyImport()
additional_methods = [
op
for op in _REGISTERED_OPERATORS
if '_ENGINE_' not in op or '_ENGINE_CUDNN' in op]
return sorted(set(chain(
dir(type(self)),
self.__dict__.keys(),
additional_methods
)))
def ScopedName(name):
    """Prefix ``name`` with the current name scope.

    ``bytes`` names are decoded as UTF-8, the same encoding BlobReference
    uses, so every ASCII name behaves as before and non-ASCII names are
    accepted too (they used to raise UnicodeDecodeError here).
    """
    if isinstance(name, bytes):
        name = name.decode('utf-8')
    return scope.CurrentNameScope() + name
def ScopedBlobReference(name, *args, **kwargs):
    """Returns a blob reference with scope prefixed.

    Any extra positional/keyword arguments (e.g. ``net``) are forwarded to
    BlobReference.
    """
    scoped_name = ScopedName(name)
    return BlobReference(scoped_name, *args, **kwargs)
def _RectifyInputOutput(blobs, net=None):
    """Normalize the inputs/outputs argument of CreateOperator to a list of
    BlobReference.

    Accepts a single str/bytes name (scoped via ScopedBlobReference), a single
    BlobReference, or a list/tuple of those. Any other type raises TypeError.
    """
    if isinstance(blobs, (bytes, str)):
        # TODO(jiayq): enforce using BlobReference instead of raw strings.
        return [ScopedBlobReference(blobs, net=net)]
    if type(blobs) is BlobReference:
        return [blobs]
    if type(blobs) not in (list, tuple):
        raise TypeError(
            "Unknown input/output type: %s of type %s." %
            (str(blobs), type(blobs))
        )
    rectified = []
    for position, blob in enumerate(blobs):
        if isinstance(blob, (bytes, str)):
            rectified.append(ScopedBlobReference(blob, net=net))
        elif type(blob) is BlobReference:
            rectified.append(blob)
        else:
            raise TypeError(
                "I/O blob #{} of unsupported type: {} of type {}"
                .format(position, str(blob), type(blob)))
    return rectified
def CreateOperator(
    operator_type,
    inputs,
    outputs,
    name='',
    control_input=None,
    device_option=None,
    arg=None,
    engine=None,
    debug_info=None,
    **kwargs
):
    """A function wrapper that allows one to create operators based on the
    operator type. The type should be a string corresponding to an operator
    registered with Caffe2.

    Device option precedence: an explicit ``device_option``, otherwise
    scope.CurrentDeviceScope() if it is not None, otherwise left unset.
    A ``random_seed`` keyword is stored on the device option rather than as
    an argument. Remaining non-None keywords become operator arguments via
    utils.MakeArgument. In immediate mode the operator is also run.
    """
    op_def = caffe2_pb2.OperatorDef()
    if os.environ.get('CAFFE2_DEBUG'):
        # Record the creation call stack, dropping this function's own frame.
        op_def.debug_info = "".join(traceback.format_stack()[:-1])
    op_def.type = operator_type
    op_def.name = name

    rectified_inputs = _RectifyInputOutput(inputs)
    rectified_outputs = _RectifyInputOutput(outputs)
    op_def.input.extend(str(blob) for blob in rectified_inputs)
    op_def.output.extend(str(blob) for blob in rectified_outputs)
    if control_input:
        op_def.control_input.extend(
            str(blob) for blob in _RectifyInputOutput(control_input))

    if device_option is None:
        device_option = scope.CurrentDeviceScope()
    if device_option is not None:
        op_def.device_option.CopyFrom(device_option)

    if engine is not None:
        op_def.engine = engine
    if debug_info is not None:
        op_def.debug_info = debug_info

    # random_seed lives in the device option, not in the argument list.
    if 'random_seed' in kwargs:
        op_def.device_option.random_seed = kwargs.pop('random_seed')

    if arg is not None:
        # Pre-built arguments are added without parsing.
        op_def.arg.extend(arg)
    for key, value in kwargs.items():
        if value is None:
            continue
        op_def.arg.add().CopyFrom(utils.MakeArgument(key, value))

    if workspace.IsImmediate():
        workspace.RunOperatorImmediate(op_def)
    return op_def
def _RegisterPythonImpl(
    f, grad_f=None, python_func_type=None, pass_workspace=False
):
    """Register a Python forward (and optional gradient) function with the C
    extension and return the registration token.

    With ``python_func_type``, ``python_func_type(f)`` supplies ``forward``
    and ``backward``. Otherwise ``f`` / ``grad_f`` are used directly, except
    that a tuple ``(factory, args, kwargs)`` is first materialized by calling
    ``factory(*args, **kwargs)``.
    """
    def _materialize(spec):
        if isinstance(spec, tuple):
            return spec[0](*spec[1], **spec[2])
        return spec

    if python_func_type:
        wrapper = python_func_type(f)
        f, grad_f = wrapper.forward, wrapper.backward
    else:
        f = _materialize(f)
        grad_f = _materialize(grad_f)
    token = C.register_python_op(f, pass_workspace, '')
    if grad_f:
        C.register_python_gradient_op(token, grad_f)
    return token
def CreatePythonOperator(
    f, inputs,
    outputs,
    grad_f=None,
    pass_workspace=False,
    python_func_type=None,
    *args,
    **kwargs
):
    """Create a "Python" operator backed by the Python callable ``f``.

    `f` should have a signature (inputs, outputs).
    If `pass_workspace` is True, the signature is changed to
    (inputs, outputs, workspace) where `workspace` is the workspace the op
    is going to run on. This is potentially dangerous (as the op can manipulate
    the workspace directly), use at your own risk.

    The registration token is passed to CreateOperator as the ``token``
    keyword, overriding any ``token`` already present in ``kwargs``.
    """
    registration_token = _RegisterPythonImpl(
        f, grad_f, python_func_type, pass_workspace=pass_workspace
    )
    kwargs["token"] = registration_token
    return CreateOperator("Python", inputs, outputs, *args, **kwargs)
def GetIndexFromGradientList(g_list, name):
    """Return the position of ``name`` in a gradient list, or None.

    An entry matches if it equals ``name`` or, for a GradientSlice entry, if
    either its indices or its values equal ``name``.
    """
    for position, entry in enumerate(g_list):
        if entry == name:
            return position
        if type(entry) is GradientSlice and (
                entry.indices == name or entry.values == name):
            return position
    return None
# One replayed forward op plus {blob: version} dicts for its inputs / outputs.
OpSSA = namedtuple('OpSSA', ['op', 'in_versions', 'out_versions'])

# Dense gradient producer: output #idx of grad_op is `gradient`
# (grad_op is None when the gradient blob is passed through unchanged).
GradGenMeta = namedtuple(
    'GradGenMeta',
    ['grad_op', 'idx', 'gradient', 'device_option'],
)

# Sparse gradient producer; indices and values may come from different ops
# (either op may be None).
SparseGradGenMeta = namedtuple(
    'SparseGradGenMeta',
    [
        'grad_op_indices', 'idx_indices',
        'grad_op_values', 'idx_values',
        'gradient', 'device_option',
    ],
)
class IR:
"""A simple IR class to keep track of all intermediate representations used
in the gradient computation.
"""
def __init__(self, operators):
    # Metadata collected while replaying the forward pass:
    #   ssa: one OpSSA(op, in_versions, out_versions) per replayed op, in order.
    #   input_usages[blob][version]: indices into ssa of the ops that read that
    #       version of the blob.
    #   frontier[blob]: latest version of each blob after the ops replayed so
    #       far; used to detect gradients that need an overwritten version.
    #   gradient_frontier[blob]: the blob version its gradient corresponds to.
    #   gradient_generators[blob][version]: GradGenMeta / SparseGradGenMeta
    #       entries describing where that version's gradient comes from.
    #   out_version_history / in_version_history[blob]: (op, version) pairs,
    #       used only to build debug messages.
    def _per_version_lists():
        return defaultdict(list)

    self.ssa = []
    self.frontier = defaultdict(int)
    self.gradient_frontier = {}
    self.input_usages = defaultdict(_per_version_lists)
    self.gradient_generators = defaultdict(_per_version_lists)
    self.in_version_history = defaultdict(list)
    self.out_version_history = defaultdict(list)
    for forward_op in operators:
        self.Play(forward_op)
    self.SanityCheck(operators)
def SanityCheck(self, operators):
    """Reject StopGradient ops whose first output is never read as an input.

    Raises:
        ValueError: naming the orphaned output and the offending op.
    """
    stop_gradient_ops = (op for op in operators if op.type == 'StopGradient')
    for op in stop_gradient_ops:
        if op.output[0] in self.input_usages:
            continue
        raise ValueError("""StopGradient's output '{}' is orphan.
You typically want to specify same input and output for
StopGradient. Op:\n\n{}""".format(op.output[0], str(op)))
def Play(self, op):
    """Adds an op to the current IR, and updates the internal states to
    reflect the blobs and versions after the execution of the op.

    Inputs are read at their current frontier version (an unseen blob starts
    at version 0). Each output of a blob already in the frontier is bumped to
    the next version; a brand-new output starts at version 0.
    """
    op_position = len(self.ssa)

    in_versions = {}
    for blob in op.input:
        version = self.frontier[blob]
        in_versions[blob] = version
        self.input_usages[blob][version].append(op_position)
        self.in_version_history[blob].append((op, version))

    out_versions = {}
    for blob in op.output:
        if blob in self.frontier:
            self.frontier[blob] += 1
        version = self.frontier[blob]
        out_versions[blob] = version
        self.out_version_history[blob].append((op, version))

    self.ssa.append(OpSSA(op, in_versions, out_versions))
def CheckGradientOperatorInput(
        self, grad_op_input, g_output, fwd_op_idx, locally_generated_blobs):
    """Checks if the gradient operators can be correctly carried out.

    ``grad_op_input`` must be one of the following:
      * an output gradient from ``g_output``, whose forward output version
        must equal the current gradient frontier;
      * a forward output, still at the version the forward op produced;
      * a forward input, still at the version the forward op read;
      * a blob produced by an earlier gradient op (``locally_generated_blobs``).

    Raises:
        RuntimeError: describing the mismatch, with version history appended
            for the first three cases.
    """
    forward_op, in_versions, out_versions = self.ssa[fwd_op_idx]
    original_index = GetIndexFromGradientList(g_output, grad_op_input)

    out_hint = "Maybe you use same output blob twice for different ops?\n"
    in_hint = "Maybe the blob was overwritten by another op?\n"

    def _version_history_help(name, hint, label, history):
        # Build the DEBUG HELP text listing every (version, op) for `name`.
        pieces = [
            "DEBUG HELP:\n",
            hint,
            "== Version history of blob [{}]\n".format(name),
        ]
        for op, vers in history[name]:
            pieces.append("{} {} <-- {}".format(label, vers, op))
            pieces.append("\n")
        return "".join(pieces)

    if original_index is not None:
        # Dense or sparse gradient name: must match the version of the
        # corresponding forward output.
        original_name = forward_op.output[original_index]
        expected = out_versions[original_name]
        current = self.gradient_frontier[original_name]
        if expected != current:
            message = (
                'Gradient name "%s" is expected to correspond '
                'to version %d of "%s", but currently we have '
                'version %d.\n\n' % (
                    grad_op_input, expected, original_name, current))
            raise RuntimeError(message + _version_history_help(
                original_name, out_hint, "Version (out)",
                self.out_version_history))
    elif grad_op_input in out_versions:
        # Forward output: must not have been overwritten since.
        current = self.frontier[grad_op_input]
        expected = out_versions[grad_op_input]
        if current != expected:
            message = (
                'Gradient operator needs output "%s" at version'
                ' %d, but currently we have version %d.\n\n' % (
                    grad_op_input, expected, current))
            raise RuntimeError(message + _version_history_help(
                grad_op_input, out_hint, "Version (out)",
                self.out_version_history))
    elif grad_op_input in in_versions:
        # Forward input: must not have been overwritten since.
        current = self.frontier[grad_op_input]
        expected = in_versions[grad_op_input]
        if current != expected:
            message = (
                'Gradient operator needs input "%s" at version '
                '%d, but currently we have version %d.\n\n' % (
                    grad_op_input, expected, current))
            raise RuntimeError(message + _version_history_help(
                grad_op_input, in_hint, "version (in)",
                self.in_version_history))
    elif grad_op_input not in locally_generated_blobs:
        raise RuntimeError(
            'Blob name "%s" not in the scope of operator: '
            '%s\nand is not generated by any of the local '
            'gradient operators.' % (grad_op_input, str(forward_op))
        )
def AppendSparseGenerators(self, sparse_generators):
    """Merge partial sparse-gradient generators and record them.

    ``sparse_generators[name][version]`` holds one or two partial
    SparseGradGenMeta entries: one covering the indices and/or one covering
    the values. A pair is merged into a single entry after asserting that the
    two parts are compatible. The result is appended to
    ``self.gradient_generators[name][version]``.
    """
    for name, per_version in sparse_generators.items():
        for version, partials in per_version.items():
            if len(partials) == 1:
                # Only indices or only values come from a gradient op.
                self.gradient_generators[name][version].append(partials[0])
                continue
            # Indices and values come from separate generators; combine them.
            assert(len(partials) == 2)
            op1_i, idx1_i, op1_v, idx1_v, g1, dev_1 = partials[0]
            op2_i, idx2_i, op2_v, idx2_v, g2, dev_2 = partials[1]
            assert(g1 == g2)
            assert dev_1 == dev_2, (
                "Unequal devices for sparse generators: "
                "{} and {}".format(dev_1, dev_2)
            )
            assert(op1_i is None or op2_i is None)
            assert(op1_v is None or op2_v is None)
            assert(idx1_i == 0 or idx2_i == 0)
            assert(idx1_v == 0 or idx2_v == 0)
            merged = SparseGradGenMeta(
                op1_i or op2_i, idx1_i + idx2_i,
                op1_v or op2_v, idx1_v + idx2_v,
                g1, dev_1)
            self.gradient_generators[name][version].append(merged)
def BuildGradientGenerators(  # NOQA
        self, fwd_op_idx, gradient_ops, g_output, g_input):
    """Updates gradient_generators and gradient_frontier

    Every input of every gradient op is validated against the forward op's
    SSA record. The method then records which gradient-op output, or which
    pass-through blob, produces the gradient of each forward input version.

    Args:
        fwd_op_idx: index into self.ssa of the forward op.
        gradient_ops: gradient operators created for that forward op, in
            order; an op may read outputs of earlier ops in this list.
        g_output: gradients of the forward op's outputs, aligned with
            forward_op.output; entries may be GradientSlice.
        g_input: gradients of the forward op's inputs, aligned with
            forward_op.input; entries may be GradientSlice, or falsy when
            that input gets no gradient.

    Raises:
        RuntimeError: propagated from CheckGradientOperatorInput.
    """
    forward_op, in_versions, out_versions = self.ssa[fwd_op_idx]
    locally_generated_blobs = []
    # input name -> version -> partial SparseGradGenMeta entries (indices-only
    # or values-only); merged in step (3).
    sparse_generators = defaultdict(lambda: defaultdict(list))
    for grad_op in gradient_ops:
        # (1) check that inputs are valid
        for s in grad_op.input:
            self.CheckGradientOperatorInput(
                s, g_output, fwd_op_idx, locally_generated_blobs)
        # (2) add outputs to the locally generated blobs
        # If an output corresponds to the gradient of an input, we also
        # record it to gradient_generators
        locally_generated_blobs.extend(map(str, grad_op.output))
        for i, output in enumerate(grad_op.output):
            input_index = GetIndexFromGradientList(g_input, output)
            if input_index is not None:
                input_name = forward_op.input[input_index]
                input_version = in_versions[input_name]
                g = g_input[input_index]
                if type(g) is GradientSlice:
                    # the output corresponds either to the indices or the
                    # values of the sparse gradient. In either case we
                    # create a (partial) SparseGradGenMeta. If necessary,
                    # we'll merge indices and values generators
                    # corresponding to the same gradient in step (3)
                    if g.indices == output:
                        m = SparseGradGenMeta(
                            grad_op, i, None, 0, g, grad_op.device_option)
                    else:
                        assert(g.values == output)
                        m = SparseGradGenMeta(
                            None, 0, grad_op, i, g, grad_op.device_option)
                    sparse_generators[input_name][input_version].append(m)
                else:
                    self.gradient_generators[input_name][input_version] \
                        .append(GradGenMeta(
                            grad_op, i, g, grad_op.device_option))
    # (3) merge indices and values generators for sparse gradients, and
    # add them to gradient_generators
    self.AppendSparseGenerators(sparse_generators)
    # (4) for ops (e.g., Add, Sum, Sub) which have gradient outputs directly
    # passed from inputs (not computed from gradient ops), we create an
    # GradGenMeta with None grad_op and idx so that the gradient_generators
    # knows where the gradients are coming from. This is needed for creating
    # Sum op to accumulate the gradients from multiple parents.
    # These pass-through entries use the forward op's device option.
    for input_index, g in enumerate(g_input):
        input_name = forward_op.input[input_index]
        input_version = in_versions[input_name]
        if not g:
            continue
        if type(g) is GradientSlice:
            if str(g.indices) not in locally_generated_blobs and \
                    str(g.values) not in locally_generated_blobs:
                self.gradient_generators[input_name][input_version].append(
                    SparseGradGenMeta(None, 0, None, 0, g, forward_op.device_option))
        else:
            if str(g) not in locally_generated_blobs:
                self.gradient_generators[input_name][input_version].append(
                    GradGenMeta(None, 0, g, forward_op.device_option))
    # Finally, for the gradients specified in g_input, we update the
    # gradient frontier to reflect the input versions that the gradients
    # correspond to.
    for i, g in enumerate(g_input):
        if g is not None:
            input_name = forward_op.input[i]
            input_version = in_versions[input_name]
            self.gradient_frontier[input_name] = input_version
def _GetSumOpOutputName(self, generator, input_name):
    """Choose the blob name for the accumulated gradient of ``input_name``.

    Uses the output name of the first generator that has a gradient op. For
    sparse generators the '_indices' / '_values' suffix is stripped. Falls
    back to ``input_name + '_grad'`` if no generator has a gradient op.
    """
    def _strip_suffix(text, suffix):
        return text[:-len(suffix)] if text.endswith(suffix) else text

    for meta in generator:
        if type(meta) is GradGenMeta:
            grad_op, idx, _, _ = meta
            if grad_op:
                return grad_op.output[idx]
            continue
        assert(type(meta) is SparseGradGenMeta)
        op_i, idx_i, op_v, idx_v, _, _ = meta
        if op_i:
            return _strip_suffix(op_i.output[idx_i], '_indices')
        if op_v:
            return _strip_suffix(op_v.output[idx_v], '_values')
    return input_name + '_grad'
# "<IS_AUTO_GEN_SUM_OPS_TAG>:1" is appended to device_option.extra_info of
# each auto-generated gradient-accumulation op in _SetSumOpsDeviceOption.
IS_AUTO_GEN_SUM_OPS_TAG = "is_auto_gen_sum_ops"
# If some generator's device_option.extra_info contains
# "<ONLY_KEEP_IS_AUTO_GEN_SUM_OPS_TAG>:1", _SetSumOpsDeviceOption does not
# copy the generator's device option onto the sum ops; it only adds the tag
# above.
ONLY_KEEP_IS_AUTO_GEN_SUM_OPS_TAG = "only_keep_is_auto_gen_sum_ops_tag"
def _SetSumOpsDeviceOption(self, sum_ops, generators):
    """Assign device options to auto-generated gradient sum ops.

    If any generator's device_option.extra_info contains
    "<ONLY_KEEP_IS_AUTO_GEN_SUM_OPS_TAG>:1", every sum op only gets
    "<IS_AUTO_GEN_SUM_OPS_TAG>:1" appended to its extra_info. Otherwise each
    sum op copies the first generator's device option and then gets that tag.
    With no generators, and without the keep-only flag, nothing changes.
    """
    keep_only_marker = "{}:1".format(IR.ONLY_KEEP_IS_AUTO_GEN_SUM_OPS_TAG)
    flags = [
        any(info == keep_only_marker
            for info in generator.device_option.extra_info)
        for generator in generators
    ]
    keep_only_tag = any(flags)
    sum_marker = "{}:1".format(IR.IS_AUTO_GEN_SUM_OPS_TAG)

    if keep_only_tag:
        for op in sum_ops:
            op.device_option.extra_info.extend([sum_marker])
        return

    # Device options were already checked for consistency, so the first
    # generator is representative.
    first = next(iter(generators), None)
    if first is None:
        return
    for op in sum_ops:
        op.device_option.CopyFrom(first.device_option)
        op.device_option.extra_info.extend([sum_marker])
def _DisambiguateGradOpOutput(self, grad_op, idx, cnt):
    """Rename output #idx of ``grad_op`` to '_<name>_autosplit_<cnt>'.

    "If" ops are renamed through disambiguate_grad_if_op_output; any other op
    has its output list edited in place.

    Returns:
        (new output name as read back from grad_op.output[idx], cnt + 1)
    """
    renamed = '_' + grad_op.output[idx] + '_autosplit_{}'.format(cnt)
    if grad_op.type != "If":
        grad_op.output[idx] = renamed
    else:
        disambiguate_grad_if_op_output(grad_op, idx, renamed)
    return grad_op.output[idx], cnt + 1
def _CheckSumOpsConflict(self, out_base_name, g):
    """Raise RuntimeError if pass-through gradient ``g`` has the same name
    (compared via str) as the sum op's output ``out_base_name``."""
    if str(out_base_name) != str(g):
        return
    # TODO not sure what this message really means
    raise RuntimeError(
        'The gradient output of empty gradient op can not '
        'be the same as the normal name of the current '
        'input gradient.')
def _MakeDenseSumOps(self, generators, out_base_name):
    """Build a single Sum op that accumulates dense gradients.

    The first generator that has a gradient op keeps its output name. Later
    gradient-op outputs are renamed with _DisambiguateGradOpOutput.
    Pass-through gradients (grad_op is None) are used by name after a
    conflict check.

    Returns:
        ([Sum op], out_base_name)
    """
    assert len(generators) > 1
    sum_inputs = []
    split_count = 0
    seen_grad_op = False
    for grad_op, idx, g, _ in generators:
        assert(type(g) is not GradientSlice)
        if not grad_op:
            self._CheckSumOpsConflict(out_base_name, g)
            sum_inputs.append(str(g))
            continue
        if seen_grad_op:
            out, split_count = self._DisambiguateGradOpOutput(
                grad_op, idx, split_count)
        else:
            seen_grad_op = True
            out = grad_op.output[idx]
        sum_inputs.append(out)

    if out_base_name in sum_inputs:
        # Sum's in-place mode only works on its first input, so swap the
        # output name into position 0.
        pos = sum_inputs.index(out_base_name)
        sum_inputs[0], sum_inputs[pos] = sum_inputs[pos], sum_inputs[0]

    sum_ops = [CreateOperator(
        "Sum",
        [BlobReference(x) for x in sum_inputs],
        BlobReference(out_base_name))]
    return sum_ops, out_base_name
def _MakeSparseSumOps(self, generators, out_base_name):
    """Accumulate sparse gradients by concatenating indices and values.

    Every gradient-op output is renamed with _DisambiguateGradOpOutput.
    Pass-through parts are used by name after a conflict check.

    Returns:
        ([indices Concat op, values Concat op],
         GradientSlice of the two concatenated blobs)
    """
    indices_parts = []
    values_parts = []
    split_i = 0
    split_v = 0
    for generator in generators:
        assert(type(generator) is SparseGradGenMeta)
        op_i, idx_i, op_v, idx_v, g, _ = generator
        if op_i:
            renamed, split_i = self._DisambiguateGradOpOutput(
                op_i, idx_i, split_i)
            indices_parts.append(renamed)
        else:
            self._CheckSumOpsConflict(out_base_name, g.indices)
            indices_parts.append(g.indices)
        if op_v:
            renamed, split_v = self._DisambiguateGradOpOutput(
                op_v, idx_v, split_v)
            values_parts.append(renamed)
        else:
            self._CheckSumOpsConflict(out_base_name, g.values)
            values_parts.append(g.values)

    indices_out = out_base_name + '_indices_concat'
    indices_split = out_base_name + '_indices_concat_split'
    values_out = out_base_name + '_values_concat'
    values_split = out_base_name + '_values_concat_split'

    def _concat(parts, output, split):
        return CreateOperator(
            "Concat",
            [BlobReference(x) for x in parts],
            [BlobReference(x) for x in [output, split]],
            axis=0
        )

    # Summing sparse representations = concatenating along axis 0. Indices
    # are not de-duplicated here; that is done as needed before the
    # optimizer is called.
    sum_ops = [
        _concat(indices_parts, indices_out, indices_split),
        _concat(values_parts, values_out, values_split),
    ]
    return sum_ops, GradientSlice(indices=indices_out, values=values_out)
def _MakeSumOps(self, input_name, input_version):
    """Create the ops accumulating every gradient recorded for
    (input_name, input_version) and set their device options.

    Returns:
        (sum_ops, accumulated gradient): a blob name for dense gradients, or
        a GradientSlice for sparse ones.
    """
    generators = self.gradient_generators[input_name][input_version]
    out_base_name = self._GetSumOpOutputName(generators, input_name)
    kinds = {type(x) for x in generators}
    assert(len(kinds) == 1)
    kind = list(kinds)[0]
    if kind is GradGenMeta:
        builder = self._MakeDenseSumOps
    else:
        assert(kind is SparseGradGenMeta)
        builder = self._MakeSparseSumOps
    sum_ops, g = builder(generators, out_base_name)
    self._SetSumOpsDeviceOption(sum_ops, generators)
    return sum_ops, g
def _VerifyGradientGenerators(self, generator):
    """Decide whether the gradients in ``generator`` need to be summed.

    Returns:
        False if there are fewer than two generators (nothing to sum),
        True otherwise.

    Raises:
        RuntimeError: if dense and sparse generators are mixed, or if the
            non-empty device options are not all equal according to
            device_option_equal.
        AssertionError: if an entry is neither GradGenMeta nor
            SparseGradGenMeta.
    """
    # (1) check if all gradients are of the same type. Aggregating a mix of
    # sparse and dense gradients is not supported yet
    if len({type(g) for g in generator}) > 1:
        raise RuntimeError(
            'Automatic aggregation of a mix of sparse and dense gradients '
            'is not supported yet')
    # If for all the operators that used the operator, none or only one
    # produced the gradient, then no additional sum needs to be carried
    # out.
    if len(generator) < 2:
        return False
    # The unused `all_gradient_names` list that used to be built here has
    # been dropped; only device options affect the result.
    all_device_options = []
    for g in generator:
        if g.device_option:
            all_device_options.append(g.device_option)
        assert type(g) is GradGenMeta or type(g) is SparseGradGenMeta
    # Check if all grad op device options are the same.
    if len(all_device_options) >= 2 and not all(
            device_option_equal(d, all_device_options[0])
            for d in all_device_options[1:]):
        raise RuntimeError('Unexpected behavior: not all grad ops '
                           'have the same device option.')
    return True
def DoGradientAccumulation(self, fwd_op_idx):
    """For each input name in the forward op, check if we will need to
    add gradient accumulation. If so, do gradient accumulation and return
    the list of gradient operators.

    The criteria for doing gradient accumulation is:
    (1) the specific input version has been used by multiple operators.
    (2) the current fwd_op_idx is the first to use that input, i.e. in the
        backward pass, is the last to optionally generate the gradient for
        the op.
    (3) For the operators that used the input, their gradient operators
        have generated more than 1 gradient.

    When accumulating operators, our current solution is to rename all the
    created gradients with an internal intermediate name, and then add a
    Sum() operator that adds up all the gradients. This may use more memory
    due to intermediate storage, but is usually the fastest approach as one
    can do one single sum for multiple intermediate gradients.

    Distinct input names are visited in order of first appearance in
    forward_op.input. The order of the returned sum ops is therefore
    reproducible across runs; iterating a set of names made it depend on
    string hash randomization.

    Returns:
        (additional_sum_ops, grad_map): the new ops, and a mapping from
        input name to its accumulated gradient.

    Raises:
        RuntimeError: if an input's gradient generators fail verification.
    """
    forward_op, in_versions, out_versions = self.ssa[fwd_op_idx]
    additional_sum_ops = []
    grad_map = {}
    # OrderedDict.fromkeys de-duplicates while keeping first-seen order.
    for input_name in OrderedDict.fromkeys(forward_op.input):
        input_version = in_versions[input_name]
        input_usage = self.input_usages[input_name][input_version]
        if (len(input_usage) <= 1 or fwd_op_idx != input_usage[0]):
            # We do not need to do gradient accumulation yet.
            continue
        generator = self.gradient_generators[input_name][input_version]
        try:
            if not self._VerifyGradientGenerators(generator):
                continue
        except RuntimeError as err:
            raise RuntimeError(
                "Gradients for param ''{}'' failed to verify: {}".format(
                    input_name,
                    err
                )
            ) from err
        # Finally, let's create the sum operator.
        sum_ops, g = self._MakeSumOps(input_name, input_version)
        additional_sum_ops.extend(sum_ops)
        grad_map[input_name] = g
    return additional_sum_ops, grad_map
def _AppendAutoGradGenerator(self, y, grad, autograd_op):
    """Record ``autograd_op`` as producing the gradient ``grad`` of ``y``.

    A GradGenMeta is appended to gradient_generators[str(y)] at y's current
    frontier version. autograd_op.device_option is read unconditionally, so
    ``autograd_op`` must be an operator; the ``None`` index for a falsy op
    only applies to a falsy object that still has ``device_option``.
    """
    # Gradient here is not sparse as it was generated by
    # a ConstantFill operator. Autogeneration for sparse gradients is
    # not supported
    meta = GradGenMeta(
        grad_op=autograd_op,
        idx=0 if autograd_op else None,
        gradient=str(grad),
        device_option=autograd_op.device_option,
    )
    blob = str(y)
    self.gradient_generators[blob][self.frontier[blob]].append(meta)
# Suffix for auto-generated gradient blob names.
# NOTE(review): not referenced in the visible part of this file; presumably
# used when seeding gradients in _GetInitGradients — confirm.
AUTOGEN_GRAD_SUFFIX = "_autogen_grad"
def _GetInitGradients(self, ys):
input_to_grad = {}
gradient_ops = []
for y, g in ys.items():
autograd_op = None
if g is None:
autograd_op = CreateOperator(