HubHop committed on
Commit
bcfa144
·
1 Parent(s): c8ed6d7
.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Default ignored files
2
+ /shelf/
3
+ /workspace.xml
4
+ # Editor-based HTTP Client requests
5
+ /httpRequests/
6
+ # Datasource local storage ignored files
7
+ /dataSources/
8
+ /dataSources.local.xml
__pycache__/datasets.cpython-39.pyc ADDED
Binary file (2.97 kB). View file
 
__pycache__/models_v2.cpython-39.pyc ADDED
Binary file (17.5 kB). View file
 
__pycache__/snnet.cpython-39.pyc ADDED
Binary file (13.6 kB). View file
 
__pycache__/utils.cpython-39.pyc ADDED
Binary file (13 kB). View file
 
app.py CHANGED
@@ -1,7 +1,414 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
1
+ # Copyright (c) 2015-present, Facebook, Inc.
2
+ # All rights reserved.
3
+ import argparse
4
+ import datetime
5
+ import numpy as np
6
+ import time
7
+ import torch
8
+ import torch.backends.cudnn as cudnn
9
+ import json
10
+
11
+ from pathlib import Path
12
+ from utils import get_root_logger
13
+ from timm.models import create_model
14
+ import models_v2
15
+ import requests
16
+
17
+ import utils
18
+ import time
19
+ import sys
20
+ import datetime
21
+ import os
22
+ from snnet import SNNet, SNNetv2
23
+ import warnings
24
+
25
+ warnings.filterwarnings("ignore")
26
+ from fvcore.nn import FlopCountAnalysis
27
+
28
+ from PIL import Image
29
  import gradio as gr
30
+ import plotly.express as px
31
+
32
+
33
def get_args_parser():
    """Build the option parser shared by the training/evaluation scripts and the demo.

    The parser is created with ``add_help=False`` so that it can be passed as a
    ``parents=[...]`` entry to another ``argparse.ArgumentParser``.

    Returns:
        argparse.ArgumentParser: parser with every option registered below.
    """
    parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False)
    parser.add_argument('--batch-size', default=64, type=int)
    parser.add_argument('--epochs', default=300, type=int)
    parser.add_argument('--bce-loss', action='store_true')
    parser.add_argument('--unscale-lr', action='store_true')

    # Model parameters
    parser.add_argument('--model', default='deit_base_patch16_224', type=str, metavar='MODEL',
                        help='Name of model to train')
    parser.add_argument('--input-size', default=224, type=int, help='images input size')

    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                        help='Dropout rate (default: 0.)')
    parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
                        help='Drop path rate (default: 0.1)')

    # --model-ema / --no-model-ema share one destination; EMA is on by default.
    parser.add_argument('--model-ema', action='store_true')
    parser.add_argument('--no-model-ema', action='store_false', dest='model_ema')
    parser.set_defaults(model_ema=True)
    parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='')
    parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='')

    # Optimizer parameters
    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                        help='Optimizer (default: "adamw"')
    parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
                        help='Optimizer Epsilon (default: 1e-8)')
    parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                        help='Optimizer Betas (default: None, use opt default)')
    parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                        help='Clip gradient norm (default: None, no clipping)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--weight-decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')

    # Learning rate schedule parameters
    parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                        help='LR scheduler (default: "cosine"')
    parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',
                        help='learning rate (default: 5e-4)')
    parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                        help='learning rate noise on/off epoch percentages')
    parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                        help='learning rate noise limit percent (default: 0.67)')
    parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                        help='learning rate noise std-dev (default: 1.0)')
    parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
                        help='warmup learning rate (default: 1e-6)')
    parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')

    parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
                        help='epoch interval to decay LR')
    parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
                        help='epochs to warmup LR, if scheduler supports')
    parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                        help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
    parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                        help='patience epochs for Plateau LR scheduler (default: 10')
    parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                        help='LR decay rate (default: 0.1)')

    # Augmentation parameters
    parser.add_argument('--color-jitter', type=float, default=0.3, metavar='PCT',
                        help='Color jitter factor (default: 0.3)')
    # The help text used to embed a literal `" + \` continuation inside the string
    # (and the statement ended in a stray comma); it is now a single clean sentence.
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original". '
                             '(default: rand-m9-mstd0.5-inc1)')
    parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
    parser.add_argument('--train-interpolation', type=str, default='bicubic',
                        help='Training interpolation (random, bilinear, bicubic default: "bicubic")')

    parser.add_argument('--repeated-aug', action='store_true')
    parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
    parser.set_defaults(repeated_aug=True)

    parser.add_argument('--train-mode', action='store_true')
    parser.add_argument('--no-train-mode', action='store_false', dest='train_mode')
    parser.set_defaults(train_mode=True)

    parser.add_argument('--ThreeAugment', action='store_true')  # 3augment

    parser.add_argument('--src', action='store_true')  # simple random crop

    # * Random Erase params
    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                        help='Random erase prob (default: 0.25)')
    parser.add_argument('--remode', type=str, default='pixel',
                        help='Random erase mode (default: "pixel")')
    parser.add_argument('--recount', type=int, default=1,
                        help='Random erase count (default: 1)')
    parser.add_argument('--resplit', action='store_true', default=False,
                        help='Do not random erase first (clean) augmentation split')

    # * Mixup params
    parser.add_argument('--mixup', type=float, default=0.8,
                        help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
    parser.add_argument('--cutmix', type=float, default=1.0,
                        help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
    parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
    parser.add_argument('--mixup-prob', type=float, default=1.0,
                        help='Probability of performing mixup or cutmix when either/both is enabled')
    parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                        help='Probability of switching to cutmix when both mixup and cutmix enabled')
    parser.add_argument('--mixup-mode', type=str, default='batch',
                        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')

    # Distillation parameters
    parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL',
                        help='Name of teacher model to train (default: "regnety_160"')
    parser.add_argument('--teacher-path', type=str, default='')
    parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="")
    parser.add_argument('--distillation-alpha', default=0.5, type=float, help="")
    parser.add_argument('--distillation-tau', default=1.0, type=float, help="")

    # * Finetuning params
    parser.add_argument('--finetune', default='', help='finetune from checkpoint')
    parser.add_argument('--attn-only', action='store_true')

    # Dataset parameters
    parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', type=str,
                        help='dataset path')
    parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'],
                        type=str, help='Image Net dataset path')
    parser.add_argument('--inat-category', default='name',
                        choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'],
                        type=str, help='semantic granularity')

    parser.add_argument('--output_dir', default='',
                        help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cpu',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--eval-crop-ratio', default=0.875, type=float, help="Crop ratio for evaluation")
    parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation')
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin-mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
                        help='')
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')

    parser.add_argument('--exp_name', default='deit', type=str, help='experiment name')
    parser.add_argument('--config', default=None, type=str, help='configuration')
    # NOTE(review): the help strings of the options below look copy-pasted from
    # unrelated options; only --lora_r is corrected here because its use is visible
    # in this file. Confirm the intended meaning of the others before rewording them.
    parser.add_argument('--scoring', action='store_true', default=False, help='configuration')
    parser.add_argument('--proxy', default='synflow', type=str, help='configuration')
    parser.add_argument('--snnet_name', default='snnetv2', type=str, help='configuration')
    parser.add_argument('--get_flops', action='store_true')
    parser.add_argument('--flops_sampling_k', default=None, type=float, help="Crop ratio for evaluation")
    parser.add_argument('--low_rank', action='store_true', default=False, help='Enabling distributed evaluation')
    parser.add_argument('--lora_r', default=64, type=int,
                        help='lora_r value passed to the SNNetv2 constructor (default: 64)')
    parser.add_argument('--flops_gap', default=1.0, type=float,
                        help='number of distributed processes')

    return parser
200
+
201
+
202
def initialize_model_stitching_layer(model, mixup_fn, data_loader, device):
    """Initialise the model's stitching weights from the first batch of ``data_loader``.

    Args:
        model: object exposing ``initialize_stitching_weights(samples)``.
        mixup_fn: optional callable ``(samples, targets) -> (samples, targets)``
            applied to the batch before initialisation; skipped when ``None``.
        data_loader: iterable yielding ``(samples, targets)`` pairs.
        device: device the batch is moved to.

    Only the first batch is consumed; an empty loader leaves the model untouched.
    """
    for batch_inputs, batch_labels in data_loader:
        batch_inputs = batch_inputs.to(device, non_blocking=True)
        batch_labels = batch_labels.to(device, non_blocking=True)

        if mixup_fn is not None:
            batch_inputs, batch_labels = mixup_fn(batch_inputs, batch_labels)

        with torch.cuda.amp.autocast():
            model.initialize_stitching_weights(batch_inputs)

        # One batch is enough; stop after the first.
        return
214
+
215
+
216
@torch.no_grad()
def analyse_flops_for_all(model, config_name):
    """Count FLOPs for every stitch configuration and dump them to a JSON file.

    For each id in ``model.all_cfgs`` the model is switched to that stitch via
    ``model.reset_stitch_id`` and analysed with ``FlopCountAnalysis`` on a single
    ``1x3x224x224`` random input. The mapping ``cfg_id -> total FLOPs`` is written
    to ``./model_flops/flops_{config_name}.json``.

    Args:
        model: stitched model exposing ``all_cfgs`` and ``reset_stitch_id``.
        config_name: suffix used in the output file name.
    """
    # The probe used to be moved to CUDA unconditionally, which breaks when the
    # model lives on the CPU (the default --device). Follow the model's own
    # parameters instead; fall back to CUDA (the old behaviour) if it has none.
    first_param = next(model.parameters(), None)
    probe_device = first_param.device if first_param is not None else torch.device('cuda')
    # FLOP counts do not depend on the input values, so one probe serves all stitches.
    probe = torch.randn(1, 3, 224, 224).to(probe_device)

    stitch_results = {}
    for cfg_id in model.all_cfgs:
        model.reset_stitch_id(cfg_id)
        stitch_results[cfg_id] = FlopCountAnalysis(model, probe).total()

    save_dir = './model_flops'
    # exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
    os.makedirs(save_dir, exist_ok=True)

    with open(os.path.join(save_dir, f'flops_{config_name}.json'), 'w') as f:
        json.dump(stitch_results, f, indent=4)
232
+
233
+
234
def main(args):
    """Load the stitched SN-Netv2 model and serve the Gradio classification demo.

    Steps, in order: initialise distributed mode and logging, seed the RNGs,
    build the evaluation transform, create the anchor models named in
    ``args.anchors`` and wrap them in ``SNNetv2``, load the checkpoint at
    ``args.resume``, read per-stitch accuracy/FLOPs from ``stitches_res_s_l.txt``,
    download the class-label list, then build and launch the Gradio page.

    Args:
        args: parsed namespace. Besides the parser options it must provide
            ``anchors`` and ``anchor_drop_path`` (set from the JSON config by the
            script entry point).

    Raises:
        NotImplementedError: distillation combined with fine-tuning outside eval mode.
        requests.HTTPError: the label list could not be downloaded.
    """
    utils.init_distributed_mode(args)

    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    logger = get_root_logger(os.path.join(args.output_dir, f'{timestamp}.log'))

    logger.info(str(args))

    if args.distillation_type != 'none' and args.finetune and not args.eval:
        raise NotImplementedError("Finetuning with distillation not yet supported")

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    from datasets import build_transform

    transform = build_transform(False, args)

    anchors = []
    for i, anchor_name in enumerate(args.anchors):
        logger.info(f"Creating model: {anchor_name}")
        anchor = create_model(
            anchor_name,
            pretrained=False,
            pretrained_deit=None,
            num_classes=1000,
            drop_path_rate=args.anchor_drop_path[i],
            img_size=args.input_size
        )
        anchors.append(anchor)

    model = SNNetv2(anchors, lora_r=args.lora_r)

    checkpoint = torch.load(args.resume, map_location='cpu')

    logger.info(f"load checkpoint from {args.resume}")
    model.load_state_dict(checkpoint['model'])

    model.to(device)
    model.eval()

    # Per-stitch results: one JSON object per line with 'cfg_id', 'acc1' and 'flops'.
    eval_res = {}
    flops_res = {}
    with open('stitches_res_s_l.txt', 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                # A blank (e.g. trailing) line would make json.loads raise.
                continue
            epoch_stat = json.loads(line)
            eval_res[epoch_stat['cfg_id']] = epoch_stat['acc1']
            flops_res[epoch_stat['cfg_id']] = epoch_stat['flops'] / 1e9

    def visualize_stitch_pos(stitch_id):
        """Return a FLOPs-vs-accuracy scatter plot with ``stitch_id`` highlighted."""
        if stitch_id == 13:
            # 13 is equivalent to 0
            stitch_id = 0

        names = [f'ID {key}' for key in flops_res.keys()]

        fig = px.scatter(x=list(flops_res.values()), y=list(eval_res.values()), hover_name=names)
        fig.update_layout(
            title=f"SN-Netv2 - Stitch ID - {stitch_id}",
            title_x=0.5,
            xaxis_title="GFLOPs",
            # The plotted values are the 'acc1' entries read above, not mIoU.
            yaxis_title="Top-1 Accuracy",
            font=dict(
                family="Courier New, monospace",
                size=18,
                color="RebeccaPurple"
            ),
            legend=dict(
                yanchor="bottom",
                y=0.99,
                xanchor="left",
                x=0.01),
        )
        fig.update_traces(marker=dict(size=10,
                                      line=dict(width=2)),
                          selector=dict(mode='markers'))

        fig.add_scatter(x=[flops_res[stitch_id]], y=[eval_res[stitch_id]], mode='markers', marker=dict(size=15),
                        name='Current Stitch')
        return fig

    # Download human-readable labels for ImageNet.
    # A timeout keeps start-up from hanging forever, and raise_for_status stops an
    # error page from silently being used as the label list.
    response = requests.get("https://git.io/JJkYN", timeout=30)
    response.raise_for_status()
    labels = response.text.split("\n")

    def process_image(image, stitch_id):
        """Classify ``image`` with the selected stitch; return (confidences, plot)."""
        inp = transform(image).unsqueeze(0).to(device)
        model.reset_stitch_id(stitch_id)
        with torch.no_grad():
            prediction = torch.nn.functional.softmax(model(inp)[0], dim=0)
            confidences = {labels[i]: float(prediction[i]) for i in range(1000)}
        fig = visualize_stitch_pos(stitch_id)
        return confidences, fig

    with gr.Blocks() as main_page:
        with gr.Column():
            gr.HTML("""
            <h1 align="center" style=" display: flex; flex-direction: row; justify-content: center; font-size: 25pt; ">Stitched ViTs are Flexible Vision Backbones</h1>
            <div align="center"> <img align="center" src='file/gradio_banner.png' width="70%"> </div>
            <h3 align="center" >This is the classification demo page of SN-Netv2, an flexible vision backbone that allows for 100+ runtime speed and performance trade-offs.</h3>
            <h3 align="center" >You can also run this gradio demo on your local GPUs at https://github.com/ziplab/SN-Netv2</h3>
            """)
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type='pil')
                stitch_slider = gr.Slider(minimum=0, maximum=134, step=1, label="Stitch ID")
                with gr.Row():
                    clear_button = gr.ClearButton()
                    submit_button = gr.Button()
            with gr.Column():
                label_output = gr.Label(num_top_classes=5)
                stitch_plot = gr.Plot(label='Stitch Position')

        submit_button.click(
            fn=process_image,
            inputs=[image_input, stitch_slider],
            outputs=[label_output, stitch_plot],
        )

        # Moving the slider only refreshes the plot; inference runs on submit.
        stitch_slider.change(
            fn=visualize_stitch_pos,
            inputs=[stitch_slider],
            outputs=[stitch_plot],
            show_progress=False
        )

        clear_button.click(
            lambda: [None, 0, None, None],
            outputs=[image_input, stitch_slider, label_output, stitch_plot],
        )

        gr.Examples(
            [
                ['demo.jpg', 0],
            ],
            inputs=[
                image_input,
                stitch_slider
            ],
            outputs=[
                label_output,
                stitch_plot
            ],
        )

    main_page.launch(allowed_paths=['./'])
392
+
393
+
394
if __name__ == '__main__':
    # Script entry point: parse CLI options, merge the JSON config, prepare the
    # output directory and start the demo.
    parser = argparse.ArgumentParser('DeiT training and evaluation script', parents=[get_args_parser()])
    args = parser.parse_args()

    # The demo ships with gradio_demo.json; use it unless --config names another
    # file (previously a user-supplied --config was silently overwritten).
    if args.config is None:
        setattr(args, 'config', 'gradio_demo.json')

    # Close the config file deterministically instead of leaking the handle.
    with open(args.config) as config_file:
        config_args = json.load(config_file)

    # Options given explicitly on the command line win over the config file.
    # argparse stores '--batch-size' as 'batch_size' and the config uses the
    # underscore form, so dashes must be normalised before comparing keys.
    override_keys = {arg[2:].split('=')[0].replace('-', '_') for arg in sys.argv[1:]
                     if arg.startswith('--')}
    for k, v in config_args.items():
        if k not in override_keys:
            setattr(args, k, v)

    output_dir = os.path.join('outputs', args.exp_name)
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    checkpoint_path = os.path.join(output_dir, 'checkpoint.pth')
    # Resume from an existing checkpoint only when none was specified.
    if os.path.exists(checkpoint_path) and not args.resume:
        setattr(args, 'resume', checkpoint_path)

    setattr(args, 'output_dir', output_dir)

    main(args)
 
datasets.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2015-present, Facebook, Inc.
2
+ # All rights reserved.
3
+ import os
4
+ import json
5
+
6
+ from torchvision import datasets, transforms
7
+ from torchvision.datasets.folder import ImageFolder, default_loader
8
+
9
+ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
10
+ from timm.data import create_transform
11
+
12
+
13
class INatDataset(ImageFolder):
    """iNaturalist dataset whose sample list is built from the JSON annotation files.

    ``samples`` is filled with ``(image_path, class_index)`` pairs and
    ``nb_classes`` holds the number of distinct labels at the chosen ``category``
    level. Class indices are always derived from the *train* split of ``year`` so
    that train and val share one label mapping.
    """

    def __init__(self, root, train=True, year=2018, transform=None, target_transform=None,
                 category='name', loader=default_loader):
        self.transform = transform
        self.loader = loader
        self.target_transform = target_transform
        self.year = year

        split_name = "train" if train else "val"
        with open(os.path.join(root, f'{split_name}{year}.json')) as split_file:
            split_meta = json.load(split_file)

        with open(os.path.join(root, 'categories.json')) as categories_file:
            categories = json.load(categories_file)

        with open(os.path.join(root, f"train{year}.json")) as train_file:
            train_meta = json.load(train_file)

        # Assign consecutive indices to labels in order of first appearance
        # among the train annotations.
        label_to_index = {}
        for annotation in train_meta['annotations']:
            label = categories[int(annotation['category_id'])][category]
            if label not in label_to_index:
                label_to_index[label] = len(label_to_index)
        self.nb_classes = len(label_to_index)

        self.samples = []
        for image_info in split_meta['images']:
            # file_name is split on '/'; part 2 is used as the category id and
            # parts 0, 2 and 3 form the path under root.
            parts = image_info['file_name'].split('/')
            category_id = int(parts[2])
            image_path = os.path.join(root, parts[0], parts[2], parts[3])
            class_index = label_to_index[categories[category_id][category]]
            self.samples.append((image_path, class_index))

    # __getitem__ and __len__ inherited from ImageFolder
54
+
55
+
56
def build_dataset(is_train, args):
    """Create the dataset selected by ``args.data_set``.

    Args:
        is_train: build the training split (and training transform) when true,
            otherwise the validation split.
        args: namespace providing ``data_set``, ``data_path``, ``inat_category``
            and the fields read by ``build_transform``.

    Returns:
        tuple: ``(dataset, nb_classes)``.

    Raises:
        ValueError: ``args.data_set`` is not one of CIFAR, IMNET, INAT, INAT19.
    """
    transform = build_transform(is_train, args)

    if args.data_set == 'CIFAR':
        dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform)
        nb_classes = 100
    elif args.data_set == 'IMNET':
        root = os.path.join(args.data_path, 'train' if is_train else 'val')
        dataset = datasets.ImageFolder(root, transform=transform)
        nb_classes = 1000
    elif args.data_set == 'INAT':
        dataset = INatDataset(args.data_path, train=is_train, year=2018,
                              category=args.inat_category, transform=transform)
        nb_classes = dataset.nb_classes
    elif args.data_set == 'INAT19':
        dataset = INatDataset(args.data_path, train=is_train, year=2019,
                              category=args.inat_category, transform=transform)
        nb_classes = dataset.nb_classes
    else:
        # Without this branch an unknown name fell through to an opaque
        # UnboundLocalError on the return statement.
        raise ValueError(f"Unsupported data_set: {args.data_set!r}")

    return dataset, nb_classes
76
+
77
+
78
def build_transform(is_train, args):
    """Build the image transform for training or evaluation.

    Training: the transform returned by timm's ``create_transform``; for inputs of
    32 px or less its first step is replaced by a padded ``RandomCrop``.
    Evaluation: optional resize + center crop (inputs larger than 32 px), then
    tensor conversion and normalisation with the ImageNet default mean/std.

    Args:
        is_train: select the training pipeline when true.
        args: namespace providing ``input_size`` and, depending on the branch,
            the augmentation options or ``eval_crop_ratio``.
    """
    needs_resize = args.input_size > 32

    if is_train:
        # this should always dispatch to transforms_imagenet_train
        train_pipeline = create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation=args.train_interpolation,
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
        )
        if not needs_resize:
            # replace RandomResizedCropAndInterpolation with RandomCrop
            train_pipeline.transforms[0] = transforms.RandomCrop(args.input_size, padding=4)
        return train_pipeline

    eval_steps = []
    if needs_resize:
        # to maintain same ratio w.r.t. 224 images
        resized_edge = int(args.input_size / args.eval_crop_ratio)
        eval_steps += [
            transforms.Resize(resized_edge, interpolation=3),
            transforms.CenterCrop(args.input_size),
        ]
    eval_steps += [
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
    ]
    return transforms.Compose(eval_steps)
demo.jpg ADDED
flops_gradio_demo.json ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "0": 4608338304,
3
+ "1": 61604135936,
4
+ "2": 56843745792,
5
+ "3": 52102230016,
6
+ "4": 47360714240,
7
+ "5": 42619198464,
8
+ "6": 37877682688,
9
+ "7": 33136166912,
10
+ "8": 28394651136,
11
+ "9": 23653135360,
12
+ "10": 18911619584,
13
+ "11": 14170103808,
14
+ "12": 9428588032,
15
+ "14": 9523655552,
16
+ "15": 14265171328,
17
+ "16": 19006687104,
18
+ "17": 23748202880,
19
+ "18": 28489718656,
20
+ "19": 33231234432,
21
+ "20": 37972750208,
22
+ "21": 42714265984,
23
+ "22": 47455781760,
24
+ "23": 52197297536,
25
+ "24": 56938813312,
26
+ "25": 57017547264,
27
+ "26": 52276031488,
28
+ "27": 47534515712,
29
+ "28": 42792999936,
30
+ "29": 38051484160,
31
+ "30": 33309968384,
32
+ "31": 28568452608,
33
+ "32": 23826936832,
34
+ "33": 19085421056,
35
+ "34": 14343905280,
36
+ "35": 57017547264,
37
+ "36": 52276031488,
38
+ "37": 47534515712,
39
+ "38": 42792999936,
40
+ "39": 38051484160,
41
+ "40": 33309968384,
42
+ "41": 28568452608,
43
+ "42": 23826936832,
44
+ "43": 19085421056,
45
+ "44": 57017547264,
46
+ "45": 52276031488,
47
+ "46": 47534515712,
48
+ "47": 42792999936,
49
+ "48": 38051484160,
50
+ "49": 33309968384,
51
+ "50": 28568452608,
52
+ "51": 23826936832,
53
+ "52": 57017547264,
54
+ "53": 52276031488,
55
+ "54": 47534515712,
56
+ "55": 42792999936,
57
+ "56": 38051484160,
58
+ "57": 33309968384,
59
+ "58": 28568452608,
60
+ "59": 57017547264,
61
+ "60": 52276031488,
62
+ "61": 47534515712,
63
+ "62": 42792999936,
64
+ "63": 38051484160,
65
+ "64": 33309968384,
66
+ "65": 57017547264,
67
+ "66": 52276031488,
68
+ "67": 47534515712,
69
+ "68": 42792999936,
70
+ "69": 38051484160,
71
+ "70": 57017547264,
72
+ "71": 52276031488,
73
+ "72": 47534515712,
74
+ "73": 42792999936,
75
+ "74": 57017547264,
76
+ "75": 52276031488,
77
+ "76": 47534515712,
78
+ "77": 57017547264,
79
+ "78": 52276031488,
80
+ "79": 57017547264,
81
+ "80": 9504781184,
82
+ "81": 14246296960,
83
+ "82": 18987812736,
84
+ "83": 23729328512,
85
+ "84": 28470844288,
86
+ "85": 33212360064,
87
+ "86": 37953875840,
88
+ "87": 42695391616,
89
+ "88": 47436907392,
90
+ "89": 52178423168,
91
+ "90": 9504781184,
92
+ "91": 14246296960,
93
+ "92": 18987812736,
94
+ "93": 23729328512,
95
+ "94": 28470844288,
96
+ "95": 33212360064,
97
+ "96": 37953875840,
98
+ "97": 42695391616,
99
+ "98": 47436907392,
100
+ "99": 9504781184,
101
+ "100": 14246296960,
102
+ "101": 18987812736,
103
+ "102": 23729328512,
104
+ "103": 28470844288,
105
+ "104": 33212360064,
106
+ "105": 37953875840,
107
+ "106": 42695391616,
108
+ "107": 9504781184,
109
+ "108": 14246296960,
110
+ "109": 18987812736,
111
+ "110": 23729328512,
112
+ "111": 28470844288,
113
+ "112": 33212360064,
114
+ "113": 37953875840,
115
+ "114": 9504781184,
116
+ "115": 14246296960,
117
+ "116": 18987812736,
118
+ "117": 23729328512,
119
+ "118": 28470844288,
120
+ "119": 33212360064,
121
+ "120": 9504781184,
122
+ "121": 14246296960,
123
+ "122": 18987812736,
124
+ "123": 23729328512,
125
+ "124": 28470844288,
126
+ "125": 9504781184,
127
+ "126": 14246296960,
128
+ "127": 18987812736,
129
+ "128": 23729328512,
130
+ "129": 9504781184,
131
+ "130": 14246296960,
132
+ "131": 18987812736,
133
+ "132": 9504781184,
134
+ "133": 14246296960,
135
+ "134": 9504781184
136
+ }
gradio_banner.png ADDED
gradio_demo.json ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "anchors": ["deit_small_patch16_LS", "deit_large_patch16_LS"],
3
+ "batch_size": 64,
4
+ "snnet_name": "snnet_v2",
5
+ "data_path": "/data2/datasets/imagenet",
6
+ "data_set": "IMNET",
7
+ "exp_name": "stitch_s_l_v2_lora_r_64_50_ep",
8
+ "input_size": 224,
9
+ "num_workers": 10,
10
+ "lr": 0.00003,
11
+ "warmup_lr": 1e-7,
12
+ "epochs": 50,
13
+ "weight_decay": 0.02,
14
+ "sched": "cosine",
15
+ "eval_crop_ratio": 1.0,
16
+ "reprob": 0.0,
17
+ "smoothing": 0.1,
18
+ "warmup_epochs": 5,
19
+ "drop": 0.0,
20
+ "seed": 0,
21
+ "opt": "fusedlamb",
22
+ "mixup": 0,
23
+ "anchor_drop_path": [0.05, 0.4],
24
+ "cutmix": 1.0,
25
+ "color_jitter": 0.3,
26
+ "unscale_lr": true,
27
+ "no_repeated_aug": true,
28
+ "ThreeAugment": true,
29
+ "src": true,
30
+ "lora_r": 64,
31
+ "pretrained_deit": "../pretrained_weights",
32
+ "resume": "snnetv2_deit3_s_l.pth"
33
+ }
models_v2.py ADDED
@@ -0,0 +1,568 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ import os.path
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from functools import partial
8
+
9
+ from timm.models.vision_transformer import Mlp, PatchEmbed , _cfg
10
+
11
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
12
+ from timm.models.registry import register_model
13
+ # from xformers.ops import memory_efficient_attention
14
+
15
class Attention(nn.Module):
    """Multi-head self-attention over a ``(B, N, C)`` input.

    taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        per_head_dim = dim // num_heads
        # An explicit (truthy) qk_scale overrides the default 1/sqrt(head_dim).
        self.scale = qk_scale or per_head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Return the attended features, same shape as ``x``."""
        batch, tokens, channels = x.shape
        # (B, N, 3C) -> (3, B, heads, N, C // heads)
        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, channels // self.num_heads)
        query, key, value = packed.permute(2, 0, 3, 1, 4).unbind(0)

        query = query * self.scale
        weights = (query @ key.transpose(-2, -1)).softmax(dim=-1)
        weights = self.attn_drop(weights)

        # Merge the heads back into the channel dimension.
        merged = (weights @ value).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))
45
+
46
class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual connection.

    taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py

    ``init_values`` is accepted but not used by this class.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, Attention_block=Attention,
                 Mlp_block=Mlp, init_values=1e-4):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention_block(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        hidden_features = int(dim * mlp_ratio)
        self.mlp = Mlp_block(in_features=dim, hidden_features=hidden_features, act_layer=act_layer, drop=drop)

    def forward(self, x):
        """Apply the attention and MLP residual branches in sequence."""
        attn_branch = self.attn(self.norm1(x))
        x = x + self.drop_path(attn_branch)
        mlp_branch = self.mlp(self.norm2(x))
        return x + self.drop_path(mlp_branch)
65
+
66
class Layer_scale_init_Block(nn.Module):
    """Pre-norm transformer block whose two residual branches are scaled per channel.

    Adapted, with slight modifications, from timm's vision_transformer.py:
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py

    ``gamma_1`` (attention branch) and ``gamma_2`` (MLP branch) are learnable
    vectors of length ``dim``, each initialised to ``init_values``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, Attention_block=Attention,
                 Mlp_block=Mlp, init_values=1e-4):
        super().__init__()
        # Attention branch.
        self.norm1 = norm_layer(dim)
        self.attn = Attention_block(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # Stochastic depth on both residual branches; a no-op when the rate is 0.
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        # MLP branch.
        self.norm2 = norm_layer(dim)
        hidden_width = int(dim * mlp_ratio)
        self.mlp = Mlp_block(in_features=dim, hidden_features=hidden_width, act_layer=act_layer, drop=drop)
        # Per-channel scales applied to each branch output.
        self.gamma_1 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
        self.gamma_2 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)

    def forward(self, x):
        """Apply the scaled attention residual, then the scaled MLP residual."""
        attn_branch = self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
        x = x + attn_branch
        mlp_branch = self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        return x + mlp_branch
88
+
89
class Layer_scale_init_Block_paralx2(nn.Module):
    """Transformer block with two parallel attention and two parallel MLP branches, each scaled.

    Adapted, with slight modifications, from timm's vision_transformer.py:
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py

    Every branch has its own normalisation layer and its own learnable
    per-channel scale (``gamma_*``), initialised to ``init_values``. The two
    attention branches read the same input and are summed into the residual;
    the two MLP branches then do the same.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, Attention_block=Attention,
                 Mlp_block=Mlp, init_values=1e-4):
        super().__init__()
        attn_kwargs = dict(
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        # Normalisation layers and attention modules of the two attention branches.
        self.norm1 = norm_layer(dim)
        self.norm11 = norm_layer(dim)
        self.attn = Attention_block(dim, **attn_kwargs)
        self.attn1 = Attention_block(dim, **attn_kwargs)
        # Stochastic depth shared by all four branches; a no-op when the rate is 0.
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        # Normalisation layers and MLP modules of the two MLP branches.
        self.norm2 = norm_layer(dim)
        self.norm21 = norm_layer(dim)
        hidden_width = int(dim * mlp_ratio)
        self.mlp = Mlp_block(in_features=dim, hidden_features=hidden_width, act_layer=act_layer, drop=drop)
        self.mlp1 = Mlp_block(in_features=dim, hidden_features=hidden_width, act_layer=act_layer, drop=drop)
        # One per-channel scale per branch.
        self.gamma_1 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
        self.gamma_1_1 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
        self.gamma_2 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
        self.gamma_2_1 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)

    def forward(self, x):
        """Add both scaled attention branches to x, then both scaled MLP branches."""
        attn_a = self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
        attn_b = self.drop_path(self.gamma_1_1 * self.attn1(self.norm11(x)))
        x = x + attn_a + attn_b
        mlp_a = self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        mlp_b = self.drop_path(self.gamma_2_1 * self.mlp1(self.norm21(x)))
        return x + mlp_a + mlp_b
118
+
119
class Block_paralx2(nn.Module):
    """Transformer block with two parallel attention and two parallel MLP branches.

    Adapted, with slight modifications, from timm's vision_transformer.py:
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py

    Each branch has its own normalisation layer. ``init_values`` is accepted so
    all block classes share one keyword set, but this class never reads it.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, Attention_block=Attention,
                 Mlp_block=Mlp, init_values=1e-4):
        super().__init__()
        attn_kwargs = dict(
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        # Normalisation layers and attention modules of the two attention branches.
        self.norm1 = norm_layer(dim)
        self.norm11 = norm_layer(dim)
        self.attn = Attention_block(dim, **attn_kwargs)
        self.attn1 = Attention_block(dim, **attn_kwargs)
        # Stochastic depth shared by all four branches; a no-op when the rate is 0.
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        # Normalisation layers and MLP modules of the two MLP branches.
        self.norm2 = norm_layer(dim)
        self.norm21 = norm_layer(dim)
        hidden_width = int(dim * mlp_ratio)
        self.mlp = Mlp_block(in_features=dim, hidden_features=hidden_width, act_layer=act_layer, drop=drop)
        self.mlp1 = Mlp_block(in_features=dim, hidden_features=hidden_width, act_layer=act_layer, drop=drop)

    def forward(self, x):
        """Add both attention branches to x, then both MLP branches."""
        attn_a = self.drop_path(self.attn(self.norm1(x)))
        attn_b = self.drop_path(self.attn1(self.norm11(x)))
        x = x + attn_a + attn_b
        mlp_a = self.drop_path(self.mlp(self.norm2(x)))
        mlp_b = self.drop_path(self.mlp1(self.norm21(x)))
        return x + mlp_a + mlp_b
144
+
145
+
146
class hMLP_stem(nn.Module):
    """Hierarchical MLP patch stem (https://arxiv.org/pdf/2203.09795.pdf).

    Adapted, with slight modifications, from timm's vision_transformer.py:
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py

    Three strided convolutions (kernel = stride = 4, 2, 2). The first two are
    followed by ``norm_layer`` and GELU, the last by ``norm_layer`` only.
    The strides are hard-coded: ``patch_size`` only feeds ``num_patches`` and
    the stored ``patch_size`` attribute.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=nn.SyncBatchNorm):
        super().__init__()
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)
        grid_along_1 = self.img_size[1] // self.patch_size[1]
        grid_along_0 = self.img_size[0] // self.patch_size[0]
        self.num_patches = grid_along_1 * grid_along_0

        quarter_dim = embed_dim // 4
        self.proj = torch.nn.Sequential(
            nn.Conv2d(in_chans, quarter_dim, kernel_size=4, stride=4),
            norm_layer(quarter_dim),
            nn.GELU(),
            nn.Conv2d(quarter_dim, quarter_dim, kernel_size=2, stride=2),
            norm_layer(quarter_dim),
            nn.GELU(),
            nn.Conv2d(quarter_dim, embed_dim, kernel_size=2, stride=2),
            norm_layer(embed_dim),
        )

    def forward(self, x):
        """Project x, flatten everything after the channel axis, and swap axes 1 and 2."""
        # The values are unused; the unpack raises unless x has exactly four dimensions.
        B, C, H, W = x.shape
        feature_map = self.proj(x)
        return feature_map.flatten(2).transpose(1, 2)
174
+
175
class vit_models(nn.Module):
    """Vision Transformer with LayerScale (https://arxiv.org/abs/2103.17239) support.

    Adapted, with slight modifications, from timm's vision_transformer.py:
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py

    Besides ``forward`` the class exposes entry points that run only part of
    the pipeline: ``forward_patch_embed``, ``selective_forward``,
    ``forward_until``, ``forward_from``, ``forward_norm_head`` and
    ``extract_block_features``.

    The positional embedding has one entry per patch and is added before the
    class token is prepended, so the class token gets no positional embedding.

    ``global_pool``, ``dpr_constant`` and ``mlp_ratio_clstk`` are accepted but
    never read. Every block is built with ``drop=0.0``; ``drop_rate`` is only
    applied as dropout in front of the head inside ``forward``.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=nn.LayerNorm, global_pool=None,
                 block_layers=Block,
                 Patch_layer=PatchEmbed, act_layer=nn.GELU,
                 Attention_block=Attention, Mlp_block=Mlp,
                 dpr_constant=True, init_scale=1e-4,
                 mlp_ratio_clstk=4.0):
        super().__init__()

        self.dropout_rate = drop_rate
        self.depth = depth
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim

        self.patch_embed = Patch_layer(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        patch_count = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, patch_count, embed_dim))

        # Every block uses the same stochastic-depth rate.
        per_block_drop_path = [drop_path_rate] * depth
        stack = []
        for rate in per_block_drop_path:
            stack.append(block_layers(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=0.0, attn_drop=attn_drop_rate, drop_path=rate, norm_layer=norm_layer,
                act_layer=act_layer, Attention_block=Attention_block, Mlp_block=Mlp_block,
                init_values=init_scale))
        self.blocks = nn.ModuleList(stack)

        self.norm = norm_layer(embed_dim)

        self.feature_info = [dict(num_chs=embed_dim, reduction=0, module='head')]
        if num_classes > 0:
            self.head = nn.Linear(embed_dim, num_classes)
        else:
            self.head = nn.Identity()

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise nn.Linear (truncated normal, zero bias) and nn.LayerNorm (weight 1, bias 0)."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the parameter names reported as exempt from weight decay."""
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        """Return the classification head module."""
        return self.head

    def get_num_layers(self):
        """Return the number of transformer blocks."""
        return len(self.blocks)

    def reset_classifier(self, num_classes, global_pool=''):
        """Replace the head with a new one for num_classes outputs (Identity if not positive).

        ``global_pool`` is accepted but not read.
        """
        self.num_classes = num_classes
        if num_classes > 0:
            self.head = nn.Linear(self.embed_dim, num_classes)
        else:
            self.head = nn.Identity()

    def _embed_with_cls(self, x):
        """Shared front end: patch tokens plus positional embedding, class token prepended."""
        batch = x.shape[0]
        tokens = self.patch_embed(x)
        cls = self.cls_token.expand(batch, -1, -1)
        tokens = tokens + self.pos_embed
        return torch.cat((cls, tokens), dim=1)

    def extract_block_features(self, x):
        """Run every block on an input image batch and return {block index: detached output}."""
        x = self._embed_with_cls(x)
        collected = {}
        for idx, blk in enumerate(self.blocks):
            x = blk(x)
            collected[idx] = x.detach()
        return collected

    def selective_forward(self, x, begin, end):
        """Apply the blocks with index in [begin, end] (both inclusive) to token input x."""
        for idx, blk in enumerate(self.blocks):
            if idx > end:
                break
            if idx >= begin:
                x = blk(x)
        return x

    def forward_until(self, x, blk_id):
        """Embed an image batch and run blocks 0..blk_id inclusive.

        All blocks run when no block index equals ``blk_id``.
        """
        x = self._embed_with_cls(x)
        for idx, blk in enumerate(self.blocks):
            x = blk(x)
            if idx == blk_id:
                break
        return x

    def forward_from(self, x, blk_id):
        """Run blocks with index >= blk_id on token input x, then the final norm and head.

        Unlike ``forward``, no dropout is applied before the head.
        """
        for idx, blk in enumerate(self.blocks):
            if idx >= blk_id:
                x = blk(x)
        x = self.norm(x)
        return self.head(x[:, 0])

    def forward_patch_embed(self, x):
        """Return the token sequence fed to the first block."""
        return self._embed_with_cls(x)

    def forward_norm_head(self, x):
        """Apply the final norm and feed token 0 to the head."""
        x = self.norm(x)
        return self.head(x[:, 0])

    def forward_features(self, x):
        """Embed, run all blocks and the final norm, and return token 0."""
        x = self._embed_with_cls(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x[:, 0]

    def forward(self, x):
        """Full forward pass: features, optional dropout, then the head."""
        x = self.forward_features(x)
        if self.dropout_rate:
            x = F.dropout(x, p=float(self.dropout_rate), training=self.training)
        return self.head(x)
340
+
341
+ # DeiT III: Revenge of the ViT (https://arxiv.org/abs/2204.07118)
342
+
343
@register_model
def deit_tiny_patch16_LS(pretrained=False, img_size=224, pretrained_21k=False, pretrained_cfg_overlay=None, **kwargs):
    """Build the tiny patch-16 model: width 192, 12 LayerScale blocks, 3 heads.

    ``pretrained``, ``pretrained_21k`` and ``pretrained_cfg_overlay`` are
    accepted but never read; no checkpoint is loaded here. Remaining keyword
    arguments are forwarded to ``vit_models``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
350
+
351
+
352
@register_model
def deit_small_patch16_LS(pretrained=False, img_size=224, pretrained_21k=False, pretrained_cfg=None, pretrained_deit=None, pretrained_cfg_overlay=None, **kwargs):
    """Build the small patch-16 model: width 384, 12 LayerScale blocks, 6 heads.

    Args:
        pretrained: when true, load weights from a local checkpoint file.
        img_size: input resolution forwarded to ``vit_models``.
        pretrained_21k: accepted but not read; the checkpoint file name is fixed.
        pretrained_cfg: accepted but not read.
        pretrained_deit: directory holding ``deit_3_small_224_21k.pth``;
            required when ``pretrained`` is true.
        pretrained_cfg_overlay: accepted but not read.
        **kwargs: forwarded to ``vit_models``.

    Raises:
        ValueError: if ``pretrained`` is true and ``pretrained_deit`` is None.
    """
    model = vit_models(
        img_size=img_size, patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Layer_scale_init_Block, **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        # Imported here so this function works without a module-level `import os`.
        import os

        if pretrained_deit is None:
            raise ValueError(
                "pretrained=True requires pretrained_deit, the directory containing 'deit_3_small_224_21k.pth'")
        # Weights come from a local file rather than a download URL.
        # map_location="cpu" lets a checkpoint saved from GPU tensors load on a
        # CPU-only host; load_state_dict then copies onto the model's own device.
        checkpoint = torch.load(os.path.join(pretrained_deit, 'deit_3_small_224_21k.pth'), map_location="cpu")
        model.load_state_dict(checkpoint["model"])

    return model
373
+
374
@register_model
def deit_medium_patch16_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build the medium patch-16 model: width 512, 12 LayerScale blocks, 8 heads.

    Args:
        pretrained: when true, download and load the released checkpoint.
        img_size: input resolution; used both to build the model and to pick
            the checkpoint URL.
        pretrained_21k: selects the ``21k`` checkpoint instead of the ``1k`` one.
        **kwargs: forwarded to ``vit_models``.
    """
    # img_size must reach vit_models: it also selects the checkpoint below, and a
    # model built at the default size would not match the positional embedding
    # of a checkpoint released for another resolution.
    model = vit_models(
        img_size=img_size, patch_size=16, embed_dim=512, depth=12, num_heads=8, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Layer_scale_init_Block, **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        name = 'https://dl.fbaipublicfiles.com/deit/deit_3_medium_' + str(img_size) + '_'
        if pretrained_21k:
            name += '21k.pth'
        else:
            name += '1k.pth'

        checkpoint = torch.hub.load_state_dict_from_url(
            url=name,
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model
393
+
394
@register_model
def deit_base_patch16_LS(pretrained=False, pretrained_cfg=None, img_size=224, pretrained_21k=False, pretrained_deit=None, pretrained_cfg_overlay=None, **kwargs):
    """Build the base patch-16 model: width 768, 12 LayerScale blocks, 12 heads.

    Args:
        pretrained: when true, load weights from a local checkpoint file.
        pretrained_cfg: accepted but not read.
        img_size: input resolution forwarded to ``vit_models``.
        pretrained_21k: accepted but not read; the checkpoint file name is fixed.
        pretrained_deit: directory holding ``deit_3_base_224_21k.pth``;
            required when ``pretrained`` is true.
        pretrained_cfg_overlay: accepted but not read.
        **kwargs: forwarded to ``vit_models``.

    Raises:
        ValueError: if ``pretrained`` is true and ``pretrained_deit`` is None.
    """
    model = vit_models(
        img_size=img_size, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Layer_scale_init_Block, **kwargs)
    if pretrained:
        # Imported here so this function works without a module-level `import os`.
        import os

        if pretrained_deit is None:
            raise ValueError(
                "pretrained=True requires pretrained_deit, the directory containing 'deit_3_base_224_21k.pth'")
        # Weights come from a local file rather than a download URL.
        # map_location="cpu" lets a checkpoint saved from GPU tensors load on a
        # CPU-only host; load_state_dict then copies onto the model's own device.
        checkpoint = torch.load(os.path.join(pretrained_deit, 'deit_3_base_224_21k.pth'), map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model
413
+
414
@register_model
def deit_large_patch16_LS(pretrained=False, img_size=224, pretrained_21k=False, pretrained_cfg=None, pretrained_deit=None, pretrained_cfg_overlay=None, **kwargs):
    """Build the large patch-16 model: width 1024, 24 LayerScale blocks, 16 heads.

    Args:
        pretrained: when true, load weights from a local checkpoint file.
        img_size: input resolution forwarded to ``vit_models``.
        pretrained_21k: accepted but not read; the checkpoint file name is fixed.
        pretrained_cfg: accepted but not read.
        pretrained_deit: directory holding ``deit_3_large_224_21k.pth``;
            required when ``pretrained`` is true.
        pretrained_cfg_overlay: accepted but not read.
        **kwargs: forwarded to ``vit_models``.

    Raises:
        ValueError: if ``pretrained`` is true and ``pretrained_deit`` is None.
    """
    model = vit_models(
        img_size=img_size, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Layer_scale_init_Block, **kwargs)
    if pretrained:
        # Imported here so this function works without a module-level `import os`.
        import os

        if pretrained_deit is None:
            raise ValueError(
                "pretrained=True requires pretrained_deit, the directory containing 'deit_3_large_224_21k.pth'")
        # Weights come from a local file rather than a download URL.
        # map_location="cpu" lets a checkpoint saved from GPU tensors load on a
        # CPU-only host; load_state_dict then copies onto the model's own device.
        checkpoint = torch.load(os.path.join(pretrained_deit, 'deit_3_large_224_21k.pth'), map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model
433
+
434
@register_model
def deit_huge_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build the huge patch-14 model: width 1280, 32 LayerScale blocks, 16 heads.

    With ``pretrained`` set, the released checkpoint for ``img_size`` is
    downloaded (``21k_v1`` when ``pretrained_21k`` is true, otherwise
    ``1k_v1``) and loaded. Remaining keyword arguments go to ``vit_models``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    model = vit_models(
        img_size=img_size,
        patch_size=14,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
    if not pretrained:
        return model

    suffix = '21k_v1.pth' if pretrained_21k else '1k_v1.pth'
    url = 'https://dl.fbaipublicfiles.com/deit/deit_3_huge_' + str(img_size) + '_' + suffix
    checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
    model.load_state_dict(checkpoint["model"])
    return model
452
+
453
@register_model
def deit_huge_patch14_52_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build the 52-block huge patch-14 model: width 1280, LayerScale blocks, 16 heads.

    ``pretrained`` and ``pretrained_21k`` are accepted but never read; no
    checkpoint is loaded here. Remaining keyword arguments go to ``vit_models``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return vit_models(
        img_size=img_size,
        patch_size=14,
        embed_dim=1280,
        depth=52,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
460
+
461
@register_model
def deit_huge_patch14_26x2_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build the huge patch-14 model with 26 parallel-x2 LayerScale blocks: width 1280, 16 heads.

    ``pretrained`` and ``pretrained_21k`` are accepted but never read; no
    checkpoint is loaded here. Remaining keyword arguments go to ``vit_models``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return vit_models(
        img_size=img_size,
        patch_size=14,
        embed_dim=1280,
        depth=26,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        block_layers=Layer_scale_init_Block_paralx2,
        **kwargs,
    )
468
+
469
@register_model
def deit_Giant_48x2_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build the Giant patch-14 model with 48 parallel-x2 LayerScale blocks: width 1664, 16 heads.

    ``pretrained`` and ``pretrained_21k`` are accepted but never read; no
    checkpoint is loaded here. Remaining keyword arguments go to ``vit_models``.
    """
    # The block class was previously given as `Block_paral_LS`, a name this module
    # never defines, so calling the factory raised NameError. The parallel
    # LayerScale block defined in this module is Layer_scale_init_Block_paralx2.
    model = vit_models(
        img_size=img_size, patch_size=14, embed_dim=1664, depth=48, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Layer_scale_init_Block_paralx2, **kwargs)

    return model
476
+
477
@register_model
def deit_giant_40x2_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build the giant patch-14 model with 40 parallel-x2 LayerScale blocks: width 1408, 16 heads.

    ``pretrained`` and ``pretrained_21k`` are accepted but never read; no
    checkpoint is loaded here. Remaining keyword arguments go to ``vit_models``.
    """
    # The block class was previously given as `Block_paral_LS`, a name this module
    # never defines, so calling the factory raised NameError. The parallel
    # LayerScale block defined in this module is Layer_scale_init_Block_paralx2.
    model = vit_models(
        img_size=img_size, patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Layer_scale_init_Block_paralx2, **kwargs)
    return model
483
+
484
@register_model
def deit_Giant_48_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build the Giant patch-14 model: width 1664, 48 LayerScale blocks, 16 heads.

    ``pretrained`` and ``pretrained_21k`` are accepted but never read; no
    checkpoint is loaded here. Remaining keyword arguments go to ``vit_models``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return vit_models(
        img_size=img_size,
        patch_size=14,
        embed_dim=1664,
        depth=48,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
490
+
491
@register_model
def deit_giant_40_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build the giant patch-14 model: width 1408, 40 LayerScale blocks, 16 heads.

    ``pretrained`` and ``pretrained_21k`` are accepted but never read; no
    checkpoint is loaded and no ``default_cfg`` is set here. Remaining keyword
    arguments go to ``vit_models``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return vit_models(
        img_size=img_size,
        patch_size=14,
        embed_dim=1408,
        depth=40,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
499
+
500
+ # Models from Three things everyone should know about Vision Transformers (https://arxiv.org/pdf/2203.09795.pdf)
501
+
502
@register_model
def deit_small_patch16_36_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build the 36-block small patch-16 model: width 384, LayerScale blocks, 6 heads.

    ``pretrained`` and ``pretrained_21k`` are accepted but never read; no
    checkpoint is loaded here. Remaining keyword arguments go to ``vit_models``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=384,
        depth=36,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
509
+
510
@register_model
def deit_small_patch16_36(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build the 36-block small patch-16 model with the default block type: width 384, 6 heads.

    No ``block_layers`` is passed, so ``vit_models`` uses its default.
    ``pretrained`` and ``pretrained_21k`` are accepted but never read; no
    checkpoint is loaded here. Remaining keyword arguments go to ``vit_models``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=384,
        depth=36,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        **kwargs,
    )
517
+
518
@register_model
def deit_small_patch16_18x2_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build the small patch-16 model with 18 parallel-x2 LayerScale blocks: width 384, 6 heads.

    ``pretrained`` and ``pretrained_21k`` are accepted but never read; no
    checkpoint is loaded here. Remaining keyword arguments go to ``vit_models``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=384,
        depth=18,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        block_layers=Layer_scale_init_Block_paralx2,
        **kwargs,
    )
525
+
526
@register_model
def deit_small_patch16_18x2(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build the small patch-16 model with 18 parallel-x2 blocks (no LayerScale): width 384, 6 heads.

    ``pretrained`` and ``pretrained_21k`` are accepted but never read; no
    checkpoint is loaded here. Remaining keyword arguments go to ``vit_models``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=384,
        depth=18,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        block_layers=Block_paralx2,
        **kwargs,
    )
533
+
534
+
535
@register_model
def deit_base_patch16_18x2_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build the base patch-16 model with 18 parallel-x2 LayerScale blocks: width 768, 12 heads.

    ``pretrained`` and ``pretrained_21k`` are accepted but never read; no
    checkpoint is loaded here. Remaining keyword arguments go to ``vit_models``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=768,
        depth=18,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        block_layers=Layer_scale_init_Block_paralx2,
        **kwargs,
    )
542
+
543
+
544
@register_model
def deit_base_patch16_18x2(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build the base patch-16 model with 18 parallel-x2 blocks (no LayerScale): width 768, 12 heads.

    ``pretrained`` and ``pretrained_21k`` are accepted but never read; no
    checkpoint is loaded here. Remaining keyword arguments go to ``vit_models``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=768,
        depth=18,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        block_layers=Block_paralx2,
        **kwargs,
    )
551
+
552
+
553
@register_model
def deit_base_patch16_36x1_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build the 36-block base patch-16 model: width 768, LayerScale blocks, 12 heads.

    ``pretrained`` and ``pretrained_21k`` are accepted but never read; no
    checkpoint is loaded here. Remaining keyword arguments go to ``vit_models``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=768,
        depth=36,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
560
+
561
@register_model
def deit_base_patch16_36x1(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build the 36-block base patch-16 model with the default block type: width 768, 12 heads.

    No ``block_layers`` is passed, so ``vit_models`` uses its default.
    ``pretrained`` and ``pretrained_21k`` are accepted but never read; no
    checkpoint is loaded here. Remaining keyword arguments go to ``vit_models``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=768,
        depth=36,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        **kwargs,
    )
568
+
outputs/deit/20240118_171921.log ADDED
@@ -0,0 +1 @@
 
 
1
+ 2024-01-18 17:19:21,866 - snnet - INFO - Namespace(batch_size=64, epochs=300, bce_loss=False, unscale_lr=False, model='deit_base_patch16_224', input_size=224, drop=0.0, drop_path=0.1, model_ema=True, model_ema_decay=0.99996, model_ema_force_cpu=False, opt='adamw', opt_eps=1e-08, opt_betas=None, clip_grad=None, momentum=0.9, weight_decay=0.05, sched='cosine', lr=0.0005, lr_noise=None, lr_noise_pct=0.67, lr_noise_std=1.0, warmup_lr=1e-06, min_lr=1e-05, decay_epochs=30, warmup_epochs=5, cooldown_epochs=10, patience_epochs=10, decay_rate=0.1, color_jitter=0.3, aa='rand-m9-mstd0.5-inc1', smoothing=0.1, train_interpolation='bicubic', repeated_aug=True, train_mode=True, ThreeAugment=False, src=False, reprob=0.25, remode='pixel', recount=1, resplit=False, mixup=0.8, cutmix=1.0, cutmix_minmax=None, mixup_prob=1.0, mixup_switch_prob=0.5, mixup_mode='batch', teacher_model='regnety_160', teacher_path='', distillation_type='none', distillation_alpha=0.5, distillation_tau=1.0, finetune='', attn_only=False, data_path='/datasets01/imagenet_full_size/061417/', data_set='IMNET', inat_category='name', output_dir='outputs/deit', device='cuda', seed=0, resume='', start_epoch=0, eval=False, eval_crop_ratio=0.875, dist_eval=False, num_workers=10, pin_mem=True, world_size=1, dist_url='env://', exp_name='deit', config=None, scoring=False, proxy='synflow', snnet_name='snnetv2', get_flops=False, flops_sampling_k=None, low_rank=False, lora_r=64, flops_gap=1.0, distributed=False)
outputs/stitch_s_l_v2_lora_r_64_50_ep/20240118_172124.log ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ 2024-01-18 17:21:24,162 - snnet - INFO - Namespace(batch_size=64, epochs=50, bce_loss=False, unscale_lr=True, model='deit_base_patch16_224', input_size=224, drop=0.0, drop_path=0.1, model_ema=True, model_ema_decay=0.99996, model_ema_force_cpu=False, opt='fusedlamb', opt_eps=1e-08, opt_betas=None, clip_grad=None, momentum=0.9, weight_decay=0.02, sched='cosine', lr=3e-05, lr_noise=None, lr_noise_pct=0.67, lr_noise_std=1.0, warmup_lr=1e-07, min_lr=1e-05, decay_epochs=30, warmup_epochs=5, cooldown_epochs=10, patience_epochs=10, decay_rate=0.1, color_jitter=0.3, aa='rand-m9-mstd0.5-inc1', smoothing=0.1, train_interpolation='bicubic', repeated_aug=True, train_mode=True, ThreeAugment=True, src=True, reprob=0.0, remode='pixel', recount=1, resplit=False, mixup=0, cutmix=1.0, cutmix_minmax=None, mixup_prob=1.0, mixup_switch_prob=0.5, mixup_mode='batch', teacher_model='regnety_160', teacher_path='', distillation_type='none', distillation_alpha=0.5, distillation_tau=1.0, finetune='', attn_only=False, data_path='/data2/datasets/imagenet', data_set='IMNET', inat_category='name', output_dir='outputs/stitch_s_l_v2_lora_r_64_50_ep', device='cuda', seed=0, resume='snnetv2_deit3_s_l.pth', start_epoch=0, eval=False, eval_crop_ratio=1.0, dist_eval=False, num_workers=10, pin_mem=True, world_size=1, dist_url='env://', exp_name='stitch_s_l_v2_lora_r_64_50_ep', config='gradio_demo.json', scoring=False, proxy='synflow', snnet_name='snnet_v2', get_flops=False, flops_sampling_k=None, low_rank=False, lora_r=64, flops_gap=1.0, anchors=['deit_small_patch16_LS', 'deit_large_patch16_LS'], anchor_drop_path=[0.05, 0.4], no_repeated_aug=True, pretrained_deit='../pretrained_weights', distributed=False)
2
+ 2024-01-18 17:21:24,163 - snnet - INFO - Creating model: deit_small_patch16_LS
outputs/stitch_s_l_v2_lora_r_64_50_ep/20240118_172140.log ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ 2024-01-18 17:21:40,831 - snnet - INFO - Namespace(batch_size=64, epochs=50, bce_loss=False, unscale_lr=True, model='deit_base_patch16_224', input_size=224, drop=0.0, drop_path=0.1, model_ema=True, model_ema_decay=0.99996, model_ema_force_cpu=False, opt='fusedlamb', opt_eps=1e-08, opt_betas=None, clip_grad=None, momentum=0.9, weight_decay=0.02, sched='cosine', lr=3e-05, lr_noise=None, lr_noise_pct=0.67, lr_noise_std=1.0, warmup_lr=1e-07, min_lr=1e-05, decay_epochs=30, warmup_epochs=5, cooldown_epochs=10, patience_epochs=10, decay_rate=0.1, color_jitter=0.3, aa='rand-m9-mstd0.5-inc1', smoothing=0.1, train_interpolation='bicubic', repeated_aug=True, train_mode=True, ThreeAugment=True, src=True, reprob=0.0, remode='pixel', recount=1, resplit=False, mixup=0, cutmix=1.0, cutmix_minmax=None, mixup_prob=1.0, mixup_switch_prob=0.5, mixup_mode='batch', teacher_model='regnety_160', teacher_path='', distillation_type='none', distillation_alpha=0.5, distillation_tau=1.0, finetune='', attn_only=False, data_path='/data2/datasets/imagenet', data_set='IMNET', inat_category='name', output_dir='outputs/stitch_s_l_v2_lora_r_64_50_ep', device='cuda', seed=0, resume='snnetv2_deit3_s_l.pth', start_epoch=0, eval=False, eval_crop_ratio=1.0, dist_eval=False, num_workers=10, pin_mem=True, world_size=1, dist_url='env://', exp_name='stitch_s_l_v2_lora_r_64_50_ep', config='gradio_demo.json', scoring=False, proxy='synflow', snnet_name='snnet_v2', get_flops=False, flops_sampling_k=None, low_rank=False, lora_r=64, flops_gap=1.0, anchors=['deit_small_patch16_LS', 'deit_large_patch16_LS'], anchor_drop_path=[0.05, 0.4], no_repeated_aug=True, pretrained_deit='../pretrained_weights', distributed=False)
2
+ 2024-01-18 17:21:40,832 - snnet - INFO - Creating model: deit_small_patch16_LS
outputs/stitch_s_l_v2_lora_r_64_50_ep/20240118_172156.log ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ 2024-01-18 17:21:56,859 - snnet - INFO - Namespace(batch_size=64, epochs=50, bce_loss=False, unscale_lr=True, model='deit_base_patch16_224', input_size=224, drop=0.0, drop_path=0.1, model_ema=True, model_ema_decay=0.99996, model_ema_force_cpu=False, opt='fusedlamb', opt_eps=1e-08, opt_betas=None, clip_grad=None, momentum=0.9, weight_decay=0.02, sched='cosine', lr=3e-05, lr_noise=None, lr_noise_pct=0.67, lr_noise_std=1.0, warmup_lr=1e-07, min_lr=1e-05, decay_epochs=30, warmup_epochs=5, cooldown_epochs=10, patience_epochs=10, decay_rate=0.1, color_jitter=0.3, aa='rand-m9-mstd0.5-inc1', smoothing=0.1, train_interpolation='bicubic', repeated_aug=True, train_mode=True, ThreeAugment=True, src=True, reprob=0.0, remode='pixel', recount=1, resplit=False, mixup=0, cutmix=1.0, cutmix_minmax=None, mixup_prob=1.0, mixup_switch_prob=0.5, mixup_mode='batch', teacher_model='regnety_160', teacher_path='', distillation_type='none', distillation_alpha=0.5, distillation_tau=1.0, finetune='', attn_only=False, data_path='/data2/datasets/imagenet', data_set='IMNET', inat_category='name', output_dir='outputs/stitch_s_l_v2_lora_r_64_50_ep', device='cuda', seed=0, resume='snnetv2_deit3_s_l.pth', start_epoch=0, eval=False, eval_crop_ratio=1.0, dist_eval=False, num_workers=10, pin_mem=True, world_size=1, dist_url='env://', exp_name='stitch_s_l_v2_lora_r_64_50_ep', config='gradio_demo.json', scoring=False, proxy='synflow', snnet_name='snnet_v2', get_flops=False, flops_sampling_k=None, low_rank=False, lora_r=64, flops_gap=1.0, anchors=['deit_small_patch16_LS', 'deit_large_patch16_LS'], anchor_drop_path=[0.05, 0.4], no_repeated_aug=True, pretrained_deit='../pretrained_weights', distributed=False)
2
+ 2024-01-18 17:21:56,859 - snnet - INFO - Creating model: deit_small_patch16_LS
3
+ 2024-01-18 17:21:57,078 - snnet - INFO - Creating model: deit_large_patch16_LS
4
+ 2024-01-18 17:21:59,994 - snnet - INFO - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134]
5
+ 2024-01-18 17:22:00,521 - snnet - INFO - load checkpoint from snnetv2_deit3_s_l.pth
outputs/stitch_s_l_v2_lora_r_64_50_ep/20240118_172250.log ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ 2024-01-18 17:22:50,304 - snnet - INFO - Namespace(batch_size=64, epochs=50, bce_loss=False, unscale_lr=True, model='deit_base_patch16_224', input_size=224, drop=0.0, drop_path=0.1, model_ema=True, model_ema_decay=0.99996, model_ema_force_cpu=False, opt='fusedlamb', opt_eps=1e-08, opt_betas=None, clip_grad=None, momentum=0.9, weight_decay=0.02, sched='cosine', lr=3e-05, lr_noise=None, lr_noise_pct=0.67, lr_noise_std=1.0, warmup_lr=1e-07, min_lr=1e-05, decay_epochs=30, warmup_epochs=5, cooldown_epochs=10, patience_epochs=10, decay_rate=0.1, color_jitter=0.3, aa='rand-m9-mstd0.5-inc1', smoothing=0.1, train_interpolation='bicubic', repeated_aug=True, train_mode=True, ThreeAugment=True, src=True, reprob=0.0, remode='pixel', recount=1, resplit=False, mixup=0, cutmix=1.0, cutmix_minmax=None, mixup_prob=1.0, mixup_switch_prob=0.5, mixup_mode='batch', teacher_model='regnety_160', teacher_path='', distillation_type='none', distillation_alpha=0.5, distillation_tau=1.0, finetune='', attn_only=False, data_path='/data2/datasets/imagenet', data_set='IMNET', inat_category='name', output_dir='outputs/stitch_s_l_v2_lora_r_64_50_ep', device='cpu', seed=0, resume='snnetv2_deit3_s_l.pth', start_epoch=0, eval=False, eval_crop_ratio=1.0, dist_eval=False, num_workers=10, pin_mem=True, world_size=1, dist_url='env://', exp_name='stitch_s_l_v2_lora_r_64_50_ep', config='gradio_demo.json', scoring=False, proxy='synflow', snnet_name='snnet_v2', get_flops=False, flops_sampling_k=None, low_rank=False, lora_r=64, flops_gap=1.0, anchors=['deit_small_patch16_LS', 'deit_large_patch16_LS'], anchor_drop_path=[0.05, 0.4], no_repeated_aug=True, pretrained_deit='../pretrained_weights', distributed=False)
2
+ 2024-01-18 17:22:50,305 - snnet - INFO - Creating model: deit_small_patch16_LS
3
+ 2024-01-18 17:22:50,535 - snnet - INFO - Creating model: deit_large_patch16_LS
4
+ 2024-01-18 17:22:53,873 - snnet - INFO - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134]
5
+ 2024-01-18 17:22:54,392 - snnet - INFO - load checkpoint from snnetv2_deit3_s_l.pth
outputs/stitch_s_l_v2_lora_r_64_50_ep/20240118_172309.log ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ 2024-01-18 17:23:09,551 - snnet - INFO - Namespace(batch_size=64, epochs=50, bce_loss=False, unscale_lr=True, model='deit_base_patch16_224', input_size=224, drop=0.0, drop_path=0.1, model_ema=True, model_ema_decay=0.99996, model_ema_force_cpu=False, opt='fusedlamb', opt_eps=1e-08, opt_betas=None, clip_grad=None, momentum=0.9, weight_decay=0.02, sched='cosine', lr=3e-05, lr_noise=None, lr_noise_pct=0.67, lr_noise_std=1.0, warmup_lr=1e-07, min_lr=1e-05, decay_epochs=30, warmup_epochs=5, cooldown_epochs=10, patience_epochs=10, decay_rate=0.1, color_jitter=0.3, aa='rand-m9-mstd0.5-inc1', smoothing=0.1, train_interpolation='bicubic', repeated_aug=True, train_mode=True, ThreeAugment=True, src=True, reprob=0.0, remode='pixel', recount=1, resplit=False, mixup=0, cutmix=1.0, cutmix_minmax=None, mixup_prob=1.0, mixup_switch_prob=0.5, mixup_mode='batch', teacher_model='regnety_160', teacher_path='', distillation_type='none', distillation_alpha=0.5, distillation_tau=1.0, finetune='', attn_only=False, data_path='/data2/datasets/imagenet', data_set='IMNET', inat_category='name', output_dir='outputs/stitch_s_l_v2_lora_r_64_50_ep', device='cpu', seed=0, resume='snnetv2_deit3_s_l.pth', start_epoch=0, eval=False, eval_crop_ratio=1.0, dist_eval=False, num_workers=10, pin_mem=True, world_size=1, dist_url='env://', exp_name='stitch_s_l_v2_lora_r_64_50_ep', config='gradio_demo.json', scoring=False, proxy='synflow', snnet_name='snnet_v2', get_flops=False, flops_sampling_k=None, low_rank=False, lora_r=64, flops_gap=1.0, anchors=['deit_small_patch16_LS', 'deit_large_patch16_LS'], anchor_drop_path=[0.05, 0.4], no_repeated_aug=True, pretrained_deit='../pretrained_weights', distributed=False)
2
+ 2024-01-18 17:23:09,553 - snnet - INFO - Creating model: deit_small_patch16_LS
3
+ 2024-01-18 17:23:09,778 - snnet - INFO - Creating model: deit_large_patch16_LS
4
+ 2024-01-18 17:23:13,077 - snnet - INFO - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134]
5
+ 2024-01-18 17:23:13,587 - snnet - INFO - load checkpoint from snnetv2_deit3_s_l.pth
outputs/stitch_s_l_v2_lora_r_64_50_ep/20240118_172332.log ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ 2024-01-18 17:23:32,357 - snnet - INFO - Namespace(batch_size=64, epochs=50, bce_loss=False, unscale_lr=True, model='deit_base_patch16_224', input_size=224, drop=0.0, drop_path=0.1, model_ema=True, model_ema_decay=0.99996, model_ema_force_cpu=False, opt='fusedlamb', opt_eps=1e-08, opt_betas=None, clip_grad=None, momentum=0.9, weight_decay=0.02, sched='cosine', lr=3e-05, lr_noise=None, lr_noise_pct=0.67, lr_noise_std=1.0, warmup_lr=1e-07, min_lr=1e-05, decay_epochs=30, warmup_epochs=5, cooldown_epochs=10, patience_epochs=10, decay_rate=0.1, color_jitter=0.3, aa='rand-m9-mstd0.5-inc1', smoothing=0.1, train_interpolation='bicubic', repeated_aug=True, train_mode=True, ThreeAugment=True, src=True, reprob=0.0, remode='pixel', recount=1, resplit=False, mixup=0, cutmix=1.0, cutmix_minmax=None, mixup_prob=1.0, mixup_switch_prob=0.5, mixup_mode='batch', teacher_model='regnety_160', teacher_path='', distillation_type='none', distillation_alpha=0.5, distillation_tau=1.0, finetune='', attn_only=False, data_path='/data2/datasets/imagenet', data_set='IMNET', inat_category='name', output_dir='outputs/stitch_s_l_v2_lora_r_64_50_ep', device='cpu', seed=0, resume='snnetv2_deit3_s_l.pth', start_epoch=0, eval=False, eval_crop_ratio=1.0, dist_eval=False, num_workers=10, pin_mem=True, world_size=1, dist_url='env://', exp_name='stitch_s_l_v2_lora_r_64_50_ep', config='gradio_demo.json', scoring=False, proxy='synflow', snnet_name='snnet_v2', get_flops=False, flops_sampling_k=None, low_rank=False, lora_r=64, flops_gap=1.0, anchors=['deit_small_patch16_LS', 'deit_large_patch16_LS'], anchor_drop_path=[0.05, 0.4], no_repeated_aug=True, pretrained_deit='../pretrained_weights', distributed=False)
2
+ 2024-01-18 17:23:32,358 - snnet - INFO - Creating model: deit_small_patch16_LS
3
+ 2024-01-18 17:23:32,606 - snnet - INFO - Creating model: deit_large_patch16_LS
4
+ 2024-01-18 17:23:35,576 - snnet - INFO - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134]
5
+ 2024-01-18 17:23:36,120 - snnet - INFO - load checkpoint from snnetv2_deit3_s_l.pth
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch
2
+ timm==0.6.12
3
+ fvcore
snnet.py ADDED
@@ -0,0 +1,473 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import json
3
+ import math
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ import numpy as np
9
+
10
+ from collections import defaultdict
11
+ from utils import get_root_logger
12
+ import torch.nn.functional as F
13
+
14
def rearrange_activations(activations):
    """Collapse every leading axis of ``activations`` into a single axis.

    The size of the last axis is kept, so the result is two-dimensional with
    shape ``(-1, activations.shape[-1])``.
    """
    return activations.reshape(-1, activations.shape[-1])
18
+
19
+
20
def ps_inv(x1, x2):
    """Least-squares solver given feature maps from two anchors.

    Both inputs are flattened to 2-D with ``rearrange_activations`` (last axis
    kept). A column of ones is appended to ``x1`` and the affine map
    ``[x1, 1] @ A ~= x2`` is solved with the Moore-Penrose pseudo-inverse.

    Args:
        x1: activations of the source anchor; last axis is the channel axis.
        x2: activations of the target anchor; must flatten to the same number
            of rows as ``x1``.

    Returns:
        ``(w, b)`` where ``w`` has shape ``(x2_channels, x1_channels)`` and
        ``b`` has shape ``(x2_channels,)``.

    Raises:
        ValueError: if the flattened inputs do not have the same row count.
    """
    x1 = rearrange_activations(x1)
    x2 = rearrange_activations(x2)

    if x1.shape[0] != x2.shape[0]:
        raise ValueError('Spatial size of compared neurons must match when '
                         'calculating psuedo inverse matrix.')

    # Design matrix [x1 | 1]; it is allocated with torch's default device and
    # dtype, and x1 is copied (and cast) into it.
    shape = list(x1.shape)
    shape[-1] += 1
    x1_ones = torch.ones(shape)
    x1_ones[:, :-1] = x1

    # torch.matmul does not promote dtypes, so x2 must match the design matrix
    # in dtype as well as device (e.g. half-precision or double features).
    x2 = x2.to(device=x1_ones.device, dtype=x1_ones.dtype)
    A_ones = torch.matmul(torch.linalg.pinv(x1_ones), x2).T

    # Last column is the bias, the remaining columns are the weight matrix.
    w = A_ones[..., :-1]
    b = A_ones[..., -1]

    return w, b
44
+
45
+
46
def reset_out_indices(front_depth=12, end_depth=24, out_indices=(9, 14, 19, 23)):
    """Translate output block indices of a deep model to a shallower one.

    The block ids ``0 .. front_depth - 1`` are stretched to ``end_depth``
    entries with ``torch.nn.functional.interpolate`` (default mode), giving
    for each block of the deep model the id of a block in the shallow model.

    Args:
        front_depth: number of blocks in the shallow model.
        end_depth: number of blocks in the deep model.
        out_indices: block indices of the deep model to translate.

    Returns:
        list[int]: the mapped shallow-model ids for the positions listed in
        ``out_indices``, in increasing position order.
    """
    block_ids = torch.arange(front_depth).float()[None, None, :]
    end_mapping_ids = torch.nn.functional.interpolate(block_ids, end_depth)
    # Index the two singleton axes explicitly: squeeze() would also drop the
    # length axis when end_depth == 1 and turn the result into a bare int.
    end_mapping_ids = end_mapping_ids[0, 0].long().tolist()

    return [idx for i, idx in enumerate(end_mapping_ids) if i in out_indices]
58
+
59
+
60
def get_stitch_configs_general_unequal(depths):
    """List the one-directional stitching configurations for two anchors.

    ``depths`` is sorted first. The returned list starts with the anchor-only
    entry ``{'comb_id': [1]}``, followed by one entry per block ``k`` of the
    shallower anchor that stitches block ``k`` to block
    ``(k + 1) * (deep // shallow)`` of the deeper anchor.

    Returns:
        tuple: ``(configs, num_stitches)`` where ``num_stitches`` equals the
        depth of the shallower anchor.
    """
    ordered = sorted(depths)
    num_stitches = ordered[0]

    stitched = [
        {
            'comb_id': (0, 1),
            'stitch_cfgs': (k, (k + 1) * (ordered[1] // ordered[0])),
        }
        for k in range(num_stitches)
    ]

    # anchor configuration first, then every stitch position
    return [{'comb_id': [1]}] + stitched, num_stitches
74
+
75
def get_stitch_configs_bidirection(depths):
    """Enumerate the stitching configurations between two anchors in both directions.

    ``depths`` is sorted, so index 0 refers to the shallower ("small") anchor
    and index 1 to the deeper ("large") one. The configurations are produced
    in this order: the two anchors alone, small->large, large->small,
    large->small->large, small->large->small.

    Every non-anchor configuration is a dict with:
        * ``comb_id``: the sequence of anchor ids that is traversed;
        * ``stitch_cfgs``: one ``[front, end]`` pair per stitch;
        * ``stitch_layer_ids``: one stitching-layer index per stitch, followed
          by a trailing ``-1``;
        * ``route``: one ``[start, stop]`` pair per traversed anchor.

    Args:
        depths: the block counts of the two anchors.

    Returns:
        tuple: ``(total_configs, model_combos, num_stitches, anchor_ids,
        sl_ids, ls_ids, lsl_ids, sls_ids)``. ``num_stitches`` is
        ``[len(small->large), len(large->small)]`` and the ``*_ids`` lists
        index into ``total_configs``. The small->large configuration that
        stitches after the last block of the small anchor stays in
        ``total_configs`` but is left out of ``sl_ids``.
    """
    depths = sorted(depths)
    small_depth, large_depth = depths[0], depths[1]

    # anchor configurations
    total_configs = [{'comb_id': [0]}, {'comb_id': [1]}]

    # small --> large
    sl_configs = [
        {
            'comb_id': [0, 1],
            'stitch_cfgs': [[i, (i + 1) * (large_depth // small_depth)]],
            'stitch_layer_ids': [i],
        }
        for i in range(small_depth)
    ]

    # For each block of the large anchor, the id of a block of the small
    # anchor, obtained by stretching 0..small_depth-1 to large_depth entries.
    block_ids = torch.arange(small_depth).float()[None, None, :]
    end_mapping_ids = torch.nn.functional.interpolate(block_ids, large_depth)
    end_mapping_ids = end_mapping_ids[0, 0].long().tolist()

    # large --> small: never after the last block of the large anchor, and
    # only after odd-indexed blocks when the two depths differ.
    same_depth = large_depth == small_depth
    ls_configs = []
    for i in range(large_depth - 1):
        if not same_depth and i % 2 != 1:
            continue
        ls_configs.append({
            'comb_id': [1, 0],
            'stitch_cfgs': [[i, end_mapping_ids[i] + 1]],
            'stitch_layer_ids': [i // (large_depth // small_depth)]
        })

    # large --> small --> large
    lsl_configs = []
    for ls_cfg in ls_configs:
        for sl_cfg in sl_configs:
            if sl_cfg['stitch_layer_ids'][0] == small_depth - 1:
                continue
            # the second stitch must not start before the first one landed
            if sl_cfg['stitch_cfgs'][0][0] >= ls_cfg['stitch_cfgs'][0][1]:
                lsl_configs.append({
                    'comb_id': [1, 0, 1],
                    'stitch_cfgs': [ls_cfg['stitch_cfgs'][0], sl_cfg['stitch_cfgs'][0]],
                    'stitch_layer_ids': ls_cfg['stitch_layer_ids'] + sl_cfg['stitch_layer_ids']
                })

    # small --> large --> small
    sls_configs = []
    for sl_cfg in sl_configs:
        for ls_cfg in ls_configs:
            if ls_cfg['stitch_cfgs'][0][0] >= sl_cfg['stitch_cfgs'][0][1]:
                sls_configs.append({
                    'comb_id': [0, 1, 0],
                    'stitch_cfgs': [sl_cfg['stitch_cfgs'][0], ls_cfg['stitch_cfgs'][0]],
                    'stitch_layer_ids': sl_cfg['stitch_layer_ids'] + ls_cfg['stitch_layer_ids']
                })

    total_configs += sl_configs + ls_configs + lsl_configs + sls_configs

    anchor_ids = []
    sl_ids = []
    ls_ids = []
    lsl_ids = []
    sls_ids = []

    for cfg_id, cfg in enumerate(total_configs):
        comb_id = cfg['comb_id']

        if len(comb_id) == 1:
            anchor_ids.append(cfg_id)
            continue

        # Route: the first segment starts at block 0, each later segment
        # starts at the `end` of the preceding stitch; each segment stops at
        # the `front` of the next stitch, the last one at the depth of the
        # final anchor.
        stitches = cfg['stitch_cfgs']
        starts = [0] + [end for _, end in stitches]
        stops = [front for front, _ in stitches] + [depths[comb_id[-1]]]
        cfg['route'] = [[start, stop] for start, stop in zip(starts, stops)]

        if comb_id == [0, 1]:
            # Skip the stitch placed after the last block of the small anchor.
            # This was hard-coded as `11`, which is only right for a 12-block
            # small anchor.
            if stitches[0][0] != small_depth - 1:
                sl_ids.append(cfg_id)
        elif comb_id == [1, 0]:
            ls_ids.append(cfg_id)
        elif comb_id == [1, 0, 1]:
            lsl_ids.append(cfg_id)
        elif comb_id == [0, 1, 0]:
            sls_ids.append(cfg_id)

        # One extra layer id so that every route segment has an entry.
        cfg['stitch_layer_ids'].append(-1)

    model_combos = [(0, 1), (1, 0)]
    return total_configs, model_combos, [len(sl_configs), len(ls_configs)], anchor_ids, sl_ids, ls_ids, lsl_ids, sls_ids
190
+
191
+
192
def format_out_features(outs, with_cls_token, hw_shape):
    """Reshape token sequences into channel-first 2-D feature maps, in place.

    Each entry of ``outs`` is reshaped to ``(B, hw_shape[0], hw_shape[1], C)``
    and permuted to ``(B, C, hw_shape[0], hw_shape[1])``; ``B`` and ``C`` are
    read from ``outs[0]``. When ``with_cls_token`` is true the first token of
    every sequence is dropped beforehand. The entries of ``outs`` are
    overwritten and the same list is returned.
    """
    batch, _, channels = outs[0].shape
    height, width = hw_shape[0], hw_shape[1]

    for pos, tokens in enumerate(outs):
        if with_cls_token:
            # Remove class token before reshaping the tokens for the decoder head
            tokens = tokens[:, 1:]
        grid = tokens.reshape(batch, height, width, channels)
        outs[pos] = grid.permute(0, 3, 1, 2).contiguous()

    return outs
203
+
204
+
205
class LoRALayer():
    """Holds the LoRA hyper-parameters and merge state of a layer.

    Attributes set by ``__init__``:
        r: rank given by the caller.
        lora_alpha: scaling numerator given by the caller.
        lora_dropout: an ``nn.Dropout`` when the given probability is
            positive, otherwise an identity callable.
        merged: starts as ``False``.
        merge_weights: flag given by the caller.
    """

    def __init__(
        self,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        merge_weights: bool,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        # Dropout is optional: a non-positive probability means pass-through
        self.lora_dropout = nn.Dropout(p=lora_dropout) if lora_dropout > 0. else (lambda x: x)
        # The weight starts out unmerged
        self.merged = False
        self.merge_weights = merge_weights
223
+
224
class Linear(nn.Linear, LoRALayer):
    """Dense layer with an optional low-rank (LoRA) update.

    With ``r > 0`` the layer owns two extra parameters, ``lora_A`` of shape
    ``(r, in_features)`` and ``lora_B`` of shape ``(out_features, r)``, and
    the base ``weight`` is frozen. The update ``lora_B @ lora_A`` scaled by
    ``lora_alpha / r`` is either added on the fly in ``forward`` or, when
    ``merge_weights`` is set, folded into ``weight`` on switching to eval mode
    and removed again on switching back to train mode.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.,
        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                           merge_weights=merge_weights)

        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.transpose(0, 1)

    def _T(self, w):
        """Transpose ``w`` when the weight is stored as (fan_in, fan_out)."""
        return w.transpose(0, 1) if self.fan_in_fan_out else w

    def reset_parameters(self):
        """Re-initialize the base weights and, if present, the LoRA factors."""
        nn.Linear.reset_parameters(self)
        # nn.Linear.__init__ calls this before lora_A / lora_B exist
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def train(self, mode: bool = True):
        """Switch train/eval mode, merging or un-merging the LoRA update.

        Returns:
            ``self``, as ``nn.Module.train`` does, so that ``layer.train()``
            and ``layer.eval()`` can be chained or re-assigned.
        """
        nn.Linear.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                if self.r > 0:
                    self.weight.data -= self._T(self.lora_B @ self.lora_A) * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self.r > 0:
                    self.weight.data += self._T(self.lora_B @ self.lora_A) * self.scaling
                self.merged = True
        return self

    def forward(self, x: torch.Tensor):
        """Apply the base linear map, plus the LoRA path while it is unmerged."""
        result = F.linear(x, self._T(self.weight), bias=self.bias)
        if self.r > 0 and not self.merged:
            result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
        return result
287
+
288
+
289
class StitchingLayer(nn.Module):
    """Maps features between two anchors through a single ``Linear`` layer."""

    def __init__(self, in_features=None, out_features=None, r=0):
        super().__init__()
        # `r` is forwarded unchanged as the LoRA rank of the projection
        self.transform = Linear(in_features, out_features, r=r)

    def init_stitch_weights_bias(self, weight, bias):
        """Copy ``weight`` and ``bias`` into the projection, in place."""
        projection = self.transform
        projection.weight.data.copy_(weight)
        projection.bias.data.copy_(bias)

    def forward(self, x):
        return self.transform(x)
301
+
302
+
303
class SNNet(nn.Module):
    """Stitches a pair of anchors in one direction (anchor 0 into anchor 1).

    One ``StitchingLayer`` is created per configuration returned by
    ``get_stitch_configs_general_unequal``; ``stitch_config_id`` selects which
    configuration ``forward`` executes.
    """

    def __init__(self, anchors=None):
        super(SNNet, self).__init__()
        self.anchors = nn.ModuleList(anchors)

        self.depths = [len(anchor.blocks) for anchor in self.anchors]

        total_configs, num_stitches = get_stitch_configs_general_unequal(self.depths)
        layers = [
            StitchingLayer(self.anchors[0].embed_dim, self.anchors[1].embed_dim)
            for _ in range(num_stitches)
        ]
        self.stitch_layers = nn.ModuleList(layers)

        self.stitch_configs = dict(enumerate(total_configs))
        self.all_cfgs = list(self.stitch_configs)
        self.num_configs = len(self.all_cfgs)
        self.stitch_config_id = 0
        self.is_ranking = False

    def reset_stitch_id(self, stitch_config_id):
        """Select the configuration used by subsequent ``forward`` calls."""
        self.stitch_config_id = stitch_config_id

    def initialize_stitching_weights(self, x):
        """Fit every stitching layer by least squares on the features of ``x``.

        Block features of both anchors are extracted without gradients, then
        ``ps_inv`` is solved per stitch position and the result copied into
        the corresponding layer.
        """
        logger = get_root_logger()
        front, end = 0, 1
        with torch.no_grad():
            front_features = self.anchors[front].extract_block_features(x)
            end_features = self.anchors[end].extract_block_features(x)

        for i in range(self.depths[0]):
            end_id = (i + 1) * (self.depths[1] // self.depths[0])
            # source: output of front block i; target: output of end block end_id - 1
            w, b = ps_inv(front_features[i], end_features[end_id - 1])
            self.stitch_layers[i].init_stitch_weights_bias(w, b)
            logger.info(f'Initialized Stitching Model {front} to Model {end}, Layer {i}')

    def init_weights(self):
        for anchor in self.anchors:
            anchor.init_weights()

    def sampling_stitch_config(self):
        """Pick the active configuration uniformly at random."""
        self.stitch_config_id = np.random.choice(self.all_cfgs)

    def forward(self, x):
        config = self.stitch_configs[self.stitch_config_id]
        comb_id = config['comb_id']

        # a single anchor runs on its own
        if len(comb_id) == 1:
            return self.anchors[comb_id[0]](x)

        front_blk, end_blk = config['stitch_cfgs'][0], config['stitch_cfgs'][1]

        x = self.anchors[comb_id[0]].forward_until(x, blk_id=front_blk)
        x = self.stitch_layers[front_blk](x)
        return self.anchors[comb_id[1]].forward_from(x, blk_id=end_blk)
361
+
362
+
363
class SNNetv2(nn.Module):
    """Stitches a pair of anchors in both directions.

    The configurations come from ``get_stitch_configs_bidirection``.
    ``stitch_layers[k]`` holds the stitching layers whose source is anchor
    ``k``, followed by one ``nn.Identity`` that is reached through layer id
    ``-1``.

    Args:
        anchors: the two anchor models.
        include_sl, include_ls, include_lsl, include_sls: whether the
            small->large, large->small, large->small->large and
            small->large->small configuration ids are added to ``all_cfgs``.
        lora_r: LoRA rank passed to every ``StitchingLayer``.
    """

    def __init__(self, anchors=None, include_sl=True, include_ls=True, include_lsl=True, include_sls=True, lora_r=0):
        super(SNNetv2, self).__init__()
        self.anchors = nn.ModuleList(anchors)

        self.lora_r = lora_r

        self.depths = [len(anc.blocks) for anc in self.anchors]

        total_configs, model_combos, num_stitches, anchor_ids, sl_ids, ls_ids, lsl_ids, sls_ids = get_stitch_configs_bidirection(self.depths)

        self.stitch_layers = nn.ModuleList()
        self.stitching_map_id = {}

        for (front, end), num_sth in zip(model_combos, num_stitches):
            temp = nn.ModuleList(
                [StitchingLayer(self.anchors[front].embed_dim, self.anchors[end].embed_dim, r=lora_r) for _ in range(num_sth)])
            # Reached via layer id -1: the last route segment is not followed
            # by a real stitch.
            temp.append(nn.Identity())
            self.stitch_layers.append(temp)

        self.stitch_configs = {i: cfg for i, cfg in enumerate(total_configs)}
        self.stitch_init_configs = {i: cfg for i, cfg in enumerate(total_configs) if len(cfg['comb_id']) == 2}

        # Log every configuration id before the selection below narrows it.
        logger = get_root_logger()
        logger.info(str(list(self.stitch_configs.keys())))

        self.all_cfgs = anchor_ids

        if include_sl:
            self.all_cfgs += sl_ids

        if include_ls:
            self.all_cfgs += ls_ids

        if include_lsl:
            self.all_cfgs += lsl_ids

        if include_sls:
            self.all_cfgs += sls_ids

        self.num_configs = len(self.stitch_configs)
        self.stitch_config_id = 0
        # Same default as SNNet, so the attribute exists before
        # set_ranking_mode() is first called.
        self.is_ranking = False

    def reset_stitch_id(self, stitch_config_id):
        """Select the configuration used by ``forward`` outside training mode."""
        self.stitch_config_id = stitch_config_id

    def set_ranking_mode(self, ranking_mode):
        self.is_ranking = ranking_mode

    def initialize_stitching_weights(self, x):
        """Fit the stitching layers of all two-anchor configurations on ``x``.

        Block features of every anchor are extracted without gradients; each
        configuration in ``stitch_init_configs`` is then solved with
        ``ps_inv`` and copied into its stitching layer.
        """
        logger = get_root_logger()
        anchor_features = []
        for anchor in self.anchors:
            with torch.no_grad():
                temp = anchor.extract_block_features(x)
            anchor_features.append(temp)

        for idx, cfg in self.stitch_init_configs.items():
            comb_id = cfg['comb_id']
            if len(comb_id) == 2:
                front_id, end_id = cfg['stitch_cfgs'][0]
                stitch_layer_id = cfg['stitch_layer_ids'][0]
                front_blk_feat = anchor_features[comb_id[0]][front_id]
                end_blk_feat = anchor_features[comb_id[1]][end_id - 1]
                w, b = ps_inv(front_blk_feat, end_blk_feat)
                self.stitch_layers[comb_id[0]][stitch_layer_id].init_stitch_weights_bias(w, b)
                logger.info(f'Initialized Stitching Layer {cfg}')

    def init_weights(self):
        for anc in self.anchors:
            anc.init_weights()

    def sampling_stitch_config(self):
        """Sample a configuration id: first a FLOPs group, then a member of it.

        NOTE(review): reads ``self.flops_grouped_cfgs`` and
        ``self.flops_sampling_probs``, which this class never assigns;
        presumably the training script sets them before training - confirm.
        """
        flops_id = np.random.choice(len(self.flops_grouped_cfgs), p=self.flops_sampling_probs)
        stitch_config_id = np.random.choice(self.flops_grouped_cfgs[flops_id])
        return stitch_config_id

    def forward(self, x):
        # A fresh configuration is sampled per call while training; otherwise
        # the one chosen with reset_stitch_id() is used.
        if self.training:
            stitch_cfg_id = self.sampling_stitch_config()
        else:
            stitch_cfg_id = self.stitch_config_id

        comb_id = self.stitch_configs[stitch_cfg_id]['comb_id']

        # forward by a single anchor
        if len(comb_id) == 1:
            return self.anchors[comb_id[0]](x)

        # forward among anchors
        route = self.stitch_configs[stitch_cfg_id]['route']
        stitch_layer_ids = self.stitch_configs[stitch_cfg_id]['stitch_layer_ids']

        # patch embedding of the first anchor on the route
        x = self.anchors[comb_id[0]].forward_patch_embed(x)

        for i, (model_id, cfg) in enumerate(zip(comb_id, route)):
            x = self.anchors[model_id].selective_forward(x, cfg[0], cfg[1])
            x = self.stitch_layers[model_id][stitch_layer_ids[i]](x)

        # norm and head of the last anchor on the route
        x = self.anchors[comb_id[-1]].forward_norm_head(x)
        return x
473
+
snnetv2_deit3_s_l.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9d455f17d73f4ed74702076d4cea516194d8c4aa8fbbc63192f85795f79c76b4
3
+ size 1350494458
stitches_res_s_l.txt ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {"loss": 0.7156664722345092, "acc1": 82.9060024609375, "acc5": 96.73400244140625, "cfg_id": 0, "flops": 4608338304}
2
+ {"loss": 0.5377805712209507, "acc1": 86.97800256835937, "acc5": 98.2540023046875, "cfg_id": 1, "flops": 61604135936}
3
+ {"loss": 0.5598483879796483, "acc1": 86.57800241210937, "acc5": 98.08200240234375, "cfg_id": 2, "flops": 56843745792}
4
+ {"loss": 0.5534007405354218, "acc1": 86.6480025390625, "acc5": 98.1760025390625, "cfg_id": 3, "flops": 52102230016}
5
+ {"loss": 0.5610568577028585, "acc1": 86.49800245117187, "acc5": 98.06600229492187, "cfg_id": 4, "flops": 47360714240}
6
+ {"loss": 0.5747850706067049, "acc1": 86.26000259765625, "acc5": 97.93800240234376, "cfg_id": 5, "flops": 42619198464}
7
+ {"loss": 0.5890085864812136, "acc1": 85.79200244140625, "acc5": 97.80000272460937, "cfg_id": 6, "flops": 37877682688}
8
+ {"loss": 0.6165087098876635, "acc1": 85.08200264648437, "acc5": 97.55600231445312, "cfg_id": 7, "flops": 33136166912}
9
+ {"loss": 0.6652509210574807, "acc1": 83.69200263671875, "acc5": 97.23600259765625, "cfg_id": 8, "flops": 28394651136}
10
+ {"loss": 0.7374675334290122, "acc1": 81.7120026171875, "acc5": 96.53200251953125, "cfg_id": 9, "flops": 23653135360}
11
+ {"loss": 0.7991558508665273, "acc1": 79.50600241210938, "acc5": 95.90200240234375, "cfg_id": 10, "flops": 18911619584}
12
+ {"loss": 0.7554851990531791, "acc1": 80.63600265625, "acc5": 96.09000245117187, "cfg_id": 11, "flops": 14170103808}
13
+ {"loss": 0.7068120487824534, "acc1": 82.25000237304687, "acc5": 96.35600284179688, "cfg_id": 12, "flops": 9428588032}
14
+ {"loss": 0.7329587066038088, "acc1": 82.6600027734375, "acc5": 96.58200264648437, "cfg_id": 14, "flops": 9523655552}
15
+ {"loss": 0.7238117807516546, "acc1": 82.94800252929687, "acc5": 96.68600232421875, "cfg_id": 15, "flops": 14265171328}
16
+ {"loss": 0.7139950410434694, "acc1": 83.0860026953125, "acc5": 96.75800252929687, "cfg_id": 16, "flops": 19006687104}
17
+ {"loss": 0.7004092067028537, "acc1": 83.25400249023437, "acc5": 96.8740026171875, "cfg_id": 17, "flops": 23748202880}
18
+ {"loss": 0.6828147762201049, "acc1": 83.45000244140626, "acc5": 96.9520026171875, "cfg_id": 18, "flops": 28489718656}
19
+ {"loss": 0.6787144099221085, "acc1": 83.56400258789063, "acc5": 97.0600024609375, "cfg_id": 19, "flops": 33231234432}
20
+ {"loss": 0.6765228407175252, "acc1": 83.43400251953125, "acc5": 97.19200266601563, "cfg_id": 20, "flops": 37972750208}
21
+ {"loss": 0.6841061733888857, "acc1": 83.5900022265625, "acc5": 97.20800275390626, "cfg_id": 21, "flops": 42714265984}
22
+ {"loss": 0.6446758140104286, "acc1": 84.8660023828125, "acc5": 97.44400258789062, "cfg_id": 22, "flops": 47455781760}
23
+ {"loss": 0.5939652780917558, "acc1": 86.23000265625, "acc5": 97.69200270507812, "cfg_id": 23, "flops": 52197297536}
24
+ {"loss": 0.5654762382760192, "acc1": 86.43400250976562, "acc5": 97.632002578125, "cfg_id": 24, "flops": 56938813312}
25
+ {"loss": 0.5636055112788172, "acc1": 86.39000270507813, "acc5": 98.04800252929688, "cfg_id": 25, "flops": 57017547264}
26
+ {"loss": 0.5706944450397383, "acc1": 86.234002578125, "acc5": 98.00000237304687, "cfg_id": 26, "flops": 52276031488}
27
+ {"loss": 0.5833309799658529, "acc1": 85.9240025390625, "acc5": 97.9260024609375, "cfg_id": 27, "flops": 47534515712}
28
+ {"loss": 0.5972222860225223, "acc1": 85.57400262695313, "acc5": 97.70800255859375, "cfg_id": 28, "flops": 42792999936}
29
+ {"loss": 0.6253456006560362, "acc1": 84.89800259765624, "acc5": 97.47400255859375, "cfg_id": 29, "flops": 38051484160}
30
+ {"loss": 0.6745385262889393, "acc1": 83.5380026171875, "acc5": 97.07600244140625, "cfg_id": 30, "flops": 33309968384}
31
+ {"loss": 0.7486309014034994, "acc1": 81.42600245117187, "acc5": 96.33600258789062, "cfg_id": 31, "flops": 28568452608}
32
+ {"loss": 0.8134756960877867, "acc1": 79.16000235351562, "acc5": 95.72400271484375, "cfg_id": 32, "flops": 23826936832}
33
+ {"loss": 0.7671100513050051, "acc1": 80.37200240234375, "acc5": 95.98400258789063, "cfg_id": 33, "flops": 19085421056}
34
+ {"loss": 0.7206548866674756, "acc1": 81.91000239257812, "acc5": 96.23800239257812, "cfg_id": 34, "flops": 14343905280}
35
+ {"loss": 0.5626872230998494, "acc1": 86.44600235351562, "acc5": 98.062002421875, "cfg_id": 35, "flops": 57017547264}
36
+ {"loss": 0.5785287711769342, "acc1": 86.06400251953124, "acc5": 97.9420023046875, "cfg_id": 36, "flops": 52276031488}
37
+ {"loss": 0.5930487287202568, "acc1": 85.78400234375, "acc5": 97.79000255859376, "cfg_id": 37, "flops": 47534515712}
38
+ {"loss": 0.6189901619923838, "acc1": 85.10800268554688, "acc5": 97.50400228515625, "cfg_id": 38, "flops": 42792999936}
39
+ {"loss": 0.6674688318462083, "acc1": 83.76600272460938, "acc5": 97.09400264648437, "cfg_id": 39, "flops": 38051484160}
40
+ {"loss": 0.7388352820593299, "acc1": 81.70200266601563, "acc5": 96.47000241210938, "cfg_id": 40, "flops": 33309968384}
41
+ {"loss": 0.803126322613521, "acc1": 79.4560025390625, "acc5": 95.81400245117187, "cfg_id": 41, "flops": 28568452608}
42
+ {"loss": 0.7581946616145697, "acc1": 80.70800243164062, "acc5": 96.08600255859375, "cfg_id": 42, "flops": 23826936832}
43
+ {"loss": 0.7118472667467414, "acc1": 82.22600248046875, "acc5": 96.31000268554688, "cfg_id": 43, "flops": 19085421056}
44
+ {"loss": 0.5727639499713074, "acc1": 86.2180025, "acc5": 97.98200247070312, "cfg_id": 44, "flops": 57017547264}
45
+ {"loss": 0.5866389607615543, "acc1": 85.84400263671876, "acc5": 97.8600024609375, "cfg_id": 45, "flops": 52276031488}
46
+ {"loss": 0.6107792718279542, "acc1": 85.19800255859376, "acc5": 97.61800235351562, "cfg_id": 46, "flops": 47534515712}
47
+ {"loss": 0.6602028349809574, "acc1": 83.92600266601562, "acc5": 97.23800282226563, "cfg_id": 47, "flops": 42792999936}
48
+ {"loss": 0.7285334389431007, "acc1": 82.0040028125, "acc5": 96.52400247070312, "cfg_id": 48, "flops": 38051484160}
49
+ {"loss": 0.7910783413910505, "acc1": 79.69600262695313, "acc5": 95.95800241210938, "cfg_id": 49, "flops": 33309968384}
50
+ {"loss": 0.7478298004152197, "acc1": 80.89400243164063, "acc5": 96.172002421875, "cfg_id": 50, "flops": 28568452608}
51
+ {"loss": 0.7014034044449077, "acc1": 82.45600264648438, "acc5": 96.438002421875, "cfg_id": 51, "flops": 23826936832}
52
+ {"loss": 0.5799332931637764, "acc1": 85.92400239257813, "acc5": 97.94000249023438, "cfg_id": 52, "flops": 57017547264}
53
+ {"loss": 0.6004864230300441, "acc1": 85.43800255859375, "acc5": 97.70800227539063, "cfg_id": 53, "flops": 52276031488}
54
+ {"loss": 0.647012604287628, "acc1": 84.20200264648437, "acc5": 97.30600264648437, "cfg_id": 54, "flops": 47534515712}
55
+ {"loss": 0.7162722434961435, "acc1": 82.29000248046874, "acc5": 96.6640023046875, "cfg_id": 55, "flops": 42792999936}
56
+ {"loss": 0.7757266998065241, "acc1": 79.9760025, "acc5": 96.050002421875, "cfg_id": 56, "flops": 38051484160}
57
+ {"loss": 0.7351311787285588, "acc1": 81.04400232421875, "acc5": 96.2400026953125, "cfg_id": 57, "flops": 33309968384}
58
+ {"loss": 0.6896895027408997, "acc1": 82.6220026171875, "acc5": 96.55400252929688, "cfg_id": 58, "flops": 28568452608}
59
+ {"loss": 0.5911911701727094, "acc1": 85.53000266601562, "acc5": 97.76600241210937, "cfg_id": 59, "flops": 57017547264}
60
+ {"loss": 0.6371258264125297, "acc1": 84.41200249023437, "acc5": 97.44200245117187, "cfg_id": 60, "flops": 52276031488}
61
+ {"loss": 0.7022040815403064, "acc1": 82.49400240234375, "acc5": 96.74400272460937, "cfg_id": 61, "flops": 47534515712}
62
+ {"loss": 0.7612808859257987, "acc1": 80.29200265625, "acc5": 96.15800239257813, "cfg_id": 62, "flops": 42792999936}
63
+ {"loss": 0.7246641330420971, "acc1": 81.20400250976563, "acc5": 96.42400264648437, "cfg_id": 63, "flops": 38051484160}
64
+ {"loss": 0.6782861414619468, "acc1": 82.8040024609375, "acc5": 96.60800270507812, "cfg_id": 64, "flops": 33309968384}
65
+ {"loss": 0.629801401755575, "acc1": 84.65200262695312, "acc5": 97.54200265625, "cfg_id": 65, "flops": 57017547264}
66
+ {"loss": 0.6992729283643492, "acc1": 82.58200259765626, "acc5": 96.85600262695313, "cfg_id": 66, "flops": 52276031488}
67
+ {"loss": 0.7595290538262237, "acc1": 80.35600247070313, "acc5": 96.27000247070312, "cfg_id": 67, "flops": 47534515712}
68
+ {"loss": 0.7238247728709019, "acc1": 81.37600248046876, "acc5": 96.46200267578125, "cfg_id": 68, "flops": 42792999936}
69
+ {"loss": 0.6760879844765771, "acc1": 82.96800264648438, "acc5": 96.69400274414062, "cfg_id": 69, "flops": 38051484160}
70
+ {"loss": 0.68392569430624, "acc1": 83.16200254882813, "acc5": 97.09200258789062, "cfg_id": 70, "flops": 57017547264}
71
+ {"loss": 0.7509645553249301, "acc1": 80.68000260742187, "acc5": 96.3900025, "cfg_id": 71, "flops": 52276031488}
72
+ {"loss": 0.7208586267449639, "acc1": 81.55200274414062, "acc5": 96.62200272460937, "cfg_id": 72, "flops": 47534515712}
73
+ {"loss": 0.6785354860352747, "acc1": 82.86000262695312, "acc5": 96.80000244140625, "cfg_id": 73, "flops": 42792999936}
74
+ {"loss": 0.7184764705598354, "acc1": 81.61200241210938, "acc5": 96.63200258789063, "cfg_id": 74, "flops": 57017547264}
75
+ {"loss": 0.7229886900520686, "acc1": 81.45000249023437, "acc5": 96.62200250976562, "cfg_id": 75, "flops": 52276031488}
76
+ {"loss": 0.6883746685855316, "acc1": 83.01600272460938, "acc5": 96.83200240234375, "cfg_id": 76, "flops": 47534515712}
77
+ {"loss": 0.6293963799535325, "acc1": 83.90800231445313, "acc5": 97.27400245117188, "cfg_id": 77, "flops": 57017547264}
78
+ {"loss": 0.642419446824175, "acc1": 84.31400258789063, "acc5": 97.19200287109375, "cfg_id": 78, "flops": 52276031488}
79
+ {"loss": 0.5880116202275861, "acc1": 85.90600231445312, "acc5": 97.58400263671875, "cfg_id": 79, "flops": 57017547264}
80
+ {"loss": 0.750676692096573, "acc1": 82.14200271484376, "acc5": 96.36600260742188, "cfg_id": 80, "flops": 9504781184}
81
+ {"loss": 0.7431871895537232, "acc1": 82.234002578125, "acc5": 96.39200241210938, "cfg_id": 81, "flops": 14246296960}
82
+ {"loss": 0.7236298957105839, "acc1": 82.62600249023437, "acc5": 96.57600243164063, "cfg_id": 82, "flops": 18987812736}
83
+ {"loss": 0.7074674766397837, "acc1": 82.84600237304687, "acc5": 96.68800247070313, "cfg_id": 83, "flops": 23729328512}
84
+ {"loss": 0.7014015182062532, "acc1": 82.99000265625, "acc5": 96.7740025, "cfg_id": 84, "flops": 28470844288}
85
+ {"loss": 0.6996880258348855, "acc1": 82.98000252929687, "acc5": 96.90200263671875, "cfg_id": 85, "flops": 33212360064}
86
+ {"loss": 0.7077699161953095, "acc1": 82.96200270507812, "acc5": 96.98600258789062, "cfg_id": 86, "flops": 37953875840}
87
+ {"loss": 0.6674120087515224, "acc1": 84.33000274414063, "acc5": 97.2640025, "cfg_id": 87, "flops": 42695391616}
88
+ {"loss": 0.6169534720141779, "acc1": 85.86200280273438, "acc5": 97.49200272460938, "cfg_id": 88, "flops": 47436907392}
89
+ {"loss": 0.5848360503600403, "acc1": 86.02600271484376, "acc5": 97.4480026171875, "cfg_id": 89, "flops": 52178423168}
90
+ {"loss": 0.7346750153510859, "acc1": 82.5540027734375, "acc5": 96.52400241210937, "cfg_id": 90, "flops": 9504781184}
91
+ {"loss": 0.7158081559182117, "acc1": 82.82200255859375, "acc5": 96.65600255859376, "cfg_id": 91, "flops": 14246296960}
92
+ {"loss": 0.6994372372600165, "acc1": 83.03600239257813, "acc5": 96.74600252929687, "cfg_id": 92, "flops": 18987812736}
93
+ {"loss": 0.6947964186582601, "acc1": 83.07400255859375, "acc5": 96.88200241210937, "cfg_id": 93, "flops": 23729328512}
94
+ {"loss": 0.6946824553112189, "acc1": 83.0200026171875, "acc5": 96.99400251953125, "cfg_id": 94, "flops": 28470844288}
95
+ {"loss": 0.7041463901599249, "acc1": 83.21600236328125, "acc5": 97.00400250976563, "cfg_id": 95, "flops": 33212360064}
96
+ {"loss": 0.6699620116163384, "acc1": 84.54600258789063, "acc5": 97.31600244140625, "cfg_id": 96, "flops": 37953875840}
97
+ {"loss": 0.6176637105192199, "acc1": 85.91800228515625, "acc5": 97.57000255859376, "cfg_id": 97, "flops": 42695391616}
98
+ {"loss": 0.5765587539045196, "acc1": 86.1880023046875, "acc5": 97.54000249023437, "cfg_id": 98, "flops": 47436907392}
99
+ {"loss": 0.7319535712401072, "acc1": 82.49200258789062, "acc5": 96.46600271484375, "cfg_id": 99, "flops": 9504781184}
100
+ {"loss": 0.710809505516381, "acc1": 82.7860023046875, "acc5": 96.6260026171875, "cfg_id": 100, "flops": 14246296960}
101
+ {"loss": 0.7044268037107858, "acc1": 82.93000239257813, "acc5": 96.79800258789062, "cfg_id": 101, "flops": 18987812736}
102
+ {"loss": 0.7076575808001287, "acc1": 82.83200243164063, "acc5": 96.8820025390625, "cfg_id": 102, "flops": 23729328512}
103
+ {"loss": 0.7188302328189214, "acc1": 82.88200259765625, "acc5": 96.93600249023437, "cfg_id": 103, "flops": 28470844288}
104
+ {"loss": 0.6856357377361167, "acc1": 84.2520023046875, "acc5": 97.2200026171875, "cfg_id": 104, "flops": 33212360064}
105
+ {"loss": 0.6273381847210906, "acc1": 85.758002734375, "acc5": 97.44800275390625, "cfg_id": 105, "flops": 37953875840}
106
+ {"loss": 0.5830204013847944, "acc1": 86.01000260742188, "acc5": 97.4780026171875, "cfg_id": 106, "flops": 42695391616}
107
+ {"loss": 0.7305513945492831, "acc1": 82.3180023828125, "acc5": 96.47200256835937, "cfg_id": 107, "flops": 9504781184}
108
+ {"loss": 0.7206297208639708, "acc1": 82.37200228515626, "acc5": 96.61600234375, "cfg_id": 108, "flops": 14246296960}
109
+ {"loss": 0.7241401795975186, "acc1": 82.33800244140625, "acc5": 96.74400267578125, "cfg_id": 109, "flops": 18987812736}
110
+ {"loss": 0.7350799917723193, "acc1": 82.57800231445313, "acc5": 96.78000252929688, "cfg_id": 110, "flops": 23729328512}
111
+ {"loss": 0.7013292148935072, "acc1": 83.95000255859375, "acc5": 97.17600254882812, "cfg_id": 111, "flops": 28470844288}
112
+ {"loss": 0.64130035031474, "acc1": 85.54600244140624, "acc5": 97.36800262695313, "cfg_id": 112, "flops": 33212360064}
113
+ {"loss": 0.5961506485826138, "acc1": 85.76800252929688, "acc5": 97.33400268554688, "cfg_id": 113, "flops": 37953875840}
114
+ {"loss": 0.7443677056580782, "acc1": 81.86200244140625, "acc5": 96.29200241210937, "cfg_id": 114, "flops": 9504781184}
115
+ {"loss": 0.7442678388659701, "acc1": 81.82000262695313, "acc5": 96.3900026953125, "cfg_id": 115, "flops": 14246296960}
116
+ {"loss": 0.749958168037913, "acc1": 81.99400229492187, "acc5": 96.49600264648437, "cfg_id": 116, "flops": 18987812736}
117
+ {"loss": 0.7073916116672935, "acc1": 83.43200267578125, "acc5": 96.99200250976563, "cfg_id": 117, "flops": 23729328512}
118
+ {"loss": 0.6501240834706661, "acc1": 85.1480025, "acc5": 97.212002734375, "cfg_id": 118, "flops": 28470844288}
119
+ {"loss": 0.6135486943477934, "acc1": 85.35400266601563, "acc5": 97.16000249023438, "cfg_id": 119, "flops": 33212360064}
120
+ {"loss": 0.7824224890633062, "acc1": 81.16600250976562, "acc5": 96.11200255859374, "cfg_id": 120, "flops": 9504781184}
121
+ {"loss": 0.7932735298844901, "acc1": 80.93600233398438, "acc5": 96.12600254882813, "cfg_id": 121, "flops": 14246296960}
122
+ {"loss": 0.7476008046757091, "acc1": 82.49800259765625, "acc5": 96.65000265625, "cfg_id": 122, "flops": 18987812736}
123
+ {"loss": 0.6842290776019747, "acc1": 84.35800275390625, "acc5": 96.94000270507813, "cfg_id": 123, "flops": 23729328512}
124
+ {"loss": 0.6409159135073423, "acc1": 84.520002578125, "acc5": 96.91800252929687, "cfg_id": 124, "flops": 28470844288}
125
+ {"loss": 0.8394820786109476, "acc1": 79.96600270507813, "acc5": 95.5980024609375, "cfg_id": 125, "flops": 9504781184}
126
+ {"loss": 0.7924092794683847, "acc1": 81.1100024609375, "acc5": 96.12600264648438, "cfg_id": 126, "flops": 14246296960}
127
+ {"loss": 0.724013783997207, "acc1": 83.04000248046874, "acc5": 96.46600255859374, "cfg_id": 127, "flops": 18987812736}
128
+ {"loss": 0.6948224610338608, "acc1": 83.016002421875, "acc5": 96.44000282226563, "cfg_id": 128, "flops": 23729328512}
129
+ {"loss": 0.8688964597655066, "acc1": 79.15800264160156, "acc5": 95.2460025, "cfg_id": 129, "flops": 9504781184}
130
+ {"loss": 0.8095247168658357, "acc1": 80.57800265625, "acc5": 95.63000274414063, "cfg_id": 130, "flops": 14246296960}
131
+ {"loss": 0.7750140779059042, "acc1": 80.9680026171875, "acc5": 95.68200263671875, "cfg_id": 131, "flops": 18987812736}
132
+ {"loss": 0.9017180648039688, "acc1": 77.88400251953125, "acc5": 94.78400234375, "cfg_id": 132, "flops": 9504781184}
133
+ {"loss": 0.8799216277671583, "acc1": 77.84800252929688, "acc5": 94.79800247070312, "cfg_id": 133, "flops": 14246296960}
134
+ {"loss": 0.8987371790589709, "acc1": 78.12000266601562, "acc5": 94.92200256835937, "cfg_id": 134, "flops": 9504781184}
utils.py ADDED
@@ -0,0 +1,408 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2015-present, Facebook, Inc.
2
+ # All rights reserved.
3
+ """
4
+ Misc functions, including distributed helpers.
5
+
6
+ Mostly copy-paste from torchvision references.
7
+ """
8
+ import io
9
+ import os
10
+ import time
11
+ from collections import defaultdict, deque
12
+ import datetime
13
+
14
+ import torch
15
+ import torch.distributed as dist
16
+ import logging
17
+
18
+ logger_initialized = {}
19
+
20
def group_subnets_by_flops(data, flops_gap=1.0):
    """Bucket configuration ids whose FLOPs lie close to each other.

    Entries are visited in ascending order of their FLOPs value. Each value
    is divided by 1e9 and compared against the anchor of the current bucket;
    a new bucket is opened whenever the distance exceeds ``flops_gap``.

    Note that the anchor starts at 0, so leading entries within
    ``flops_gap`` of zero are collected in a bucket of their own that is
    still anchored at 0.

    Args:
        data (dict): Mapping from configuration id (anything accepted by
            ``int()``) to a raw FLOPs count.
        flops_gap (float): Maximum distance, after the 1e9 scaling, between
            a bucket's anchor and its members.

    Returns:
        list[list[int]]: The buckets in ascending FLOPs order, each one
        holding its configuration ids sorted ascending.
    """
    buckets = []
    current = []
    anchor = 0
    for cfg_id, raw_flops in sorted(data.items(), key=lambda pair: pair[1]):
        scaled = raw_flops / 1e9
        opens_new_bucket = abs(anchor - scaled) > flops_gap
        if not opens_new_bucket:
            current.append(int(cfg_id))
            continue
        # Close the bucket collected so far before starting the next one.
        if current:
            buckets.append(sorted(current))
        current = [int(cfg_id)]
        anchor = scaled

    if current:
        buckets.append(sorted(current))

    return buckets
39
+
40
def find_best_candidates(data, flops_gap=1):
    """Pick the best-scoring configuration of every FLOPs neighbourhood.

    Entries are visited in ascending order of their ``(flops, score)``
    value. The first entry always opens a neighbourhood; a later entry opens
    a new one when its FLOPs differ from the current anchor by more than
    ``flops_gap``. Otherwise it replaces the current pick if its score is
    strictly higher.

    The previous implementation compared the first entry against an implicit
    anchor of 0 and raised ``IndexError`` when that entry fell within the
    gap, because there was no pick to compare against yet.

    Args:
        data (dict): Mapping from configuration id to a ``(flops, score)``
            pair.
        flops_gap (float): Maximum FLOPs distance from the anchor for an
            entry to stay in the same neighbourhood. Defaults to 1.

    Returns:
        list: One configuration id (as keyed in ``data``) per neighbourhood,
        in ascending FLOPs order.
    """
    best_ids = []
    anchor_flops = None
    for cfg_id, (flops, score) in sorted(data.items(), key=lambda item: item[1]):
        if not best_ids or abs(anchor_flops - flops) > flops_gap:
            best_ids.append(cfg_id)
            anchor_flops = flops
        elif score > data[best_ids[-1]][1]:
            # Same neighbourhood: keep only the strictly better candidate.
            best_ids[-1] = cfg_id

    return best_ids
54
+
55
+
56
+
57
def find_top_candidates(data, ratio=0.9, flops_gap=3):
    """Keep the top-scoring fraction of every FLOPs group.

    Entries are visited in ascending order of their ``(flops, score)``
    value and split into groups: a new group is opened whenever an entry's
    FLOPs differ from the current anchor by more than ``flops_gap``. The
    anchor starts at 0, so leading entries within ``flops_gap`` of zero form
    a group that stays anchored at 0.

    Within each group, the ids are ranked by score (descending) and the
    first ``max(1, int(ratio * len(group)))`` of them are kept.

    Args:
        data (dict): Mapping from configuration id (anything accepted by
            ``int()``) to a ``(flops, score)`` pair.
        ratio (float): Fraction of each group to keep. At least one id per
            group is always kept.
        flops_gap (float): Maximum FLOPs distance from a group's anchor for
            an entry to join it. Defaults to 3, the value that was
            previously hard-coded.

    Returns:
        list[int]: The selected configuration ids, group by group in
        ascending FLOPs order and by descending score within a group.
    """
    groups = []
    current = []
    anchor_flops = 0
    for cfg_id, (flops, _score) in sorted(data.items(), key=lambda item: item[1]):
        if abs(anchor_flops - flops) > flops_gap:
            if current:
                groups.append(current)
            current = [cfg_id]
            anchor_flops = flops
        else:
            current.append(cfg_id)

    if current:
        groups.append(current)

    selected_ids = []
    for group in groups:
        if len(group) == 1:
            selected_ids.append(int(group[0]))
            continue
        scores = torch.tensor([data[cfg_id][-1] for cfg_id in group])
        ranking = torch.argsort(scores, descending=True).tolist()
        # Always keep at least one candidate, even for a very small ratio.
        num_selected = max(1, int(ratio * len(group)))
        selected_ids.extend(int(group[idx]) for idx in ranking[:num_selected])

    return selected_ids
90
+
91
+
92
+
93
def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):
    """Initialize and get a logger by name.

    If the logger has not been initialized, this method will initialize the
    logger by adding one or two handlers, otherwise the initialized logger will
    be directly returned. During initialization, a StreamHandler will always be
    added. If `log_file` is specified and the process rank is 0, a FileHandler
    will also be added.

    A logger is also treated as already initialized when it is a dotted child
    of an initialized logger (e.g. "a.b" after "a"). A name that merely shares
    a prefix (e.g. "ab" after "a") is NOT a child and gets its own handlers.

    Args:
        name (str): Logger name.
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the logger.
        log_level (int): The logger level. Note that only the process of
            rank 0 is affected, and other processes will set the level to
            "Error" thus be silent most of the time.
        file_mode (str): The file mode used in opening log file.
            Defaults to 'w'.

    Returns:
        logging.Logger: The expected logger.
    """
    logger = logging.getLogger(name)
    if name in logger_initialized:
        return logger
    # handle hierarchical names
    # e.g., logger "a" is initialized, then logger "a.b" will skip the
    # initialization since it is a child of "a". The trailing dot matters:
    # a plain prefix test would also (wrongly) skip the unrelated "ab".
    for logger_name in logger_initialized:
        if name.startswith(logger_name + '.'):
            return logger

    stream_handler = logging.StreamHandler()
    handlers = [stream_handler]

    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
    else:
        rank = 0

    # only rank 0 will add a FileHandler
    if rank == 0 and log_file is not None:
        # Here, the default behaviour of the official logger is 'a'. Thus, we
        # provide an interface to change the file mode to the default
        # behaviour.
        file_handler = logging.FileHandler(log_file, file_mode)
        handlers.append(file_handler)

    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    for handler in handlers:
        handler.setFormatter(formatter)
        handler.setLevel(log_level)
        logger.addHandler(handler)

    if rank == 0:
        logger.setLevel(log_level)
    else:
        logger.setLevel(logging.ERROR)

    logger_initialized[name] = True

    return logger
156
+
157
+
158
def get_root_logger(log_file=None, log_level=logging.INFO):
    """Return the project-wide logger, which is registered as "snnet".

    This is a thin wrapper around ``get_logger`` with a fixed logger name;
    see that function for how handlers are attached on first use.

    Args:
        log_file (str | None): The log filename, forwarded to ``get_logger``.
        log_level (int): The logger level, forwarded to ``get_logger``.

    Returns:
        logging.Logger: The logger named "snnet".
    """
    return get_logger('snnet', log_file, log_level)
180
+
181
class SmoothedValue(object):
    """Keep a sliding window of recent values plus running global totals.

    The window (``deque``) backs the ``median``, ``avg``, ``max`` and
    ``value`` statistics; ``total`` and ``count`` back ``global_avg``.
    """

    def __init__(self, window_size=20, fmt=None):
        # `fmt` may reference: median, avg, global_avg, max, value.
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0

    def update(self, value, n=1):
        """Record `value` once in the window and `n` times in the totals."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Sum `count` and `total` across processes.

        Warning: does not synchronize the deque!
        """
        if is_dist_avail_and_initialized():
            packed = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
            dist.barrier()
            dist.all_reduce(packed)
            reduced_count, reduced_total = packed.tolist()
            self.count = int(reduced_count)
            self.total = reduced_total

    @property
    def median(self):
        """Median of the values currently in the window."""
        window = torch.tensor(list(self.deque))
        return window.median().item()

    @property
    def avg(self):
        """Mean of the values currently in the window (float32)."""
        window = torch.tensor(list(self.deque), dtype=torch.float32)
        return window.mean().item()

    @property
    def global_avg(self):
        """Average over everything recorded so far: `total / count`."""
        return self.total / self.count

    @property
    def max(self):
        """Largest value currently in the window."""
        return max(self.deque)

    @property
    def value(self):
        """Most recently recorded value."""
        return self.deque[-1]

    def __str__(self):
        stats = {
            'median': self.median,
            'avg': self.avg,
            'global_avg': self.global_avg,
            'max': self.max,
            'value': self.value,
        }
        return self.fmt.format(**stats)
241
+
242
+
243
class MetricLogger(object):
    """Collect named meters and periodically log them while iterating.

    Args:
        delimiter (str): Separator placed between the fields of a log line.
        logger (logging.Logger | None): Destination for the log lines. When
            it is None, lines are written with ``print`` instead.
    """

    def __init__(self, delimiter="\t", logger=None):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter
        self.logger = logger

    def update(self, **kwargs):
        """Feed one new value into each meter named by a keyword."""
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Go through __dict__ so that a lookup made before `meters` exists
        # (e.g. on an instance whose __init__ has not run) raises
        # AttributeError instead of recursing into __getattr__ forever.
        meters = self.__dict__.get('meters')
        if meters is not None and attr in meters:
            return meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        """Synchronize every meter across processes."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register (or replace) the meter stored under `name`."""
        self.meters[name] = meter

    def _log(self, message):
        """Send `message` to the configured logger, or print it if none."""
        if self.logger is not None:
            self.logger.info(message)
        else:
            print(message)

    def log_every(self, iterable, print_freq, header=None):
        """Yield the items of `iterable`, logging progress along the way.

        A line is logged every `print_freq` iterations and on the last one,
        followed by a total-time summary once the iterable is exhausted.

        Args:
            iterable: A sized iterable (``len()`` is required).
            print_freq (int): Number of iterations between two log lines.
            header (str | None): Text placed at the start of every line.

        Yields:
            The items of `iterable`, unchanged.
        """
        if not header:
            header = ''
        num_iters = len(iterable)
        use_cuda = torch.cuda.is_available()
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(num_iters))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if use_cuda:
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for i, obj in enumerate(iterable):
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == num_iters - 1:
                eta_seconds = iter_time.global_avg * (num_iters - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                fields = dict(
                    eta=eta_string,
                    meters=str(self),
                    time=str(iter_time), data=str(data_time))
                if use_cuda:
                    fields['memory'] = torch.cuda.max_memory_allocated() / MB
                self._log(log_msg.format(i, num_iters, **fields))
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # max(..., 1) keeps the summary from dividing by zero when the
        # iterable is empty.
        self._log('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / max(num_iters, 1)))
324
+
325
+
326
+ def _load_checkpoint_for_ema(model_ema, checkpoint):
327
+ """
328
+ Workaround for ModelEma._load_checkpoint to accept an already-loaded object
329
+ """
330
+ mem_file = io.BytesIO()
331
+ torch.save({'state_dict_ema':checkpoint}, mem_file)
332
+ mem_file.seek(0)
333
+ model_ema._load_checkpoint(mem_file)
334
+
335
+
336
def setup_for_distributed(is_master):
    """
    Replace the builtin print with a version that is silent unless this is
    the master process or the call passes force=True.
    """
    import builtins

    wrapped_print = builtins.print

    def print(*args, **kwargs):
        # 'force' is consumed here and never reaches the wrapped print.
        forced = kwargs.pop('force', False)
        if not (is_master or forced):
            return
        wrapped_print(*args, **kwargs)

    builtins.print = print
349
+
350
+
351
def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is both available and initialized."""
    return bool(dist.is_available() and dist.is_initialized())
357
+
358
+
359
def get_world_size():
    """Return the distributed world size, or 1 when not running distributed."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1
363
+
364
+
365
def get_rank():
    """Return this process's distributed rank, or 0 when not running distributed."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
369
+
370
+
371
def is_main_process():
    """Return True when this process has rank 0."""
    rank = get_rank()
    return rank == 0
373
+
374
+
375
def save_on_master(*args, **kwargs):
    """Forward all arguments to torch.save, but only on the main process."""
    if not is_main_process():
        return
    torch.save(*args, **kwargs)
378
+
379
+
380
def init_distributed_mode(args):
    """Configure distributed execution from environment variables.

    Reads RANK / WORLD_SIZE / LOCAL_RANK when the first two are set,
    otherwise falls back to SLURM_PROCID. With neither present,
    ``args.distributed`` is set to False and nothing else happens.

    In the distributed case the function selects the CUDA device, initializes
    the 'nccl' process group from ``args.dist_url``, ``args.world_size`` and
    ``args.rank``, waits on a barrier, and silences ``print`` on non-zero
    ranks. Note that the SLURM branch does not assign ``args.world_size``;
    it must already be present on ``args``.

    Args:
        args: A mutable namespace; attributes are read from and written to it
            as described above.
    """
    env = os.environ
    has_rank_env = 'RANK' in env and 'WORLD_SIZE' in env
    has_slurm_env = 'SLURM_PROCID' in env

    if not has_rank_env and not has_slurm_env:
        print('Not using distributed mode')
        args.distributed = False
        return

    if has_rank_env:
        args.rank = int(env["RANK"])
        args.world_size = int(env['WORLD_SIZE'])
        args.gpu = int(env['LOCAL_RANK'])
    else:
        args.rank = int(env['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(
        backend=args.dist_backend,
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
    )
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)
403
+
404
+ import json
405
def save_on_master_eval_res(log_stats, output_dir):
    """Append `log_stats` as one JSON line to a file, on the main process only.

    Args:
        log_stats: A JSON-serializable object.
        output_dir (str): Path passed to ``open`` in append mode; despite
            the name it is used as a file path here.
    """
    if not is_main_process():
        return
    record = json.dumps(log_stats)
    with open(output_dir, 'a') as f:
        f.write(record + "\n")