# spiking_resnet.py
# https://github.com/fangwei123456/Spike-Element-Wise-ResNet
# https://spikingjelly.readthedocs.io/zh_CN/latest/activation_based_en/train_large_scale_snn.html
import torch
import torch.nn as nn
from copy import deepcopy
# `load_state_dict_from_url` is imported from `torchvision.models.utils` when
# that module is importable; on ImportError the copy in
# `torchvision._internally_replaced_utils` is used instead.
try:
    from torchvision.models.utils import load_state_dict_from_url
except ImportError:
    from torchvision._internally_replaced_utils import load_state_dict_from_url
from spikingjelly.activation_based import neuron, surrogate, functional, layer, base
# Public names of this module: the network class followed by its factory
# functions, grouped by architecture family.
__all__ = [
    'SpikingResNet',
    # plain ResNets
    'spiking_resnet18',
    'spiking_resnet34',
    'spiking_resnet50',
    'spiking_resnet101',
    'spiking_resnet152',
    # ResNeXt variants
    'spiking_resnext50_32x4d',
    'spiking_resnext101_32x8d',
    # wide ResNets
    'spiking_wide_resnet50_2',
    'spiking_wide_resnet101_2',
]
# Maps an architecture key (the `arch` argument of `_spiking_resnet`) to the
# URL of the state dict that is downloaded when `pretrained` is true.
model_urls = dict(
    resnet18="https://download.pytorch.org/models/resnet18-f37072fd.pth",
    resnet34="https://download.pytorch.org/models/resnet34-b627a593.pth",
    resnet50="https://download.pytorch.org/models/resnet50-0676ba61.pth",
    resnet101="https://download.pytorch.org/models/resnet101-63fe2227.pth",
    resnet152="https://download.pytorch.org/models/resnet152-394f9c45.pth",
    resnext50_32x4d="https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
    resnext101_32x8d="https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
    wide_resnet50_2="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
    wide_resnet101_2="https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
)
# Adapted from https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Create a bias-free 3x3 ``layer.Conv2d``.

    The padding is set to ``dilation``; ``stride``, ``groups`` and
    ``dilation`` are passed straight through to the convolution.
    """
    conv_options = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
    return layer.Conv2d(in_planes, out_planes, **conv_options)
def conv1x1(in_planes, out_planes, stride=1):
    """Create a bias-free 1x1 ``layer.Conv2d`` with the given stride."""
    pointwise = layer.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return pointwise
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions with spiking neurons.

    Main path: conv3x3 -> norm -> neuron -> conv3x3 -> norm. The shortcut
    (the input, or ``downsample(input)`` when a downsample module is given)
    is added to the main path before the second neuron fires.

    :param inplanes: input channels of the first convolution
    :param planes: output channels of both convolutions
    :param stride: stride of the first convolution
    :param downsample: optional module applied to the input on the shortcut
    :param groups: must be 1
    :param base_width: must be 64
    :param dilation: must not exceed 1
    :param norm_layer: normalisation constructor; ``layer.BatchNorm2d`` when None
    :param spiking_neuron: callable that builds a spiking neuron layer
    :param kwargs: keyword arguments for ``spiking_neuron``; a deep copy is
        handed to each neuron that is created
    :raises ValueError: if ``groups != 1`` or ``base_width != 64``
    :raises NotImplementedError: if ``dilation > 1``
    """

    # Channel multiplier read by ``SpikingResNet._make_layer``.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None, spiking_neuron: callable = None, **kwargs):
        super().__init__()
        norm_layer = layer.BatchNorm2d if norm_layer is None else norm_layer
        if not (groups == 1 and base_width == 64):
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        def new_neuron():
            # Every neuron receives its own copy of the keyword arguments.
            return spiking_neuron(**deepcopy(kwargs))

        # When stride != 1, both conv1 and the downsample module reduce the input.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.sn1 = new_neuron()

        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.sn2 = new_neuron()

        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run the block on ``x`` and return the output of the second neuron."""
        residual = self.sn1(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))

        shortcut = x if self.downsample is None else self.downsample(x)

        residual += shortcut
        return self.sn2(residual)
class Bottleneck(nn.Module):
    """Three-convolution (1x1 -> 3x3 -> 1x1) residual block with spiking neurons.

    The stride is applied in the 3x3 convolution (``conv2``), as torchvision
    does, whereas the original implementation from "Deep residual learning
    for image recognition" (https://arxiv.org/abs/1512.03385) places it in the
    first 1x1 convolution (``conv1``). This variant is also known as
    ResNet V1.5 and improves accuracy according to
    https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    :param inplanes: input channels of the first 1x1 convolution
    :param planes: base channel count; the block outputs ``planes * expansion`` channels
    :param stride: stride of the 3x3 convolution
    :param downsample: optional module applied to the input on the shortcut
    :param groups: groups of the 3x3 convolution
    :param base_width: with ``groups``, sets the inner width
        ``int(planes * (base_width / 64.)) * groups``
    :param dilation: dilation of the 3x3 convolution
    :param norm_layer: normalisation constructor; ``layer.BatchNorm2d`` when None
    :param spiking_neuron: callable that builds a spiking neuron layer
    :param kwargs: keyword arguments for ``spiking_neuron``; a deep copy is
        handed to each neuron that is created
    """

    # Channel multiplier read by ``SpikingResNet._make_layer``.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None, spiking_neuron: callable = None, **kwargs):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            # Fix: the bare name ``BatchNorm2d`` is not defined in this module
            # and raised NameError; use the same default as BasicBlock.
            norm_layer = layer.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.sn1 = spiking_neuron(**deepcopy(kwargs))
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.sn2 = spiking_neuron(**deepcopy(kwargs))
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.sn3 = spiking_neuron(**deepcopy(kwargs))
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run the block on ``x`` and return the output of the third neuron."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.sn1(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.sn2(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        # The shortcut is added before the last neuron fires.
        out += identity
        out = self.sn3(out)

        return out
class SpikingResNet(nn.Module):
    """ResNet backbone built from spikingjelly layers and spiking neurons.

    Layout: 7x7 stride-2 convolution -> norm -> neuron -> 3x3 stride-2 max
    pool -> four residual stages (``layer1`` .. ``layer4``) -> adaptive
    average pool to 1x1 -> linear classifier.

    :param block: residual block class (e.g. ``BasicBlock`` or ``Bottleneck``);
        must expose an ``expansion`` attribute
    :param layers: number of blocks per stage; indices 0 to 3 are read
    :param num_classes: output features of the final linear layer
    :param zero_init_residual: if true, set the weight of the last norm layer
        of every ``Bottleneck`` / ``BasicBlock`` to zero
    :param groups: stored and passed to every block
    :param width_per_group: stored as ``base_width`` and passed to every block
    :param replace_stride_with_dilation: None, or three flags telling whether
        ``layer2`` / ``layer3`` / ``layer4`` trade their stride for dilation
    :param norm_layer: normalisation constructor; ``layer.BatchNorm2d`` when None
    :param spiking_neuron: callable that builds a spiking neuron layer
    :param kwargs: keyword arguments for ``spiking_neuron``
    :raises ValueError: if ``replace_stride_with_dilation`` does not have
        exactly three elements
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, spiking_neuron: callable = None, **kwargs):
        super().__init__()
        norm_layer = layer.BatchNorm2d if norm_layer is None else norm_layer

        # State that _make_layer reads and updates while the stages are built.
        self._norm_layer = norm_layer
        self.inplanes = 64
        self.dilation = 1
        self.groups = groups
        self.base_width = width_per_group

        if replace_stride_with_dilation is None:
            # One flag per strided stage: replace the 2x2 stride with a
            # dilated convolution instead.
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))

        # Stem.
        self.conv1 = layer.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                                  bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.sn1 = spiking_neuron(**deepcopy(kwargs))
        self.maxpool = layer.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages: layer1 keeps the resolution, layer2..layer4 use
        # stride 2 (or dilation when the matching flag is set).
        self.layer1 = self._make_layer(block, 64, layers[0], spiking_neuron=spiking_neuron, **kwargs)
        for stage, planes in ((2, 128), (3, 256), (4, 512)):
            stage_module = self._make_layer(
                block, planes, layers[stage - 1], stride=2,
                dilate=replace_stride_with_dilation[stage - 2],
                spiking_neuron=spiking_neuron, **kwargs)
            setattr(self, "layer{}".format(stage), stage_module)

        # Head.
        self.avgpool = layer.AdaptiveAvgPool2d((1, 1))
        self.fc = layer.Linear(512 * block.expansion, num_classes)

        for module in self.modules():
            if isinstance(module, layer.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if isinstance(module, (layer.BatchNorm2d, layer.GroupNorm)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

        # Zero-initialize the last BN in each residual branch, so that the
        # residual branch starts with zeros and each residual block behaves
        # like an identity. This improves the model by 0.2~0.3% according to
        # https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for module in self.modules():
                if isinstance(module, Bottleneck):
                    last_norm = module.bn3
                elif isinstance(module, BasicBlock):
                    last_norm = module.bn2
                else:
                    continue
                nn.init.constant_(last_norm.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False, spiking_neuron: callable = None, **kwargs):
        """Build one stage of ``blocks`` residual blocks as an ``nn.Sequential``.

        Updates ``self.inplanes`` (and ``self.dilation`` when ``dilate`` is
        true) so the next stage continues from this one's output.
        """
        norm_layer = self._norm_layer
        out_planes = planes * block.expansion

        dilation_before = self.dilation
        if dilate:
            # Trade the stride for dilation: later blocks see the larger dilation.
            self.dilation *= stride
            stride = 1

        # A projection shortcut is needed whenever the first block changes
        # the resolution or the channel count.
        shortcut = None
        if not (stride == 1 and self.inplanes == out_planes):
            shortcut = nn.Sequential(
                conv1x1(self.inplanes, out_planes, stride),
                norm_layer(out_planes),
            )

        first = block(self.inplanes, planes, stride, shortcut, self.groups,
                      self.base_width, dilation_before, norm_layer, spiking_neuron, **kwargs)
        self.inplanes = out_planes

        remaining = [
            block(self.inplanes, planes, groups=self.groups,
                  base_width=self.base_width, dilation=self.dilation,
                  norm_layer=norm_layer, spiking_neuron=spiking_neuron, **kwargs)
            for _ in range(blocks - 1)
        ]
        return nn.Sequential(first, *remaining)

    def _forward_impl(self, x):
        """Apply stem, the four stages, pooling, flatten and the classifier."""
        pipeline = (
            self.conv1, self.bn1, self.sn1, self.maxpool,
            self.layer1, self.layer2, self.layer3, self.layer4,
            self.avgpool,
        )
        for stage in pipeline:
            x = stage(x)

        # Flatten from dim 1 in single-step mode ('s') and from dim 2 in
        # multi-step mode ('m'); any other mode leaves x unflattened.
        start_dim = {'s': 1, 'm': 2}.get(self.avgpool.step_mode)
        if start_dim is not None:
            x = torch.flatten(x, start_dim)

        return self.fc(x)

    def forward(self, x):
        """Delegate to ``_forward_impl``."""
        return self._forward_impl(x)
# NOTE(review): these three names are assigned here but nothing in the visible
# code reads them (the functions below take `spiking_neuron` as a parameter
# and the `__main__` block passes its own values). Presumably leftover
# defaults — confirm no external code imports them before removing.
spiking_neuron=neuron.IFNode
surrogate_function=surrogate.ATan()
detach_reset=True
def _spiking_resnet(arch, block, layers, pretrained, progress, spiking_neuron, **kwargs):
    """Build a ``SpikingResNet`` and optionally load a downloaded state dict.

    :param arch: key into ``model_urls``; only used when ``pretrained`` is true
    :param block: residual block class passed to ``SpikingResNet``
    :param layers: per-stage block counts passed to ``SpikingResNet``
    :param pretrained: if true, download and load the state dict for ``arch``
    :param progress: passed to ``load_state_dict_from_url`` as ``progress``
    :param spiking_neuron: callable that builds a spiking neuron layer
    :param kwargs: remaining keyword arguments passed to ``SpikingResNet``
    :return: the constructed model
    """
    model = SpikingResNet(block, layers, spiking_neuron=spiking_neuron, **kwargs)
    if not pretrained:
        return model

    weights = load_state_dict_from_url(model_urls[arch], progress=progress)
    model.load_state_dict(weights)
    return model
def spiking_resnet18(pretrained=False, progress=True, spiking_neuron: callable=None, **kwargs):
    """
    Build a spiking ResNet-18 (``BasicBlock``, stage depths ``[2, 2, 2, 2]``).

    Spiking version of the ResNet-18 model from `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    :param pretrained: If True, the state dict registered under ``'resnet18'`` in ``model_urls`` (ANN weights pre-trained on ImageNet) is loaded into the SNN
    :type pretrained: bool
    :param progress: If True, displays a progress bar of the download to stderr
    :type progress: bool
    :param spiking_neuron: callable that creates a spiking neuron layer
    :type spiking_neuron: callable
    :param kwargs: extra keyword arguments handed on to ``_spiking_resnet``
    :type kwargs: dict
    :return: Spiking ResNet-18
    :rtype: torch.nn.Module
    """
    arch, block_cls, stage_depths = 'resnet18', BasicBlock, [2, 2, 2, 2]
    return _spiking_resnet(arch, block_cls, stage_depths, pretrained, progress, spiking_neuron, **kwargs)
def spiking_resnet34(pretrained=False, progress=True, spiking_neuron: callable=None, **kwargs):
    """
    Build a spiking ResNet-34 (``BasicBlock``, stage depths ``[3, 4, 6, 3]``).

    Spiking version of the ResNet-34 model from `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    :param pretrained: If True, the state dict registered under ``'resnet34'`` in ``model_urls`` (ANN weights pre-trained on ImageNet) is loaded into the SNN
    :type pretrained: bool
    :param progress: If True, displays a progress bar of the download to stderr
    :type progress: bool
    :param spiking_neuron: callable that creates a spiking neuron layer
    :type spiking_neuron: callable
    :param kwargs: extra keyword arguments handed on to ``_spiking_resnet``
    :type kwargs: dict
    :return: Spiking ResNet-34
    :rtype: torch.nn.Module
    """
    arch, block_cls, stage_depths = 'resnet34', BasicBlock, [3, 4, 6, 3]
    return _spiking_resnet(arch, block_cls, stage_depths, pretrained, progress, spiking_neuron, **kwargs)
def spiking_resnet50(pretrained=False, progress=True, spiking_neuron: callable=None, **kwargs):
    """
    Build a spiking ResNet-50 (``Bottleneck``, stage depths ``[3, 4, 6, 3]``).

    Spiking version of the ResNet-50 model from `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    :param pretrained: If True, the state dict registered under ``'resnet50'`` in ``model_urls`` (ANN weights pre-trained on ImageNet) is loaded into the SNN
    :type pretrained: bool
    :param progress: If True, displays a progress bar of the download to stderr
    :type progress: bool
    :param spiking_neuron: callable that creates a spiking neuron layer
    :type spiking_neuron: callable
    :param kwargs: extra keyword arguments handed on to ``_spiking_resnet``
    :type kwargs: dict
    :return: Spiking ResNet-50
    :rtype: torch.nn.Module
    """
    arch, block_cls, stage_depths = 'resnet50', Bottleneck, [3, 4, 6, 3]
    return _spiking_resnet(arch, block_cls, stage_depths, pretrained, progress, spiking_neuron, **kwargs)
def spiking_resnet101(pretrained=False, progress=True, spiking_neuron: callable=None, **kwargs):
    """
    Build a spiking ResNet-101 (``Bottleneck``, stage depths ``[3, 4, 23, 3]``).

    Spiking version of the ResNet-101 model from `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    :param pretrained: If True, the state dict registered under ``'resnet101'`` in ``model_urls`` (ANN weights pre-trained on ImageNet) is loaded into the SNN
    :type pretrained: bool
    :param progress: If True, displays a progress bar of the download to stderr
    :type progress: bool
    :param spiking_neuron: callable that creates a spiking neuron layer
    :type spiking_neuron: callable
    :param kwargs: extra keyword arguments handed on to ``_spiking_resnet``
    :type kwargs: dict
    :return: Spiking ResNet-101
    :rtype: torch.nn.Module
    """
    arch, block_cls, stage_depths = 'resnet101', Bottleneck, [3, 4, 23, 3]
    return _spiking_resnet(arch, block_cls, stage_depths, pretrained, progress, spiking_neuron, **kwargs)
def spiking_resnet152(pretrained=False, progress=True, spiking_neuron: callable=None, **kwargs):
    """
    Build a spiking ResNet-152 (``Bottleneck``, stage depths ``[3, 8, 36, 3]``).

    Spiking version of the ResNet-152 model from `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    :param pretrained: If True, the state dict registered under ``'resnet152'`` in ``model_urls`` (ANN weights pre-trained on ImageNet) is loaded into the SNN
    :type pretrained: bool
    :param progress: If True, displays a progress bar of the download to stderr
    :type progress: bool
    :param spiking_neuron: callable that creates a spiking neuron layer
    :type spiking_neuron: callable
    :param kwargs: extra keyword arguments handed on to ``_spiking_resnet``
    :type kwargs: dict
    :return: Spiking ResNet-152
    :rtype: torch.nn.Module
    """
    arch, block_cls, stage_depths = 'resnet152', Bottleneck, [3, 8, 36, 3]
    return _spiking_resnet(arch, block_cls, stage_depths, pretrained, progress, spiking_neuron, **kwargs)
def spiking_resnext50_32x4d(pretrained=False, progress=True, spiking_neuron: callable=None, **kwargs):
    """
    Build a spiking ResNeXt-50 32x4d (``Bottleneck``, stage depths ``[3, 4, 6, 3]``).

    ``groups`` is forced to 32 and ``width_per_group`` to 4, overriding any
    values supplied through ``kwargs``.

    Spiking version of the ResNeXt-50 32x4d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    :param pretrained: If True, the state dict registered under ``'resnext50_32x4d'`` in ``model_urls`` (ANN weights pre-trained on ImageNet) is loaded into the SNN
    :type pretrained: bool
    :param progress: If True, displays a progress bar of the download to stderr
    :type progress: bool
    :param spiking_neuron: callable that creates a spiking neuron layer
    :type spiking_neuron: callable
    :param kwargs: extra keyword arguments handed on to ``_spiking_resnet``
    :type kwargs: dict
    :return: Spiking ResNeXt-50 32x4d
    :rtype: torch.nn.Module
    """
    kwargs.update(groups=32, width_per_group=4)
    arch, block_cls, stage_depths = 'resnext50_32x4d', Bottleneck, [3, 4, 6, 3]
    return _spiking_resnet(arch, block_cls, stage_depths, pretrained, progress, spiking_neuron, **kwargs)
def spiking_resnext101_32x8d(pretrained=False, progress=True, spiking_neuron: callable=None, **kwargs):
    """
    Build a spiking ResNeXt-101 32x8d (``Bottleneck``, stage depths ``[3, 4, 23, 3]``).

    ``groups`` is forced to 32 and ``width_per_group`` to 8, overriding any
    values supplied through ``kwargs``.

    Spiking version of the ResNeXt-101 32x8d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    :param pretrained: If True, the state dict registered under ``'resnext101_32x8d'`` in ``model_urls`` (ANN weights pre-trained on ImageNet) is loaded into the SNN
    :type pretrained: bool
    :param progress: If True, displays a progress bar of the download to stderr
    :type progress: bool
    :param spiking_neuron: callable that creates a spiking neuron layer
    :type spiking_neuron: callable
    :param kwargs: extra keyword arguments handed on to ``_spiking_resnet``
    :type kwargs: dict
    :return: Spiking ResNeXt-101 32x8d
    :rtype: torch.nn.Module
    """
    kwargs.update(groups=32, width_per_group=8)
    arch, block_cls, stage_depths = 'resnext101_32x8d', Bottleneck, [3, 4, 23, 3]
    return _spiking_resnet(arch, block_cls, stage_depths, pretrained, progress, spiking_neuron, **kwargs)
def spiking_wide_resnet50_2(pretrained=False, progress=True, spiking_neuron: callable=None, **kwargs):
    """
    Build a spiking Wide ResNet-50-2 (``Bottleneck``, stage depths ``[3, 4, 6, 3]``).

    ``width_per_group`` is forced to ``64 * 2``, overriding any value supplied
    through ``kwargs``.

    Spiking version of the Wide ResNet-50-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.

    :param pretrained: If True, the state dict registered under ``'wide_resnet50_2'`` in ``model_urls`` (ANN weights pre-trained on ImageNet) is loaded into the SNN
    :type pretrained: bool
    :param progress: If True, displays a progress bar of the download to stderr
    :type progress: bool
    :param spiking_neuron: callable that creates a spiking neuron layer
    :type spiking_neuron: callable
    :param kwargs: extra keyword arguments handed on to ``_spiking_resnet``
    :type kwargs: dict
    :return: Spiking Wide ResNet-50-2
    :rtype: torch.nn.Module
    """
    kwargs.update(width_per_group=64 * 2)
    arch, block_cls, stage_depths = 'wide_resnet50_2', Bottleneck, [3, 4, 6, 3]
    return _spiking_resnet(arch, block_cls, stage_depths, pretrained, progress, spiking_neuron, **kwargs)
def spiking_wide_resnet101_2(pretrained=False, progress=True, spiking_neuron: callable=None, **kwargs):
    """
    Build a spiking Wide ResNet-101-2 (``Bottleneck``, stage depths ``[3, 4, 23, 3]``).

    ``width_per_group`` is forced to ``64 * 2``, overriding any value supplied
    through ``kwargs``.

    Spiking version of the Wide ResNet-101-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.

    :param pretrained: If True, the state dict registered under ``'wide_resnet101_2'`` in ``model_urls`` (ANN weights pre-trained on ImageNet) is loaded into the SNN
    :type pretrained: bool
    :param progress: If True, displays a progress bar of the download to stderr
    :type progress: bool
    :param spiking_neuron: callable that creates a spiking neuron layer
    :type spiking_neuron: callable
    :param kwargs: extra keyword arguments handed on to ``_spiking_resnet``
    :type kwargs: dict
    :return: Spiking Wide ResNet-101-2
    :rtype: torch.nn.Module
    """
    kwargs.update(width_per_group=64 * 2)
    arch, block_cls, stage_depths = 'wide_resnet101_2', Bottleneck, [3, 4, 23, 3]
    return _spiking_resnet(arch, block_cls, stage_depths, pretrained, progress, spiking_neuron, **kwargs)
if __name__ == "__main__":
    # Smoke run: build a spiking ResNet-18 with IF neurons, switch the whole
    # model to multi-step mode and push one random batch through it.
    model = spiking_resnet18(
        spiking_neuron=neuron.IFNode,
        surrogate_function=surrogate.ATan(),
    )
    x = torch.randn(10, 16, 3, 64, 64)  # (T, B, C, H, W)
    functional.set_step_mode(model, step_mode="m")
    model(x)  # (T, B, D)