-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathmodel.py
120 lines (100 loc) · 4.51 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import numpy as np
import torch
# torch.set_default_tensor_type(torch.cuda.FloatTensor)
import torch.nn as nn
from torch.nn.modules.module import _addindent
class Conv2dWithConstraint(nn.Conv2d):
    """2-D convolution whose filters are norm-clipped before every forward pass.

    Behaves like ``nn.Conv2d`` except that, each time ``forward`` runs, the
    weight tensor is renormalised so that every slice along dimension 0
    (one slice per output filter) has an L2 norm of at most ``max_norm``.

    Args:
        *args: Positional arguments forwarded unchanged to ``nn.Conv2d``.
        max_norm: Upper bound on the L2 norm of each output filter
            (keyword-only, defaults to 1).
        **kwargs: Keyword arguments forwarded unchanged to ``nn.Conv2d``.
    """

    def __init__(self, *args, max_norm=1, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_norm = max_norm

    def forward(self, x):
        """Clip the filter norms, then apply the regular convolution to ``x``."""
        # Work on ``.data`` so the clipping itself is not recorded by autograd.
        clipped = torch.renorm(
            self.weight.data, p=2, dim=0, maxnorm=self.max_norm
        )
        self.weight.data = clipped
        return super().forward(x)
class EEGNet(nn.Module):
    """EEGNet-style convolutional classifier.

    The network consists of two convolutional stages (``self.blocks``)
    followed by a bias-free linear layer with a softmax
    (``self.classifierBlock``).  ``forward`` expects input shaped
    ``(batch, 1, channels, samples)`` and returns class probabilities shaped
    ``(batch, n_classes)``.

    Args:
        n_classes: Number of output classes.
        channels: Height of the input (presumably the number of EEG
            electrodes -- confirm against the data loader).
        samples: Width of the input (presumably time samples per trial).
        dropoutRate: Dropout probability used at the end of both stages.
        kernelLength: Width of the first (1 x kernelLength) convolution.
        kernelLength2: Width of the depthwise part of the separable
            convolution in the second stage.
        F1: Number of filters of the first convolution.
        D: Depth multiplier; the depthwise convolution emits ``F1 * D`` maps.
        F2: Number of feature maps produced by the second stage.
    """

    def __init__(self, n_classes=4, channels=60, samples=151,
                 dropoutRate=0.5, kernelLength=64, kernelLength2=16, F1=8,
                 D=2, F2=16):
        super().__init__()
        self.F1 = F1
        self.F2 = F2
        self.D = D
        self.samples = samples
        self.n_classes = n_classes
        self.channels = channels
        self.kernelLength = kernelLength
        self.kernelLength2 = kernelLength2
        self.dropoutRate = dropoutRate

        self.blocks = self.InitialBlocks(dropoutRate)
        # Probe the convolutional stack once to learn the (height, width) of
        # its output so the linear layer can be sized accordingly.
        self.blockOutputSize = self.CalculateOutSize(self.blocks, channels, samples)
        # The depthwise convolution spans all ``channels`` rows, so the
        # output height is 1 and only the width contributes to the size.
        self.classifierBlock = self.ClassifierBlock(
            self.F2 * self.blockOutputSize[1], n_classes)

    def InitialBlocks(self, dropoutRate, *args, **kwargs):
        """Build the two convolutional feature-extraction stages.

        Args:
            dropoutRate: Dropout probability applied at the end of each stage.
            *args, **kwargs: Accepted for backward compatibility; ignored.

        Returns:
            ``nn.Sequential(block1, block2)``.
        """
        depth_channels = self.F1 * self.D
        block1 = nn.Sequential(
            nn.Conv2d(1, self.F1, (1, self.kernelLength), stride=1,
                      padding=(0, self.kernelLength // 2), bias=False),
            nn.BatchNorm2d(self.F1, momentum=0.01, affine=True, eps=1e-3),
            # Depthwise convolution across all input rows, with a max-norm
            # constraint on every filter.
            Conv2dWithConstraint(self.F1, depth_channels, (self.channels, 1),
                                 max_norm=1, stride=1, padding=(0, 0),
                                 groups=self.F1, bias=False),
            nn.BatchNorm2d(depth_channels, momentum=0.01, affine=True, eps=1e-3),
            nn.ELU(),
            nn.AvgPool2d((1, 4), stride=4),
            nn.Dropout(p=dropoutRate),
        )
        block2 = nn.Sequential(
            # Separable convolution = depthwise conv followed by a 1x1
            # pointwise conv.
            nn.Conv2d(depth_channels, depth_channels, (1, self.kernelLength2),
                      stride=1, padding=(0, self.kernelLength2 // 2),
                      bias=False, groups=depth_channels),
            nn.Conv2d(depth_channels, self.F2, 1, padding=(0, 0), groups=1,
                      bias=False, stride=1),
            nn.BatchNorm2d(self.F2, momentum=0.01, affine=True, eps=1e-3),
            nn.ELU(),
            nn.AvgPool2d((1, 8), stride=8),
            nn.Dropout(p=dropoutRate),
        )
        return nn.Sequential(block1, block2)

    def ClassifierBlock(self, inputSize, n_classes):
        """Return the classification head: bias-free linear layer + softmax.

        Args:
            inputSize: Number of flattened input features.
            n_classes: Number of output classes.
        """
        return nn.Sequential(
            nn.Linear(inputSize, n_classes, bias=False),
            nn.Softmax(dim=1),
        )

    def CalculateOutSize(self, model, channels, samples):
        """Return the spatial output size of ``model`` for a given input size.

        A random tensor shaped ``(1, 1, channels, samples)`` is pushed through
        ``model`` in evaluation mode and without gradient tracking.

        Args:
            model: The ``nn.Module`` to probe.
            channels: Height of the probe input.
            samples: Width of the probe input.

        Returns:
            The output shape with the batch and feature-map dimensions
            removed, i.e. ``output.shape[2:]``.
        """
        # Remember every sub-module's mode: the probe must run in eval mode
        # (a batch of one would otherwise disturb BatchNorm statistics), but
        # leaving the model in eval mode would silently disable dropout and
        # BatchNorm updates for callers that never call ``.train()``.
        previous_modes = [(module, module.training) for module in model.modules()]
        data = torch.rand(1, 1, channels, samples)
        model.eval()
        try:
            with torch.no_grad():
                out = model(data).shape
        finally:
            for module, was_training in previous_modes:
                module.training = was_training
        return out[2:]

    def forward(self, x):
        """Map input ``(batch, 1, channels, samples)`` to class probabilities."""
        x = self.blocks(x)
        x = torch.flatten(x, start_dim=1)  # Flatten everything but the batch axis
        x = self.classifierBlock(x)
        return x
def categorical_cross_entropy(y_pred, y_true):
    """Mean categorical cross-entropy between probabilities and targets.

    Args:
        y_pred: Predicted probabilities; summed over ``dim=1``, so at least
            two-dimensional (presumably ``(batch, n_classes)``).
        y_true: Target weights with the same shape as ``y_pred``
            (presumably one-hot labels).

    Returns:
        Scalar tensor: the negative target-weighted log-probability, summed
        over ``dim=1`` and averaged over the remaining entries.
    """
    # Keep probabilities strictly inside (0, 1) so the logarithm stays finite.
    clipped = y_pred.clamp(min=1e-9, max=1 - 1e-9)
    log_probs = torch.log(clipped)
    per_sample = (y_true * log_probs).sum(dim=1)
    return -per_sample.mean()
def torch_summarize(model, show_weights=True, show_parameters=True):
    """Summarize a torch model, listing each child's weights and parameters.

    Every direct child of ``model`` gets one entry.  Container children
    (``nn.Sequential`` and, where it still exists, the deprecated
    ``nn.Container``) are expanded recursively with the same options; all
    other children are shown through their ``repr``.

    Args:
        model: The ``nn.Module`` to summarize.
        show_weights: Append the shapes of each child's parameter tensors.
        show_parameters: Append each child's total parameter count.

    Returns:
        A multi-line string describing the model.
    """
    # ``nn.Container`` is deprecated; look it up defensively so the summary
    # keeps working on PyTorch versions that no longer ship it.
    container_module = torch.nn.modules.container
    container_types = tuple(
        cls for cls in (
            getattr(container_module, 'Container', None),
            container_module.Sequential,
        )
        if cls is not None
    )

    lines = [model.__class__.__name__ + ' (']
    for key, module in model._modules.items():
        if isinstance(module, container_types):
            # Forward the display flags so nested containers honour them too.
            modstr = torch_summarize(
                module,
                show_weights=show_weights,
                show_parameters=show_parameters,
            )
        else:
            modstr = repr(module)
        modstr = _addindent(modstr, 2)

        # A child slot may legitimately hold None; it owns no parameters.
        child_params = [] if module is None else list(module.parameters())

        entry = '  (' + key + '): ' + modstr
        if show_weights:
            weights = tuple(tuple(p.size()) for p in child_params)
            entry += ', weights={}'.format(weights)
        if show_parameters:
            params = sum(p.numel() for p in child_params)
            entry += ', parameters={}'.format(params)
        lines.append(entry)
    lines.append(')')
    return '\n'.join(lines)