Spaces:
Runtime error
Runtime error
HubHop
committed on
Commit
·
bcfa144
1
Parent(s):
c8ed6d7
update
Browse files- .idea/.gitignore +8 -0
- __pycache__/datasets.cpython-39.pyc +0 -0
- __pycache__/models_v2.cpython-39.pyc +0 -0
- __pycache__/snnet.cpython-39.pyc +0 -0
- __pycache__/utils.cpython-39.pyc +0 -0
- app.py +411 -4
- datasets.py +109 -0
- demo.jpg +0 -0
- flops_gradio_demo.json +136 -0
- gradio_banner.png +0 -0
- gradio_demo.json +33 -0
- models_v2.py +568 -0
- outputs/deit/20240118_171921.log +1 -0
- outputs/stitch_s_l_v2_lora_r_64_50_ep/20240118_172124.log +2 -0
- outputs/stitch_s_l_v2_lora_r_64_50_ep/20240118_172140.log +2 -0
- outputs/stitch_s_l_v2_lora_r_64_50_ep/20240118_172156.log +5 -0
- outputs/stitch_s_l_v2_lora_r_64_50_ep/20240118_172250.log +5 -0
- outputs/stitch_s_l_v2_lora_r_64_50_ep/20240118_172309.log +5 -0
- outputs/stitch_s_l_v2_lora_r_64_50_ep/20240118_172332.log +5 -0
- requirements.txt +3 -0
- snnet.py +473 -0
- snnetv2_deit3_s_l.pth +3 -0
- stitches_res_s_l.txt +134 -0
- utils.py +408 -0
.idea/.gitignore
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Default ignored files
|
2 |
+
/shelf/
|
3 |
+
/workspace.xml
|
4 |
+
# Editor-based HTTP Client requests
|
5 |
+
/httpRequests/
|
6 |
+
# Datasource local storage ignored files
|
7 |
+
/dataSources/
|
8 |
+
/dataSources.local.xml
|
__pycache__/datasets.cpython-39.pyc
ADDED
Binary file (2.97 kB). View file
|
|
__pycache__/models_v2.cpython-39.pyc
ADDED
Binary file (17.5 kB). View file
|
|
__pycache__/snnet.cpython-39.pyc
ADDED
Binary file (13.6 kB). View file
|
|
__pycache__/utils.cpython-39.pyc
ADDED
Binary file (13 kB). View file
|
|
app.py
CHANGED
@@ -1,7 +1,414 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
-
|
4 |
-
return "Hello " + name + "!!"
|
5 |
|
6 |
-
|
7 |
-
iface.launch()
|
|
|
1 |
+
# Copyright (c) 2015-present, Facebook, Inc.
|
2 |
+
# All rights reserved.
|
3 |
+
import argparse
|
4 |
+
import datetime
|
5 |
+
import numpy as np
|
6 |
+
import time
|
7 |
+
import torch
|
8 |
+
import torch.backends.cudnn as cudnn
|
9 |
+
import json
|
10 |
+
|
11 |
+
from pathlib import Path
|
12 |
+
from utils import get_root_logger
|
13 |
+
from timm.models import create_model
|
14 |
+
import models_v2
|
15 |
+
import requests
|
16 |
+
|
17 |
+
import utils
|
18 |
+
import time
|
19 |
+
import sys
|
20 |
+
import datetime
|
21 |
+
import os
|
22 |
+
from snnet import SNNet, SNNetv2
|
23 |
+
import warnings
|
24 |
+
|
25 |
+
warnings.filterwarnings("ignore")
|
26 |
+
from fvcore.nn import FlopCountAnalysis
|
27 |
+
|
28 |
+
from PIL import Image
|
29 |
import gradio as gr
|
30 |
+
import plotly.express as px
|
31 |
+
|
32 |
+
|
33 |
+
def get_args_parser():
    """Build the argument parser shared by the DeiT-style scripts and this demo.

    The parser is created with ``add_help=False`` so that it can be passed as a
    ``parents=[...]`` entry to another ``argparse.ArgumentParser``.

    Returns:
        argparse.ArgumentParser: parser with all options registered.
    """
    parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False)
    parser.add_argument('--batch-size', default=64, type=int)
    parser.add_argument('--epochs', default=300, type=int)
    parser.add_argument('--bce-loss', action='store_true')
    parser.add_argument('--unscale-lr', action='store_true')

    # Model parameters
    parser.add_argument('--model', default='deit_base_patch16_224', type=str, metavar='MODEL',
                        help='Name of model to train')
    parser.add_argument('--input-size', default=224, type=int, help='images input size')

    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                        help='Dropout rate (default: 0.)')
    parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
                        help='Drop path rate (default: 0.1)')

    parser.add_argument('--model-ema', action='store_true')
    parser.add_argument('--no-model-ema', action='store_false', dest='model_ema')
    parser.set_defaults(model_ema=True)
    parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='')
    parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='')

    # Optimizer parameters
    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                        help='Optimizer (default: "adamw")')
    parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
                        help='Optimizer Epsilon (default: 1e-8)')
    parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                        help='Optimizer Betas (default: None, use opt default)')
    parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                        help='Clip gradient norm (default: None, no clipping)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--weight-decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')

    # Learning rate schedule parameters
    parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                        help='LR scheduler (default: "cosine")')
    parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',
                        help='learning rate (default: 5e-4)')
    parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                        help='learning rate noise on/off epoch percentages')
    parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                        help='learning rate noise limit percent (default: 0.67)')
    parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                        help='learning rate noise std-dev (default: 1.0)')
    parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
                        help='warmup learning rate (default: 1e-6)')
    parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')

    parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
                        help='epoch interval to decay LR')
    parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
                        help='epochs to warmup LR, if scheduler supports')
    parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                        help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
    parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                        help='patience epochs for Plateau LR scheduler (default: 10)')
    parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                        help='LR decay rate (default: 0.1)')

    # Augmentation parameters
    parser.add_argument('--color-jitter', type=float, default=0.3, metavar='PCT',
                        help='Color jitter factor (default: 0.3)')
    # The help text used to be a single malformed literal containing a stray
    # '" + \' and a trailing comma; it is now one well-formed string.
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original". '
                             '(default: rand-m9-mstd0.5-inc1)')
    parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
    parser.add_argument('--train-interpolation', type=str, default='bicubic',
                        help='Training interpolation (random, bilinear, bicubic default: "bicubic")')

    parser.add_argument('--repeated-aug', action='store_true')
    parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
    parser.set_defaults(repeated_aug=True)

    parser.add_argument('--train-mode', action='store_true')
    parser.add_argument('--no-train-mode', action='store_false', dest='train_mode')
    parser.set_defaults(train_mode=True)

    parser.add_argument('--ThreeAugment', action='store_true')  # 3augment

    parser.add_argument('--src', action='store_true')  # simple random crop

    # * Random Erase params
    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                        help='Random erase prob (default: 0.25)')
    parser.add_argument('--remode', type=str, default='pixel',
                        help='Random erase mode (default: "pixel")')
    parser.add_argument('--recount', type=int, default=1,
                        help='Random erase count (default: 1)')
    parser.add_argument('--resplit', action='store_true', default=False,
                        help='Do not random erase first (clean) augmentation split')

    # * Mixup params
    parser.add_argument('--mixup', type=float, default=0.8,
                        help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
    parser.add_argument('--cutmix', type=float, default=1.0,
                        help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
    parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
    parser.add_argument('--mixup-prob', type=float, default=1.0,
                        help='Probability of performing mixup or cutmix when either/both is enabled')
    parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                        help='Probability of switching to cutmix when both mixup and cutmix enabled')
    parser.add_argument('--mixup-mode', type=str, default='batch',
                        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')

    # Distillation parameters
    parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL',
                        help='Name of teacher model to train (default: "regnety_160")')
    parser.add_argument('--teacher-path', type=str, default='')
    parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="")
    parser.add_argument('--distillation-alpha', default=0.5, type=float, help="")
    parser.add_argument('--distillation-tau', default=1.0, type=float, help="")

    # * Finetuning params
    parser.add_argument('--finetune', default='', help='finetune from checkpoint')
    parser.add_argument('--attn-only', action='store_true')

    # Dataset parameters
    parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', type=str,
                        help='dataset path')
    parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'],
                        type=str, help='dataset to use')
    parser.add_argument('--inat-category', default='name',
                        choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'],
                        type=str, help='semantic granularity')

    parser.add_argument('--output_dir', default='',
                        help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cpu',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--eval-crop-ratio', default=0.875, type=float, help="Crop ratio for evaluation")
    parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation')
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin-mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
                        help='')
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')

    # Stitching / demo specific options.
    # NOTE(review): the help texts for --scoring, --proxy and --snnet_name are
    # placeholders inherited from the original script; their exact meaning is not
    # visible from this file.
    parser.add_argument('--exp_name', default='deit', type=str, help='experiment name')
    parser.add_argument('--config', default=None, type=str, help='path to a JSON configuration file')
    parser.add_argument('--scoring', action='store_true', default=False, help='configuration')
    parser.add_argument('--proxy', default='synflow', type=str, help='configuration')
    parser.add_argument('--snnet_name', default='snnetv2', type=str, help='configuration')
    parser.add_argument('--get_flops', action='store_true')
    # The four help strings below were previously copy-pasted from unrelated
    # options (crop ratio / distributed evaluation / number of processes).
    parser.add_argument('--flops_sampling_k', default=None, type=float,
                        help='FLOPs sampling parameter k (default: None)')
    parser.add_argument('--low_rank', action='store_true', default=False,
                        help='Enable the low-rank option')
    parser.add_argument('--lora_r', default=64, type=int,
                        help='LoRA rank passed to SNNetv2 (default: 64)')
    parser.add_argument('--flops_gap', default=1.0, type=float,
                        help='FLOPs gap (default: 1.0)')

    return parser
|
200 |
+
|
201 |
+
|
202 |
+
def initialize_model_stitching_layer(model, mixup_fn, data_loader, device):
    """Initialise the model's stitching weights from the first batch of ``data_loader``.

    Only the first ``(samples, targets)`` pair is consumed; if the loader yields
    nothing the function returns without touching the model.

    Args:
        model: object exposing ``initialize_stitching_weights(samples)``.
        mixup_fn: optional callable ``(samples, targets) -> (samples, targets)``
            applied before initialisation; skipped when ``None``.
        data_loader: iterable of ``(samples, targets)`` pairs.
        device: target device handed to ``.to(...)`` for both tensors.
    """
    no_batch = object()  # sentinel so an empty loader is a silent no-op
    first_batch = next(iter(data_loader), no_batch)
    if first_batch is no_batch:
        return

    batch_samples, batch_targets = first_batch
    batch_samples = batch_samples.to(device, non_blocking=True)
    batch_targets = batch_targets.to(device, non_blocking=True)

    if mixup_fn is not None:
        batch_samples, batch_targets = mixup_fn(batch_samples, batch_targets)

    # The initialisation pass runs under the CUDA autocast context.
    with torch.cuda.amp.autocast():
        model.initialize_stitching_weights(batch_samples)
|
214 |
+
|
215 |
+
|
216 |
+
@torch.no_grad()
def analyse_flops_for_all(model, config_name):
    """Count FLOPs for every stitch configuration and dump them to JSON.

    For each id in ``model.all_cfgs`` the model is switched to that stitch via
    ``model.reset_stitch_id`` and analysed with ``FlopCountAnalysis`` on a
    ``1 x 3 x 224 x 224`` random input. Results are written to
    ``./model_flops/flops_<config_name>.json``.

    Args:
        model: stitched model exposing ``all_cfgs`` and ``reset_stitch_id``.
        config_name: name used to build the output file name.
    """
    # Build the probe input on whatever device the model lives on. The previous
    # unconditional ``.cuda()`` failed on CPU-only hosts (the parser's default
    # device is 'cpu') and mismatched a CPU-resident model.
    try:
        device = next(model.parameters()).device
    except StopIteration:
        # Parameter-less model: fall back to CPU.
        device = torch.device('cpu')

    stitch_results = {}
    for cfg_id in model.all_cfgs:
        model.reset_stitch_id(cfg_id)
        probe = torch.randn(1, 3, 224, 224, device=device)
        stitch_results[cfg_id] = FlopCountAnalysis(model, probe).total()

    save_dir = './model_flops'
    # exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
    os.makedirs(save_dir, exist_ok=True)

    with open(os.path.join(save_dir, f'flops_{config_name}.json'), 'w') as f:
        json.dump(stitch_results, f, indent=4)
|
232 |
+
|
233 |
+
|
234 |
+
def main(args):
    """Load the stitched SN-Netv2 checkpoint and launch the Gradio classification demo.

    Steps: initialise logging and seeding, build the evaluation transform, create
    the anchor models listed in ``args.anchors``, wrap them in ``SNNetv2``, load
    weights from ``args.resume``, read per-stitch accuracy/FLOPs from
    ``stitches_res_s_l.txt`` and finally build and launch the Gradio page.

    Args:
        args: namespace produced by ``get_args_parser`` merged with the JSON
            config (must provide ``anchors`` and ``anchor_drop_path``).

    Raises:
        NotImplementedError: if distillation is combined with finetuning while
            not in eval mode.

    NOTE(review): this block was reconstructed from a copy whose indentation had
    been stripped; the nesting of the Gradio layout containers should be
    confirmed against the rendered page.
    """
    utils.init_distributed_mode(args)

    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    logger = get_root_logger(os.path.join(args.output_dir, f'{timestamp}.log'))

    logger.info(str(args))

    if args.distillation_type != 'none' and args.finetune and not args.eval:
        raise NotImplementedError("Finetuning with distillation not yet supported")

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    # Local import: resolves to the project's own datasets.py next to this file.
    from datasets import build_transform

    transform = build_transform(False, args)

    anchors = []
    for i, anchor_name in enumerate(args.anchors):
        logger.info(f"Creating model: {anchor_name}")
        anchor = create_model(
            anchor_name,
            pretrained=False,
            pretrained_deit=None,
            num_classes=1000,
            drop_path_rate=args.anchor_drop_path[i],
            img_size=args.input_size
        )
        anchors.append(anchor)

    model = SNNetv2(anchors, lora_r=args.lora_r)

    checkpoint = torch.load(args.resume, map_location='cpu')

    logger.info(f"load checkpoint from {args.resume}")
    model.load_state_dict(checkpoint['model'])

    model.to(device)
    model.eval()

    # Per-stitch results: one JSON object per line with 'cfg_id', 'acc1', 'flops'.
    eval_res = {}
    flops_res = {}
    with open('stitches_res_s_l.txt', 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank/trailing lines instead of failing in json.loads.
                continue
            epoch_stat = json.loads(line)
            eval_res[epoch_stat['cfg_id']] = epoch_stat['acc1']
            flops_res[epoch_stat['cfg_id']] = epoch_stat['flops'] / 1e9

    def visualize_stitch_pos(stitch_id):
        """Return a scatter plot of all stitches with ``stitch_id`` highlighted."""
        if stitch_id == 13:
            # 13 is equivalent to 0
            stitch_id = 0

        names = [f'ID {key}' for key in flops_res.keys()]

        fig = px.scatter(x=list(flops_res.values()), y=list(eval_res.values()), hover_name=names)
        fig.update_layout(
            title=f"SN-Netv2 - Stitch ID - {stitch_id}",
            title_x=0.5,
            xaxis_title="GFLOPs",
            # The plotted values are the 'acc1' entries, not a segmentation mIoU.
            yaxis_title="Top-1 Accuracy",
            font=dict(
                family="Courier New, monospace",
                size=18,
                color="RebeccaPurple"
            ),
            legend=dict(
                yanchor="bottom",
                y=0.99,
                xanchor="left",
                x=0.01),
        )
        fig.update_traces(marker=dict(size=10,
                                      line=dict(width=2)),
                          selector=dict(mode='markers'))

        fig.add_scatter(x=[flops_res[stitch_id]], y=[eval_res[stitch_id]], mode='markers', marker=dict(size=15),
                        name='Current Stitch')
        return fig

    # Download human-readable labels for ImageNet.
    # A timeout keeps start-up from hanging forever, and an HTTP error is raised
    # immediately instead of being parsed as a label list.
    response = requests.get("https://git.io/JJkYN", timeout=30)
    response.raise_for_status()
    labels = response.text.split("\n")

    def process_image(image, stitch_id):
        """Classify ``image`` with the selected stitch; return (confidences, plot)."""
        if image is None:
            # Submitting without an image would otherwise fail inside the transform
            # with an opaque error.
            raise gr.Error("Please provide an input image.")
        inp = transform(image).unsqueeze(0).to(device)
        # NOTE(review): unlike visualize_stitch_pos, id 13 is not remapped to 0
        # before being handed to the model - confirm reset_stitch_id accepts it.
        model.reset_stitch_id(stitch_id)
        with torch.no_grad():
            prediction = torch.nn.functional.softmax(model(inp)[0], dim=0)
        confidences = {labels[i]: float(prediction[i]) for i in range(1000)}
        fig = visualize_stitch_pos(stitch_id)
        return confidences, fig

    with gr.Blocks() as main_page:
        with gr.Column():
            gr.HTML("""
            <h1 align="center" style=" display: flex; flex-direction: row; justify-content: center; font-size: 25pt; ">Stitched ViTs are Flexible Vision Backbones</h1>
            <div align="center"> <img align="center" src='file/gradio_banner.png' width="70%"> </div>
            <h3 align="center" >This is the classification demo page of SN-Netv2, a flexible vision backbone that allows for 100+ runtime speed and performance trade-offs.</h3>
            <h3 align="center" >You can also run this gradio demo on your local GPUs at https://github.com/ziplab/SN-Netv2</h3>
            """)
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type='pil')
                stitch_slider = gr.Slider(minimum=0, maximum=134, step=1, label="Stitch ID")
                with gr.Row():
                    clear_button = gr.ClearButton()
                    submit_button = gr.Button()
            with gr.Column():
                label_output = gr.Label(num_top_classes=5)
                stitch_plot = gr.Plot(label='Stitch Position')

        submit_button.click(
            fn=process_image,
            inputs=[image_input, stitch_slider],
            outputs=[label_output, stitch_plot],
        )

        # Moving the slider only refreshes the plot; inference runs on submit.
        stitch_slider.change(
            fn=visualize_stitch_pos,
            inputs=[stitch_slider],
            outputs=[stitch_plot],
            show_progress=False
        )

        clear_button.click(
            lambda: [None, 0, None, None],
            outputs=[image_input, stitch_slider, label_output, stitch_plot],
        )

        gr.Examples(
            [
                ['demo.jpg', 0],
            ],
            inputs=[
                image_input,
                stitch_slider
            ],
            outputs=[
                label_output,
                stitch_plot
            ],
        )

    main_page.launch(allowed_paths=['./'])
|
392 |
+
|
393 |
+
|
394 |
+
if __name__ == '__main__':
    # Script entry point: parse CLI options, merge the JSON config underneath
    # them, prepare the output directory and start the demo.
    parser = argparse.ArgumentParser('DeiT training and evaluation script', parents=[get_args_parser()])
    args = parser.parse_args()

    # The demo ships with its own configuration; use it unless --config was
    # given explicitly (previously --config was always overwritten).
    if args.config is None:
        args.config = 'gradio_demo.json'

    # Context manager so the config file handle is closed deterministically.
    with open(args.config) as config_file:
        config_args = json.load(config_file)

    # Options given on the command line win over the config file. argparse maps
    # '--batch-size' to the attribute 'batch_size', and the config uses the
    # underscore form, so hyphens must be normalised before comparing; otherwise
    # hyphenated CLI overrides are silently replaced by config values.
    override_keys = {
        arg[2:].split('=')[0].replace('-', '_')
        for arg in sys.argv[1:]
        if arg.startswith('--')
    }
    for k, v in config_args.items():
        if k not in override_keys:
            setattr(args, k, v)

    output_dir = os.path.join('outputs', args.exp_name)
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    checkpoint_path = os.path.join(output_dir, 'checkpoint.pth')
    # Use an existing checkpoint in the output directory only when no explicit
    # --resume / config "resume" was provided.
    if os.path.exists(checkpoint_path) and not args.resume:
        args.resume = checkpoint_path

    args.output_dir = output_dir

    main(args)
|
|
datasets.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2015-present, Facebook, Inc.
|
2 |
+
# All rights reserved.
|
3 |
+
import os
|
4 |
+
import json
|
5 |
+
|
6 |
+
from torchvision import datasets, transforms
|
7 |
+
from torchvision.datasets.folder import ImageFolder, default_loader
|
8 |
+
|
9 |
+
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
10 |
+
from timm.data import create_transform
|
11 |
+
|
12 |
+
|
13 |
+
class INatDataset(ImageFolder):
    """iNaturalist dataset driven by the JSON annotation files under ``root``.

    Reads ``<train|val><year>.json``, ``categories.json`` and ``train<year>.json``
    and fills ``self.samples`` with ``(image_path, class_index)`` pairs. Class
    indices are assigned in order of first appearance of the chosen ``category``
    value in the training annotations, and ``self.nb_classes`` is their count.

    ``ImageFolder.__init__`` is not invoked; the attributes below are assigned
    directly. ``__getitem__`` and ``__len__`` are inherited from ImageFolder.
    """

    def __init__(self, root, train=True, year=2018, transform=None, target_transform=None,
                 category='name', loader=default_loader):
        self.transform = transform
        self.loader = loader
        self.target_transform = target_transform
        self.year = year

        # Annotation file of the requested split.
        split_name = 'train' if train else 'val'
        with open(os.path.join(root, f'{split_name}{year}.json')) as split_file:
            split_meta = json.load(split_file)

        with open(os.path.join(root, 'categories.json')) as catg_file:
            data_catg = json.load(catg_file)

        # Label indices always come from the training annotations, also for val.
        with open(os.path.join(root, f"train{year}.json")) as train_file:
            train_meta = json.load(train_file)

        label_to_index = {}
        for annotation in train_meta['annotations']:
            label = data_catg[int(annotation['category_id'])][category]
            # A new label receives the next free index; known labels are untouched.
            label_to_index.setdefault(label, len(label_to_index))
        self.nb_classes = len(label_to_index)

        self.samples = []
        for image_info in split_meta['images']:
            parts = image_info['file_name'].split('/')
            class_id = int(parts[2])
            image_path = os.path.join(root, parts[0], parts[2], parts[3])

            class_index = label_to_index[data_catg[class_id][category]]
            self.samples.append((image_path, class_index))
|
54 |
+
|
55 |
+
|
56 |
+
def build_dataset(is_train, args):
    """Create the dataset selected by ``args.data_set``.

    Args:
        is_train: build the training split when true, otherwise the validation
            split (for CIFAR this is forwarded as ``train=``).
        args: namespace providing ``data_set``, ``data_path``, ``inat_category``
            and the fields consumed by ``build_transform``.

    Returns:
        tuple: ``(dataset, nb_classes)``.

    Raises:
        ValueError: if ``args.data_set`` is not one of the supported names.
            Previously this case fell through every branch and surfaced as an
            ``UnboundLocalError`` at the ``return`` statement.
    """
    supported = ('CIFAR', 'IMNET', 'INAT', 'INAT19')
    if args.data_set not in supported:
        raise ValueError(
            f"Unsupported data_set {args.data_set!r}; expected one of {supported}")

    transform = build_transform(is_train, args)

    if args.data_set == 'CIFAR':
        dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform)
        nb_classes = 100
    elif args.data_set == 'IMNET':
        root = os.path.join(args.data_path, 'train' if is_train else 'val')
        dataset = datasets.ImageFolder(root, transform=transform)
        nb_classes = 1000
    else:
        # 'INAT' and 'INAT19' differ only in the annotation year.
        year = 2018 if args.data_set == 'INAT' else 2019
        dataset = INatDataset(args.data_path, train=is_train, year=year,
                              category=args.inat_category, transform=transform)
        nb_classes = dataset.nb_classes

    return dataset, nb_classes
|
76 |
+
|
77 |
+
|
78 |
+
def build_transform(is_train, args):
    """Return the preprocessing pipeline for training or evaluation.

    Training: timm's ``create_transform`` configured from ``args``; for inputs of
    32 px or less its first transform is replaced by a padded ``RandomCrop``.
    Evaluation: optional resize + centre crop (only for inputs larger than
    32 px), then ``ToTensor`` and ImageNet mean/std normalisation.

    Args:
        is_train: select the training pipeline when true.
        args: namespace providing ``input_size`` plus the augmentation fields
            (training) or ``eval_crop_ratio`` (evaluation).
    """
    small_input = args.input_size <= 32

    if not is_train:
        eval_steps = []
        if not small_input:
            # Resize so that the centre crop keeps the eval_crop_ratio fraction,
            # i.e. the same ratio as for 224 px images.
            resized = int(args.input_size / args.eval_crop_ratio)
            eval_steps.extend([
                transforms.Resize(resized, interpolation=3),
                transforms.CenterCrop(args.input_size),
            ])
        eval_steps.extend([
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
        ])
        return transforms.Compose(eval_steps)

    # this should always dispatch to transforms_imagenet_train
    train_transform = create_transform(
        input_size=args.input_size,
        is_training=True,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        interpolation=args.train_interpolation,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
    )
    if small_input:
        # replace RandomResizedCropAndInterpolation with RandomCrop
        train_transform.transforms[0] = transforms.RandomCrop(
            args.input_size, padding=4)
    return train_transform
|
demo.jpg
ADDED
flops_gradio_demo.json
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"0": 4608338304,
|
3 |
+
"1": 61604135936,
|
4 |
+
"2": 56843745792,
|
5 |
+
"3": 52102230016,
|
6 |
+
"4": 47360714240,
|
7 |
+
"5": 42619198464,
|
8 |
+
"6": 37877682688,
|
9 |
+
"7": 33136166912,
|
10 |
+
"8": 28394651136,
|
11 |
+
"9": 23653135360,
|
12 |
+
"10": 18911619584,
|
13 |
+
"11": 14170103808,
|
14 |
+
"12": 9428588032,
|
15 |
+
"14": 9523655552,
|
16 |
+
"15": 14265171328,
|
17 |
+
"16": 19006687104,
|
18 |
+
"17": 23748202880,
|
19 |
+
"18": 28489718656,
|
20 |
+
"19": 33231234432,
|
21 |
+
"20": 37972750208,
|
22 |
+
"21": 42714265984,
|
23 |
+
"22": 47455781760,
|
24 |
+
"23": 52197297536,
|
25 |
+
"24": 56938813312,
|
26 |
+
"25": 57017547264,
|
27 |
+
"26": 52276031488,
|
28 |
+
"27": 47534515712,
|
29 |
+
"28": 42792999936,
|
30 |
+
"29": 38051484160,
|
31 |
+
"30": 33309968384,
|
32 |
+
"31": 28568452608,
|
33 |
+
"32": 23826936832,
|
34 |
+
"33": 19085421056,
|
35 |
+
"34": 14343905280,
|
36 |
+
"35": 57017547264,
|
37 |
+
"36": 52276031488,
|
38 |
+
"37": 47534515712,
|
39 |
+
"38": 42792999936,
|
40 |
+
"39": 38051484160,
|
41 |
+
"40": 33309968384,
|
42 |
+
"41": 28568452608,
|
43 |
+
"42": 23826936832,
|
44 |
+
"43": 19085421056,
|
45 |
+
"44": 57017547264,
|
46 |
+
"45": 52276031488,
|
47 |
+
"46": 47534515712,
|
48 |
+
"47": 42792999936,
|
49 |
+
"48": 38051484160,
|
50 |
+
"49": 33309968384,
|
51 |
+
"50": 28568452608,
|
52 |
+
"51": 23826936832,
|
53 |
+
"52": 57017547264,
|
54 |
+
"53": 52276031488,
|
55 |
+
"54": 47534515712,
|
56 |
+
"55": 42792999936,
|
57 |
+
"56": 38051484160,
|
58 |
+
"57": 33309968384,
|
59 |
+
"58": 28568452608,
|
60 |
+
"59": 57017547264,
|
61 |
+
"60": 52276031488,
|
62 |
+
"61": 47534515712,
|
63 |
+
"62": 42792999936,
|
64 |
+
"63": 38051484160,
|
65 |
+
"64": 33309968384,
|
66 |
+
"65": 57017547264,
|
67 |
+
"66": 52276031488,
|
68 |
+
"67": 47534515712,
|
69 |
+
"68": 42792999936,
|
70 |
+
"69": 38051484160,
|
71 |
+
"70": 57017547264,
|
72 |
+
"71": 52276031488,
|
73 |
+
"72": 47534515712,
|
74 |
+
"73": 42792999936,
|
75 |
+
"74": 57017547264,
|
76 |
+
"75": 52276031488,
|
77 |
+
"76": 47534515712,
|
78 |
+
"77": 57017547264,
|
79 |
+
"78": 52276031488,
|
80 |
+
"79": 57017547264,
|
81 |
+
"80": 9504781184,
|
82 |
+
"81": 14246296960,
|
83 |
+
"82": 18987812736,
|
84 |
+
"83": 23729328512,
|
85 |
+
"84": 28470844288,
|
86 |
+
"85": 33212360064,
|
87 |
+
"86": 37953875840,
|
88 |
+
"87": 42695391616,
|
89 |
+
"88": 47436907392,
|
90 |
+
"89": 52178423168,
|
91 |
+
"90": 9504781184,
|
92 |
+
"91": 14246296960,
|
93 |
+
"92": 18987812736,
|
94 |
+
"93": 23729328512,
|
95 |
+
"94": 28470844288,
|
96 |
+
"95": 33212360064,
|
97 |
+
"96": 37953875840,
|
98 |
+
"97": 42695391616,
|
99 |
+
"98": 47436907392,
|
100 |
+
"99": 9504781184,
|
101 |
+
"100": 14246296960,
|
102 |
+
"101": 18987812736,
|
103 |
+
"102": 23729328512,
|
104 |
+
"103": 28470844288,
|
105 |
+
"104": 33212360064,
|
106 |
+
"105": 37953875840,
|
107 |
+
"106": 42695391616,
|
108 |
+
"107": 9504781184,
|
109 |
+
"108": 14246296960,
|
110 |
+
"109": 18987812736,
|
111 |
+
"110": 23729328512,
|
112 |
+
"111": 28470844288,
|
113 |
+
"112": 33212360064,
|
114 |
+
"113": 37953875840,
|
115 |
+
"114": 9504781184,
|
116 |
+
"115": 14246296960,
|
117 |
+
"116": 18987812736,
|
118 |
+
"117": 23729328512,
|
119 |
+
"118": 28470844288,
|
120 |
+
"119": 33212360064,
|
121 |
+
"120": 9504781184,
|
122 |
+
"121": 14246296960,
|
123 |
+
"122": 18987812736,
|
124 |
+
"123": 23729328512,
|
125 |
+
"124": 28470844288,
|
126 |
+
"125": 9504781184,
|
127 |
+
"126": 14246296960,
|
128 |
+
"127": 18987812736,
|
129 |
+
"128": 23729328512,
|
130 |
+
"129": 9504781184,
|
131 |
+
"130": 14246296960,
|
132 |
+
"131": 18987812736,
|
133 |
+
"132": 9504781184,
|
134 |
+
"133": 14246296960,
|
135 |
+
"134": 9504781184
|
136 |
+
}
|
gradio_banner.png
ADDED
gradio_demo.json
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"anchors": ["deit_small_patch16_LS", "deit_large_patch16_LS"],
|
3 |
+
"batch_size": 64,
|
4 |
+
"snnet_name": "snnet_v2",
|
5 |
+
"data_path": "/data2/datasets/imagenet",
|
6 |
+
"data_set": "IMNET",
|
7 |
+
"exp_name": "stitch_s_l_v2_lora_r_64_50_ep",
|
8 |
+
"input_size": 224,
|
9 |
+
"num_workers": 10,
|
10 |
+
"lr": 0.00003,
|
11 |
+
"warmup_lr": 1e-7,
|
12 |
+
"epochs": 50,
|
13 |
+
"weight_decay": 0.02,
|
14 |
+
"sched": "cosine",
|
15 |
+
"eval_crop_ratio": 1.0,
|
16 |
+
"reprob": 0.0,
|
17 |
+
"smoothing": 0.1,
|
18 |
+
"warmup_epochs": 5,
|
19 |
+
"drop": 0.0,
|
20 |
+
"seed": 0,
|
21 |
+
"opt": "fusedlamb",
|
22 |
+
"mixup": 0,
|
23 |
+
"anchor_drop_path": [0.05, 0.4],
|
24 |
+
"cutmix": 1.0,
|
25 |
+
"color_jitter": 0.3,
|
26 |
+
"unscale_lr": true,
|
27 |
+
"no_repeated_aug": true,
|
28 |
+
"ThreeAugment": true,
|
29 |
+
"src": true,
|
30 |
+
"lora_r": 64,
|
31 |
+
"pretrained_deit": "../pretrained_weights",
|
32 |
+
"resume": "snnetv2_deit3_s_l.pth"
|
33 |
+
}
|
models_v2.py
ADDED
@@ -0,0 +1,568 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
import os.path
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from functools import partial
|
8 |
+
|
9 |
+
from timm.models.vision_transformer import Mlp, PatchEmbed , _cfg
|
10 |
+
|
11 |
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
12 |
+
from timm.models.registry import register_model
|
13 |
+
# from xformers.ops import memory_efficient_attention
|
14 |
+
|
15 |
+
class Attention(nn.Module):
    """Multi-head self-attention over a ``(B, N, C)`` token sequence.

    taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py

    Args:
        dim: Channel width ``C`` of the input and output tokens.
        num_heads: Number of attention heads; ``dim`` is split evenly across them.
        qkv_bias: Whether the fused query/key/value projection has a bias.
        qk_scale: Optional override for the query scaling factor; when falsy,
            ``head_dim ** -0.5`` is used.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        per_head_dim = dim // num_heads
        self.scale = qk_scale or per_head_dim ** -0.5

        # One linear layer produces queries, keys and values at once.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Return the attended tokens, same shape as ``x``."""
        batch, tokens, channels = x.shape
        fused = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, channels // self.num_heads)
        # -> (3, B, heads, N, head_dim), then split along the leading axis.
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)

        scores = (q * self.scale) @ k.transpose(-2, -1)
        weights = self.attn_drop(scores.softmax(dim=-1))

        # Merge the heads back into the channel dimension.
        merged = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))
|
45 |
+
|
46 |
+
class Block(nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual connection.

    taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py

    ``init_values`` is accepted so the block shares a signature with the
    LayerScale variants, but it is not used here.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, Attention_block=Attention,
                 Mlp_block=Mlp, init_values=1e-4):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention_block(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # Stochastic depth on the residual branches; a no-op when the rate is 0.
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp_block(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x):
        """Apply the attention and MLP residual branches in sequence."""
        attn_branch = self.attn(self.norm1(x))
        x = x + self.drop_path(attn_branch)
        mlp_branch = self.mlp(self.norm2(x))
        x = x + self.drop_path(mlp_branch)
        return x
|
65 |
+
|
66 |
+
class Layer_scale_init_Block(nn.Module):
    """Pre-norm transformer block whose residual branches are scaled by learnable per-channel factors.

    taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    with slight modifications

    ``gamma_1`` scales the attention branch and ``gamma_2`` the MLP branch;
    both are vectors of length ``dim`` initialised to ``init_values``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, Attention_block=Attention,
                 Mlp_block=Mlp, init_values=1e-4):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention_block(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # Stochastic depth on the residual branches; a no-op when the rate is 0.
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp_block(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )
        # LayerScale factors, one per channel and per branch.
        self.gamma_1 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
        self.gamma_2 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)

    def forward(self, x):
        """Apply the scaled attention and MLP residual branches in sequence."""
        attn_branch = self.gamma_1 * self.attn(self.norm1(x))
        x = x + self.drop_path(attn_branch)
        mlp_branch = self.gamma_2 * self.mlp(self.norm2(x))
        x = x + self.drop_path(mlp_branch)
        return x
|
88 |
+
|
89 |
+
class Layer_scale_init_Block_paralx2(nn.Module):
    """Transformer block with two parallel attention branches and two parallel MLP branches, each LayerScale-weighted.

    taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    with slight modifications

    Every branch has its own norm layer and its own per-channel scale vector
    (``gamma_1``/``gamma_1_1`` for attention, ``gamma_2``/``gamma_2_1`` for MLP),
    all initialised to ``init_values``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, Attention_block=Attention,
                 Mlp_block=Mlp, init_values=1e-4):
        super().__init__()
        attn_kwargs = dict(num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                           attn_drop=attn_drop, proj_drop=drop)
        self.norm1 = norm_layer(dim)
        self.norm11 = norm_layer(dim)
        self.attn = Attention_block(dim, **attn_kwargs)
        self.attn1 = Attention_block(dim, **attn_kwargs)
        # Stochastic depth on the residual branches; a no-op when the rate is 0.
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.norm21 = norm_layer(dim)
        mlp_kwargs = dict(in_features=dim, hidden_features=int(dim * mlp_ratio),
                          act_layer=act_layer, drop=drop)
        self.mlp = Mlp_block(**mlp_kwargs)
        self.mlp1 = Mlp_block(**mlp_kwargs)
        # LayerScale factors, one vector per branch.
        self.gamma_1 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
        self.gamma_1_1 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
        self.gamma_2 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
        self.gamma_2_1 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)

    def forward(self, x):
        """Add both attention branches to ``x``, then both MLP branches."""
        # Both branches of a pair read the same input; first branch is evaluated first.
        attn_a = self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
        attn_b = self.drop_path(self.gamma_1_1 * self.attn1(self.norm11(x)))
        x = x + attn_a + attn_b
        mlp_a = self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        mlp_b = self.drop_path(self.gamma_2_1 * self.mlp1(self.norm21(x)))
        x = x + mlp_a + mlp_b
        return x
|
118 |
+
|
119 |
+
class Block_paralx2(nn.Module):
    """Transformer block with two parallel attention branches and two parallel MLP branches (no LayerScale).

    taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    with slight modifications

    ``init_values`` is accepted for signature compatibility with the
    LayerScale variants but is not used here.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, Attention_block=Attention,
                 Mlp_block=Mlp, init_values=1e-4):
        super().__init__()
        attn_kwargs = dict(num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                           attn_drop=attn_drop, proj_drop=drop)
        self.norm1 = norm_layer(dim)
        self.norm11 = norm_layer(dim)
        self.attn = Attention_block(dim, **attn_kwargs)
        self.attn1 = Attention_block(dim, **attn_kwargs)
        # Stochastic depth on the residual branches; a no-op when the rate is 0.
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.norm21 = norm_layer(dim)
        mlp_kwargs = dict(in_features=dim, hidden_features=int(dim * mlp_ratio),
                          act_layer=act_layer, drop=drop)
        self.mlp = Mlp_block(**mlp_kwargs)
        self.mlp1 = Mlp_block(**mlp_kwargs)

    def forward(self, x):
        """Add both attention branches to ``x``, then both MLP branches."""
        # Both branches of a pair read the same input; first branch is evaluated first.
        attn_a = self.drop_path(self.attn(self.norm1(x)))
        attn_b = self.drop_path(self.attn1(self.norm11(x)))
        x = x + attn_a + attn_b
        mlp_a = self.drop_path(self.mlp(self.norm2(x)))
        mlp_b = self.drop_path(self.mlp1(self.norm21(x)))
        x = x + mlp_a + mlp_b
        return x
|
144 |
+
|
145 |
+
|
146 |
+
class hMLP_stem(nn.Module):
    """ hMLP_stem: https://arxiv.org/pdf/2203.09795.pdf
    taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    with slight modifications

    Patch embedding built from three strided convolutions (strides 4, 2, 2),
    each followed by ``norm_layer``; the first two are also followed by GELU.
    The output is flattened to a ``(B, num_tokens, embed_dim)`` sequence.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=nn.SyncBatchNorm):
        super().__init__()
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)
        grid_w = self.img_size[1] // self.patch_size[1]
        grid_h = self.img_size[0] // self.patch_size[0]
        self.num_patches = grid_w * grid_h

        stem_width = embed_dim // 4
        self.proj = torch.nn.Sequential(
            nn.Conv2d(in_chans, stem_width, kernel_size=4, stride=4),
            norm_layer(stem_width),
            nn.GELU(),
            nn.Conv2d(stem_width, stem_width, kernel_size=2, stride=2),
            norm_layer(stem_width),
            nn.GELU(),
            nn.Conv2d(stem_width, embed_dim, kernel_size=2, stride=2),
            norm_layer(embed_dim),
        )

    def forward(self, x):
        """Embed an image batch and return it as a token sequence."""
        # The unpacking is kept from the original: it rejects inputs that are not 4-D.
        _batch, _chans, _height, _width = x.shape
        feature_map = self.proj(x)
        return feature_map.flatten(2).transpose(1, 2)
|
174 |
+
|
175 |
+
class vit_models(nn.Module):
    """ Vision Transformer with LayerScale (https://arxiv.org/abs/2103.17239) support
    taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    with slight modifications

    In addition to the regular ``forward``, the class exposes helpers that run
    only part of the network (``forward_patch_embed``, ``selective_forward``,
    ``forward_until``, ``forward_from``, ``forward_norm_head``,
    ``extract_block_features``).

    Args:
        img_size, patch_size, in_chans, embed_dim: Forwarded to ``Patch_layer``.
        num_classes: Output size of the linear head; ``<= 0`` replaces the
            head with ``nn.Identity``.
        depth: Number of transformer blocks.
        num_heads, mlp_ratio, qkv_bias, qk_scale, attn_drop_rate, norm_layer,
        act_layer, Attention_block, Mlp_block: Forwarded to every block.
        drop_rate: Dropout probability applied to the pooled feature right
            before the head in ``forward``. Blocks are always built with
            ``drop=0.0``.
        drop_path_rate: Stochastic-depth rate, the same for every block.
        block_layers: Block class instantiated ``depth`` times.
        Patch_layer: Patch-embedding class; must expose ``num_patches``.
        init_scale: Passed to each block as ``init_values``.
        global_pool, dpr_constant, mlp_ratio_clstk: Accepted but not used by
            this class.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=nn.LayerNorm, global_pool=None,
                 block_layers = Block,
                 Patch_layer=PatchEmbed,act_layer=nn.GELU,
                 Attention_block = Attention, Mlp_block=Mlp,
                 dpr_constant=True,init_scale=1e-4,
                 mlp_ratio_clstk = 4.0):
        super().__init__()

        self.dropout_rate = drop_rate
        self.depth = depth

        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim

        self.patch_embed = Patch_layer(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # Positional embedding covers the patch tokens only; it is added
        # before the class token is prepended.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))

        # Constant stochastic-depth rate across all blocks.
        dpr = [drop_path_rate] * depth
        self.blocks = nn.ModuleList([
            block_layers(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=0.0, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                act_layer=act_layer, Attention_block=Attention_block, Mlp_block=Mlp_block, init_values=init_scale)
            for i in range(depth)])

        self.norm = norm_layer(embed_dim)

        self.feature_info = [dict(num_chs=embed_dim, reduction=0, module='head')]
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise ``nn.Linear`` (truncated normal, zero bias) and ``nn.LayerNorm`` (weight 1, bias 0)."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the parameter names that should be excluded from weight decay."""
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        """Return the classification head module."""
        return self.head

    def get_num_layers(self):
        """Return the number of transformer blocks."""
        return len(self.blocks)

    def reset_classifier(self, num_classes, global_pool=''):
        """Replace the head with a fresh one for ``num_classes`` outputs (``global_pool`` is ignored)."""
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def extract_block_features(self, x):
        """Run the full block stack on an image batch and return ``{block_index: detached output}``."""
        x = self.forward_patch_embed(x)

        outs = {}
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            outs[i] = x.detach()
        return outs

    def selective_forward(self, x, begin, end):
        """Apply blocks ``begin`` through ``end`` (both inclusive) to an already-embedded token sequence."""
        for i, blk in enumerate(self.blocks):
            if i < begin:
                continue
            if i > end:
                break
            x = blk(x)
        return x

    def forward_until(self, x, blk_id):
        """Embed an image batch and run blocks ``0`` through ``blk_id`` (inclusive).

        If no block index equals ``blk_id``, every block is applied.
        """
        x = self.forward_patch_embed(x)

        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i == blk_id:
                break

        return x

    def forward_from(self, x, blk_id):
        """Run blocks ``blk_id`` onward on a token sequence, then the final norm and head on token 0.

        Unlike ``forward``, no dropout is applied before the head.
        """
        for i, blk in enumerate(self.blocks):
            if i < blk_id:
                continue
            x = blk(x)

        x = self.norm(x)
        x = self.head(x[:, 0])

        return x

    def forward_patch_embed(self, x):
        """Turn an image batch into the token sequence fed to the first block.

        Patches are embedded, the positional embedding is added, and the
        class token is prepended at index 0.
        """
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)

        x = x + self.pos_embed

        x = torch.cat((cls_tokens, x), dim=1)
        return x

    def forward_norm_head(self, x):
        """Apply the final norm and the head to token 0 of a token sequence."""
        x = self.norm(x)
        x = self.head(x[:, 0])
        return x

    def forward_features(self, x):
        """Return the normalised token at index 0 (the class token) after all blocks."""
        x = self.forward_patch_embed(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        return x[:, 0]

    def forward(self, x):
        """Classify an image batch: features, optional dropout, then the head."""
        x = self.forward_features(x)

        if self.dropout_rate:
            # The original called ``F.dropout`` but ``F`` is never imported in
            # this module, so any non-zero drop_rate raised NameError. ``nn``
            # is already imported, so go through ``nn.functional`` instead.
            x = nn.functional.dropout(x, p=float(self.dropout_rate), training=self.training)
        x = self.head(x)

        return x
|
340 |
+
|
341 |
+
# DeiT III: Revenge of the ViT (https://arxiv.org/abs/2204.07118)
|
342 |
+
|
343 |
+
@register_model
def deit_tiny_patch16_LS(pretrained=False, img_size=224, pretrained_21k=False, pretrained_cfg_overlay=None, **kwargs):
    """Build the tiny LayerScale ViT (patch 16, width 192, depth 12, 3 heads).

    ``pretrained``, ``pretrained_21k`` and ``pretrained_cfg_overlay`` are
    accepted but unused: no weights are loaded. Remaining keyword arguments
    are forwarded to ``vit_models``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
|
350 |
+
|
351 |
+
|
352 |
+
@register_model
def deit_small_patch16_LS(pretrained=False, img_size=224, pretrained_21k = False, pretrained_cfg=None, pretrained_deit=None, pretrained_cfg_overlay=None, **kwargs):
    """Build the small LayerScale ViT (patch 16, width 384, depth 12, 6 heads).

    Args:
        pretrained: When true, load weights from a local checkpoint.
        img_size: Input resolution forwarded to ``vit_models``.
        pretrained_21k, pretrained_cfg, pretrained_cfg_overlay: Accepted but
            unused.
        pretrained_deit: Directory containing ``deit_3_small_224_21k.pth``;
            required when ``pretrained`` is true.
        **kwargs: Forwarded to ``vit_models``.

    Raises:
        ValueError: If ``pretrained`` is true but ``pretrained_deit`` is None.
    """
    model = vit_models(
        img_size = img_size, patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),block_layers=Layer_scale_init_Block, **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        # NOTE: the checkpoint file name is fixed; img_size and pretrained_21k
        # do not influence which file is loaded.
        if pretrained_deit is None:
            raise ValueError("pretrained=True requires pretrained_deit to point at the checkpoint directory")
        # map_location="cpu" lets a checkpoint saved from a GPU load on a
        # CPU-only host; load_state_dict then copies into the model's tensors.
        checkpoint = torch.load(os.path.join(pretrained_deit, 'deit_3_small_224_21k.pth'), map_location="cpu")
        model.load_state_dict(checkpoint["model"])

    return model
|
373 |
+
|
374 |
+
@register_model
def deit_medium_patch16_LS(pretrained=False, img_size=224, pretrained_21k = False, **kwargs):
    """Build the medium LayerScale ViT (patch 16, width 512, depth 12, 8 heads).

    Args:
        pretrained: When true, download and load the DeiT-III checkpoint
            matching ``img_size`` and ``pretrained_21k``.
        img_size: Input resolution; used for both the model and the
            checkpoint URL.
        pretrained_21k: Select the ``21k`` checkpoint instead of the ``1k`` one.
        **kwargs: Forwarded to ``vit_models``.
    """
    # img_size was previously used only for the checkpoint URL and never
    # reached the model, so a non-224 request built a 224 model.
    model = vit_models(
        img_size = img_size, patch_size=16, embed_dim=512, depth=12, num_heads=8, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),block_layers = Layer_scale_init_Block, **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        name = 'https://dl.fbaipublicfiles.com/deit/deit_3_medium_'+str(img_size)+'_'
        if pretrained_21k:
            name+='21k.pth'
        else:
            name+='1k.pth'

        checkpoint = torch.hub.load_state_dict_from_url(
            url=name,
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model
|
393 |
+
|
394 |
+
@register_model
def deit_base_patch16_LS(pretrained=False, pretrained_cfg=None, img_size=224, pretrained_21k = False, pretrained_deit=None, pretrained_cfg_overlay=None, **kwargs):
    """Build the base LayerScale ViT (patch 16, width 768, depth 12, 12 heads).

    Args:
        pretrained: When true, load weights from a local checkpoint.
        img_size: Input resolution forwarded to ``vit_models``.
        pretrained_21k, pretrained_cfg, pretrained_cfg_overlay: Accepted but
            unused.
        pretrained_deit: Directory containing ``deit_3_base_224_21k.pth``;
            required when ``pretrained`` is true.
        **kwargs: Forwarded to ``vit_models``.

    Raises:
        ValueError: If ``pretrained`` is true but ``pretrained_deit`` is None.
    """
    model = vit_models(
        img_size = img_size, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),block_layers=Layer_scale_init_Block, **kwargs)
    if pretrained:
        # NOTE: the checkpoint file name is fixed; img_size and pretrained_21k
        # do not influence which file is loaded.
        if pretrained_deit is None:
            raise ValueError("pretrained=True requires pretrained_deit to point at the checkpoint directory")
        # map_location="cpu" lets a checkpoint saved from a GPU load on a
        # CPU-only host; load_state_dict then copies into the model's tensors.
        checkpoint = torch.load(os.path.join(pretrained_deit, 'deit_3_base_224_21k.pth'), map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model
|
413 |
+
|
414 |
+
@register_model
def deit_large_patch16_LS(pretrained=False, img_size=224, pretrained_21k = False, pretrained_cfg=None, pretrained_deit=None, pretrained_cfg_overlay=None, **kwargs):
    """Build the large LayerScale ViT (patch 16, width 1024, depth 24, 16 heads).

    Args:
        pretrained: When true, load weights from a local checkpoint.
        img_size: Input resolution forwarded to ``vit_models``.
        pretrained_21k, pretrained_cfg, pretrained_cfg_overlay: Accepted but
            unused.
        pretrained_deit: Directory containing ``deit_3_large_224_21k.pth``;
            required when ``pretrained`` is true.
        **kwargs: Forwarded to ``vit_models``.

    Raises:
        ValueError: If ``pretrained`` is true but ``pretrained_deit`` is None.
    """
    model = vit_models(
        img_size = img_size, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),block_layers=Layer_scale_init_Block, **kwargs)
    if pretrained:
        # NOTE: the checkpoint file name is fixed; img_size and pretrained_21k
        # do not influence which file is loaded.
        if pretrained_deit is None:
            raise ValueError("pretrained=True requires pretrained_deit to point at the checkpoint directory")
        # map_location="cpu" lets a checkpoint saved from a GPU load on a
        # CPU-only host; load_state_dict then copies into the model's tensors.
        checkpoint = torch.load(os.path.join(pretrained_deit, 'deit_3_large_224_21k.pth'), map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model
|
433 |
+
|
434 |
+
@register_model
def deit_huge_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build the huge LayerScale ViT (patch 14, width 1280, depth 32, 16 heads).

    When ``pretrained`` is true the DeiT-III checkpoint matching ``img_size``
    and ``pretrained_21k`` is downloaded and loaded. Remaining keyword
    arguments are forwarded to ``vit_models``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    model = vit_models(
        img_size=img_size,
        patch_size=14,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
    if not pretrained:
        return model

    suffix = '21k_v1.pth' if pretrained_21k else '1k_v1.pth'
    name = 'https://dl.fbaipublicfiles.com/deit/deit_3_huge_' + str(img_size) + '_' + suffix
    checkpoint = torch.hub.load_state_dict_from_url(
        url=name,
        map_location="cpu",
        check_hash=True,
    )
    model.load_state_dict(checkpoint["model"])
    return model
|
452 |
+
|
453 |
+
@register_model
def deit_huge_patch14_52_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build a 52-block huge LayerScale ViT (patch 14, width 1280, 16 heads).

    ``pretrained`` and ``pretrained_21k`` are accepted but unused: no weights
    are loaded. Remaining keyword arguments are forwarded to ``vit_models``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return vit_models(
        img_size=img_size,
        patch_size=14,
        embed_dim=1280,
        depth=52,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
|
460 |
+
|
461 |
+
@register_model
def deit_huge_patch14_26x2_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build a huge ViT of 26 parallel-x2 LayerScale blocks (patch 14, width 1280, 16 heads).

    ``pretrained`` and ``pretrained_21k`` are accepted but unused: no weights
    are loaded. Remaining keyword arguments are forwarded to ``vit_models``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return vit_models(
        img_size=img_size,
        patch_size=14,
        embed_dim=1280,
        depth=26,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        block_layers=Layer_scale_init_Block_paralx2,
        **kwargs,
    )
|
468 |
+
|
469 |
+
@register_model
def deit_Giant_48x2_patch14_LS(pretrained=False, img_size=224, pretrained_21k = False, **kwargs):
    """Build a Giant ViT of 48 parallel-x2 LayerScale blocks (patch 14, width 1664, 16 heads).

    ``pretrained`` and ``pretrained_21k`` are accepted but unused: no weights
    are loaded. Remaining keyword arguments are forwarded to ``vit_models``.
    """
    # The original referenced ``Block_paral_LS``, which is not defined in this
    # module, so every call raised NameError. Layer_scale_init_Block_paralx2 is
    # the parallel-x2 LayerScale block this file provides, and it is what the
    # other "x2_LS" factories use.
    model = vit_models(
        img_size = img_size, patch_size=14, embed_dim=1664, depth=48, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),block_layers = Layer_scale_init_Block_paralx2, **kwargs)

    return model
|
476 |
+
|
477 |
+
@register_model
def deit_giant_40x2_patch14_LS(pretrained=False, img_size=224, pretrained_21k = False, **kwargs):
    """Build a giant ViT of 40 parallel-x2 LayerScale blocks (patch 14, width 1408, 16 heads).

    ``pretrained`` and ``pretrained_21k`` are accepted but unused: no weights
    are loaded. Remaining keyword arguments are forwarded to ``vit_models``.
    """
    # The original referenced ``Block_paral_LS``, which is not defined in this
    # module, so every call raised NameError. Layer_scale_init_Block_paralx2 is
    # the parallel-x2 LayerScale block this file provides, and it is what the
    # other "x2_LS" factories use.
    model = vit_models(
        img_size = img_size, patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),block_layers = Layer_scale_init_Block_paralx2, **kwargs)
    return model
|
483 |
+
|
484 |
+
@register_model
def deit_Giant_48_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build a 48-block Giant LayerScale ViT (patch 14, width 1664, 16 heads).

    ``pretrained`` and ``pretrained_21k`` are accepted but unused: no weights
    are loaded. Remaining keyword arguments are forwarded to ``vit_models``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return vit_models(
        img_size=img_size,
        patch_size=14,
        embed_dim=1664,
        depth=48,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
|
490 |
+
|
491 |
+
@register_model
def deit_giant_40_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build a 40-block giant LayerScale ViT (patch 14, width 1408, 16 heads).

    ``pretrained`` and ``pretrained_21k`` are accepted but unused: no weights
    are loaded and no ``default_cfg`` is attached. Remaining keyword
    arguments are forwarded to ``vit_models``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return vit_models(
        img_size=img_size,
        patch_size=14,
        embed_dim=1408,
        depth=40,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
|
499 |
+
|
500 |
+
# Models from Three things everyone should know about Vision Transformers (https://arxiv.org/pdf/2203.09795.pdf)
|
501 |
+
|
502 |
+
@register_model
def deit_small_patch16_36_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build a 36-block small LayerScale ViT (patch 16, width 384, 6 heads).

    ``pretrained`` and ``pretrained_21k`` are accepted but unused: no weights
    are loaded. Remaining keyword arguments are forwarded to ``vit_models``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=384,
        depth=36,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
|
509 |
+
|
510 |
+
@register_model
def deit_small_patch16_36(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build a 36-block small ViT with the default block type (patch 16, width 384, 6 heads).

    ``pretrained`` and ``pretrained_21k`` are accepted but unused: no weights
    are loaded. Remaining keyword arguments are forwarded to ``vit_models``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=384,
        depth=36,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        **kwargs,
    )
|
517 |
+
|
518 |
+
@register_model
def deit_small_patch16_18x2_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build a small ViT of 18 parallel-x2 LayerScale blocks (patch 16, width 384, 6 heads).

    ``pretrained`` and ``pretrained_21k`` are accepted but unused: no weights
    are loaded. Remaining keyword arguments are forwarded to ``vit_models``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=384,
        depth=18,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        block_layers=Layer_scale_init_Block_paralx2,
        **kwargs,
    )
|
525 |
+
|
526 |
+
@register_model
def deit_small_patch16_18x2(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build a small ViT of 18 parallel-x2 blocks without LayerScale (patch 16, width 384, 6 heads).

    ``pretrained`` and ``pretrained_21k`` are accepted but unused: no weights
    are loaded. Remaining keyword arguments are forwarded to ``vit_models``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=384,
        depth=18,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        block_layers=Block_paralx2,
        **kwargs,
    )
|
533 |
+
|
534 |
+
|
535 |
+
@register_model
def deit_base_patch16_18x2_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build a base ViT of 18 parallel-x2 LayerScale blocks (patch 16, width 768, 12 heads).

    ``pretrained`` and ``pretrained_21k`` are accepted but unused: no weights
    are loaded. Remaining keyword arguments are forwarded to ``vit_models``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=768,
        depth=18,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        block_layers=Layer_scale_init_Block_paralx2,
        **kwargs,
    )
|
542 |
+
|
543 |
+
|
544 |
+
@register_model
def deit_base_patch16_18x2(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build a base ViT of 18 parallel-x2 blocks without LayerScale (patch 16, width 768, 12 heads).

    ``pretrained`` and ``pretrained_21k`` are accepted but unused: no weights
    are loaded. Remaining keyword arguments are forwarded to ``vit_models``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=768,
        depth=18,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        block_layers=Block_paralx2,
        **kwargs,
    )
|
551 |
+
|
552 |
+
|
553 |
+
@register_model
def deit_base_patch16_36x1_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build a 36-block base LayerScale ViT (patch 16, width 768, 12 heads).

    ``pretrained`` and ``pretrained_21k`` are accepted but unused: no weights
    are loaded. Remaining keyword arguments are forwarded to ``vit_models``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=768,
        depth=36,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
|
560 |
+
|
561 |
+
@register_model
def deit_base_patch16_36x1(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build a 36-block base ViT with the default block type (patch 16, width 768, 12 heads).

    ``pretrained`` and ``pretrained_21k`` are accepted but unused: no weights
    are loaded. Remaining keyword arguments are forwarded to ``vit_models``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=768,
        depth=36,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        **kwargs,
    )
|
568 |
+
|
outputs/deit/20240118_171921.log
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
2024-01-18 17:19:21,866 - snnet - INFO - Namespace(batch_size=64, epochs=300, bce_loss=False, unscale_lr=False, model='deit_base_patch16_224', input_size=224, drop=0.0, drop_path=0.1, model_ema=True, model_ema_decay=0.99996, model_ema_force_cpu=False, opt='adamw', opt_eps=1e-08, opt_betas=None, clip_grad=None, momentum=0.9, weight_decay=0.05, sched='cosine', lr=0.0005, lr_noise=None, lr_noise_pct=0.67, lr_noise_std=1.0, warmup_lr=1e-06, min_lr=1e-05, decay_epochs=30, warmup_epochs=5, cooldown_epochs=10, patience_epochs=10, decay_rate=0.1, color_jitter=0.3, aa='rand-m9-mstd0.5-inc1', smoothing=0.1, train_interpolation='bicubic', repeated_aug=True, train_mode=True, ThreeAugment=False, src=False, reprob=0.25, remode='pixel', recount=1, resplit=False, mixup=0.8, cutmix=1.0, cutmix_minmax=None, mixup_prob=1.0, mixup_switch_prob=0.5, mixup_mode='batch', teacher_model='regnety_160', teacher_path='', distillation_type='none', distillation_alpha=0.5, distillation_tau=1.0, finetune='', attn_only=False, data_path='/datasets01/imagenet_full_size/061417/', data_set='IMNET', inat_category='name', output_dir='outputs/deit', device='cuda', seed=0, resume='', start_epoch=0, eval=False, eval_crop_ratio=0.875, dist_eval=False, num_workers=10, pin_mem=True, world_size=1, dist_url='env://', exp_name='deit', config=None, scoring=False, proxy='synflow', snnet_name='snnetv2', get_flops=False, flops_sampling_k=None, low_rank=False, lora_r=64, flops_gap=1.0, distributed=False)
|
outputs/stitch_s_l_v2_lora_r_64_50_ep/20240118_172124.log
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
2024-01-18 17:21:24,162 - snnet - INFO - Namespace(batch_size=64, epochs=50, bce_loss=False, unscale_lr=True, model='deit_base_patch16_224', input_size=224, drop=0.0, drop_path=0.1, model_ema=True, model_ema_decay=0.99996, model_ema_force_cpu=False, opt='fusedlamb', opt_eps=1e-08, opt_betas=None, clip_grad=None, momentum=0.9, weight_decay=0.02, sched='cosine', lr=3e-05, lr_noise=None, lr_noise_pct=0.67, lr_noise_std=1.0, warmup_lr=1e-07, min_lr=1e-05, decay_epochs=30, warmup_epochs=5, cooldown_epochs=10, patience_epochs=10, decay_rate=0.1, color_jitter=0.3, aa='rand-m9-mstd0.5-inc1', smoothing=0.1, train_interpolation='bicubic', repeated_aug=True, train_mode=True, ThreeAugment=True, src=True, reprob=0.0, remode='pixel', recount=1, resplit=False, mixup=0, cutmix=1.0, cutmix_minmax=None, mixup_prob=1.0, mixup_switch_prob=0.5, mixup_mode='batch', teacher_model='regnety_160', teacher_path='', distillation_type='none', distillation_alpha=0.5, distillation_tau=1.0, finetune='', attn_only=False, data_path='/data2/datasets/imagenet', data_set='IMNET', inat_category='name', output_dir='outputs/stitch_s_l_v2_lora_r_64_50_ep', device='cuda', seed=0, resume='snnetv2_deit3_s_l.pth', start_epoch=0, eval=False, eval_crop_ratio=1.0, dist_eval=False, num_workers=10, pin_mem=True, world_size=1, dist_url='env://', exp_name='stitch_s_l_v2_lora_r_64_50_ep', config='gradio_demo.json', scoring=False, proxy='synflow', snnet_name='snnet_v2', get_flops=False, flops_sampling_k=None, low_rank=False, lora_r=64, flops_gap=1.0, anchors=['deit_small_patch16_LS', 'deit_large_patch16_LS'], anchor_drop_path=[0.05, 0.4], no_repeated_aug=True, pretrained_deit='../pretrained_weights', distributed=False)
|
2 |
+
2024-01-18 17:21:24,163 - snnet - INFO - Creating model: deit_small_patch16_LS
|
outputs/stitch_s_l_v2_lora_r_64_50_ep/20240118_172140.log
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
2024-01-18 17:21:40,831 - snnet - INFO - Namespace(batch_size=64, epochs=50, bce_loss=False, unscale_lr=True, model='deit_base_patch16_224', input_size=224, drop=0.0, drop_path=0.1, model_ema=True, model_ema_decay=0.99996, model_ema_force_cpu=False, opt='fusedlamb', opt_eps=1e-08, opt_betas=None, clip_grad=None, momentum=0.9, weight_decay=0.02, sched='cosine', lr=3e-05, lr_noise=None, lr_noise_pct=0.67, lr_noise_std=1.0, warmup_lr=1e-07, min_lr=1e-05, decay_epochs=30, warmup_epochs=5, cooldown_epochs=10, patience_epochs=10, decay_rate=0.1, color_jitter=0.3, aa='rand-m9-mstd0.5-inc1', smoothing=0.1, train_interpolation='bicubic', repeated_aug=True, train_mode=True, ThreeAugment=True, src=True, reprob=0.0, remode='pixel', recount=1, resplit=False, mixup=0, cutmix=1.0, cutmix_minmax=None, mixup_prob=1.0, mixup_switch_prob=0.5, mixup_mode='batch', teacher_model='regnety_160', teacher_path='', distillation_type='none', distillation_alpha=0.5, distillation_tau=1.0, finetune='', attn_only=False, data_path='/data2/datasets/imagenet', data_set='IMNET', inat_category='name', output_dir='outputs/stitch_s_l_v2_lora_r_64_50_ep', device='cuda', seed=0, resume='snnetv2_deit3_s_l.pth', start_epoch=0, eval=False, eval_crop_ratio=1.0, dist_eval=False, num_workers=10, pin_mem=True, world_size=1, dist_url='env://', exp_name='stitch_s_l_v2_lora_r_64_50_ep', config='gradio_demo.json', scoring=False, proxy='synflow', snnet_name='snnet_v2', get_flops=False, flops_sampling_k=None, low_rank=False, lora_r=64, flops_gap=1.0, anchors=['deit_small_patch16_LS', 'deit_large_patch16_LS'], anchor_drop_path=[0.05, 0.4], no_repeated_aug=True, pretrained_deit='../pretrained_weights', distributed=False)
|
2 |
+
2024-01-18 17:21:40,832 - snnet - INFO - Creating model: deit_small_patch16_LS
|
outputs/stitch_s_l_v2_lora_r_64_50_ep/20240118_172156.log
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
2024-01-18 17:21:56,859 - snnet - INFO - Namespace(batch_size=64, epochs=50, bce_loss=False, unscale_lr=True, model='deit_base_patch16_224', input_size=224, drop=0.0, drop_path=0.1, model_ema=True, model_ema_decay=0.99996, model_ema_force_cpu=False, opt='fusedlamb', opt_eps=1e-08, opt_betas=None, clip_grad=None, momentum=0.9, weight_decay=0.02, sched='cosine', lr=3e-05, lr_noise=None, lr_noise_pct=0.67, lr_noise_std=1.0, warmup_lr=1e-07, min_lr=1e-05, decay_epochs=30, warmup_epochs=5, cooldown_epochs=10, patience_epochs=10, decay_rate=0.1, color_jitter=0.3, aa='rand-m9-mstd0.5-inc1', smoothing=0.1, train_interpolation='bicubic', repeated_aug=True, train_mode=True, ThreeAugment=True, src=True, reprob=0.0, remode='pixel', recount=1, resplit=False, mixup=0, cutmix=1.0, cutmix_minmax=None, mixup_prob=1.0, mixup_switch_prob=0.5, mixup_mode='batch', teacher_model='regnety_160', teacher_path='', distillation_type='none', distillation_alpha=0.5, distillation_tau=1.0, finetune='', attn_only=False, data_path='/data2/datasets/imagenet', data_set='IMNET', inat_category='name', output_dir='outputs/stitch_s_l_v2_lora_r_64_50_ep', device='cuda', seed=0, resume='snnetv2_deit3_s_l.pth', start_epoch=0, eval=False, eval_crop_ratio=1.0, dist_eval=False, num_workers=10, pin_mem=True, world_size=1, dist_url='env://', exp_name='stitch_s_l_v2_lora_r_64_50_ep', config='gradio_demo.json', scoring=False, proxy='synflow', snnet_name='snnet_v2', get_flops=False, flops_sampling_k=None, low_rank=False, lora_r=64, flops_gap=1.0, anchors=['deit_small_patch16_LS', 'deit_large_patch16_LS'], anchor_drop_path=[0.05, 0.4], no_repeated_aug=True, pretrained_deit='../pretrained_weights', distributed=False)
|
2 |
+
2024-01-18 17:21:56,859 - snnet - INFO - Creating model: deit_small_patch16_LS
|
3 |
+
2024-01-18 17:21:57,078 - snnet - INFO - Creating model: deit_large_patch16_LS
|
4 |
+
2024-01-18 17:21:59,994 - snnet - INFO - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134]
|
5 |
+
2024-01-18 17:22:00,521 - snnet - INFO - load checkpoint from snnetv2_deit3_s_l.pth
|
outputs/stitch_s_l_v2_lora_r_64_50_ep/20240118_172250.log
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
2024-01-18 17:22:50,304 - snnet - INFO - Namespace(batch_size=64, epochs=50, bce_loss=False, unscale_lr=True, model='deit_base_patch16_224', input_size=224, drop=0.0, drop_path=0.1, model_ema=True, model_ema_decay=0.99996, model_ema_force_cpu=False, opt='fusedlamb', opt_eps=1e-08, opt_betas=None, clip_grad=None, momentum=0.9, weight_decay=0.02, sched='cosine', lr=3e-05, lr_noise=None, lr_noise_pct=0.67, lr_noise_std=1.0, warmup_lr=1e-07, min_lr=1e-05, decay_epochs=30, warmup_epochs=5, cooldown_epochs=10, patience_epochs=10, decay_rate=0.1, color_jitter=0.3, aa='rand-m9-mstd0.5-inc1', smoothing=0.1, train_interpolation='bicubic', repeated_aug=True, train_mode=True, ThreeAugment=True, src=True, reprob=0.0, remode='pixel', recount=1, resplit=False, mixup=0, cutmix=1.0, cutmix_minmax=None, mixup_prob=1.0, mixup_switch_prob=0.5, mixup_mode='batch', teacher_model='regnety_160', teacher_path='', distillation_type='none', distillation_alpha=0.5, distillation_tau=1.0, finetune='', attn_only=False, data_path='/data2/datasets/imagenet', data_set='IMNET', inat_category='name', output_dir='outputs/stitch_s_l_v2_lora_r_64_50_ep', device='cpu', seed=0, resume='snnetv2_deit3_s_l.pth', start_epoch=0, eval=False, eval_crop_ratio=1.0, dist_eval=False, num_workers=10, pin_mem=True, world_size=1, dist_url='env://', exp_name='stitch_s_l_v2_lora_r_64_50_ep', config='gradio_demo.json', scoring=False, proxy='synflow', snnet_name='snnet_v2', get_flops=False, flops_sampling_k=None, low_rank=False, lora_r=64, flops_gap=1.0, anchors=['deit_small_patch16_LS', 'deit_large_patch16_LS'], anchor_drop_path=[0.05, 0.4], no_repeated_aug=True, pretrained_deit='../pretrained_weights', distributed=False)
|
2 |
+
2024-01-18 17:22:50,305 - snnet - INFO - Creating model: deit_small_patch16_LS
|
3 |
+
2024-01-18 17:22:50,535 - snnet - INFO - Creating model: deit_large_patch16_LS
|
4 |
+
2024-01-18 17:22:53,873 - snnet - INFO - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134]
|
5 |
+
2024-01-18 17:22:54,392 - snnet - INFO - load checkpoint from snnetv2_deit3_s_l.pth
|
outputs/stitch_s_l_v2_lora_r_64_50_ep/20240118_172309.log
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
2024-01-18 17:23:09,551 - snnet - INFO - Namespace(batch_size=64, epochs=50, bce_loss=False, unscale_lr=True, model='deit_base_patch16_224', input_size=224, drop=0.0, drop_path=0.1, model_ema=True, model_ema_decay=0.99996, model_ema_force_cpu=False, opt='fusedlamb', opt_eps=1e-08, opt_betas=None, clip_grad=None, momentum=0.9, weight_decay=0.02, sched='cosine', lr=3e-05, lr_noise=None, lr_noise_pct=0.67, lr_noise_std=1.0, warmup_lr=1e-07, min_lr=1e-05, decay_epochs=30, warmup_epochs=5, cooldown_epochs=10, patience_epochs=10, decay_rate=0.1, color_jitter=0.3, aa='rand-m9-mstd0.5-inc1', smoothing=0.1, train_interpolation='bicubic', repeated_aug=True, train_mode=True, ThreeAugment=True, src=True, reprob=0.0, remode='pixel', recount=1, resplit=False, mixup=0, cutmix=1.0, cutmix_minmax=None, mixup_prob=1.0, mixup_switch_prob=0.5, mixup_mode='batch', teacher_model='regnety_160', teacher_path='', distillation_type='none', distillation_alpha=0.5, distillation_tau=1.0, finetune='', attn_only=False, data_path='/data2/datasets/imagenet', data_set='IMNET', inat_category='name', output_dir='outputs/stitch_s_l_v2_lora_r_64_50_ep', device='cpu', seed=0, resume='snnetv2_deit3_s_l.pth', start_epoch=0, eval=False, eval_crop_ratio=1.0, dist_eval=False, num_workers=10, pin_mem=True, world_size=1, dist_url='env://', exp_name='stitch_s_l_v2_lora_r_64_50_ep', config='gradio_demo.json', scoring=False, proxy='synflow', snnet_name='snnet_v2', get_flops=False, flops_sampling_k=None, low_rank=False, lora_r=64, flops_gap=1.0, anchors=['deit_small_patch16_LS', 'deit_large_patch16_LS'], anchor_drop_path=[0.05, 0.4], no_repeated_aug=True, pretrained_deit='../pretrained_weights', distributed=False)
|
2 |
+
2024-01-18 17:23:09,553 - snnet - INFO - Creating model: deit_small_patch16_LS
|
3 |
+
2024-01-18 17:23:09,778 - snnet - INFO - Creating model: deit_large_patch16_LS
|
4 |
+
2024-01-18 17:23:13,077 - snnet - INFO - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134]
|
5 |
+
2024-01-18 17:23:13,587 - snnet - INFO - load checkpoint from snnetv2_deit3_s_l.pth
|
outputs/stitch_s_l_v2_lora_r_64_50_ep/20240118_172332.log
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
2024-01-18 17:23:32,357 - snnet - INFO - Namespace(batch_size=64, epochs=50, bce_loss=False, unscale_lr=True, model='deit_base_patch16_224', input_size=224, drop=0.0, drop_path=0.1, model_ema=True, model_ema_decay=0.99996, model_ema_force_cpu=False, opt='fusedlamb', opt_eps=1e-08, opt_betas=None, clip_grad=None, momentum=0.9, weight_decay=0.02, sched='cosine', lr=3e-05, lr_noise=None, lr_noise_pct=0.67, lr_noise_std=1.0, warmup_lr=1e-07, min_lr=1e-05, decay_epochs=30, warmup_epochs=5, cooldown_epochs=10, patience_epochs=10, decay_rate=0.1, color_jitter=0.3, aa='rand-m9-mstd0.5-inc1', smoothing=0.1, train_interpolation='bicubic', repeated_aug=True, train_mode=True, ThreeAugment=True, src=True, reprob=0.0, remode='pixel', recount=1, resplit=False, mixup=0, cutmix=1.0, cutmix_minmax=None, mixup_prob=1.0, mixup_switch_prob=0.5, mixup_mode='batch', teacher_model='regnety_160', teacher_path='', distillation_type='none', distillation_alpha=0.5, distillation_tau=1.0, finetune='', attn_only=False, data_path='/data2/datasets/imagenet', data_set='IMNET', inat_category='name', output_dir='outputs/stitch_s_l_v2_lora_r_64_50_ep', device='cpu', seed=0, resume='snnetv2_deit3_s_l.pth', start_epoch=0, eval=False, eval_crop_ratio=1.0, dist_eval=False, num_workers=10, pin_mem=True, world_size=1, dist_url='env://', exp_name='stitch_s_l_v2_lora_r_64_50_ep', config='gradio_demo.json', scoring=False, proxy='synflow', snnet_name='snnet_v2', get_flops=False, flops_sampling_k=None, low_rank=False, lora_r=64, flops_gap=1.0, anchors=['deit_small_patch16_LS', 'deit_large_patch16_LS'], anchor_drop_path=[0.05, 0.4], no_repeated_aug=True, pretrained_deit='../pretrained_weights', distributed=False)
|
2 |
+
2024-01-18 17:23:32,358 - snnet - INFO - Creating model: deit_small_patch16_LS
|
3 |
+
2024-01-18 17:23:32,606 - snnet - INFO - Creating model: deit_large_patch16_LS
|
4 |
+
2024-01-18 17:23:35,576 - snnet - INFO - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134]
|
5 |
+
2024-01-18 17:23:36,120 - snnet - INFO - load checkpoint from snnetv2_deit3_s_l.pth
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
timm==0.6.12
|
3 |
+
fvcore
|
snnet.py
ADDED
@@ -0,0 +1,473 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import json
|
3 |
+
import math
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
from collections import defaultdict
|
11 |
+
from utils import get_root_logger
|
12 |
+
import torch.nn.functional as F
|
13 |
+
|
14 |
+
def rearrange_activations(activations):
    """Collapse all leading dimensions so the result is 2-D.

    The size of the last dimension is kept; everything in front of it is
    merged into a single leading dimension via ``reshape(-1, C)``.
    """
    return activations.reshape(-1, activations.shape[-1])
|
18 |
+
|
19 |
+
|
20 |
+
def ps_inv(x1, x2):
    """Least-squares solver given feature maps from two anchors.

    Both inputs are flattened to 2-D with ``rearrange_activations``. A column
    of ones is appended to the flattened ``x1`` so that the pseudo-inverse
    solution carries a bias term in addition to the weight matrix.

    Returns:
        ``(w, b)``: the transposed solution without its last column, and the
        last column of the transposed solution.

    Raises:
        ValueError: if the two flattened inputs have a different number of rows.
    """
    src = rearrange_activations(x1)
    dst = rearrange_activations(x2)

    if src.shape[0] != dst.shape[0]:
        raise ValueError('Spatial size of compared neurons must match when '
                         'calculating psuedo inverse matrix.')

    # Augmented input [src | 1]; allocated with torch.ones' default
    # device/dtype, exactly as sized from the flattened source.
    n_rows, n_cols = src.shape
    augmented = torch.ones([n_rows, n_cols + 1])
    augmented[:, :n_cols] = src

    # Solve augmented @ A = dst via the pseudo-inverse, then transpose.
    solution = (torch.linalg.pinv(augmented) @ dst.to(augmented.device)).T

    weight = solution[..., :-1]
    bias = solution[..., -1]
    return weight, bias
|
44 |
+
|
45 |
+
|
46 |
+
def reset_out_indices(front_depth=12, end_depth=24, out_indices=(9, 14, 19, 23)):
    """Map output indices of a deeper model onto block ids of a shallower one.

    The block ids ``0 .. front_depth - 1`` are stretched to length
    ``end_depth`` with ``torch.nn.functional.interpolate`` (default mode), so
    position ``i`` of the stretched list holds the shallow block id assigned
    to deep position ``i``.

    Args:
        front_depth: number of blocks whose ids are stretched.
        end_depth: length the id sequence is stretched to.
        out_indices: positions in the stretched sequence to pick.

    Returns:
        list[int]: the stretched block id at every position listed in
        ``out_indices``, in increasing position order.
    """
    block_ids = torch.arange(front_depth)[None, None, :].float()
    end_mapping_ids = torch.nn.functional.interpolate(block_ids, end_depth)
    # flatten() instead of squeeze(): with end_depth == 1 squeeze() yields a
    # 0-d tensor whose tolist() is a bare number, which cannot be enumerated.
    end_mapping_ids = end_mapping_ids.flatten().long().tolist()

    return [idx for i, idx in enumerate(end_mapping_ids) if i in out_indices]
|
58 |
+
|
59 |
+
|
60 |
+
def get_stitch_configs_general_unequal(depths):
    """Enumerate the stitching configurations used by ``SNNet``.

    ``depths`` is sorted ascending. The first configuration is the anchor
    ``{'comb_id': [1]}``; it is followed by one configuration per block of the
    shallower model, stitching after block ``i`` into position
    ``(i + 1) * (depths[1] // depths[0])`` of the deeper model.

    Returns:
        ``(total_configs, num_stitches)`` where ``num_stitches`` equals the
        smaller depth.
    """
    ordered = sorted(depths)
    num_stitches = ordered[0]

    # anchor configuration first
    configs = [{'comb_id': [1]}]
    configs.extend(
        {
            'comb_id': (0, 1),
            'stitch_cfgs': (blk, (blk + 1) * (ordered[1] // ordered[0])),
        }
        for blk in range(num_stitches)
    )
    return configs, num_stitches
|
74 |
+
|
75 |
+
def get_stitch_configs_bidirection(depths):
    """Enumerate anchor and stitching configurations in both directions.

    ``depths`` is sorted ascending, so index 0 is the shallower ("small")
    anchor and index 1 the deeper ("large") one. Configurations are emitted in
    this order: the two anchors, small->large (sl), large->small (ls),
    large->small->large (lsl) and small->large->small (sls).

    Every non-anchor configuration gets

    * ``stitch_cfgs``: ``[front_block, end_block]`` pairs, one per stitch;
    * ``route``: ``[start, stop]`` block ranges, one per model in ``comb_id``;
    * ``stitch_layer_ids``: one id per stitch followed by a trailing ``-1``.

    Args:
        depths: the block counts of the two anchors.

    Returns:
        tuple: ``(total_configs, model_combos, [len(sl), len(ls)], anchor_ids,
        sl_ids, ls_ids, lsl_ids, sls_ids)``. The ``*_ids`` lists hold indices
        into ``total_configs``.
    """
    depths = sorted(depths)
    small_depth, large_depth = depths[0], depths[1]

    # anchor configurations
    total_configs = [{'comb_id': [0]}, {'comb_id': [1]}]

    # small --> large: one stitch after every block of the small anchor
    sl_configs = [
        {
            'comb_id': [0, 1],
            'stitch_cfgs': [[i, (i + 1) * (large_depth // small_depth)]],
            'stitch_layer_ids': [i],
        }
        for i in range(small_depth)
    ]

    # For every large-anchor block position, the small-anchor block id it maps
    # to (small block ids stretched to the large depth).
    block_ids = torch.arange(small_depth)[None, None, :].float()
    end_mapping_ids = torch.nn.functional.interpolate(block_ids, large_depth)
    end_mapping_ids = end_mapping_ids.flatten().long().tolist()

    # large --> small: never at the last large block; when the depths differ
    # only odd block positions are used.
    ls_configs = []
    for i in range(large_depth - 1):
        if large_depth != small_depth and i % 2 != 1:
            continue
        ls_configs.append({
            'comb_id': [1, 0],
            'stitch_cfgs': [[i, end_mapping_ids[i] + 1]],
            'stitch_layer_ids': [i // (large_depth // small_depth)],
        })

    # large --> small --> large
    lsl_configs = []
    for ls_cfg in ls_configs:
        for sl_cfg in sl_configs:
            # the stitch after the last small block is not used here
            if sl_cfg['stitch_layer_ids'][0] == small_depth - 1:
                continue
            if sl_cfg['stitch_cfgs'][0][0] >= ls_cfg['stitch_cfgs'][0][1]:
                lsl_configs.append({
                    'comb_id': [1, 0, 1],
                    'stitch_cfgs': [ls_cfg['stitch_cfgs'][0], sl_cfg['stitch_cfgs'][0]],
                    'stitch_layer_ids': ls_cfg['stitch_layer_ids'] + sl_cfg['stitch_layer_ids'],
                })

    # small --> large --> small
    sls_configs = []
    for sl_cfg in sl_configs:
        for ls_cfg in ls_configs:
            if ls_cfg['stitch_cfgs'][0][0] >= sl_cfg['stitch_cfgs'][0][1]:
                sls_configs.append({
                    'comb_id': [0, 1, 0],
                    'stitch_cfgs': [sl_cfg['stitch_cfgs'][0], ls_cfg['stitch_cfgs'][0]],
                    'stitch_layer_ids': sl_cfg['stitch_layer_ids'] + ls_cfg['stitch_layer_ids'],
                })

    total_configs += sl_configs + ls_configs + lsl_configs + sls_configs

    anchor_ids = []
    sl_ids = []
    ls_ids = []
    lsl_ids = []
    sls_ids = []

    for i, cfg in enumerate(total_configs):
        comb_id = cfg['comb_id']

        if len(comb_id) == 1:
            anchor_ids.append(i)
            continue

        if len(comb_id) == 2:
            front, end = cfg['stitch_cfgs'][0]
            cfg['route'] = [[0, front], [end, depths[comb_id[-1]]]]

            # The stitch after the last small block is left out of sl_ids.
            # This used to be a hard-coded `front != 11`, which is only right
            # for a 12-block small anchor.
            if comb_id == [0, 1] and front != small_depth - 1:
                sl_ids.append(i)
            elif comb_id == [1, 0]:
                ls_ids.append(i)

        if len(comb_id) == 3:
            front_1, end_1 = cfg['stitch_cfgs'][0]
            front_2, end_2 = cfg['stitch_cfgs'][1]
            cfg['route'] = [
                [0, front_1],
                [end_1, front_2],
                [end_2, depths[comb_id[-1]]],
            ]

            if comb_id == [1, 0, 1]:
                lsl_ids.append(i)
            elif comb_id == [0, 1, 0]:
                sls_ids.append(i)

        # Trailing -1 so there is one layer id per entry of the route.
        cfg['stitch_layer_ids'].append(-1)

    model_combos = [(0, 1), (1, 0)]
    return (total_configs, model_combos, [len(sl_configs), len(ls_configs)],
            anchor_ids, sl_ids, ls_ids, lsl_ids, sls_ids)
|
190 |
+
|
191 |
+
|
192 |
+
def format_out_features(outs, with_cls_token, hw_shape):
    """Reshape token sequences in ``outs`` to channel-first feature maps.

    Every entry is reshaped to ``(B, hw_shape[0], hw_shape[1], C)`` and then
    permuted to ``(B, C, hw_shape[0], hw_shape[1])``; ``B`` and ``C`` are read
    from the first entry. When ``with_cls_token`` is truthy the first token of
    each sequence is dropped beforehand.

    The list is modified in place and also returned.
    """
    batch, _, channels = outs[0].shape
    height, width = hw_shape[0], hw_shape[1]

    for pos, tokens in enumerate(outs):
        if with_cls_token:
            # Remove class token before reshaping for the decoder head
            tokens = tokens[:, 1:]
        feature_map = tokens.reshape(batch, height, width, channels)
        outs[pos] = feature_map.permute(0, 3, 1, 2).contiguous()
    return outs
|
203 |
+
|
204 |
+
|
205 |
+
class LoRALayer():
    """Holder for the LoRA hyper-parameters and merge state of a layer.

    Attributes set by ``__init__``:
        r: LoRA rank as passed in.
        lora_alpha: LoRA alpha as passed in.
        lora_dropout: ``nn.Dropout(p=lora_dropout)`` when the rate is
            positive, otherwise an identity function.
        merged: starts as ``False`` (weights are marked as unmerged).
        merge_weights: flag as passed in.
    """

    def __init__(
        self,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        merge_weights: bool,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        # Dropout is optional: a non-positive rate means "pass through".
        self.lora_dropout = (
            nn.Dropout(p=lora_dropout) if lora_dropout > 0. else (lambda x: x)
        )
        # Mark the weight as unmerged
        self.merged = False
        self.merge_weights = merge_weights
|
223 |
+
|
224 |
+
class Linear(nn.Linear, LoRALayer):
    """LoRA implemented in a dense layer.

    With ``r > 0`` the layer owns two extra parameters, ``lora_A`` of shape
    ``(r, in_features)`` and ``lora_B`` of shape ``(out_features, r)``, and the
    base ``weight`` has ``requires_grad`` set to ``False``. The low-rank
    product ``lora_B @ lora_A`` scaled by ``lora_alpha / r`` is either added
    on the fly in ``forward`` (unmerged) or folded into ``weight`` (merged).
    With ``r == 0`` the layer computes a plain ``F.linear``.

    Args:
        in_features: input feature size.
        out_features: output feature size.
        r: LoRA rank; ``0`` disables the low-rank branch.
        lora_alpha: numerator of the LoRA scaling factor.
        lora_dropout: dropout rate applied to the input of the LoRA branch.
        fan_in_fan_out: set this to True if the layer to replace stores its
            weight like (fan_in, fan_out).
        merge_weights: whether ``train(False)`` folds the LoRA product into
            ``weight`` and ``train(True)`` removes it again.
        **kwargs: forwarded to ``nn.Linear.__init__``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.,
        fan_in_fan_out: bool = False,
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                           merge_weights=merge_weights)

        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.transpose(0, 1)

    def _maybe_transpose(self, w):
        """Return ``w`` transposed when ``fan_in_fan_out`` is set, else ``w``."""
        return w.transpose(0, 1) if self.fan_in_fan_out else w

    def reset_parameters(self):
        """Re-initialise the base layer and, when present, the LoRA factors."""
        nn.Linear.reset_parameters(self)
        # nn.Linear.__init__ calls this before lora_A exists, hence the guard.
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def train(self, mode: bool = True):
        """Switch train/eval mode, merging or unmerging the LoRA weights.

        Entering training mode subtracts a previously merged LoRA product from
        ``weight``; entering eval mode adds it. Both only happen when
        ``merge_weights`` is set.

        Returns:
            self, like ``nn.Module.train``, so that ``layer.train()`` and
            ``layer.eval()`` can be chained.
        """
        nn.Linear.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                if self.r > 0:
                    self.weight.data -= self._maybe_transpose(self.lora_B @ self.lora_A) * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self.r > 0:
                    self.weight.data += self._maybe_transpose(self.lora_B @ self.lora_A) * self.scaling
                self.merged = True
        return self

    def forward(self, x: torch.Tensor):
        """Apply the linear map, adding the LoRA branch when it is unmerged."""
        result = F.linear(x, self._maybe_transpose(self.weight), bias=self.bias)
        if self.r > 0 and not self.merged:
            result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1)
                       @ self.lora_B.transpose(0, 1)) * self.scaling
        return result
|
287 |
+
|
288 |
+
|
289 |
+
class StitchingLayer(nn.Module):
    """Single ``Linear`` projection used to stitch one anchor into another.

    Args:
        in_features: feature size of the incoming activations.
        out_features: feature size expected downstream.
        r: LoRA rank forwarded to ``Linear``.
    """

    def __init__(self, in_features=None, out_features=None, r=0):
        super(StitchingLayer, self).__init__()
        self.transform = Linear(in_features, out_features, r=r)

    def init_stitch_weights_bias(self, weight, bias):
        """Copy ``weight`` and ``bias`` into the projection in place."""
        projection = self.transform
        projection.weight.data.copy_(weight)
        projection.bias.data.copy_(bias)

    def forward(self, x):
        """Project ``x`` with the stitching transform."""
        return self.transform(x)
|
301 |
+
|
302 |
+
|
303 |
+
class SNNet(nn.Module):
    """Stitchable network over two anchors with one-directional stitches.

    The configurations come from ``get_stitch_configs_general_unequal``; one
    ``StitchingLayer`` mapping ``anchors[0].embed_dim`` to
    ``anchors[1].embed_dim`` is created per stitch position. The active
    configuration is selected through ``stitch_config_id``.
    """

    def __init__(self, anchors=None):
        super().__init__()
        self.anchors = nn.ModuleList(anchors)

        self.depths = [len(anchor.blocks) for anchor in self.anchors]

        total_configs, num_stitches = get_stitch_configs_general_unequal(self.depths)
        stitch_layers = nn.ModuleList()
        for _ in range(num_stitches):
            stitch_layers.append(
                StitchingLayer(self.anchors[0].embed_dim, self.anchors[1].embed_dim))
        self.stitch_layers = stitch_layers

        self.stitch_configs = dict(enumerate(total_configs))
        self.all_cfgs = list(self.stitch_configs)
        self.num_configs = len(self.all_cfgs)
        self.stitch_config_id = 0
        self.is_ranking = False

    def reset_stitch_id(self, stitch_config_id):
        """Select the configuration used by ``forward``."""
        self.stitch_config_id = stitch_config_id

    def initialize_stitching_weights(self, x):
        """Initialise every stitching layer with a least-squares fit on ``x``.

        Block features of both anchors are extracted without gradients; layer
        ``i`` is fitted from small-anchor feature ``i`` to large-anchor feature
        ``(i + 1) * (depths[1] // depths[0]) - 1`` using ``ps_inv``.
        """
        logger = get_root_logger()
        front, end = 0, 1
        with torch.no_grad():
            front_features = self.anchors[front].extract_block_features(x)
            end_features = self.anchors[end].extract_block_features(x)

        for layer_idx in range(self.depths[0]):
            end_id = (layer_idx + 1) * (self.depths[1] // self.depths[0])
            source_feat = front_features[layer_idx]
            target_feat = end_features[end_id - 1]
            w, b = ps_inv(source_feat, target_feat)
            self.stitch_layers[layer_idx].init_stitch_weights_bias(w, b)
            logger.info(f'Initialized Stitching Model {front} to Model {end}, Layer {layer_idx}')

    def init_weights(self):
        """Call ``init_weights`` on every anchor."""
        for anchor in self.anchors:
            anchor.init_weights()

    def sampling_stitch_config(self):
        """Pick a configuration id uniformly at random and make it active."""
        self.stitch_config_id = np.random.choice(self.all_cfgs)

    def forward(self, x):
        """Run the active configuration: a single anchor or one stitch."""
        active = self.stitch_configs[self.stitch_config_id]
        comb_id = active['comb_id']

        # forward by a single anchor
        if len(comb_id) == 1:
            return self.anchors[comb_id[0]](x)

        front_blk, end_blk = active['stitch_cfgs'][0], active['stitch_cfgs'][1]

        x = self.anchors[comb_id[0]].forward_until(x, blk_id=front_blk)
        x = self.stitch_layers[front_blk](x)
        x = self.anchors[comb_id[1]].forward_from(x, blk_id=end_blk)

        return x
|
361 |
+
|
362 |
+
|
363 |
+
class SNNetv2(nn.Module):
    """Stitchable network over two anchors with bidirectional stitches.

    Configurations come from ``get_stitch_configs_bidirection``.
    ``stitch_layers[k]`` holds the stitching layers for ``model_combos[k]``
    (``(0, 1)`` then ``(1, 0)``) followed by a trailing ``nn.Identity``, which
    is what a stitch layer id of ``-1`` selects.

    Args:
        anchors: the two anchor models. Each must expose ``blocks`` and
            ``embed_dim``, plus the methods called below.
        include_sl, include_ls, include_lsl, include_sls: whether the ids of
            the respective configuration family are added to ``all_cfgs``.
        lora_r: LoRA rank passed to every ``StitchingLayer``.

    Training-mode sampling uses ``flops_grouped_cfgs`` and
    ``flops_sampling_probs``. Neither is set in this class; presumably the
    training script attaches them (TODO confirm). When no grouping is
    attached, a configuration is sampled uniformly from ``all_cfgs``.
    """

    def __init__(self, anchors=None, include_sl=True, include_ls=True, include_lsl=True, include_sls=True, lora_r=0):
        super(SNNetv2, self).__init__()
        self.anchors = nn.ModuleList(anchors)

        self.lora_r = lora_r

        self.depths = [len(anc.blocks) for anc in self.anchors]

        (total_configs, model_combos, num_stitches, anchor_ids,
         sl_ids, ls_ids, lsl_ids, sls_ids) = get_stitch_configs_bidirection(self.depths)

        self.stitch_layers = nn.ModuleList()
        self.stitching_map_id = {}

        for comb, num_sth in zip(model_combos, num_stitches):
            front, end = comb
            temp = nn.ModuleList(
                [StitchingLayer(self.anchors[front].embed_dim, self.anchors[end].embed_dim, r=lora_r)
                 for _ in range(num_sth)])
            # Trailing identity: selected by the stitch layer id -1.
            temp.append(nn.Identity())
            self.stitch_layers.append(temp)

        self.stitch_configs = {i: cfg for i, cfg in enumerate(total_configs)}
        # Only single-stitch configurations are used for weight initialisation.
        self.stitch_init_configs = {i: cfg for i, cfg in enumerate(total_configs) if len(cfg['comb_id']) == 2}

        logger = get_root_logger()
        logger.info(str(list(self.stitch_configs.keys())))

        # Copy so that extending all_cfgs does not mutate the helper's list.
        self.all_cfgs = list(anchor_ids)

        if include_sl:
            self.all_cfgs += sl_ids

        if include_ls:
            self.all_cfgs += ls_ids

        if include_lsl:
            self.all_cfgs += lsl_ids

        if include_sls:
            self.all_cfgs += sls_ids

        self.num_configs = len(self.stitch_configs)
        self.stitch_config_id = 0
        # Initialised here so the attribute exists before set_ranking_mode().
        self.is_ranking = False

    def reset_stitch_id(self, stitch_config_id):
        """Select the configuration used by ``forward`` outside training mode."""
        self.stitch_config_id = stitch_config_id

    def set_ranking_mode(self, ranking_mode):
        """Store the ranking-mode flag in ``is_ranking``."""
        self.is_ranking = ranking_mode

    def initialize_stitching_weights(self, x):
        """Initialise the stitching layers with least-squares fits on ``x``.

        Block features of every anchor are extracted without gradients. For
        each single-stitch configuration the layer is fitted with ``ps_inv``
        from the front anchor's feature ``front_id`` to the end anchor's
        feature ``end_id - 1``.
        """
        logger = get_root_logger()
        anchor_features = []
        for anchor in self.anchors:
            with torch.no_grad():
                temp = anchor.extract_block_features(x)
            anchor_features.append(temp)

        for idx, cfg in self.stitch_init_configs.items():
            comb_id = cfg['comb_id']
            if len(comb_id) == 2:
                front_id, end_id = cfg['stitch_cfgs'][0]
                stitch_layer_id = cfg['stitch_layer_ids'][0]
                front_blk_feat = anchor_features[comb_id[0]][front_id]
                end_blk_feat = anchor_features[comb_id[1]][end_id - 1]
                w, b = ps_inv(front_blk_feat, end_blk_feat)
                self.stitch_layers[comb_id[0]][stitch_layer_id].init_stitch_weights_bias(w, b)
                logger.info(f'Initialized Stitching Layer {cfg}')

    def init_weights(self):
        """Call ``init_weights`` on every anchor."""
        for anc in self.anchors:
            anc.init_weights()

    def sampling_stitch_config(self):
        """Sample and return a configuration id for one training step.

        With a FLOPs grouping attached, a group is drawn according to
        ``flops_sampling_probs`` and a configuration uniformly within it.
        Without one, a configuration is drawn uniformly from ``all_cfgs``
        instead of failing with ``AttributeError``.
        """
        grouped_cfgs = getattr(self, 'flops_grouped_cfgs', None)
        if grouped_cfgs is None:
            return np.random.choice(self.all_cfgs)

        probs = getattr(self, 'flops_sampling_probs', None)
        flops_id = np.random.choice(len(grouped_cfgs), p=probs)
        return np.random.choice(grouped_cfgs[flops_id])

    def forward(self, x):
        """Run one configuration: sampled in training mode, fixed otherwise."""
        if self.training:
            stitch_cfg_id = self.sampling_stitch_config()
        else:
            stitch_cfg_id = self.stitch_config_id

        stitch_cfg = self.stitch_configs[stitch_cfg_id]
        comb_id = stitch_cfg['comb_id']

        # forward by a single anchor
        if len(comb_id) == 1:
            return self.anchors[comb_id[0]](x)

        # forward among anchors
        route = stitch_cfg['route']
        stitch_layer_ids = stitch_cfg['stitch_layer_ids']

        # patch embedding
        x = self.anchors[comb_id[0]].forward_patch_embed(x)

        for i, (model_id, cfg) in enumerate(zip(comb_id, route)):
            x = self.anchors[model_id].selective_forward(x, cfg[0], cfg[1])
            # The last stitch layer id is -1, i.e. the trailing nn.Identity.
            x = self.stitch_layers[model_id][stitch_layer_ids[i]](x)

        x = self.anchors[comb_id[-1]].forward_norm_head(x)
        return x
|
473 |
+
|
snnetv2_deit3_s_l.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9d455f17d73f4ed74702076d4cea516194d8c4aa8fbbc63192f85795f79c76b4
|
3 |
+
size 1350494458
|
stitches_res_s_l.txt
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{"loss": 0.7156664722345092, "acc1": 82.9060024609375, "acc5": 96.73400244140625, "cfg_id": 0, "flops": 4608338304}
|
2 |
+
{"loss": 0.5377805712209507, "acc1": 86.97800256835937, "acc5": 98.2540023046875, "cfg_id": 1, "flops": 61604135936}
|
3 |
+
{"loss": 0.5598483879796483, "acc1": 86.57800241210937, "acc5": 98.08200240234375, "cfg_id": 2, "flops": 56843745792}
|
4 |
+
{"loss": 0.5534007405354218, "acc1": 86.6480025390625, "acc5": 98.1760025390625, "cfg_id": 3, "flops": 52102230016}
|
5 |
+
{"loss": 0.5610568577028585, "acc1": 86.49800245117187, "acc5": 98.06600229492187, "cfg_id": 4, "flops": 47360714240}
|
6 |
+
{"loss": 0.5747850706067049, "acc1": 86.26000259765625, "acc5": 97.93800240234376, "cfg_id": 5, "flops": 42619198464}
|
7 |
+
{"loss": 0.5890085864812136, "acc1": 85.79200244140625, "acc5": 97.80000272460937, "cfg_id": 6, "flops": 37877682688}
|
8 |
+
{"loss": 0.6165087098876635, "acc1": 85.08200264648437, "acc5": 97.55600231445312, "cfg_id": 7, "flops": 33136166912}
|
9 |
+
{"loss": 0.6652509210574807, "acc1": 83.69200263671875, "acc5": 97.23600259765625, "cfg_id": 8, "flops": 28394651136}
|
10 |
+
{"loss": 0.7374675334290122, "acc1": 81.7120026171875, "acc5": 96.53200251953125, "cfg_id": 9, "flops": 23653135360}
|
11 |
+
{"loss": 0.7991558508665273, "acc1": 79.50600241210938, "acc5": 95.90200240234375, "cfg_id": 10, "flops": 18911619584}
|
12 |
+
{"loss": 0.7554851990531791, "acc1": 80.63600265625, "acc5": 96.09000245117187, "cfg_id": 11, "flops": 14170103808}
|
13 |
+
{"loss": 0.7068120487824534, "acc1": 82.25000237304687, "acc5": 96.35600284179688, "cfg_id": 12, "flops": 9428588032}
|
14 |
+
{"loss": 0.7329587066038088, "acc1": 82.6600027734375, "acc5": 96.58200264648437, "cfg_id": 14, "flops": 9523655552}
|
15 |
+
{"loss": 0.7238117807516546, "acc1": 82.94800252929687, "acc5": 96.68600232421875, "cfg_id": 15, "flops": 14265171328}
|
16 |
+
{"loss": 0.7139950410434694, "acc1": 83.0860026953125, "acc5": 96.75800252929687, "cfg_id": 16, "flops": 19006687104}
|
17 |
+
{"loss": 0.7004092067028537, "acc1": 83.25400249023437, "acc5": 96.8740026171875, "cfg_id": 17, "flops": 23748202880}
|
18 |
+
{"loss": 0.6828147762201049, "acc1": 83.45000244140626, "acc5": 96.9520026171875, "cfg_id": 18, "flops": 28489718656}
|
19 |
+
{"loss": 0.6787144099221085, "acc1": 83.56400258789063, "acc5": 97.0600024609375, "cfg_id": 19, "flops": 33231234432}
|
20 |
+
{"loss": 0.6765228407175252, "acc1": 83.43400251953125, "acc5": 97.19200266601563, "cfg_id": 20, "flops": 37972750208}
|
21 |
+
{"loss": 0.6841061733888857, "acc1": 83.5900022265625, "acc5": 97.20800275390626, "cfg_id": 21, "flops": 42714265984}
|
22 |
+
{"loss": 0.6446758140104286, "acc1": 84.8660023828125, "acc5": 97.44400258789062, "cfg_id": 22, "flops": 47455781760}
|
23 |
+
{"loss": 0.5939652780917558, "acc1": 86.23000265625, "acc5": 97.69200270507812, "cfg_id": 23, "flops": 52197297536}
|
24 |
+
{"loss": 0.5654762382760192, "acc1": 86.43400250976562, "acc5": 97.632002578125, "cfg_id": 24, "flops": 56938813312}
|
25 |
+
{"loss": 0.5636055112788172, "acc1": 86.39000270507813, "acc5": 98.04800252929688, "cfg_id": 25, "flops": 57017547264}
|
26 |
+
{"loss": 0.5706944450397383, "acc1": 86.234002578125, "acc5": 98.00000237304687, "cfg_id": 26, "flops": 52276031488}
|
27 |
+
{"loss": 0.5833309799658529, "acc1": 85.9240025390625, "acc5": 97.9260024609375, "cfg_id": 27, "flops": 47534515712}
|
28 |
+
{"loss": 0.5972222860225223, "acc1": 85.57400262695313, "acc5": 97.70800255859375, "cfg_id": 28, "flops": 42792999936}
|
29 |
+
{"loss": 0.6253456006560362, "acc1": 84.89800259765624, "acc5": 97.47400255859375, "cfg_id": 29, "flops": 38051484160}
|
30 |
+
{"loss": 0.6745385262889393, "acc1": 83.5380026171875, "acc5": 97.07600244140625, "cfg_id": 30, "flops": 33309968384}
|
31 |
+
{"loss": 0.7486309014034994, "acc1": 81.42600245117187, "acc5": 96.33600258789062, "cfg_id": 31, "flops": 28568452608}
|
32 |
+
{"loss": 0.8134756960877867, "acc1": 79.16000235351562, "acc5": 95.72400271484375, "cfg_id": 32, "flops": 23826936832}
|
33 |
+
{"loss": 0.7671100513050051, "acc1": 80.37200240234375, "acc5": 95.98400258789063, "cfg_id": 33, "flops": 19085421056}
|
34 |
+
{"loss": 0.7206548866674756, "acc1": 81.91000239257812, "acc5": 96.23800239257812, "cfg_id": 34, "flops": 14343905280}
|
35 |
+
{"loss": 0.5626872230998494, "acc1": 86.44600235351562, "acc5": 98.062002421875, "cfg_id": 35, "flops": 57017547264}
|
36 |
+
{"loss": 0.5785287711769342, "acc1": 86.06400251953124, "acc5": 97.9420023046875, "cfg_id": 36, "flops": 52276031488}
|
37 |
+
{"loss": 0.5930487287202568, "acc1": 85.78400234375, "acc5": 97.79000255859376, "cfg_id": 37, "flops": 47534515712}
|
38 |
+
{"loss": 0.6189901619923838, "acc1": 85.10800268554688, "acc5": 97.50400228515625, "cfg_id": 38, "flops": 42792999936}
|
39 |
+
{"loss": 0.6674688318462083, "acc1": 83.76600272460938, "acc5": 97.09400264648437, "cfg_id": 39, "flops": 38051484160}
|
40 |
+
{"loss": 0.7388352820593299, "acc1": 81.70200266601563, "acc5": 96.47000241210938, "cfg_id": 40, "flops": 33309968384}
|
41 |
+
{"loss": 0.803126322613521, "acc1": 79.4560025390625, "acc5": 95.81400245117187, "cfg_id": 41, "flops": 28568452608}
|
42 |
+
{"loss": 0.7581946616145697, "acc1": 80.70800243164062, "acc5": 96.08600255859375, "cfg_id": 42, "flops": 23826936832}
|
43 |
+
{"loss": 0.7118472667467414, "acc1": 82.22600248046875, "acc5": 96.31000268554688, "cfg_id": 43, "flops": 19085421056}
|
44 |
+
{"loss": 0.5727639499713074, "acc1": 86.2180025, "acc5": 97.98200247070312, "cfg_id": 44, "flops": 57017547264}
|
45 |
+
{"loss": 0.5866389607615543, "acc1": 85.84400263671876, "acc5": 97.8600024609375, "cfg_id": 45, "flops": 52276031488}
|
46 |
+
{"loss": 0.6107792718279542, "acc1": 85.19800255859376, "acc5": 97.61800235351562, "cfg_id": 46, "flops": 47534515712}
|
47 |
+
{"loss": 0.6602028349809574, "acc1": 83.92600266601562, "acc5": 97.23800282226563, "cfg_id": 47, "flops": 42792999936}
|
48 |
+
{"loss": 0.7285334389431007, "acc1": 82.0040028125, "acc5": 96.52400247070312, "cfg_id": 48, "flops": 38051484160}
|
49 |
+
{"loss": 0.7910783413910505, "acc1": 79.69600262695313, "acc5": 95.95800241210938, "cfg_id": 49, "flops": 33309968384}
|
50 |
+
{"loss": 0.7478298004152197, "acc1": 80.89400243164063, "acc5": 96.172002421875, "cfg_id": 50, "flops": 28568452608}
|
51 |
+
{"loss": 0.7014034044449077, "acc1": 82.45600264648438, "acc5": 96.438002421875, "cfg_id": 51, "flops": 23826936832}
|
52 |
+
{"loss": 0.5799332931637764, "acc1": 85.92400239257813, "acc5": 97.94000249023438, "cfg_id": 52, "flops": 57017547264}
|
53 |
+
{"loss": 0.6004864230300441, "acc1": 85.43800255859375, "acc5": 97.70800227539063, "cfg_id": 53, "flops": 52276031488}
|
54 |
+
{"loss": 0.647012604287628, "acc1": 84.20200264648437, "acc5": 97.30600264648437, "cfg_id": 54, "flops": 47534515712}
|
55 |
+
{"loss": 0.7162722434961435, "acc1": 82.29000248046874, "acc5": 96.6640023046875, "cfg_id": 55, "flops": 42792999936}
|
56 |
+
{"loss": 0.7757266998065241, "acc1": 79.9760025, "acc5": 96.050002421875, "cfg_id": 56, "flops": 38051484160}
|
57 |
+
{"loss": 0.7351311787285588, "acc1": 81.04400232421875, "acc5": 96.2400026953125, "cfg_id": 57, "flops": 33309968384}
|
58 |
+
{"loss": 0.6896895027408997, "acc1": 82.6220026171875, "acc5": 96.55400252929688, "cfg_id": 58, "flops": 28568452608}
|
59 |
+
{"loss": 0.5911911701727094, "acc1": 85.53000266601562, "acc5": 97.76600241210937, "cfg_id": 59, "flops": 57017547264}
|
60 |
+
{"loss": 0.6371258264125297, "acc1": 84.41200249023437, "acc5": 97.44200245117187, "cfg_id": 60, "flops": 52276031488}
|
61 |
+
{"loss": 0.7022040815403064, "acc1": 82.49400240234375, "acc5": 96.74400272460937, "cfg_id": 61, "flops": 47534515712}
|
62 |
+
{"loss": 0.7612808859257987, "acc1": 80.29200265625, "acc5": 96.15800239257813, "cfg_id": 62, "flops": 42792999936}
|
63 |
+
{"loss": 0.7246641330420971, "acc1": 81.20400250976563, "acc5": 96.42400264648437, "cfg_id": 63, "flops": 38051484160}
|
64 |
+
{"loss": 0.6782861414619468, "acc1": 82.8040024609375, "acc5": 96.60800270507812, "cfg_id": 64, "flops": 33309968384}
|
65 |
+
{"loss": 0.629801401755575, "acc1": 84.65200262695312, "acc5": 97.54200265625, "cfg_id": 65, "flops": 57017547264}
|
66 |
+
{"loss": 0.6992729283643492, "acc1": 82.58200259765626, "acc5": 96.85600262695313, "cfg_id": 66, "flops": 52276031488}
|
67 |
+
{"loss": 0.7595290538262237, "acc1": 80.35600247070313, "acc5": 96.27000247070312, "cfg_id": 67, "flops": 47534515712}
|
68 |
+
{"loss": 0.7238247728709019, "acc1": 81.37600248046876, "acc5": 96.46200267578125, "cfg_id": 68, "flops": 42792999936}
|
69 |
+
{"loss": 0.6760879844765771, "acc1": 82.96800264648438, "acc5": 96.69400274414062, "cfg_id": 69, "flops": 38051484160}
|
70 |
+
{"loss": 0.68392569430624, "acc1": 83.16200254882813, "acc5": 97.09200258789062, "cfg_id": 70, "flops": 57017547264}
|
71 |
+
{"loss": 0.7509645553249301, "acc1": 80.68000260742187, "acc5": 96.3900025, "cfg_id": 71, "flops": 52276031488}
|
72 |
+
{"loss": 0.7208586267449639, "acc1": 81.55200274414062, "acc5": 96.62200272460937, "cfg_id": 72, "flops": 47534515712}
|
73 |
+
{"loss": 0.6785354860352747, "acc1": 82.86000262695312, "acc5": 96.80000244140625, "cfg_id": 73, "flops": 42792999936}
|
74 |
+
{"loss": 0.7184764705598354, "acc1": 81.61200241210938, "acc5": 96.63200258789063, "cfg_id": 74, "flops": 57017547264}
|
75 |
+
{"loss": 0.7229886900520686, "acc1": 81.45000249023437, "acc5": 96.62200250976562, "cfg_id": 75, "flops": 52276031488}
|
76 |
+
{"loss": 0.6883746685855316, "acc1": 83.01600272460938, "acc5": 96.83200240234375, "cfg_id": 76, "flops": 47534515712}
|
77 |
+
{"loss": 0.6293963799535325, "acc1": 83.90800231445313, "acc5": 97.27400245117188, "cfg_id": 77, "flops": 57017547264}
|
78 |
+
{"loss": 0.642419446824175, "acc1": 84.31400258789063, "acc5": 97.19200287109375, "cfg_id": 78, "flops": 52276031488}
|
79 |
+
{"loss": 0.5880116202275861, "acc1": 85.90600231445312, "acc5": 97.58400263671875, "cfg_id": 79, "flops": 57017547264}
|
80 |
+
{"loss": 0.750676692096573, "acc1": 82.14200271484376, "acc5": 96.36600260742188, "cfg_id": 80, "flops": 9504781184}
|
81 |
+
{"loss": 0.7431871895537232, "acc1": 82.234002578125, "acc5": 96.39200241210938, "cfg_id": 81, "flops": 14246296960}
|
82 |
+
{"loss": 0.7236298957105839, "acc1": 82.62600249023437, "acc5": 96.57600243164063, "cfg_id": 82, "flops": 18987812736}
|
83 |
+
{"loss": 0.7074674766397837, "acc1": 82.84600237304687, "acc5": 96.68800247070313, "cfg_id": 83, "flops": 23729328512}
|
84 |
+
{"loss": 0.7014015182062532, "acc1": 82.99000265625, "acc5": 96.7740025, "cfg_id": 84, "flops": 28470844288}
|
85 |
+
{"loss": 0.6996880258348855, "acc1": 82.98000252929687, "acc5": 96.90200263671875, "cfg_id": 85, "flops": 33212360064}
|
86 |
+
{"loss": 0.7077699161953095, "acc1": 82.96200270507812, "acc5": 96.98600258789062, "cfg_id": 86, "flops": 37953875840}
|
87 |
+
{"loss": 0.6674120087515224, "acc1": 84.33000274414063, "acc5": 97.2640025, "cfg_id": 87, "flops": 42695391616}
|
88 |
+
{"loss": 0.6169534720141779, "acc1": 85.86200280273438, "acc5": 97.49200272460938, "cfg_id": 88, "flops": 47436907392}
|
89 |
+
{"loss": 0.5848360503600403, "acc1": 86.02600271484376, "acc5": 97.4480026171875, "cfg_id": 89, "flops": 52178423168}
|
90 |
+
{"loss": 0.7346750153510859, "acc1": 82.5540027734375, "acc5": 96.52400241210937, "cfg_id": 90, "flops": 9504781184}
|
91 |
+
{"loss": 0.7158081559182117, "acc1": 82.82200255859375, "acc5": 96.65600255859376, "cfg_id": 91, "flops": 14246296960}
|
92 |
+
{"loss": 0.6994372372600165, "acc1": 83.03600239257813, "acc5": 96.74600252929687, "cfg_id": 92, "flops": 18987812736}
|
93 |
+
{"loss": 0.6947964186582601, "acc1": 83.07400255859375, "acc5": 96.88200241210937, "cfg_id": 93, "flops": 23729328512}
|
94 |
+
{"loss": 0.6946824553112189, "acc1": 83.0200026171875, "acc5": 96.99400251953125, "cfg_id": 94, "flops": 28470844288}
|
95 |
+
{"loss": 0.7041463901599249, "acc1": 83.21600236328125, "acc5": 97.00400250976563, "cfg_id": 95, "flops": 33212360064}
|
96 |
+
{"loss": 0.6699620116163384, "acc1": 84.54600258789063, "acc5": 97.31600244140625, "cfg_id": 96, "flops": 37953875840}
|
97 |
+
{"loss": 0.6176637105192199, "acc1": 85.91800228515625, "acc5": 97.57000255859376, "cfg_id": 97, "flops": 42695391616}
|
98 |
+
{"loss": 0.5765587539045196, "acc1": 86.1880023046875, "acc5": 97.54000249023437, "cfg_id": 98, "flops": 47436907392}
|
99 |
+
{"loss": 0.7319535712401072, "acc1": 82.49200258789062, "acc5": 96.46600271484375, "cfg_id": 99, "flops": 9504781184}
|
100 |
+
{"loss": 0.710809505516381, "acc1": 82.7860023046875, "acc5": 96.6260026171875, "cfg_id": 100, "flops": 14246296960}
|
101 |
+
{"loss": 0.7044268037107858, "acc1": 82.93000239257813, "acc5": 96.79800258789062, "cfg_id": 101, "flops": 18987812736}
|
102 |
+
{"loss": 0.7076575808001287, "acc1": 82.83200243164063, "acc5": 96.8820025390625, "cfg_id": 102, "flops": 23729328512}
|
103 |
+
{"loss": 0.7188302328189214, "acc1": 82.88200259765625, "acc5": 96.93600249023437, "cfg_id": 103, "flops": 28470844288}
|
104 |
+
{"loss": 0.6856357377361167, "acc1": 84.2520023046875, "acc5": 97.2200026171875, "cfg_id": 104, "flops": 33212360064}
|
105 |
+
{"loss": 0.6273381847210906, "acc1": 85.758002734375, "acc5": 97.44800275390625, "cfg_id": 105, "flops": 37953875840}
|
106 |
+
{"loss": 0.5830204013847944, "acc1": 86.01000260742188, "acc5": 97.4780026171875, "cfg_id": 106, "flops": 42695391616}
|
107 |
+
{"loss": 0.7305513945492831, "acc1": 82.3180023828125, "acc5": 96.47200256835937, "cfg_id": 107, "flops": 9504781184}
|
108 |
+
{"loss": 0.7206297208639708, "acc1": 82.37200228515626, "acc5": 96.61600234375, "cfg_id": 108, "flops": 14246296960}
|
109 |
+
{"loss": 0.7241401795975186, "acc1": 82.33800244140625, "acc5": 96.74400267578125, "cfg_id": 109, "flops": 18987812736}
|
110 |
+
{"loss": 0.7350799917723193, "acc1": 82.57800231445313, "acc5": 96.78000252929688, "cfg_id": 110, "flops": 23729328512}
|
111 |
+
{"loss": 0.7013292148935072, "acc1": 83.95000255859375, "acc5": 97.17600254882812, "cfg_id": 111, "flops": 28470844288}
|
112 |
+
{"loss": 0.64130035031474, "acc1": 85.54600244140624, "acc5": 97.36800262695313, "cfg_id": 112, "flops": 33212360064}
|
113 |
+
{"loss": 0.5961506485826138, "acc1": 85.76800252929688, "acc5": 97.33400268554688, "cfg_id": 113, "flops": 37953875840}
|
114 |
+
{"loss": 0.7443677056580782, "acc1": 81.86200244140625, "acc5": 96.29200241210937, "cfg_id": 114, "flops": 9504781184}
|
115 |
+
{"loss": 0.7442678388659701, "acc1": 81.82000262695313, "acc5": 96.3900026953125, "cfg_id": 115, "flops": 14246296960}
|
116 |
+
{"loss": 0.749958168037913, "acc1": 81.99400229492187, "acc5": 96.49600264648437, "cfg_id": 116, "flops": 18987812736}
|
117 |
+
{"loss": 0.7073916116672935, "acc1": 83.43200267578125, "acc5": 96.99200250976563, "cfg_id": 117, "flops": 23729328512}
|
118 |
+
{"loss": 0.6501240834706661, "acc1": 85.1480025, "acc5": 97.212002734375, "cfg_id": 118, "flops": 28470844288}
|
119 |
+
{"loss": 0.6135486943477934, "acc1": 85.35400266601563, "acc5": 97.16000249023438, "cfg_id": 119, "flops": 33212360064}
|
120 |
+
{"loss": 0.7824224890633062, "acc1": 81.16600250976562, "acc5": 96.11200255859374, "cfg_id": 120, "flops": 9504781184}
|
121 |
+
{"loss": 0.7932735298844901, "acc1": 80.93600233398438, "acc5": 96.12600254882813, "cfg_id": 121, "flops": 14246296960}
|
122 |
+
{"loss": 0.7476008046757091, "acc1": 82.49800259765625, "acc5": 96.65000265625, "cfg_id": 122, "flops": 18987812736}
|
123 |
+
{"loss": 0.6842290776019747, "acc1": 84.35800275390625, "acc5": 96.94000270507813, "cfg_id": 123, "flops": 23729328512}
|
124 |
+
{"loss": 0.6409159135073423, "acc1": 84.520002578125, "acc5": 96.91800252929687, "cfg_id": 124, "flops": 28470844288}
|
125 |
+
{"loss": 0.8394820786109476, "acc1": 79.96600270507813, "acc5": 95.5980024609375, "cfg_id": 125, "flops": 9504781184}
|
126 |
+
{"loss": 0.7924092794683847, "acc1": 81.1100024609375, "acc5": 96.12600264648438, "cfg_id": 126, "flops": 14246296960}
|
127 |
+
{"loss": 0.724013783997207, "acc1": 83.04000248046874, "acc5": 96.46600255859374, "cfg_id": 127, "flops": 18987812736}
|
128 |
+
{"loss": 0.6948224610338608, "acc1": 83.016002421875, "acc5": 96.44000282226563, "cfg_id": 128, "flops": 23729328512}
|
129 |
+
{"loss": 0.8688964597655066, "acc1": 79.15800264160156, "acc5": 95.2460025, "cfg_id": 129, "flops": 9504781184}
|
130 |
+
{"loss": 0.8095247168658357, "acc1": 80.57800265625, "acc5": 95.63000274414063, "cfg_id": 130, "flops": 14246296960}
|
131 |
+
{"loss": 0.7750140779059042, "acc1": 80.9680026171875, "acc5": 95.68200263671875, "cfg_id": 131, "flops": 18987812736}
|
132 |
+
{"loss": 0.9017180648039688, "acc1": 77.88400251953125, "acc5": 94.78400234375, "cfg_id": 132, "flops": 9504781184}
|
133 |
+
{"loss": 0.8799216277671583, "acc1": 77.84800252929688, "acc5": 94.79800247070312, "cfg_id": 133, "flops": 14246296960}
|
134 |
+
{"loss": 0.8987371790589709, "acc1": 78.12000266601562, "acc5": 94.92200256835937, "cfg_id": 134, "flops": 9504781184}
|
utils.py
ADDED
@@ -0,0 +1,408 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2015-present, Facebook, Inc.
|
2 |
+
# All rights reserved.
|
3 |
+
"""
|
4 |
+
Misc functions, including distributed helpers.
|
5 |
+
|
6 |
+
Mostly copy-paste from torchvision references.
|
7 |
+
"""
|
8 |
+
import io
|
9 |
+
import os
|
10 |
+
import time
|
11 |
+
from collections import defaultdict, deque
|
12 |
+
import datetime
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torch.distributed as dist
|
16 |
+
import logging
|
17 |
+
|
18 |
+
logger_initialized = {}
|
19 |
+
|
20 |
+
def group_subnets_by_flops(data, flops_gap=1.0):
    """Bucket configuration ids whose FLOPs lie close together.

    Entries are visited in ascending order of their value, and each value
    is divided by 1e9 before comparison. A new bucket is opened whenever
    the scaled value differs from the value that opened the current bucket
    by more than ``flops_gap``; otherwise the id joins the current bucket.

    Args:
        data (dict): Maps a config id (anything ``int()`` accepts) to its
            FLOPs count.
        flops_gap (float): Maximum distance, after the 1e9 scaling, from
            the bucket's opening value for an entry to stay in the bucket.

    Returns:
        list[list[int]]: Buckets in ascending FLOPs order; ids inside each
        bucket are sorted.
    """
    buckets = []
    current = []
    bucket_start = 0
    for cfg_id, raw_flops in sorted(data.items(), key=lambda kv: kv[1]):
        scaled = raw_flops / 1e9
        exceeds_gap = abs(bucket_start - scaled) > flops_gap
        if not exceeds_gap:
            current.append(int(cfg_id))
            continue
        # Close the running bucket (if any) and open a new one here.
        if current:
            buckets.append(sorted(current))
        current = [int(cfg_id)]
        bucket_start = scaled

    if current:
        buckets.append(sorted(current))
    return buckets
|
39 |
+
|
40 |
+
def find_best_candidates(data):
    """Pick the highest-scoring configuration from each cluster of similar FLOPs.

    Entries are visited in ascending order of their ``(flops, score)``
    value. An entry opens a new cluster when its FLOPs differ by more than
    1 from the FLOPs of the entry that opened the current cluster;
    otherwise it replaces the cluster's current pick if its score is
    strictly higher.

    Args:
        data (dict): Maps a config id to a ``(flops, score)`` pair.

    Returns:
        list: One config id per cluster, in ascending FLOPs order. Empty
        when ``data`` is empty.
    """
    best_ids = []
    # FLOPs of the entry that opened the current cluster. ``None`` means no
    # cluster exists yet; using 0 as the sentinel made the first entry index
    # into an empty list whenever its FLOPs were within 1 of zero.
    cluster_flops = None
    for cfg_id, (flops, score) in sorted(data.items(), key=lambda item: item[1]):
        if cluster_flops is None or abs(cluster_flops - flops) > 1:
            best_ids.append(cfg_id)
            cluster_flops = flops
        elif score > data[best_ids[-1]][1]:
            best_ids[-1] = cfg_id

    return best_ids
|
54 |
+
|
55 |
+
|
56 |
+
|
57 |
+
def find_top_candidates(data, ratio=0.9):
    """Keep the best-scoring share of every cluster of similar FLOPs.

    Entries are visited in ascending order of their ``(flops, score)``
    value and clustered: an entry opens a new cluster when its FLOPs
    differ by more than 3 from the FLOPs that opened the current cluster.
    Single-entry clusters are kept whole; from larger clusters the
    ``int(ratio * size)`` highest-scoring entries (at least one) are kept.

    Args:
        data (dict): Maps a config id (anything ``int()`` accepts) to a
            ``(flops, score)`` pair.
        ratio (float): Fraction of each multi-entry cluster to keep.

    Returns:
        list[int]: Kept config ids, cluster by cluster, highest score
        first within each multi-entry cluster.
    """
    ordered = dict(sorted(data.items(), key=lambda kv: kv[1]))

    clusters = []
    members = []
    cluster_start = 0
    for cfg_id, (flops, _score) in ordered.items():
        if abs(cluster_start - flops) > 3:
            if members:
                clusters.append(members)
            members = [cfg_id]
            cluster_start = flops
        else:
            members.append(cfg_id)
    if members:
        clusters.append(members)

    kept = []
    for members in clusters:
        if len(members) > 1:
            scores = torch.tensor([ordered[cfg_id][-1] for cfg_id in members])
            ranking = torch.argsort(scores, descending=True)
            keep_n = max(int(ratio * len(members)), 1)
            members = [members[pos] for pos in ranking[:keep_n].tolist()]
        kept += list(map(int, members))

    return kept
|
90 |
+
|
91 |
+
|
92 |
+
|
93 |
+
def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):
    """Return the logger called ``name``, configuring it on first use.

    A logger is configured only once: if ``name`` is already registered in
    ``logger_initialized``, or starts with a registered name (e.g. "a.b"
    after "a" was set up), the logger is returned untouched. Otherwise a
    StreamHandler is attached, plus a FileHandler when ``log_file`` is given
    and the process rank is 0.

    Args:
        name (str): Logger name.
        log_file (str | None): Log filename. If given, a FileHandler is
            added on rank 0.
        log_level (int): Level applied to the handlers, and to the logger
            itself on rank 0. Other ranks set the logger to ERROR.
        file_mode (str): Mode used to open the log file. Defaults to 'w'.

    Returns:
        logging.Logger: The requested logger.
    """
    logger = logging.getLogger(name)

    # Covers both an exact match and the hierarchical "child of a
    # configured logger" case.
    already_configured = name in logger_initialized or any(
        name.startswith(known) for known in logger_initialized)
    if already_configured:
        return logger

    # Outside of an initialised process group every process counts as rank 0.
    rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0

    handlers = [logging.StreamHandler()]
    if rank == 0 and log_file is not None:
        # logging.FileHandler defaults to mode 'a'; ``file_mode`` lets the
        # caller choose (this function defaults to 'w').
        handlers.append(logging.FileHandler(log_file, file_mode))

    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    for handler in handlers:
        handler.setFormatter(formatter)
        handler.setLevel(log_level)
        logger.addHandler(handler)

    logger.setLevel(log_level if rank == 0 else logging.ERROR)

    logger_initialized[name] = True
    return logger
|
156 |
+
|
157 |
+
|
158 |
+
def get_root_logger(log_file=None, log_level=logging.INFO):
    """Return the project-wide logger, which is named "snnet".

    Thin wrapper around ``get_logger`` with the name fixed; see that
    function for how handlers and levels are set up.

    Args:
        log_file (str | None): Log filename forwarded to ``get_logger``.
        log_level (int): Log level forwarded to ``get_logger``.

    Returns:
        logging.Logger: The "snnet" logger.
    """
    return get_logger(name='snnet', log_file=log_file, log_level=log_level)
|
180 |
+
|
181 |
+
class SmoothedValue(object):
    """Running statistics over a stream of values.

    Keeps the last ``window_size`` values for windowed statistics
    (median, avg, max, value) and a running total/count for the global
    average.
    """

    def __init__(self, window_size=20, fmt=None):
        # Default rendering: windowed median followed by the global average.
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0

    def update(self, value, n=1):
        """Record ``value``, weighted ``n`` times in the global average."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """Sum ``count`` and ``total`` across processes.

        Warning: does not synchronize the deque!
        """
        if is_dist_avail_and_initialized():
            stats = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
            dist.barrier()
            dist.all_reduce(stats)
            reduced = stats.tolist()
            self.count = int(reduced[0])
            self.total = reduced[1]

    @property
    def median(self):
        """Median of the current window, as computed by ``torch.median``."""
        return torch.tensor(list(self.deque)).median().item()

    @property
    def avg(self):
        """Mean of the current window."""
        window = torch.tensor(list(self.deque), dtype=torch.float32)
        return window.mean().item()

    @property
    def global_avg(self):
        """Weighted mean of everything recorded so far."""
        return self.total / self.count

    @property
    def max(self):
        """Largest value in the current window."""
        return max(self.deque)

    @property
    def value(self):
        """Most recently recorded value."""
        return self.deque[-1]

    def __str__(self):
        stats = dict(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value,
        )
        return self.fmt.format(**stats)
|
241 |
+
|
242 |
+
|
243 |
+
class MetricLogger(object):
    """Collect named ``SmoothedValue`` meters and report progress over an iterable.

    Args:
        delimiter (str): Separator placed between fields of a log line.
        logger: Object with an ``info(msg)`` method used for output. When
            ``None`` (the default), messages are written with ``print``.
    """

    def __init__(self, delimiter="\t", logger=None):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter
        self.logger = logger

    def update(self, **kwargs):
        """Feed one value per keyword into the meter of the same name."""
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Only called when normal attribute lookup fails. Read ``meters``
        # through __dict__ so that a lookup made before __init__ has run
        # (e.g. during copy/unpickle) raises AttributeError instead of
        # recursing forever.
        meters = self.__dict__.get('meters', {})
        if attr in meters:
            return meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        """Synchronise every meter across processes."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register ``meter`` under ``name``, replacing any existing one."""
        self.meters[name] = meter

    def _emit(self, message):
        # The constructor allows logger=None; fall back to print so that
        # log_every does not fail with AttributeError in that case.
        if self.logger is not None:
            self.logger.info(message)
        else:
            print(message)

    def log_every(self, iterable, print_freq, header=None):
        """Yield items of ``iterable`` while periodically logging progress.

        A line is logged every ``print_freq`` items and on the last item,
        followed by a total-time summary once the iterable is exhausted.

        Args:
            iterable: Sized iterable to consume (``len()`` is required).
            print_freq (int): Number of items between progress lines.
            header (str | None): Text prepended to each line.

        Yields:
            The items of ``iterable``, unchanged.
        """
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if torch.cuda.is_available():
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            # Time spent waiting for the next item.
            data_time.update(time.time() - end)
            yield obj
            # Time for the full iteration, including the consumer's work.
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    self._emit(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    self._emit(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # max(..., 1) keeps the summary from dividing by zero on an empty
        # iterable.
        self._emit('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / max(len(iterable), 1)))
|
324 |
+
|
325 |
+
|
326 |
+
def _load_checkpoint_for_ema(model_ema, checkpoint):
|
327 |
+
"""
|
328 |
+
Workaround for ModelEma._load_checkpoint to accept an already-loaded object
|
329 |
+
"""
|
330 |
+
mem_file = io.BytesIO()
|
331 |
+
torch.save({'state_dict_ema':checkpoint}, mem_file)
|
332 |
+
mem_file.seek(0)
|
333 |
+
model_ema._load_checkpoint(mem_file)
|
334 |
+
|
335 |
+
|
336 |
+
def setup_for_distributed(is_master):
    """Replace the builtin ``print`` with a version that is silent off-master.

    After this call, ``print`` only produces output when ``is_master`` is
    true or the call passes ``force=True``. The ``force`` keyword is
    consumed and never forwarded to the original ``print``.
    """
    import builtins
    original_print = builtins.print

    def print(*args, **kwargs):
        forced = kwargs.pop('force', False)
        if not (is_master or forced):
            return
        original_print(*args, **kwargs)

    builtins.print = print
|
349 |
+
|
350 |
+
|
351 |
+
def is_dist_avail_and_initialized():
    """Return True only when torch.distributed is available and initialised."""
    return dist.is_available() and dist.is_initialized()
|
357 |
+
|
358 |
+
|
359 |
+
def get_world_size():
    """Return the distributed world size, or 1 when not running distributed."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1
|
363 |
+
|
364 |
+
|
365 |
+
def get_rank():
    """Return this process's distributed rank, or 0 when not running distributed."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
|
369 |
+
|
370 |
+
|
371 |
+
def is_main_process():
    """Return True when this process has rank 0."""
    rank = get_rank()
    return rank == 0
|
373 |
+
|
374 |
+
|
375 |
+
def save_on_master(*args, **kwargs):
    """Call ``torch.save`` with the given arguments, on the main process only."""
    if not is_main_process():
        return
    torch.save(*args, **kwargs)
|
378 |
+
|
379 |
+
|
380 |
+
def init_distributed_mode(args):
    """Configure ``args`` and torch.distributed from the environment.

    With ``RANK`` and ``WORLD_SIZE`` set, rank, world size and GPU index are
    read from ``RANK``, ``WORLD_SIZE`` and ``LOCAL_RANK``. Otherwise, with
    ``SLURM_PROCID`` set, the rank comes from it and the GPU index is the
    rank modulo the visible device count. With neither, ``args.distributed``
    is set to False and the function returns.

    In the distributed cases the function selects the CUDA device, creates
    the NCCL process group from ``args.dist_url``, ``args.world_size`` and
    ``args.rank``, waits on a barrier, and silences ``print`` on every rank
    but 0.

    Args:
        args: Namespace-like object that is updated in place.
    """
    env = os.environ
    has_rank_env = 'RANK' in env and 'WORLD_SIZE' in env
    has_slurm_env = 'SLURM_PROCID' in env

    if not (has_rank_env or has_slurm_env):
        print('Not using distributed mode')
        args.distributed = False
        return

    if has_rank_env:
        args.rank = int(env["RANK"])
        args.world_size = int(env['WORLD_SIZE'])
        args.gpu = int(env['LOCAL_RANK'])
    else:
        # NOTE(review): this branch does not set args.world_size, so it must
        # already be present on ``args`` — confirm against the caller.
        args.rank = int(env['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(
        backend=args.dist_backend,
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
    )
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)
|
403 |
+
|
404 |
+
import json
|
405 |
+
def save_on_master_eval_res(log_stats, output_dir):
    """Append ``log_stats`` as one JSON line to a file, on the main process only.

    Args:
        log_stats: JSON-serialisable object to record.
        output_dir: Path passed directly to ``open`` in append mode, i.e.
            it is used as the file to write to.
    """
    if not is_main_process():
        return
    record = json.dumps(log_stats) + "\n"
    with open(output_dir, 'a') as f:
        f.write(record)
|