Commit b9b435f
Parent(s): fa37760

initial version with models, embeddings

Files changed:
- app.py +79 -20
- requirements.txt +9 -0
- utils.py +365 -0
app.py
CHANGED
@@ -1,41 +1,100 @@
+# If TF version is not understood by tfimm requirements, try this:
+#try:
+#    import tfimm
+#except ModuleNotFoundError:
+#    !pip install --no-deps tfimm timm
+#    import timm
+#    import tfimm
+
 import os
+import glob
 from pathlib import Path
 from subprocess import run
+import json
 
 import gradio as gr
+from yolov5 import detect
+from utils import get_models, get_cfg, get_embeddings, get_comp_embeddings, get_test_embedding
 
 
+# YOLOv5 parameters
 yolo_input_size = 384
 versions = ('2_v108', '4_v109', '0_int6', '1_v110', '3_v111')
-
 score_thr = 0.025
-zoom_score_thr = 0.35
 iou_thr = 0.6
 max_det = 1
-yolo_ens = 'fast' # fast, val, detect, detect_internal, all
-output_size = (512, 512)
-bs = 1 #128 if 'CUDA_VERSION' in os.environ else 16
-
-project_dir = None
 working = Path(os.getcwd())
 modelbox = working / 'models'
 checkpoint_files = [modelbox / f'yolov5_l6_{yolo_input_size}_fold{x}.pt' for x in versions]
-image_root = working / 'images'
+image_root = working / 'images'
+
+
+# Individual identifier parameters
+embedding_size = 1024
+n_images = 51033 + 27956
+max_distance = 0.865
+normalize_similarity = None # test-train, None
+gamma = 0.4
+threshold = 0.09951 if (normalize_similarity == 'test-train') else 0.6 # 0.381
+knn = 300
+emb_path = '/kaggle/input/happywhale-embeddings'
+rst_path = '/kaggle/input/happywhale-models'
+rst_files = sorted(glob(f'{rst_path}/*.h5'))
+n_models = len(rst_files)
+
+
+def fast_yolo_crop(image):
+    !rm -rf {working}/labels {working}/results_ensemble
+    #%cd {working}/yolov5
+    %cd {working}
+    mpimg.imsave(yolo_source, image)
+
+    #print(f"\nInference on best {len(checkpoint_files[5:])} models with detect.py ...")
+    detect.run(weights=checkpoint_files[4:],
+               source=yolo_source,
+               data='data/dataset.yaml',
+               imgsz=yolo_input_size,
+               conf_thres=score_thr,
+               iou_thres=iou_thr,
+               max_det=max_det,
+               save_txt=False,
+               save_conf=False,
+               save_crop=True,
+               exist_ok=True,
+               name=str(working / 'results_ensemble'))
+
+    #print(f"YOLOv5 inference finished in {(perf_counter() - t0) / 60:.2f} min")
+    cropped = sorted(glob(f'{working}/results_ensemble/crops/*/{Path(yolo_source).name}'))
+    assert len(cropped) == 1, f'{len(cropped)} maritime species detected'
+    cropped = cropped[0]
+    species = Path(cropped).parent.name
+    cropped_image = mpimg.imread(cropped)
+    return cropped_image, species.replace('_', ' ')
+
 
-
-
-'https://upload.wikimedia.org/wikipedia/commons/c/c5/Common_Dolphin.jpg',
-'https://upload.wikimedia.org/wikipedia/commons/b/b8/Beluga847.jpg',
-'https://upload.wikimedia.org/wikipedia/commons/e/ea/Beluga_1_1999-07-03.jpg',
-'https://upload.wikimedia.org/wikipedia/commons/2/2b/Whale_Watching_in_Gloucester%2C_Massachusetts_5.jpg',
+# Preload embeddings for known individuals
+comp_embeddings = get_comp_embeddings(rst_files)
 
-
-
-]
-
+# Preload embedding models, input sizes
+K.clear_session()
+embed_models, sizes = [], []
+for rst_file in rst_files:
+    cfg = get_cfg(rst_file)
+    npz_file = Path(rst_file.replace('.h5', '_emb.npz')).name
+    assert cfg.FOLD_TO_RUN == use_fold[npz_file]
+    cfg.pretrained = None # avoid weight downloads
+    if isinstance(cfg.IMAGE_SIZE, int):
+        cfg.IMAGE_SIZE = (cfg.IMAGE_SIZE, cfg.IMAGE_SIZE)
+    sizes.append(cfg.IMAGE_SIZE)
+    model, embed_model = get_model(cfg)
+    model.load_weights(rst_file)
+    print(f"\nWeights loaded from {rst_file}")
+    print(f"input_size {scaled_img.shape[:2]}, fold {cfg.FOLD_TO_RUN}, arch {cfg.arch_name}, ",
+          f"DATASET {cfg.DATASET}, dropout_ps {cfg.dropout_ps}, subcenters {cfg.subcenters}")
+    embed_models.append(embed_model)
 
 
-def pred_fn(image, fake=
+def pred_fn(image, fake=False):
     if fake:
         x0, x1 = (int(f * image.shape[0]) for f in (0.2, 0.8))
         y0, y1 = (int(f * image.shape[1]) for f in (0.2, 0.8))
@@ -72,4 +131,4 @@ examples = [str(image_root / f'negative{i:03d}') for i in range(3)]
 
 demo = gr.Interface(fn=pred_fn, inputs="image", outputs=["image", "text"],
                     examples=examples)
-demo.launch()
+demo.launch()
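Review note: the hunk only shows the first lines of the new pred_fn, so the matching logic itself is outside the diff context. Below is a minimal sketch of how the pieces defined above could fit together at prediction time. It is an assumption, not part of this commit: the flow, the threshold handling and the "new individual" label are illustrative, and note that get_test_embedding as committed reads the test image from a module-level img in utils.py rather than taking the crop as an argument.

import numpy as np

def predict_individual(image):
    # 1. Crop the animal with the YOLOv5 ensemble; the crop folder name is the species
    cropped_image, species = fast_yolo_crop(image)

    # 2. Embed the crop with every model, L2-normalized per model (done inside get_test_embedding),
    #    giving a [1, n_models * embedding_size] vector aligned with comp_embeddings
    test_emb = get_test_embedding(embed_models, sizes)

    # 3. Score against all known competition embeddings and keep the knn best matches
    similarities = (comp_embeddings @ test_emb.T).squeeze(1)   # [n_images]
    top_idx = np.argsort(-similarities)[:knn]

    # 4. Below the decision threshold, report a new individual instead of a match
    if similarities[top_idx[0]] < threshold:
        return cropped_image, f'{species}: new individual'
    return cropped_image, f'{species}: closest match index {top_idx[0]}'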
requirements.txt
ADDED
@@ -0,0 +1,9 @@
+torch
+yolov5
+ensemble-boxes
+tensorflow
+tfimm
+timm
+efficientnet
+keras-efficientnet-v2
+tensorflow-hub
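Review note: the commented fallback at the top of app.py uses notebook syntax (!pip ...), which only runs under IPython. A script-safe equivalent of that same fallback might look like this (a sketch, not part of the commit):

import subprocess
import sys

try:
    import tfimm
except ModuleNotFoundError:
    # same --no-deps install suggested in the app.py comment, but callable from a plain script
    subprocess.run([sys.executable, "-m", "pip", "install", "--no-deps", "tfimm", "timm"], check=True)
    import timm
    import tfimm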
utils.py
ADDED
@@ -0,0 +1,365 @@
+import math
+
+import tensorflow as tf
+import tfimm
+import efficientnet
+import efficientnet.tfkeras as efnv1
+import keras_efficientnet_v2 as efnv2
+import tensorflow_hub as hub
+
+
+class DotDict(dict):
+    """dot.notation access to dictionary attributes
+
+    Reference:
+    https://stackoverflow.com/questions/2352181/how-to-use-a-dot-to-access-members-of-dictionary/23689767#23689767
+    """
+    __getattr__ = dict.get # returns None if missing key, don't use getattr() with default!
+    __setattr__ = dict.__setitem__
+    __delattr__ = dict.__delitem__
+
+
+def get_cfg(rst_file):
+    json_file = str(rst_file).replace('.h5', '_config.json')
+    config_dict = json.load(open(json_file))
+    return DotDict(config_dict)
+
+
+def get_embeddings(img, embed_model):
+    inp = img[None, ...]
+    embeddings = embed_model.predict(inp, verbose=1, batch_size=1, workers=4, use_multiprocessing=True)
+    return embeddings
+
+
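Review note: DotDict simply maps attribute access onto dict.get, so reading a missing key yields None instead of raising, and get_cfg wraps each model's *_config.json in it. A short illustration (the values are made up):

cfg = DotDict({'arch_name': 'efnv1b7', 'IMAGE_SIZE': 384})
cfg.arch_name                 # 'efnv1b7'
cfg.pretrained                # None: missing keys read as None rather than raising AttributeError
cfg.pretrained = 'imagenet'   # attribute assignment writes the key into the dict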
+# Train embeddings have to be re-ordered: embeddings were concatenated (train, valid)
+# in the training notebook and the valid fold is different for each ensemble model.
+FOLDS = 10
+shards, n_total = [], 0
+for fold in range(10):
+    n_img = 5104 if fold <= 2 else 5103
+    shards.append(list(range(n_total, n_total + n_img)))
+    n_total += n_img
+assert n_total == 51033
+
+def get_train_idx(use_fold):
+    "Return embedding index that restores the order of images in the tfrec files."
+    train_folds = [i for i in range(10) if i % FOLDS != use_fold]
+    valid_folds = [i for i in range(10) if i % FOLDS == use_fold]
+    folds = train_folds + valid_folds
+
+    # order of saved embeddings (train + valid)
+    train_idx = []
+    for fold in folds:
+        train_idx.append(shards[fold])
+    train_idx = np.concatenate(train_idx)
+
+    return np.argsort(train_idx)
+
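Review note: a toy illustration of what get_train_idx undoes. With, say, 3 shards of 2 images each and valid fold 1, the embeddings were saved in fold order [0, 2, 1], i.e. original image indices [0, 1, 4, 5, 2, 3]; the argsort of that sequence is the index that puts rows back into tfrec order:

import numpy as np

saved_order = np.array([0, 1, 4, 5, 2, 3])   # image ids in the order the embeddings were written
restore_idx = np.argsort(saved_order)
assert (saved_order[restore_idx] == np.arange(6)).all()   # rows restored to original (tfrec) order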
+use_fold = {
+    'efnv1b7_colab216_emb.npz': 4,
+    'efnv1b7_colab225_emb.npz': 1,
+    'efnv1b7_colab197_emb.npz': 0,
+    'efnv1b7_colab227_emb.npz': 5,
+    'efnv1b7_v72_emb.npz': 6,
+    'efnv1b7_colab229_emb.npz': 9,
+    'efnv1b6_colab217_emb.npz': 5,
+    'efnv1b6_colab218_emb.npz': 6,
+    'hub_efnv2xl_colab221_emb.npz': 8,
+    'hub_efnv2xl_v69_emb.npz': 2,
+    'hub_efnv2xl_v73_emb.npz': 0,
+    'efnv1b6_colab226_emb.npz': 2,
+    'hub_efnv2l_v70_emb.npz': 3,
+    'hub_efnv2l_colab200_emb.npz': 2,
+    'hub_efnv2l_colab199_emb.npz': 1,
+    'convnext_base_384_in22ft1k_v68_emb.npz': 0,
+    'convnext_base_384_in22ft1k_colab220_emb.npz': 9,
+    'convnext_base_384_in22ft1k_colab201_emb.npz': 3, # new
+}
+
+
+def get_comp_embeddings(rst_files):
+    "Load embeddings for competition images [n_images, embedding_size]"
+
+    comp_embeddings = []
+
+    for rst_file in rst_files:
+        # Get embeddings for all competition images
+        npz_file = Path(rst_file.replace('.h5', '_emb.npz')).name
+        d = np.load(str(Path(emb_path) / npz_file))
+        comp_train_emb = d['train']
+        comp_test_emb = d['test']
+
+        # Restore original order of comp_train_emb, targets (use targets as fingerprint-check)
+        comp_train_idx = get_train_idx(use_fold[npz_file])
+        comp_train_emb = comp_train_emb[comp_train_idx, :]
+        comp_embs = np.concatenate([comp_train_emb, comp_test_emb], axis=0)
+        assert comp_embs.shape == (n_images, embedding_size)
+
+        # Normalize embeddings
+        comp_embs_norms = np.linalg.norm(comp_embs, axis=1)
+        print("comp_embs norm:", comp_embs_norms.min(), "...", comp_embs_norms.max())
+        comp_embs /= comp_embs_norms[:, None]
+
+        comp_embeddings.append(comp_embs)
+
+    return np.concatenate(comp_embeddings, axis=1)
+
+
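Review note: get_comp_embeddings L2-normalizes each model's block separately and then concatenates along the feature axis, so the dot product between two concatenated vectors is the sum of the per-model cosine similarities (it ranges over [-n_models, n_models], not [-1, 1]). A quick self-contained check of that property:

import numpy as np

rng = np.random.default_rng(0)
blocks_a = [rng.normal(size=8) for _ in range(3)]   # 3 toy 'models', embedding size 8
blocks_b = [rng.normal(size=8) for _ in range(3)]
a = np.concatenate([v / np.linalg.norm(v) for v in blocks_a])
b = np.concatenate([v / np.linalg.norm(v) for v in blocks_b])
per_model_cos = [(u / np.linalg.norm(u)) @ (w / np.linalg.norm(w)) for u, w in zip(blocks_a, blocks_b)]
assert np.isclose(a @ b, sum(per_model_cos))   # dot of concatenated blocks = sum of per-model cosines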
+def get_test_embedding(embed_models, sizes):
+    test_embedding, similarities = [], []
+
+    for embed_model, size in zip(embed_models, sizes):
+        # Get model input
+        scaled_img = tf.image.resize(img, size)
+        scaled_img = tf.cast(scaled_img, tf.float32) / 255.0
+        #print("test image normalized and resized to", scaled_img.shape[:2])
+
+        # Get embedding for test image
+        test_emb = get_embeddings(scaled_img, embed_model) # shape: [1, embedding_size]
+        assert test_emb.shape == (1, embedding_size)
+
+        # Normalize embeddings
+        test_emb_norm = np.linalg.norm(test_emb, axis=1)
+        #print("test_emb norm: ", test_emb_norm[0])
+        test_emb /= test_emb_norm[:, None]
+
+        test_embedding.append(test_emb)
+
+    return np.concatenate(test_embedding, axis=1) # [1, embedding_size]
+
+
+class ArcMarginProductSubCenter(tf.keras.layers.Layer):
+    '''
+    Implements large margin arc distance.
+
+    References:
+        https://arxiv.org/pdf/1801.07698.pdf
+        https://github.com/lyakaap/Landmark2019-1st-and-3rd-Place-Solution/
+        https://github.com/haqishen/Google-Landmark-Recognition-2020-3rd-Place-Solution/
+
+    Sub-center version:
+        for k > 1, the embedding layer can learn k sub-centers per class
+    '''
+    def __init__(self, n_classes, s=30, m=0.50, k=3, easy_margin=False,
+                 ls_eps=0.0, **kwargs):
+
+        super(ArcMarginProductSubCenter, self).__init__(**kwargs)
+
+        self.n_classes = n_classes
+        self.s = s
+        self.m = m
+        self.k = k
+        self.ls_eps = ls_eps
+        self.easy_margin = easy_margin
+        self.cos_m = tf.math.cos(m)
+        self.sin_m = tf.math.sin(m)
+        self.th = tf.math.cos(math.pi - m)
+        self.mm = tf.math.sin(math.pi - m) * m
+
+    def get_config(self):
+
+        config = super().get_config().copy()
+        config.update({
+            'n_classes': self.n_classes,
+            's': self.s,
+            'm': self.m,
+            'k': self.k,
+            'ls_eps': self.ls_eps,
+            'easy_margin': self.easy_margin,
+        })
+        return config
+
+    def build(self, input_shape):
+        super(ArcMarginProductSubCenter, self).build(input_shape[0])
+
+        self.W = self.add_weight(
+            name='W',
+            shape=(int(input_shape[0][-1]), self.n_classes * self.k),
+            initializer='glorot_uniform',
+            dtype='float32',
+            trainable=True)
+
+    def call(self, inputs):
+        X, y = inputs
+        y = tf.cast(y, dtype=tf.int32)
+        cosine_all = tf.matmul(
+            tf.math.l2_normalize(X, axis=1),
+            tf.math.l2_normalize(self.W, axis=0)
+        )
+        if self.k > 1:
+            cosine_all = tf.reshape(cosine_all, [-1, self.n_classes, self.k])
+            cosine = tf.math.reduce_max(cosine_all, axis=2)
+        else:
+            cosine = cosine_all
+        sine = tf.math.sqrt(1.0 - tf.math.pow(cosine, 2))
+        phi = cosine * self.cos_m - sine * self.sin_m
+        if self.easy_margin:
+            phi = tf.where(cosine > 0, phi, cosine)
+        else:
+            phi = tf.where(cosine > self.th, phi, cosine - self.mm)
+        one_hot = tf.cast(
+            tf.one_hot(y, depth=self.n_classes),
+            dtype=cosine.dtype
+        )
+        if self.ls_eps > 0:
+            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.n_classes
+
+        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
+        output *= self.s
+        return output
+
+
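Review note: in call, cosine is cos(theta) between the L2-normalized embedding and each class center (taking the max over the k sub-centers), sine is sin(theta), and phi = cosine*cos_m - sine*sin_m is the angle-addition identity cos(theta + m). With the one-hot mask and the scale s, the returned logits implement the ArcFace objective from the referenced paper:

    logit_j = s * cos(theta_j + m)   if j is the target class y
    logit_j = s * cos(theta_j)       otherwise

easy_margin and the th/mm terms only guard the case where theta + m would leave the monotonic range of the cosine.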
+TFHUB = {
+    'hub_efnv2s': "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_s/feature_vector/2",
+    'hub_efnv2m': "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_m/feature_vector/2",
+    'hub_efnv2l': "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_l/feature_vector/2",
+    'hub_efnv2xl': "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_xl/feature_vector/2",
+    'bit_m-r50x1': "https://tfhub.dev/google/bit/m-r50x1/1",
+    'bit_m-r50x3': "https://tfhub.dev/google/bit/m-r50x3/1",
+    'bit_m-r101x1': "https://tfhub.dev/google/bit/m-r101x1/1",
+    'bit_m-r101x3': "https://tfhub.dev/google/bit/m-r101x3/1",
+    'bit_m-r152x4': "https://tfhub.dev/google/bit/m-r152x4/1",
+}
+
+
+def get_model(cfg):
+    aux_arcface = False # Chris Deotte suggested this
+    if cfg.head == 'arcface2':
+        head = ArcMarginPenaltyLogists
+    elif cfg.head == 'arcface':
+        head = ArcMarginProductSubCenter
+    elif cfg.head == 'addface':
+        head = AddMarginProductSubCenter
+    else:
+        assert False, "INVALID HEAD"
+
+    if cfg.adaptive_margin:
+        # define adaptive margins depending on class frequencies (dynamic margins)
+        df = pd.read_csv(f'{project_dir}/train.csv')
+        fewness = df['individual_id'].value_counts().sort_index() ** (-1/4)
+        fewness -= fewness.min()
+        fewness /= fewness.max() - fewness.min()
+        adaptive_margin = cfg.margin_min + fewness * (cfg.margin_max - cfg.margin_min)
+
+        # align margins with targets
+        splits_path = '/kaggle/input/happywhale-splits'
+        with open (f'{splits_path}/individual_ids.json', "r") as f:
+            target_encodings = json.loads(f.read()) # individual_id: index
+        individual_ids = pd.Series(target_encodings).sort_values().index.values
+        adaptive_margin = adaptive_margin.loc[individual_ids].values.astype(np.float32)
+
+    if cfg.arch_name.startswith('efnv1'):
+        EFN = {'efnv1b0': efnv1.EfficientNetB0, 'efnv1b1': efnv1.EfficientNetB1,
+               'efnv1b2': efnv1.EfficientNetB2, 'efnv1b3': efnv1.EfficientNetB3,
+               'efnv1b4': efnv1.EfficientNetB4, 'efnv1b5': efnv1.EfficientNetB5,
+               'efnv1b6': efnv1.EfficientNetB6, 'efnv1b7': efnv1.EfficientNetB7}
+
+    if cfg.arch_name.startswith('efnv2'):
+        EFN = {'efnv2s': efnv2.EfficientNetV2S, 'efnv2m': efnv2.EfficientNetV2M,
+               'efnv2l': efnv2.EfficientNetV2L, 'efnv2xl': efnv2.EfficientNetV2XL}
+
+
+    with strategy.scope():
+
+        margin = head(
+            n_classes = cfg.N_CLASSES,
+            s = 30,
+            m = adaptive_margin if cfg.adaptive_margin else 0.3,
+            k = cfg.subcenters or 1,
+            easy_margin = False,
+            name=f'head/{cfg.head}',
+            dtype='float32')
+
+        inp = tf.keras.layers.Input(shape = [*cfg.IMAGE_SIZE, 3], name = 'inp1')
+        label = tf.keras.layers.Input(shape = (), name = 'inp2')
+        if aux_arcface:
+            label2 = tf.keras.layers.Input(shape = (), name = 'inp3')
+
+        if cfg.arch_name.startswith('efnv1'):
+            x = EFN[cfg.arch_name](weights=cfg.pretrained, include_top=False)(inp)
+            if cfg.pool == 'flatten':
+                embed = tf.keras.layers.Flatten()(x)
+            elif cfg.pool == 'fc':
+                embed = tf.keras.layers.Flatten()(x)
+                embed = tf.keras.layers.Dropout(0.1)(embed)
+                embed = tf.keras.layers.Dense(1024)(embed)
+            elif cfg.pool == 'gem':
+                embed = GeMPoolingLayer(train_p=True)(x)
+            elif cfg.pool == 'concat':
+                embed = tf.keras.layers.concatenate([tf.keras.layers.GlobalAveragePooling2D()(x),
+                                                     tf.keras.layers.GlobalAveragePooling2D()(x)])
+            elif cfg.pool == 'max':
+                embed = tf.keras.layers.GlobalMaxPooling2D()(x)
+            else:
+                embed = tf.keras.layers.GlobalAveragePooling2D()(x)
+
+        elif cfg.arch_name.startswith('efnv2'):
+            x = EFN[cfg.arch_name](input_shape=(None, None, 3), num_classes=0,
+                                   pretrained=cfg.pretrained)(inp)
+            if cfg.pool == 'flatten':
+                embed = tf.keras.layers.Flatten()(x)
+            elif cfg.pool == 'fc':
+                embed = tf.keras.layers.Flatten()(x)
+                embed = tf.keras.layers.Dropout(0.1)(embed)
+                embed = tf.keras.layers.Dense(1024)(embed)
+            elif cfg.pool == 'gem':
+                embed = GeMPoolingLayer(train_p=True)(x)
+            elif cfg.pool == 'concat':
+                embed = tf.keras.layers.concatenate([tf.keras.layers.GlobalAveragePooling2D()(x),
+                                                     tf.keras.layers.GlobalAveragePooling2D()(x)])
+            elif cfg.pool == 'max':
+                embed = tf.keras.layers.GlobalMaxPooling2D()(x)
+            else:
+                embed = tf.keras.layers.GlobalAveragePooling2D()(x)
+
+        elif cfg.arch_name in TFHUB:
+            # tfhub models cannot be modified => Pooling cannot be changed!
+            url = TFHUB[cfg.arch_name]
+            model = hub.KerasLayer(url, trainable=True)
+            embed = model(inp)
+            #print(f"{cfg.arch_name} from tfhub")
+            assert cfg.pool in [None, False, 'avg', ''], 'tfhub model, no custom pooling supported!'
+
+        elif cfg.arch_name in tfimm.list_models(pretrained="timm"):
+            #print(f"{cfg.arch_name} from tfimm")
+            #embed = tfimm.create_model(cfg.arch_name, pretrained="timm", nb_classes=0)(inp)
+            embed = tfimm.create_model(cfg.arch_name, pretrained=None, nb_classes=0)(inp)
+            # create_model(nb_classes=0) includes pooling as last layer
+
+        if len(cfg.dropout_ps) > 0:
+            # Chris Deotte posted model code without Dropout/FC1 after pooling
+            embed = tf.keras.layers.Dropout(cfg.dropout_ps[0])(embed)
+            embed = tf.keras.layers.Dense(1024)(embed) # tunable embedding size
+            embed = tf.keras.layers.BatchNormalization()(embed) # missing in public notebooks
+        x = margin([embed, label])
+
+        output = tf.keras.layers.Softmax(dtype='float32', name='arc' if cfg.aux_loss else None)(x)
+
+        if cfg.aux_loss and aux_arcface:
+            # Use 2nd arcface head for species (aux loss)
+            head2 = ArcMarginProductSubCenter
+            margin2 = head(
+                n_classes = cfg.n_species,
+                s = 30,
+                m = 0.3,
+                k = 1,
+                easy_margin = False,
+                name=f'auxhead/{cfg.head}',
+                dtype='float32')
+            aux_features = margin2([embed, label2])
+            aux_output = tf.keras.layers.Softmax(dtype='float32', name='aux')(aux_features)
+
+        elif cfg.aux_loss:
+            aux_features = tf.keras.layers.Dense(cfg.n_species)(embed)
+            aux_output = tf.keras.layers.Softmax(dtype='float32', name='aux')(aux_features)
+        inputs = [inp, label, label2] if (cfg.aux_loss and aux_arcface) else [inp, label]
+        outputs = (output, aux_output) if cfg.aux_loss else [output]
+
+        model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
+        embed_model = tf.keras.models.Model(inputs=inp, outputs=embed)
+
+        opt = tf.keras.optimizers.Adam(learning_rate=cfg.LR)
+        if cfg.FREEZE_BATCH_NORM:
+            freeze_BN(model)
+
+        return model, embed_model
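Review note: as committed, utils.py references several names it neither defines nor imports (json, np, pd, strategy, project_dir, GeMPoolingLayer, freeze_BN, ArcMarginPenaltyLogists, AddMarginProductSubCenter), and app.py imports get_models although the file defines get_model. A minimal sketch of what the module would need just to resolve the most common of those names is shown below; the strategy default is an assumption (the original training code presumably created a TPU or mirrored strategy), and the GeM / freeze_BN / alternative-head helpers are only needed for configs that select them.

import json

import numpy as np
import pandas as pd
import tensorflow as tf

# Default (no-op) distribution strategy so `with strategy.scope():` works on single-device setups
strategy = tf.distribute.get_strategy()

project_dir = None   # only read when cfg.adaptive_margin is True (train.csv lookup)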