happywhale-demo / app.py
yellowdolphin's picture
import numpy
79e7d41
raw
history blame
5.18 kB
# If TF version is not understood by tfimm requirements, try this:
#try:
# import tfimm
#except ModuleNotFoundError:
# !pip install --no-deps tfimm timm
# import timm
# import tfimm
import os
import glob
from shutil import rmtree
from pathlib import Path
from subprocess import run
import json
import gradio as gr
from huggingface_hub import hf_hub_download
from yolov5 import detect
import numpy as np
from utils import get_model, get_cfg, get_embeddings, get_comp_embeddings, get_test_embedding
# --- YOLOv5 detector parameters ---------------------------------------------
yolo_input_size = 384
# Fold tags of the YOLOv5 checkpoints; each tag maps to the repo file
# 'yolov5_l6_<yolo_input_size>_fold<tag>.pt'.
versions = ('2_v108', '4_v109', '0_int6', '1_v110', '3_v111')
score_thr = 0.025  # passed to detect.run as conf_thres
iou_thr = 0.6      # passed to detect.run as iou_thres
max_det = 1        # passed to detect.run as max_det
working = Path(os.getcwd())
modelbox = "yellowdolphin/happywhale-models"  # Hugging Face Hub repo holding all weights


def _fetch(filename):
    """Return the local path hf_hub_download provides for *filename* in `modelbox`."""
    return hf_hub_download(modelbox, filename)


checkpoint_files = [_fetch(f'yolov5_l6_{yolo_input_size}_fold{tag}.pt') for tag in versions]
image_root = working / 'images'

# --- Individual identifier parameters ---------------------------------------
# NOTE: embedding_size, n_images, gamma and knn are not read by the code in this file.
embedding_size = 1024
n_images = 51033 + 27956
max_distance = 0.865  # cosine distances are divided by this before turning into similarities
normalize_similarity = None  # one of: 'test-train', None
gamma = 0.4
if normalize_similarity == 'test-train':
    threshold = 0.09951
else:
    threshold = 0.6  # alt: 0.381
knn = 300
# Embedding models of the ensemble; each name prefixes a config, an embedding
# archive and a weights file in the model repo.
rst_names = ['convnext_base_384_in22ft1k_colab220', 'efnv1b7_colab216', 'hub_efnv2xl_v73']
cfg_files = [_fetch(f'{name}_config.json') for name in rst_names]
emb_files = [_fetch(f'{name}_emb.npz') for name in rst_names]
rst_files = [_fetch(f'{name}.h5') for name in rst_names]
n_models = len(rst_names)
def fast_yolo_crop(image, yolo_source=None):
    """Detect the animal in *image* with YOLOv5 and return its crop and species.

    The image is written to disk and YOLOv5's ``detect.run`` is called with
    ``save_crop=True``. The single crop it saves under
    ``results_ensemble/crops/<class>/`` is then read back.

    Args:
        image: image array from the Gradio ``"image"`` input (presumably
            HxWx3 uint8 -- TODO confirm).
        yolo_source: where the input image is written for detect.run.
            Defaults to ``working / 'yolo_source.jpg'``.

    Returns:
        ``(cropped_image, species)``, where *species* is the YOLO class
        directory name with underscores replaced by spaces.

    Raises:
        ValueError: if YOLOv5 did not save exactly one crop for the image.
    """
    if yolo_source is None:
        # `yolo_source` is not defined anywhere else in this file; use a fixed file.
        yolo_source = working / 'yolo_source.jpg'
    yolo_source = Path(yolo_source)
    results_dir = working / 'results_ensemble'

    # Clear outputs of the previous request. The directories do not exist on
    # the first call, so a missing directory is not an error.
    rmtree(working / 'labels', ignore_errors=True)
    rmtree(results_dir, ignore_errors=True)

    # NOTE(review): `mpimg` is not imported in this file; presumably
    # `import matplotlib.image as mpimg` is intended at file level -- confirm.
    mpimg.imsave(str(yolo_source), image)
    detect.run(weights=checkpoint_files[4:],
               source=str(yolo_source),
               data='data/dataset.yaml',
               imgsz=yolo_input_size,
               conf_thres=score_thr,
               iou_thres=iou_thr,
               max_det=max_det,
               save_txt=False,
               save_conf=False,
               save_crop=True,
               exist_ok=True,
               name=str(results_dir))

    # YOLOv5 saves crops as crops/<class_name>/<source stem>.jpg.
    # Path.glob does not treat characters in `results_dir` itself as wildcards.
    crop_name = f'{yolo_source.stem}.jpg'
    crops = sorted(results_dir.glob(f'crops/*/{crop_name}'))
    if len(crops) != 1:
        raise ValueError(f'{len(crops)} maritime species detected')
    crop_path = crops[0]
    species = crop_path.parent.name
    cropped_image = mpimg.imread(str(crop_path))
    return cropped_image, species.replace('_', ' ')
# Preload embeddings for known individuals
comp_embeddings = get_comp_embeddings(emb_files)

# Preload embedding models and their input sizes.
# NOTE(review): `K` is not imported in this file; presumably the Keras backend
# (`from tensorflow.keras import backend as K`) is intended -- confirm.
K.clear_session()
embed_models, sizes = [], []
for cfg_file, rst_file in zip(cfg_files, rst_files):
    cfg = get_cfg(cfg_file)
    # A fold sanity check `cfg.FOLD_TO_RUN == use_fold[npz_file]` used to sit
    # here. `use_fold` is not defined anywhere in this file, so it raised
    # NameError on startup. Restore it once a fold mapping for the *_emb.npz
    # files exists.
    cfg.pretrained = None  # avoid weight downloads
    if isinstance(cfg.IMAGE_SIZE, int):
        # A scalar size means a square input; expand it to a pair.
        cfg.IMAGE_SIZE = (cfg.IMAGE_SIZE, cfg.IMAGE_SIZE)
    sizes.append(cfg.IMAGE_SIZE)
    model, embed_model = get_model(cfg)
    model.load_weights(rst_file)
    print(f"\nWeights loaded from {rst_file}")
    # Log the configured input size (`scaled_img` is not defined in this file).
    print(f"input_size {cfg.IMAGE_SIZE}, fold {cfg.FOLD_TO_RUN}, arch {cfg.arch_name}, "
          f"DATASET {cfg.DATASET}, dropout_ps {cfg.dropout_ps}, subcenters {cfg.subcenters}")
    embed_models.append(embed_model)
def pred_fn(image, fake=False):
    """Crop the animal in *image* and report whether this individual is known.

    Args:
        image: image array from the Gradio ``"image"`` input; it is indexed as
            ``image[rows, cols, :]`` below, so it must be 3-D.
        fake: if True, skip all models and return the central crop (from 20% to
            80% of each side) with a canned "unseen common dolphin" message.

    Returns:
        ``(cropped_image, response_str)`` for the ``["image", "text"]`` outputs.
    """
    if fake:
        height, width = image.shape[0], image.shape[1]
        top, bottom = int(0.2 * height), int(0.8 * height)
        left, right = int(0.2 * width), int(0.8 * width)
        demo_crop = image[top:bottom, left:right, :]
        canned = ("This looks like a common dolphin, but I have not seen this individual before (0.834 confidence).\n"
                  "Go submit your photo on www.happywhale.com!")
        return demo_crop, canned

    cropped_image, species = fast_yolo_crop(image)
    # NOTE(review): neither `image` nor `cropped_image` is passed here, so the
    # embedding cannot depend on the uploaded photo -- check the signature of
    # utils.get_test_embedding.
    test_embedding = get_test_embedding(embed_models, sizes)

    # Dot product with every known embedding, divided by n_models (presumably
    # the embeddings concatenate one unit vector per model -- confirm).
    similarities = np.dot(comp_embeddings, test_embedding[0]) / n_models
    # Turn into cosine distance, rescale by max_distance, and back to similarity.
    similarities = 1 - (1 - similarities) / max_distance

    best = similarities.max()
    # NOTE(review): get_confidence is not defined or imported in this file.
    confidence = get_confidence(best, threshold)
    print(f"Similarities: {similarities.min():.4f} ... {best:.4f}")
    print("Threshold:", threshold)

    if best > threshold:
        response_str = (f"This looks like a {species} I have seen before ({confidence:.3f} confidence).\n"
                        "You might find its previous encounters on www.happywhale.com")
    else:
        response_str = (f"This looks like a {species}, but I have not seen this individual before ({confidence:.3f} confidence).\n"
                        "Go submit your photo on www.happywhale.com!")
    return cropped_image, response_str
# Example inputs shown under the Gradio image widget: images/negative000 .. negative002.
example_paths = [image_root / f'negative{idx:03d}' for idx in range(3)]
examples = [str(path) for path in example_paths]

demo = gr.Interface(
    fn=pred_fn,
    inputs="image",
    outputs=["image", "text"],
    examples=examples,
)
demo.launch()