-
Notifications
You must be signed in to change notification settings - Fork 558
/
afn.py
263 lines (223 loc) · 11.8 KB
/
afn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
"""
@author: Baixu Chen
@contact: [email protected]
"""
import random
import time
import warnings
import sys
import argparse
import shutil
import os.path as osp
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.utils.data import DataLoader
import torch.nn.functional as F
import utils
from tllib.normalization.afn import AdaptiveFeatureNorm, ImageClassifier
from tllib.modules.entropy import entropy
import tllib.vision.models as models
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import collect_feature, tsne, a_distance
# Single module-level device: the first CUDA device when CUDA is available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
def main(args: argparse.Namespace):
    """Run one phase ('train', 'test' or 'analysis') of AFN for partial domain adaptation.

    Args:
        args: parsed command-line arguments, as built by the argparse setup in ``__main__``.
    """
    logger = CompleteLogger(args.log, args.phase)
    try:
        _run(args, logger)
    finally:
        # Close the logger on every exit path, including the early returns of the
        # 'analysis' and 'test' phases and any exception raised during training.
        logger.close()


def _run(args: argparse.Namespace, logger: CompleteLogger):
    """Body of :func:`main`: build data and model, then analyze, test, or train.

    ``logger`` is owned by the caller, which is responsible for closing it.
    """
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    # NOTE(review): benchmark mode is enabled even when a seed is set; per the PyTorch
    # reproducibility notes this can still make cuDNN algorithm selection vary between runs.
    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                                random_color_jitter=False)
    val_transform = utils.get_val_transform(args.val_resizing)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)
    # drop_last keeps every training batch at exactly args.batch_size in both domains
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    # Build the backbone exactly once, through the same helper whose names define the
    # valid --arch choices (utils.get_model_names); it must not be rebuilt afterwards.
    backbone = utils.get_model(args.arch)
    pool_layer = nn.Identity() if args.no_pool else None
    # num_classes is the class count reported by utils.get_dataset for this task
    classifier = ImageClassifier(backbone, num_classes, args.num_blocks,
                                 bottleneck_dim=args.bottleneck_dim, dropout_p=args.dropout_p,
                                 pool_layer=pool_layer).to(device)
    adaptive_feature_norm = AdaptiveFeatureNorm(args.delta).to(device)

    # define optimizer
    # the learning rate is fixed according to origin paper (no scheduler is used);
    # args.momentum is not passed, so SGD runs with its default momentum of 0.
    optimizer = SGD(classifier.get_parameters(), args.lr, weight_decay=args.weight_decay)

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer,
                                          classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return

    # start training
    best_acc1 = 0.
    for epoch in range(args.epochs):
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, adaptive_feature_norm, optimizer, epoch, args)

        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier, args, device)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))

    # evaluate on test set; map to CPU first, as in the resume path above
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best'), map_location='cpu'))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test_acc1 = {:3.1f}".format(acc1))
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: ImageClassifier,
          adaptive_feature_norm: AdaptiveFeatureNorm, optimizer: SGD, epoch: int, args: argparse.Namespace):
    """Run one epoch (``args.iters_per_epoch`` iterations) of AFN training.

    Each iteration draws one source and one target mini-batch and minimizes
    source cross-entropy + ``args.trade_off_norm`` * (adaptive feature-norm loss
    on source and target features), plus, when ``args.trade_off_entropy`` is
    truthy, ``args.trade_off_entropy`` * mean entropy of the target predictions.
    Target labels are used only to report target accuracy, never in the loss.

    Args:
        train_source_iter: endless iterator over labeled source batches.
        train_target_iter: endless iterator over target batches (labels used for monitoring only).
        model: classifier whose forward pass is unpacked as ``(predictions, features)``.
        adaptive_feature_norm: loss module applied to the feature tensors.
        optimizer: optimizer stepped once per iteration.
        epoch: epoch index, used only in the progress prefix.
        args: parsed command-line arguments.
    """
    batch_time = AverageMeter('Time', ':3.1f')
    data_time = AverageMeter('Data', ':3.1f')
    cls_losses = AverageMeter('Cls Loss', ':3.2f')
    norm_losses = AverageMeter('Norm Loss', ':3.2f')
    src_feature_norm = AverageMeter('Source Feature Norm', ':3.2f')
    tgt_feature_norm = AverageMeter('Target Feature Norm', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')

    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, cls_losses, norm_losses, src_feature_norm, tgt_feature_norm, cls_accs, tgt_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        # target labels only feed the 'Tgt Acc' monitor below, never the loss
        labels_t = labels_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        y_s, f_s = model(x_s)
        y_t, f_t = model(x_t)

        # classification loss (source labels only)
        cls_loss = F.cross_entropy(y_s, labels_s)
        # adaptive feature-norm loss on the features of both domains
        norm_loss = adaptive_feature_norm(f_s) + adaptive_feature_norm(f_t)
        loss = cls_loss + norm_loss * args.trade_off_norm

        # optional entropy minimization on target predictions (None or 0 disables it)
        if args.trade_off_entropy:
            entropy_loss = entropy(F.softmax(y_t, dim=1), reduction='mean')
            loss = loss + entropy_loss * args.trade_off_entropy

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # update statistics; each meter is weighted by the size of the batch it was measured on
        batch_size_s = x_s.size(0)
        batch_size_t = x_t.size(0)
        cls_acc = accuracy(y_s, labels_s)[0]
        # softmax is monotonic, so ranking logits gives the same accuracy as ranking probabilities
        tgt_acc = accuracy(y_t, labels_t)[0]

        cls_losses.update(cls_loss.item(), batch_size_s)
        norm_losses.update(norm_loss.item(), batch_size_s)
        src_feature_norm.update(f_s.norm(p=2, dim=1).mean().item(), batch_size_s)
        tgt_feature_norm.update(f_t.norm(p=2, dim=1).mean().item(), batch_size_t)
        cls_accs.update(cls_acc.item(), batch_size_s)
        tgt_accs.update(tgt_acc.item(), batch_size_t)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='AFN for Partial Domain Adaptation')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                             ' (default: Office31)')
    parser.add_argument('-s', '--source', help='source domain')
    parser.add_argument('-t', '--target', help='target domain')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet18)')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    parser.add_argument('-n', '--num-blocks', default=1, type=int, help='Number of basic blocks for classifier')
    parser.add_argument('--bottleneck-dim', default=1000, type=int, help='Dimension of bottleneck')
    parser.add_argument('--dropout-p', default=0.5, type=float,
                        help='Dropout probability')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=32, type=int,
                        metavar='N',
                        help='mini-batch size (default: 32)')
    parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    # The AFN optimizer in main() does not pass this value to SGD; the help text says so.
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum (currently unused: the AFN optimizer is plain SGD)')
    parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float,
                        metavar='W', help='weight decay (default: 5e-4)',
                        dest='weight_decay')
    parser.add_argument('--trade-off-norm', default=0.05, type=float,
                        help='the trade-off hyper-parameter for norm loss')
    parser.add_argument('--trade-off-entropy', default=None, type=float,
                        help='the trade-off hyper-parameter for entropy loss')
    # argparse does not apply `type` to non-string defaults, so keep the default a float.
    parser.add_argument('-r', '--delta', default=1., type=float, help='Increment for L2 norm')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--per-class-eval', action='store_true',
                        help='whether output per-class accuracy during evaluation')
    parser.add_argument("--log", type=str, default='afn',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model. "
                             "When phase is 'analysis', only analyze the model.")
    args = parser.parse_args()
    main(args)