# If your TF version is not recognized by tfimm's requirements, try this:
#try:
#    import tfimm
#except ModuleNotFoundError:
#    !pip install --no-deps tfimm timm
#    import timm
#    import tfimm

import os
from glob import glob
from shutil import rmtree
from pathlib import Path

import gradio as gr
import numpy as np
import matplotlib.image as mpimg
from tensorflow.keras import backend as K
from huggingface_hub import hf_hub_download
from yolov5 import detect
from utils import get_model, get_cfg, get_comp_embeddings, get_test_embedding


# YOLOv5 parameters
yolo_input_size = 384
versions = ('2_v108', '4_v109', '0_int6', '1_v110', '3_v111')
score_thr = 0.025
iou_thr = 0.6
max_det = 1
working = Path(os.getcwd())
modelbox = "yellowdolphin/happywhale-models"
checkpoint_files = [hf_hub_download(modelbox, f'yolov5_l6_{yolo_input_size}_fold{x}.pt') for x in versions]
image_root = working / 'images'
yolo_source = str(working / 'image0.jpg')  # temp file the uploaded image is written to before detection (filename assumed)


# Individual identifier parameters
embedding_size = 1024
n_images = 51033 + 27956
max_distance = 0.865
normalize_similarity = None  # test-train, None
gamma = 0.4
threshold = 0.09951 if (normalize_similarity == 'test-train') else 0.6 # 0.381
knn = 300
rst_names = 'convnext_base_384_in22ft1k_colab220 efnv1b7_colab216 hub_efnv2xl_v73'.split()
cfg_files = [hf_hub_download(modelbox, f'{x}_config.json') for x in rst_names]
emb_files = [hf_hub_download(modelbox, f'{x}_emb.npz') for x in rst_names]
rst_files = [hf_hub_download(modelbox, f'{x}.h5') for x in rst_names]
n_models = len(rst_names)
use_fold = {}  # optional: map each emb file to its training fold to enable the consistency check below


def fast_yolo_crop(image):
    # Clear outputs from any previous run (ignore_errors: they may not exist yet)
    rmtree(working / 'labels', ignore_errors=True)
    rmtree(working / 'results_ensemble', ignore_errors=True)

    mpimg.imsave(yolo_source, image)

    #print(f"\nInference on best {len(checkpoint_files[4:])} models with detect.py ...")
    detect.run(weights=checkpoint_files[4:],
               source=yolo_source,
               data='data/dataset.yaml',
               imgsz=yolo_input_size,
               conf_thres=score_thr,
               iou_thres=iou_thr,
               max_det=max_det,
               save_txt=False,
               save_conf=False,
               save_crop=True,
               exist_ok=True,
               name=str(working / 'results_ensemble'))

    cropped = sorted(glob(f'{working}/results_ensemble/crops/*/{Path(yolo_source).name}'))
    assert len(cropped) == 1, f'{len(cropped)} maritime species detected'
    cropped = cropped[0]
    species = Path(cropped).parent.name
    cropped_image = mpimg.imread(cropped)
    return cropped_image, species.replace('_', ' ')


# Preload embeddings for known individuals
comp_embeddings = get_comp_embeddings(emb_files)
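
# get_confidence is used below but not defined in this snippet, so this is a
# minimal assumed implementation: it maps the best similarity to [0, 1] by its
# margin from the decision threshold. Replace with the project's own scoring
# helper if it differs.
def get_confidence(similarity, threshold):
    if similarity > threshold:
        return float(np.clip((similarity - threshold) / (1 - threshold), 0, 1))
    return float(np.clip((threshold - similarity) / threshold, 0, 1))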

# Preload embedding models, input sizes
K.clear_session()
embed_models, sizes = [], []
for cfg_file, rst_file, npz_file in zip(cfg_files, rst_files, emb_files):
    cfg = get_cfg(cfg_file)
    if npz_file in use_fold:
        # Consistency check: the restored model must come from the same fold
        # as its precomputed embeddings
        assert cfg.FOLD_TO_RUN == use_fold[npz_file]
    cfg.pretrained = None  # avoid pretrained-weight downloads, weights are restored below
    if isinstance(cfg.IMAGE_SIZE, int):
        cfg.IMAGE_SIZE = (cfg.IMAGE_SIZE, cfg.IMAGE_SIZE)
    sizes.append(cfg.IMAGE_SIZE)
    model, embed_model = get_model(cfg)
    model.load_weights(rst_file)
    print(f"\nWeights loaded from {rst_file}")
    print(f"input_size {cfg.IMAGE_SIZE}, fold {cfg.FOLD_TO_RUN}, arch {cfg.arch_name}, ",
          f"DATASET {cfg.DATASET}, dropout_ps {cfg.dropout_ps}, subcenters {cfg.subcenters}")
    embed_models.append(embed_model)


def pred_fn(image, fake=False):
    if fake:
        x0, x1 = (int(f * image.shape[0]) for f in (0.2, 0.8))
        y0, y1 = (int(f * image.shape[1]) for f in (0.2, 0.8))
        cropped_image = image[x0:x1, y0:y1, :]
        response_str = f"This looks like a common dolphin, but I have not seen this individual before (0.834 confidence).\n" \
                       "Go submit your photo on www.happywhale.com!"
        return cropped_image, response_str

    cropped_image, species = fast_yolo_crop(image)
    test_embedding = get_test_embedding(embed_models, sizes)
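
    # comp_embeddings stacks one embedding per model for every known individual;
    # assuming each per-model embedding is L2-normalized, the dot product with
    # the test embedding divided by n_models is the mean cosine similarity over
    # the model ensemble.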
    
    cosine_similarity = np.dot(comp_embeddings, test_embedding[0]) / n_models
    cosine_distances = 1 - cosine_similarity
    normalized_distances = cosine_distances / max_distance
    normalized_similarities = 1 - normalized_distances

    min_similarity = normalized_similarities.min()
    max_similarity = normalized_similarities.max()
    confidence = get_confidence(max_similarity, threshold)

    print(f"Similarities: {min_similarity:.4f} ... {max_similarity:.4f}")
    print(f"Threshold:", threshold)
    
    if max_similarity > threshold:
        response_str = f"This looks like a {species} I have seen before ({confidence:.3f} confidence).\n" \
                       "You might find its previous encounters on www.happywhale.com"
    else:
        response_str = f"This looks like a {species}, but I have not seen this individual before ({confidence:.3f} confidence).\n" \
                       "Go submit your photo on www.happywhale.com!"

    return cropped_image, response_str
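
# Quick sanity check of the interface without running YOLO or the embedding
# models (fake path); any RGB uint8 array should do, e.g.:
#   crop, msg = pred_fn(np.zeros((480, 640, 3), dtype=np.uint8), fake=True)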

examples = [str(image_root / f'negative{i:03d}') for i in range(3)]

demo = gr.Interface(fn=pred_fn, inputs="image", outputs=["image", "text"],
                    examples=examples)
demo.launch()