yellowdolphin committed on
Commit b9b435f · 1 Parent(s): fa37760

initial version with models, embeddings

Files changed (3):
  1. app.py +79 -20
  2. requirements.txt +9 -0
  3. utils.py +365 -0
app.py CHANGED
@@ -1,41 +1,100 @@
+# If TF version is not understood by tfimm requirements, try this:
+#try:
+#    import tfimm
+#except ModuleNotFoundError:
+#    !pip install --no-deps tfimm timm
+#    import timm
+#    import tfimm
+
 import os
+import glob
 from pathlib import Path
 from subprocess import run
+import json
 
 import gradio as gr
+import matplotlib.image as mpimg
+from tensorflow.keras import backend as K
+from yolov5 import detect
+from utils import get_model, get_cfg, get_embeddings, get_comp_embeddings, get_test_embedding, use_fold
 
 
+# YOLOv5 parameters
 yolo_input_size = 384
 versions = ('2_v108', '4_v109', '0_int6', '1_v110', '3_v111')
-
 score_thr = 0.025
-zoom_score_thr = 0.35
 iou_thr = 0.6
 max_det = 1
-yolo_ens = 'fast'  # fast, val, detect, detect_internal, all
-output_size = (512, 512)
-bs = 1  #128 if 'CUDA_VERSION' in os.environ else 16
-
-project_dir = None
 working = Path(os.getcwd())
 modelbox = working / 'models'
 checkpoint_files = [modelbox / f'yolov5_l6_{yolo_input_size}_fold{x}.pt' for x in versions]
-image_root = working / 'images' / 'subdir'
+image_root = working / 'images'
+
+
+# Individual identifier parameters
+embedding_size = 1024
+n_images = 51033 + 27956
+max_distance = 0.865
+normalize_similarity = None  # test-train, None
+gamma = 0.4
+threshold = 0.09951 if (normalize_similarity == 'test-train') else 0.6  # 0.381
+knn = 300
+emb_path = '/kaggle/input/happywhale-embeddings'
+rst_path = '/kaggle/input/happywhale-models'
+rst_files = sorted(glob.glob(f'{rst_path}/*.h5'))
+n_models = len(rst_files)
+
+
+def fast_yolo_crop(image):
+    # remove YOLOv5 output from previous runs, then run detection from the app directory
+    run(f'rm -rf {working}/labels {working}/results_ensemble', shell=True)
+    # os.chdir(working / 'yolov5')
+    os.chdir(working)
+    mpimg.imsave(yolo_source, image)
+
+    #print(f"\nInference on best {len(checkpoint_files[5:])} models with detect.py ...")
+    detect.run(weights=checkpoint_files[4:],
+               source=yolo_source,
+               data='data/dataset.yaml',
+               imgsz=yolo_input_size,
+               conf_thres=score_thr,
+               iou_thres=iou_thr,
+               max_det=max_det,
+               save_txt=False,
+               save_conf=False,
+               save_crop=True,
+               exist_ok=True,
+               name=str(working / 'results_ensemble'))
+
+    #print(f"YOLOv5 inference finished in {(perf_counter() - t0) / 60:.2f} min")
+    cropped = sorted(glob.glob(f'{working}/results_ensemble/crops/*/{Path(yolo_source).name}'))
+    assert len(cropped) == 1, f'{len(cropped)} maritime species detected'
+    cropped = cropped[0]
+    species = Path(cropped).parent.name
+    cropped_image = mpimg.imread(cropped)
+    return cropped_image, species.replace('_', ' ')
+
 
-image_urls = [
-    # Negatives
-    'https://upload.wikimedia.org/wikipedia/commons/c/c5/Common_Dolphin.jpg',
-    'https://upload.wikimedia.org/wikipedia/commons/b/b8/Beluga847.jpg',
-    'https://upload.wikimedia.org/wikipedia/commons/e/ea/Beluga_1_1999-07-03.jpg',
-    'https://upload.wikimedia.org/wikipedia/commons/2/2b/Whale_Watching_in_Gloucester%2C_Massachusetts_5.jpg',
-
-    # Positives
-    '/kaggle/input/happy-whale-and-dolphin/test_images/00098d1376dab2.jpg',
-]
-yolo_source = f'{image_root}/negative001.jpg'
+# Preload embeddings for known individuals
+comp_embeddings = get_comp_embeddings(rst_files)
 
+# Preload embedding models, input sizes
+K.clear_session()
+embed_models, sizes = [], []
+for rst_file in rst_files:
+    cfg = get_cfg(rst_file)
+    npz_file = Path(rst_file.replace('.h5', '_emb.npz')).name
+    assert cfg.FOLD_TO_RUN == use_fold[npz_file]
+    cfg.pretrained = None  # avoid weight downloads
+    if isinstance(cfg.IMAGE_SIZE, int):
+        cfg.IMAGE_SIZE = (cfg.IMAGE_SIZE, cfg.IMAGE_SIZE)
+    sizes.append(cfg.IMAGE_SIZE)
+    model, embed_model = get_model(cfg)
+    model.load_weights(rst_file)
+    print(f"\nWeights loaded from {rst_file}")
+    print(f"input_size {cfg.IMAGE_SIZE}, fold {cfg.FOLD_TO_RUN}, arch {cfg.arch_name}, ",
+          f"DATASET {cfg.DATASET}, dropout_ps {cfg.dropout_ps}, subcenters {cfg.subcenters}")
+    embed_models.append(embed_model)
 
 
-def pred_fn(image, fake=True):
+def pred_fn(image, fake=False):
     if fake:
         x0, x1 = (int(f * image.shape[0]) for f in (0.2, 0.8))
         y0, y1 = (int(f * image.shape[1]) for f in (0.2, 0.8))
@@ -72,4 +131,4 @@ examples = [str(image_root / f'negative{i:03d}') for i in range(3)]
 
 demo = gr.Interface(fn=pred_fn, inputs="image", outputs=["image", "text"],
                     examples=examples)
-demo.launch()
+demo.launch()
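The diff only shows new lines 1-100 and 131-134 of app.py, so the step that turns the preloaded `comp_embeddings` and a query embedding from `get_test_embedding` into a prediction is not visible here. The following is a minimal sketch of how such a lookup could work under the parameters defined above (knn, threshold); the function name `identify`, the `train_individual_ids` label array, and the averaging of per-model cosine similarities are illustrative assumptions, not part of this commit.

import numpy as np

def identify(test_emb, comp_embeddings, train_individual_ids, n_models,
             knn=300, threshold=0.6, n_train=51033):
    """Nearest-neighbour lookup of one query embedding against the competition embeddings."""
    # Every 1024-dim block is L2-normalized per model, so the dot product is a sum of
    # per-model cosine similarities; divide by n_models to get a mean cosine.
    similarities = (comp_embeddings @ test_emb.T).squeeze(1) / n_models   # [n_images]
    train_sim = similarities[:n_train]            # only labelled train images can vote
    top = np.argsort(-train_sim)[:knn]            # indices of the knn closest known images
    best = top[0]
    if train_sim[best] < threshold:
        return 'new_individual'                   # no known individual is similar enough
    return train_individual_ids[best]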
requirements.txt ADDED
@@ -0,0 +1,9 @@
+torch
+yolov5
+ensemble-boxes
+tensorflow
+tfimm
+timm
+efficientnet
+keras-efficientnet-v2
+tensorflow-hub
utils.py ADDED
@@ -0,0 +1,365 @@
+import math
+import json
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+import tensorflow as tf
+import tfimm
+import efficientnet
+import efficientnet.tfkeras as efnv1
+import keras_efficientnet_v2 as efnv2
+import tensorflow_hub as hub
+
+
+class DotDict(dict):
+    """dot.notation access to dictionary attributes
+
+    Reference:
+    https://stackoverflow.com/questions/2352181/how-to-use-a-dot-to-access-members-of-dictionary/23689767#23689767
+    """
+    __getattr__ = dict.get  # returns None if missing key, don't use getattr() with default!
+    __setattr__ = dict.__setitem__
+    __delattr__ = dict.__delitem__
+
+
+def get_cfg(rst_file):
+    json_file = str(rst_file).replace('.h5', '_config.json')
+    config_dict = json.load(open(json_file))
+    return DotDict(config_dict)
+
+
+def get_embeddings(img, embed_model):
+    inp = img[None, ...]
+    embeddings = embed_model.predict(inp, verbose=1, batch_size=1, workers=4, use_multiprocessing=True)
+    return embeddings
+
+
+# Train embeddings have to be re-ordered: embeddings were concatenated (train, valid)
+# in the training notebook and the valid fold is different for each ensemble model.
+FOLDS = 10
+shards, n_total = [], 0
+for fold in range(10):
+    n_img = 5104 if fold <= 2 else 5103
+    shards.append(list(range(n_total, n_total + n_img)))
+    n_total += n_img
+assert n_total == 51033
+
+
+def get_train_idx(use_fold):
+    "Return embedding index that restores the order of images in the tfrec files."
+    train_folds = [i for i in range(10) if i % FOLDS != use_fold]
+    valid_folds = [i for i in range(10) if i % FOLDS == use_fold]
+    folds = train_folds + valid_folds
+
+    # order of saved embeddings (train + valid)
+    train_idx = []
+    for fold in folds:
+        train_idx.append(shards[fold])
+    train_idx = np.concatenate(train_idx)
+
+    return np.argsort(train_idx)
+
+
+use_fold = {
+    'efnv1b7_colab216_emb.npz': 4,
+    'efnv1b7_colab225_emb.npz': 1,
+    'efnv1b7_colab197_emb.npz': 0,
+    'efnv1b7_colab227_emb.npz': 5,
+    'efnv1b7_v72_emb.npz': 6,
+    'efnv1b7_colab229_emb.npz': 9,
+    'efnv1b6_colab217_emb.npz': 5,
+    'efnv1b6_colab218_emb.npz': 6,
+    'hub_efnv2xl_colab221_emb.npz': 8,
+    'hub_efnv2xl_v69_emb.npz': 2,
+    'hub_efnv2xl_v73_emb.npz': 0,
+    'efnv1b6_colab226_emb.npz': 2,
+    'hub_efnv2l_v70_emb.npz': 3,
+    'hub_efnv2l_colab200_emb.npz': 2,
+    'hub_efnv2l_colab199_emb.npz': 1,
+    'convnext_base_384_in22ft1k_v68_emb.npz': 0,
+    'convnext_base_384_in22ft1k_colab220_emb.npz': 9,
+    'convnext_base_384_in22ft1k_colab201_emb.npz': 3,  # new
+}
+
+
+def get_comp_embeddings(rst_files):
+    "Load embeddings for competition images [n_images, embedding_size]"
+    comp_embeddings = []
+
+    for rst_file in rst_files:
+        # Get embeddings for all competition images
+        npz_file = Path(rst_file.replace('.h5', '_emb.npz')).name
+        d = np.load(str(Path(emb_path) / npz_file))
+        comp_train_emb = d['train']
+        comp_test_emb = d['test']
+
+        # Restore original order of comp_train_emb, targets (use targets as fingerprint-check)
+        comp_train_idx = get_train_idx(use_fold[npz_file])
+        comp_train_emb = comp_train_emb[comp_train_idx, :]
+        comp_embs = np.concatenate([comp_train_emb, comp_test_emb], axis=0)
+        assert comp_embs.shape == (n_images, embedding_size)
+
+        # Normalize embeddings
+        comp_embs_norms = np.linalg.norm(comp_embs, axis=1)
+        print("comp_embs norm:", comp_embs_norms.min(), "...", comp_embs_norms.max())
+        comp_embs /= comp_embs_norms[:, None]
+
+        comp_embeddings.append(comp_embs)
+
+    return np.concatenate(comp_embeddings, axis=1)
+
+
+def get_test_embedding(embed_models, sizes):
+    test_embedding, similarities = [], []
+
+    for embed_model, size in zip(embed_models, sizes):
+        # Get model input
+        scaled_img = tf.image.resize(img, size)
+        scaled_img = tf.cast(scaled_img, tf.float32) / 255.0
+        #print("test image normalized and resized to", scaled_img.shape[:2])
+
+        # Get embedding for test image
+        test_emb = get_embeddings(scaled_img, embed_model)  # shape: [1, embedding_size]
+        assert test_emb.shape == (1, embedding_size)
+
+        # Normalize embeddings
+        test_emb_norm = np.linalg.norm(test_emb, axis=1)
+        #print("test_emb norm: ", test_emb_norm[0])
+        test_emb /= test_emb_norm[:, None]
+
+        test_embedding.append(test_emb)
+
+    return np.concatenate(test_embedding, axis=1)  # [1, embedding_size]
+
+
+class ArcMarginProductSubCenter(tf.keras.layers.Layer):
+    '''
+    Implements large margin arc distance.
+
+    References:
+        https://arxiv.org/pdf/1801.07698.pdf
+        https://github.com/lyakaap/Landmark2019-1st-and-3rd-Place-Solution/
+        https://github.com/haqishen/Google-Landmark-Recognition-2020-3rd-Place-Solution/
+
+    Sub-center version:
+        for k > 1, the embedding layer can learn k sub-centers per class
+    '''
+    def __init__(self, n_classes, s=30, m=0.50, k=3, easy_margin=False,
+                 ls_eps=0.0, **kwargs):
+        super(ArcMarginProductSubCenter, self).__init__(**kwargs)
+
+        self.n_classes = n_classes
+        self.s = s
+        self.m = m
+        self.k = k
+        self.ls_eps = ls_eps
+        self.easy_margin = easy_margin
+        self.cos_m = tf.math.cos(m)
+        self.sin_m = tf.math.sin(m)
+        self.th = tf.math.cos(math.pi - m)
+        self.mm = tf.math.sin(math.pi - m) * m
+
+    def get_config(self):
+        config = super().get_config().copy()
+        config.update({
+            'n_classes': self.n_classes,
+            's': self.s,
+            'm': self.m,
+            'k': self.k,
+            'ls_eps': self.ls_eps,
+            'easy_margin': self.easy_margin,
+        })
+        return config
+
+    def build(self, input_shape):
+        super(ArcMarginProductSubCenter, self).build(input_shape[0])
+
+        self.W = self.add_weight(
+            name='W',
+            shape=(int(input_shape[0][-1]), self.n_classes * self.k),
+            initializer='glorot_uniform',
+            dtype='float32',
+            trainable=True)
+
+    def call(self, inputs):
+        X, y = inputs
+        y = tf.cast(y, dtype=tf.int32)
+        cosine_all = tf.matmul(
+            tf.math.l2_normalize(X, axis=1),
+            tf.math.l2_normalize(self.W, axis=0)
+        )
+        if self.k > 1:
+            cosine_all = tf.reshape(cosine_all, [-1, self.n_classes, self.k])
+            cosine = tf.math.reduce_max(cosine_all, axis=2)
+        else:
+            cosine = cosine_all
+        sine = tf.math.sqrt(1.0 - tf.math.pow(cosine, 2))
+        phi = cosine * self.cos_m - sine * self.sin_m
+        if self.easy_margin:
+            phi = tf.where(cosine > 0, phi, cosine)
+        else:
+            phi = tf.where(cosine > self.th, phi, cosine - self.mm)
+        one_hot = tf.cast(
+            tf.one_hot(y, depth=self.n_classes),
+            dtype=cosine.dtype
+        )
+        if self.ls_eps > 0:
+            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.n_classes
+
+        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
+        output *= self.s
+        return output
+
+
+TFHUB = {
+    'hub_efnv2s': "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_s/feature_vector/2",
+    'hub_efnv2m': "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_m/feature_vector/2",
+    'hub_efnv2l': "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_l/feature_vector/2",
+    'hub_efnv2xl': "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_xl/feature_vector/2",
+    'bit_m-r50x1': "https://tfhub.dev/google/bit/m-r50x1/1",
+    'bit_m-r50x3': "https://tfhub.dev/google/bit/m-r50x3/1",
+    'bit_m-r101x1': "https://tfhub.dev/google/bit/m-r101x1/1",
+    'bit_m-r101x3': "https://tfhub.dev/google/bit/m-r101x3/1",
+    'bit_m-r152x4': "https://tfhub.dev/google/bit/m-r152x4/1",
+}
+
+
+def get_model(cfg):
+    aux_arcface = False  # Chris Deotte suggested this
+    if cfg.head == 'arcface2':
+        head = ArcMarginPenaltyLogists
+    elif cfg.head == 'arcface':
+        head = ArcMarginProductSubCenter
+    elif cfg.head == 'addface':
+        head = AddMarginProductSubCenter
+    else:
+        assert False, "INVALID HEAD"
+
+    if cfg.adaptive_margin:
+        # define adaptive margins depending on class frequencies (dynamic margins)
+        df = pd.read_csv(f'{project_dir}/train.csv')
+        fewness = df['individual_id'].value_counts().sort_index() ** (-1/4)
+        fewness -= fewness.min()
+        fewness /= fewness.max() - fewness.min()
+        adaptive_margin = cfg.margin_min + fewness * (cfg.margin_max - cfg.margin_min)
+
+        # align margins with targets
+        splits_path = '/kaggle/input/happywhale-splits'
+        with open(f'{splits_path}/individual_ids.json', "r") as f:
+            target_encodings = json.loads(f.read())  # individual_id: index
+        individual_ids = pd.Series(target_encodings).sort_values().index.values
+        adaptive_margin = adaptive_margin.loc[individual_ids].values.astype(np.float32)
+
+    if cfg.arch_name.startswith('efnv1'):
+        EFN = {'efnv1b0': efnv1.EfficientNetB0, 'efnv1b1': efnv1.EfficientNetB1,
+               'efnv1b2': efnv1.EfficientNetB2, 'efnv1b3': efnv1.EfficientNetB3,
+               'efnv1b4': efnv1.EfficientNetB4, 'efnv1b5': efnv1.EfficientNetB5,
+               'efnv1b6': efnv1.EfficientNetB6, 'efnv1b7': efnv1.EfficientNetB7}
+
+    if cfg.arch_name.startswith('efnv2'):
+        EFN = {'efnv2s': efnv2.EfficientNetV2S, 'efnv2m': efnv2.EfficientNetV2M,
+               'efnv2l': efnv2.EfficientNetV2L, 'efnv2xl': efnv2.EfficientNetV2XL}
+
+    with strategy.scope():
+
+        margin = head(
+            n_classes=cfg.N_CLASSES,
+            s=30,
+            m=adaptive_margin if cfg.adaptive_margin else 0.3,
+            k=cfg.subcenters or 1,
+            easy_margin=False,
+            name=f'head/{cfg.head}',
+            dtype='float32')
+
+        inp = tf.keras.layers.Input(shape=[*cfg.IMAGE_SIZE, 3], name='inp1')
+        label = tf.keras.layers.Input(shape=(), name='inp2')
+        if aux_arcface:
+            label2 = tf.keras.layers.Input(shape=(), name='inp3')
+
+        if cfg.arch_name.startswith('efnv1'):
+            x = EFN[cfg.arch_name](weights=cfg.pretrained, include_top=False)(inp)
+            if cfg.pool == 'flatten':
+                embed = tf.keras.layers.Flatten()(x)
+            elif cfg.pool == 'fc':
+                embed = tf.keras.layers.Flatten()(x)
+                embed = tf.keras.layers.Dropout(0.1)(embed)
+                embed = tf.keras.layers.Dense(1024)(embed)
+            elif cfg.pool == 'gem':
+                embed = GeMPoolingLayer(train_p=True)(x)
+            elif cfg.pool == 'concat':
+                embed = tf.keras.layers.concatenate([tf.keras.layers.GlobalAveragePooling2D()(x),
+                                                     tf.keras.layers.GlobalAveragePooling2D()(x)])
+            elif cfg.pool == 'max':
+                embed = tf.keras.layers.GlobalMaxPooling2D()(x)
+            else:
+                embed = tf.keras.layers.GlobalAveragePooling2D()(x)
+
+        elif cfg.arch_name.startswith('efnv2'):
+            x = EFN[cfg.arch_name](input_shape=(None, None, 3), num_classes=0,
+                                   pretrained=cfg.pretrained)(inp)
+            if cfg.pool == 'flatten':
+                embed = tf.keras.layers.Flatten()(x)
+            elif cfg.pool == 'fc':
+                embed = tf.keras.layers.Flatten()(x)
+                embed = tf.keras.layers.Dropout(0.1)(embed)
+                embed = tf.keras.layers.Dense(1024)(embed)
+            elif cfg.pool == 'gem':
+                embed = GeMPoolingLayer(train_p=True)(x)
+            elif cfg.pool == 'concat':
+                embed = tf.keras.layers.concatenate([tf.keras.layers.GlobalAveragePooling2D()(x),
+                                                     tf.keras.layers.GlobalAveragePooling2D()(x)])
+            elif cfg.pool == 'max':
+                embed = tf.keras.layers.GlobalMaxPooling2D()(x)
+            else:
+                embed = tf.keras.layers.GlobalAveragePooling2D()(x)
+
+        elif cfg.arch_name in TFHUB:
+            # tfhub models cannot be modified => Pooling cannot be changed!
+            url = TFHUB[cfg.arch_name]
+            model = hub.KerasLayer(url, trainable=True)
+            embed = model(inp)
+            #print(f"{cfg.arch_name} from tfhub")
+            assert cfg.pool in [None, False, 'avg', ''], 'tfhub model, no custom pooling supported!'
+
+        elif cfg.arch_name in tfimm.list_models(pretrained="timm"):
+            #print(f"{cfg.arch_name} from tfimm")
+            #embed = tfimm.create_model(cfg.arch_name, pretrained="timm", nb_classes=0)(inp)
+            embed = tfimm.create_model(cfg.arch_name, pretrained=None, nb_classes=0)(inp)
+            # create_model(nb_classes=0) includes pooling as last layer
+
+        if len(cfg.dropout_ps) > 0:
+            # Chris Deotte posted model code without Dropout/FC1 after pooling
+            embed = tf.keras.layers.Dropout(cfg.dropout_ps[0])(embed)
+            embed = tf.keras.layers.Dense(1024)(embed)  # tunable embedding size
+            embed = tf.keras.layers.BatchNormalization()(embed)  # missing in public notebooks
+        x = margin([embed, label])
+
+        output = tf.keras.layers.Softmax(dtype='float32', name='arc' if cfg.aux_loss else None)(x)
+
+        if cfg.aux_loss and aux_arcface:
+            # Use 2nd arcface head for species (aux loss)
+            head2 = ArcMarginProductSubCenter
+            margin2 = head2(
+                n_classes=cfg.n_species,
+                s=30,
+                m=0.3,
+                k=1,
+                easy_margin=False,
+                name=f'auxhead/{cfg.head}',
+                dtype='float32')
+            aux_features = margin2([embed, label2])
+            aux_output = tf.keras.layers.Softmax(dtype='float32', name='aux')(aux_features)
+
+        elif cfg.aux_loss:
+            aux_features = tf.keras.layers.Dense(cfg.n_species)(embed)
+            aux_output = tf.keras.layers.Softmax(dtype='float32', name='aux')(aux_features)
+
+        inputs = [inp, label, label2] if (cfg.aux_loss and aux_arcface) else [inp, label]
+        outputs = (output, aux_output) if cfg.aux_loss else [output]
+
+        model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
+        embed_model = tf.keras.models.Model(inputs=inp, outputs=embed)
+
+        opt = tf.keras.optimizers.Adam(learning_rate=cfg.LR)
+        if cfg.FREEZE_BATCH_NORM:
+            freeze_BN(model)
+
+    return model, embed_model
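For reference, ArcMarginProductSubCenter.call() applies the ArcFace margin by replacing the target-class cosine cos(θ) with cos(θ + m) via the angle-addition identity and then scaling all logits by s. A minimal NumPy check of that arithmetic (the angle is arbitrary; s=30 and m=0.5 match the layer defaults):

import numpy as np

# Reproduce the margin math from ArcMarginProductSubCenter.call() for one cosine value.
s, m = 30.0, 0.5
theta = np.deg2rad(35.0)                      # example angle between embedding and class center
cosine = np.cos(theta)
sine = np.sqrt(1.0 - cosine ** 2)
phi = cosine * np.cos(m) - sine * np.sin(m)   # angle-addition identity: equals cos(theta + m)
assert np.isclose(phi, np.cos(theta + m))
target_logit = s * phi                        # logit used for the true class
other_logit = s * cosine                      # logit used for every other class
print(target_logit, other_logit)              # the margin makes the true class harder to satisfy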