# happywhale-demo / app.py (author: yellowdolphin)
# Last commit 0781862: "fix urls"
# If TF version is not understood by tfimm requirements, try this:
#try:
# import tfimm
#except ModuleNotFoundError:
# !pip install --no-deps tfimm timm
# import timm
# import tfimm
import os
from glob import glob
from shutil import rmtree
from pathlib import Path
import gradio as gr
from huggingface_hub import hf_hub_download
import matplotlib.image as mpimg
from yolov5 import detect
import numpy as np
from tensorflow.keras import backend as K
from utils import get_model, get_cfg, get_comp_embeddings, get_test_embedding, get_confidence
# YOLOv5 parameters
yolo_input_size = 384
# Fold tags of the YOLOv5 checkpoints stored in the model repo (one .pt file per tag)
versions = ('2_v108', '4_v109', '0_int6', '1_v110', '3_v111')
score_thr = 0.025  # passed to detect.run as conf_thres
iou_thr = 0.6      # passed to detect.run as iou_thres
max_det = 1        # passed to detect.run as max_det
working = Path(os.getcwd())
modelbox = "yellowdolphin/happywhale-models"


def _fetch_model_file(filename):
    """Download *filename* from the ``modelbox`` Hub repo and return its local path."""
    return hf_hub_download(modelbox, filename)


checkpoint_files = [_fetch_model_file(f'yolov5_l6_{yolo_input_size}_fold{tag}.pt') for tag in versions]
image_root = working / 'images'
# Every uploaded image is written to this single file before running the detector.
yolo_source = str(image_root / 'testimage.jpg')

# Individual identifier parameters
max_distance = 0.865
normalize_similarity = None  # 'test-train' or None
if normalize_similarity == 'test-train':
    threshold = 0.09951
else:
    threshold = 0.6  # 0.381
rst_names = ['convnext_base_384_in22ft1k_colab220', 'efnv1b7_colab216', 'hub_efnv2xl_v73']
# Run name -> fold index; at load time it must match the run config's FOLD_TO_RUN.
use_fold = {
    'efnv1b7_colab216': 4,
    'efnv1b7_colab225': 1,
    'efnv1b7_colab197': 0,
    'efnv1b7_colab227': 5,
    'efnv1b7_v72': 6,
    'efnv1b7_colab229': 9,
    'efnv1b6_colab217': 5,
    'efnv1b6_colab218': 6,
    'hub_efnv2xl_colab221': 8,
    'hub_efnv2xl_v69': 2,
    'hub_efnv2xl_v73': 0,
    'efnv1b6_colab226': 2,
    'hub_efnv2l_v70': 3,
    'hub_efnv2l_colab200': 2,
    'hub_efnv2l_colab199': 1,
    'convnext_base_384_in22ft1k_v68': 0,
    'convnext_base_384_in22ft1k_colab220': 9,
    'convnext_base_384_in22ft1k_colab201': 3,  # new
}
# Per-run artifacts, all in rst_names order: config, stored embeddings, weights.
cfg_files = [_fetch_model_file(f'{name}_config.json') for name in rst_names]
emb_files = [_fetch_model_file(f'{name}_emb.npz') for name in rst_names]
rst_files = [_fetch_model_file(f'{name}.h5') for name in rst_names]
use_folds = [use_fold[name] for name in rst_names]
n_models = len(rst_names)
def fast_yolo_crop(image):
    """Detect the animal in *image* with YOLOv5 and return the cropped region.

    The image is written to ``yolo_source`` and passed to ``detect.run`` with
    ``save_crop=True``. YOLOv5 saves the crop into a sub-folder named after the
    detected class; that folder name is read back as the species label.
    Only ``checkpoint_files[4:]`` (the last checkpoint) is used, which keeps
    inference fast.

    Args:
        image: image array as delivered by the Gradio ``"image"`` input.

    Returns:
        ``(cropped_image, species)``: the crop as loaded by ``mpimg.imread`` and
        the species name with underscores replaced by spaces.

    Raises:
        ValueError: if the detector does not produce exactly one crop.
    """
    # Remove outputs of previous runs so the glob below only sees this image's crop.
    rmtree(working / 'labels', ignore_errors=True)
    rmtree(working / 'results_ensemble', ignore_errors=True)
    print(f"image: {type(image)}, shape {image.shape}, yolo_source: {yolo_source}")
    Path(yolo_source).parent.mkdir(parents=True, exist_ok=True)
    mpimg.imsave(yolo_source, image)
    detect.run(weights=checkpoint_files[4:],
               source=yolo_source,
               data='data/dataset.yaml',
               imgsz=yolo_input_size,
               conf_thres=score_thr,
               iou_thres=iou_thr,
               max_det=max_det,
               save_txt=False,
               save_conf=False,
               save_crop=True,
               exist_ok=True,
               name=str(working / 'results_ensemble'))
    # Crops land in results_ensemble/crops/<class name>/<source file name>.
    glob_pattern = f'{working}/results_ensemble/crops/*/{Path(yolo_source).name}'
    print("glob_pattern:", glob_pattern)
    crops = sorted(glob(glob_pattern))
    # Explicit checks instead of `assert`: asserts are stripped under `python -O`,
    # which would turn "nothing detected" into an opaque IndexError below.
    if not crops:
        raise ValueError('No whale or dolphin detected in the image')
    if len(crops) > 1:
        raise ValueError(f'{len(crops)} crops found, expected exactly one')
    crop_path = crops[0]
    species = Path(crop_path).parent.name
    cropped_image = mpimg.imread(crop_path)
    return cropped_image, species.replace('_', ' ')
# Preload embeddings for known individuals
comp_embeddings = get_comp_embeddings(emb_files, use_folds)

# Preload embedding models and their input sizes (one entry per name in rst_names)
K.clear_session()
embed_models, sizes = [], []
# Loop variable is `fold`, not `use_fold`: reusing that name would rebind the
# module-level run-name -> fold dict to an int.
for cfg_file, rst_file, fold in zip(cfg_files, rst_files, use_folds):
    cfg = get_cfg(cfg_file)
    # Explicit check (not `assert`, which `python -O` strips) so mismatched configs
    # and weights can never be loaded silently.
    if cfg.FOLD_TO_RUN != fold:
        raise ValueError(f"{cfg_file}: FOLD_TO_RUN={cfg.FOLD_TO_RUN}, expected fold {fold}")
    cfg.pretrained = None  # avoid weight downloads
    # Normalize a scalar image size to a (height, width) pair.
    if isinstance(cfg.IMAGE_SIZE, int):
        cfg.IMAGE_SIZE = (cfg.IMAGE_SIZE, cfg.IMAGE_SIZE)
    sizes.append(cfg.IMAGE_SIZE)
    model, embed_model = get_model(cfg)
    model.load_weights(rst_file)
    print(f"\nWeights loaded from {rst_file}")
    print(f"input_size {cfg.IMAGE_SIZE}, fold {cfg.FOLD_TO_RUN}, arch {cfg.arch_name}, "
          f"DATASET {cfg.DATASET}, dropout_ps {cfg.dropout_ps}, subcenters {cfg.subcenters}")
    embed_models.append(embed_model)
def pred_fn(image, fake=False):
    """Gradio callback: crop the animal in *image* and try to re-identify it.

    Args:
        image: image array from the Gradio ``"image"`` input.
        fake: if true, skip all models and return a fixed center crop together
            with a canned answer (handy for testing the UI without inference).

    Returns:
        ``(cropped_image, response_str)`` for the image and text outputs.
    """
    if not fake:
        crop, species = fast_yolo_crop(image)
        test_embedding = get_test_embedding(crop, embed_models, sizes)

        # Dot product with every stored embedding, averaged over the ensemble.
        # NOTE(review): assumes comp_embeddings / test_embedding are concatenations of
        # n_models L2-normalized per-model embeddings, so this is a mean cosine
        # similarity — confirm against utils.get_comp_embeddings.
        similarity = np.dot(comp_embeddings, test_embedding[0]) / n_models
        distance = 1 - similarity
        scaled_similarity = 1 - distance / max_distance

        lowest, best = scaled_similarity.min(), scaled_similarity.max()
        confidence = get_confidence(best, threshold)
        print(f"Similarities: {lowest:.4f} ... {best:.4f}")
        print(f"Threshold: {threshold}")

        # Only the best match decides between "known" and "new" individual.
        if best > threshold:
            verdict = (f"This looks like a {species} I have seen before ({confidence:.3f} confidence).\n"
                       "You might find its previous encounters on www.happywhale.com")
        else:
            verdict = (f"This looks like a {species}, but I have not seen this individual before ({confidence:.3f} confidence).\n"
                       "Go submit your photo on www.happywhale.com!")
        return crop, verdict

    # Fake mode: keep the central 60 % along both image axes, return a canned reply.
    n_rows, n_cols = image.shape[0], image.shape[1]
    r0, r1 = int(0.2 * n_rows), int(0.8 * n_rows)
    c0, c1 = int(0.2 * n_cols), int(0.8 * n_cols)
    canned = ("This looks like a common dolphin, but I have not seen this individual before (0.834 confidence).\n"
              "Go submit your photo on www.happywhale.com!")
    return image[r0:r1, c0:c1, :], canned
# Example images offered in the UI
examples = [str(image_root / f'negative{i:03d}.jpg') for i in range(3)]
description = """
Is it possible to identify and track individual marine mammals based on
community photos taken by tourist whale-watchers with their cameras or
smartphones?
For decades, researchers have used [photographic identification](https://whalescientists.com/photo-id/)
(photo-ID) of individual whales to study their migration, population, and behavior.
As this is a tedious and costly process, it is tempting to leverage the huge amount of
image data collected by the whale-watching community and private encounters around
the globe. Organizations like [WildMe](https://www.wildme.org) or
[happywhale](https://www.happywhale.com) develop AI models for automated identification at
scale. To push the state of the art, happywhale hosted two competitions on Kaggle:
the 2018 [Humpback Whale Identification](https://www.kaggle.com/c/humpback-whale-identification)
competition and the 2022 [Happywhale](https://www.kaggle.com/competitions/happy-whale-and-dolphin)
competition, which included 28 whale and dolphin species.
Top solutions used a two-step process: cropping the raw image with an
image detector like [YOLOv5](https://pytorch.org/hub/ultralytics_yolov5)
and presenting high-resolution crops to an identifier trained with an
ArcFace-based loss function. The detector had to be fine-tuned on the
competition images with automatically or manually generated labels.
Below, you can test a cut-down version of my solution on your own images.
To keep inference fast, the detector here is a single YOLOv5 model (one of the
five used in the full solution), and the identifier ensembles three
models with EfficientNet-B7, EfficientNetV2-XL, and ConvNeXt-base backbones.
You can find model code and training pipelines in the
[DeepTrane](https://github.com/yellowdolphin/deeptrane) repository.
"""  # appears between title and input/output
article = None  # appears below input/output; set to a markdown string to show text there
demo = gr.Interface(fn=pred_fn, inputs="image", outputs=["image", "text"],
                    examples=examples,
                    title='Happywhale: Individual Identification for Marine Mammals',
                    description=description,
                    article=article)
demo.launch()