# If your TF version is not supported by tfimm's requirements, try this:
#try:
#    import tfimm
#except ModuleNotFoundError:
#    !pip install --no-deps tfimm timm
#    import timm
#    import tfimm

import os
from glob import glob
from shutil import rmtree
from pathlib import Path

import gradio as gr
from huggingface_hub import hf_hub_download
import matplotlib.image as mpimg
from yolov5 import detect
import numpy as np
from tensorflow.keras import backend as K
from utils import get_model, get_cfg, get_comp_embeddings, get_test_embedding, get_confidence


# YOLOv5 parameters
yolo_input_size = 384
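# Fold identifiers appearing in the YOLOv5 checkpoint file names downloaded below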
versions = ('2_v108', '4_v109', '0_int6', '1_v110', '3_v111')
score_thr = 0.025
iou_thr = 0.6
max_det = 1
working = Path(os.getcwd())
modelbox = "yellowdolphin/happywhale-models"
checkpoint_files = [hf_hub_download(modelbox, f'yolov5_l6_{yolo_input_size}_fold{x}.pt') for x in versions]
image_root = working / 'images'
yolo_source = str(image_root / 'testimage.jpg')


# Individual identifier parameters
max_distance = 0.865
normalize_similarity = None  # test-train, None
threshold = 0.09951 if (normalize_similarity == 'test-train') else 0.6  # 0.381
rst_names = 'convnext_base_384_in22ft1k_colab220 efnv1b7_colab216 hub_efnv2xl_v73'.split()
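# Checkpoint name -> fold used at inference; only the entries listed in rst_names are loaded below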
use_fold = {
    'efnv1b7_colab216': 4,
    'efnv1b7_colab225': 1,
    'efnv1b7_colab197': 0,
    'efnv1b7_colab227': 5,
    'efnv1b7_v72': 6,
    'efnv1b7_colab229': 9,
    'efnv1b6_colab217': 5,
    'efnv1b6_colab218': 6,
    'hub_efnv2xl_colab221': 8,
    'hub_efnv2xl_v69': 2,
    'hub_efnv2xl_v73': 0,
    'efnv1b6_colab226': 2,
    'hub_efnv2l_v70': 3,
    'hub_efnv2l_colab200': 2,
    'hub_efnv2l_colab199': 1,
    'convnext_base_384_in22ft1k_v68': 0,
    'convnext_base_384_in22ft1k_colab220': 9,
    'convnext_base_384_in22ft1k_colab201': 3,  # new
}
cfg_files = [hf_hub_download(modelbox, f'{x}_config.json') for x in rst_names]
emb_files = [hf_hub_download(modelbox, f'{x}_emb.npz') for x in rst_names]
rst_files = [hf_hub_download(modelbox, f'{x}.h5') for x in rst_names]
use_folds = [use_fold[x] for x in rst_names]
n_models = len(rst_names)


def fast_yolo_crop(image):
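    # Clear crops and labels left over from a previous run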
    rmtree(working / 'labels', ignore_errors=True)
    rmtree(working / 'results_ensemble', ignore_errors=True)

    print("image:", type(image))
    print(image.shape)
    print("yolo_source:", yolo_source)
    print(type(yolo_source))
    mpimg.imsave(yolo_source, image)

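    # Run detection on the saved image; note that only the last YOLOv5 checkpoint is passed
    # (checkpoint_files[4:]), presumably to keep this "fast" crop quick rather than ensembling all five.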
    detect.run(weights=checkpoint_files[4:],
               source=yolo_source,
               data='data/dataset.yaml',
               imgsz=yolo_input_size,
               conf_thres=score_thr,
               iou_thres=iou_thr,
               max_det=max_det,
               save_txt=False,
               save_conf=False,
               save_crop=True,
               exist_ok=True,
               name=str(working / 'results_ensemble'))

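    # YOLOv5 saves crops under results_ensemble/crops/<class_name>/, so the parent
    # directory name of the crop encodes the detected species.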
    glob_pattern = f'{working}/results_ensemble/crops/*/{Path(yolo_source).name}'
    print("glob_pattern:", glob_pattern)
    cropped = sorted(glob(glob_pattern))
    assert len(cropped) == 1, f'{len(cropped)} marine species detected'
    cropped = cropped[0]
    species = Path(cropped).parent.name
    cropped_image = mpimg.imread(cropped)
    return cropped_image, species.replace('_', ' ')


# Preload embeddings for known individuals
comp_embeddings = get_comp_embeddings(emb_files, use_folds)

# Preload embedding models, input sizes
K.clear_session()
embed_models, sizes = [], []
for cfg_file, rst_file, fold in zip(cfg_files, rst_files, use_folds):
    cfg = get_cfg(cfg_file)
    assert cfg.FOLD_TO_RUN == fold
    cfg.pretrained = None  # avoid weight downloads
    if isinstance(cfg.IMAGE_SIZE, int):
        cfg.IMAGE_SIZE = (cfg.IMAGE_SIZE, cfg.IMAGE_SIZE)
    sizes.append(cfg.IMAGE_SIZE)
    model, embed_model = get_model(cfg)
    model.load_weights(rst_file)
    print(f"\nWeights loaded from {rst_file}")
    print(f"input_size {cfg.IMAGE_SIZE}, fold {cfg.FOLD_TO_RUN}, arch {cfg.arch_name}, ",
          f"DATASET {cfg.DATASET}, dropout_ps {cfg.dropout_ps}, subcenters {cfg.subcenters}")
    embed_models.append(embed_model)


def pred_fn(image, fake=False):
    if fake:
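        # Debug path: return a fixed center crop and a canned response without running any model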
        x0, x1 = (int(f * image.shape[0]) for f in (0.2, 0.8))
        y0, y1 = (int(f * image.shape[1]) for f in (0.2, 0.8))
        cropped_image = image[x0:x1, y0:y1, :]
        response_str = "This looks like a common dolphin, but I have not seen this individual before (0.834 confidence).\n" \
                       "Go submit your photo on www.happywhale.com!"
        return cropped_image, response_str

    cropped_image, species = fast_yolo_crop(image)
    test_embedding = get_test_embedding(cropped_image, embed_models, sizes)

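    # comp_embeddings stacks the reference embeddings of all known individuals (concatenated
    # over the n_models identifiers); assuming each model's embeddings are L2-normalized,
    # the dot product divided by n_models is the mean cosine similarity across models.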
    cosine_similarity = np.dot(comp_embeddings, test_embedding[0]) / n_models
    cosine_distances = 1 - cosine_similarity
    normalized_distances = cosine_distances / max_distance
    normalized_similarities = 1 - normalized_distances

    min_similarity = normalized_similarities.min()
    max_similarity = normalized_similarities.max()
    confidence = get_confidence(max_similarity, threshold)

    print(f"Similarities: {min_similarity:.4f} ... {max_similarity:.4f}")
    print(f"Threshold: {threshold}")

    if max_similarity > threshold:
        response_str = f"This looks like a {species} I have seen before ({confidence:.3f} confidence).\n" \
                       "You might find its previous encounters on www.happywhale.com"
    else:
        response_str = f"This looks like a {species}, but I have not seen this individual before ({confidence:.3f} confidence).\n" \
                       "Go submit your photo on www.happywhale.com!"

    return cropped_image, response_str


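# Example images bundled with the app, shown as clickable Gradio examples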
examples = [str(image_root / f'negative{i:03d}.jpg') for i in range(3)]

description = """
Is it possible to identify and track individual marine mammals based on 
community photos taken by tourist whale-watchers with their cameras or 
smartphones?

Researchers have used [photographic identification](https://whalescientists.com/photo-id/) 
(photo-ID) of individual whales for 
decades to study their migration, populations, and behavior. Because this is a 
tedious and costly process, it is tempting to leverage the huge amount of 
image data collected by the whale-watching community and from private encounters around 
the globe. Organizations like [WildMe](https://www.wildme.org) or 
[happywhale](https://www.happywhale.com) develop AI models for automated identification at 
scale. To push the state of the art, happywhale hosted two competitions on Kaggle: 
the 2018 [Humpback Whale Identification](https://www.kaggle.com/c/humpback-whale-identification) 
competition and the 2022 [Happywhale](https://www.kaggle.com/competitions/happy-whale-and-dolphin) 
competition, which included 28 whale and dolphin species.

Top solutions used a two-step process: crop the raw image with an 
object detector like [YOLOv5](https://pytorch.org/hub/ultralytics_yolov5) 
and present the high-resolution crop to an identifier trained with an 
ArcFace-based loss function. The detector had to be fine-tuned on the 
competition images with automatically or manually generated labels.

Below you can test a slimmed-down version of my solution on your own images. 
The detector is an ensemble of five YOLOv5 models; the identifier ensembles three 
models with EfficientNet-B7, EfficientNetV2-XL, and ConvNeXt-Base backbones. 
You can find the model code and training pipelines in the 
[DeepTrane](https://github.com/yellowdolphin/deeptrane) repository.
"""  # appears between title and input/output

article = """
"""  # appears below input/output

demo = gr.Interface(fn=pred_fn, inputs="image", outputs=["image", "text"],
                    examples=examples,
                    title='Happywhale: Individual Identification for Marine Mammals',
                    description=description,
                    article=None,)
demo.launch()