# If the installed TF version is not supported by tfimm's requirements, try this
# (notebook-only, because of the !pip magic):
# try:
#     import tfimm
# except ModuleNotFoundError:
#     !pip install --no-deps tfimm timm
#     import timm
#     import tfimm
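# A script-friendly variant of the same fallback (a sketch; installing at
# runtime with --no-deps assumes the remaining dependencies are already met):
# try:
#     import tfimm
# except ModuleNotFoundError:
#     import subprocess, sys
#     subprocess.check_call([sys.executable, '-m', 'pip', 'install', '--no-deps', 'tfimm', 'timm'])
#     import timm
#     import tfimm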
import os
from glob import glob
from shutil import rmtree
from pathlib import Path

import gradio as gr
from huggingface_hub import hf_hub_download
import matplotlib.image as mpimg
from yolov5 import detect
import numpy as np
from tensorflow.keras import backend as K

from utils import get_model, get_cfg, get_comp_embeddings, get_test_embedding, get_confidence
# YOLOv5 detector parameters
yolo_input_size = 384
versions = ('2_v108', '4_v109', '0_int6', '1_v110', '3_v111')
score_thr = 0.025  # confidence threshold for detections
iou_thr = 0.6      # NMS IoU threshold
max_det = 1        # expect exactly one animal per image

working = Path(os.getcwd())
modelbox = "yellowdolphin/happywhale-models"
checkpoint_files = [hf_hub_download(modelbox, f'yolov5_l6_{yolo_input_size}_fold{x}.pt') for x in versions]
image_root = working / 'images'
yolo_source = str(image_root / 'testimage.jpg')
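# hf_hub_download caches files locally, so model files are fetched from the Hub
# only once and reused on subsequent runs.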
# Individual-identifier parameters
max_distance = 0.865
normalize_similarity = None  # 'test-train' or None
threshold = 0.09951 if (normalize_similarity == 'test-train') else 0.6  # 0.381
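# How these are used (see pred_fn below): ensemble-averaged cosine distances to
# all known individuals are divided by max_distance to map them into [0, 1];
# the resulting normalized similarity is compared against `threshold`, and a
# new individual is predicted when no known individual exceeds it.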
rst_names = 'convnext_base_384_in22ft1k_colab220 efnv1b7_colab216 hub_efnv2xl_v73'.split()
use_fold = {
    'efnv1b7_colab216': 4,
    'efnv1b7_colab225': 1,
    'efnv1b7_colab197': 0,
    'efnv1b7_colab227': 5,
    'efnv1b7_v72': 6,
    'efnv1b7_colab229': 9,
    'efnv1b6_colab217': 5,
    'efnv1b6_colab218': 6,
    'hub_efnv2xl_colab221': 8,
    'hub_efnv2xl_v69': 2,
    'hub_efnv2xl_v73': 0,
    'efnv1b6_colab226': 2,
    'hub_efnv2l_v70': 3,
    'hub_efnv2l_colab200': 2,
    'hub_efnv2l_colab199': 1,
    'convnext_base_384_in22ft1k_v68': 0,
    'convnext_base_384_in22ft1k_colab220': 9,
    'convnext_base_384_in22ft1k_colab201': 3,  # new
}
cfg_files = [hf_hub_download(modelbox, f'{x}_config.json') for x in rst_names]
emb_files = [hf_hub_download(modelbox, f'{x}_emb.npz') for x in rst_names]
rst_files = [hf_hub_download(modelbox, f'{x}.h5') for x in rst_names]
use_folds = [use_fold[x] for x in rst_names]
n_models = len(rst_names)
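# For each identifier model we now have its training config (json), the
# precomputed embeddings of all known individuals (npz), and its trained
# weights (h5), each taken from the fold selected in use_fold above.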

def fast_yolo_crop(image):
    """Detect the animal with YOLOv5 and return (cropped_image, species)."""
    rmtree(working / 'labels', ignore_errors=True)
    rmtree(working / 'results_ensemble', ignore_errors=True)

    print(f"image: {type(image)}, shape {image.shape}")
    print(f"yolo_source: {yolo_source} ({type(yolo_source)})")
    mpimg.imsave(yolo_source, image)

    # Only the last of the five downloaded checkpoints is used here,
    # presumably trading ensemble accuracy for speed.
    detect.run(weights=checkpoint_files[4:],
               source=yolo_source,
               data='data/dataset.yaml',
               imgsz=yolo_input_size,
               conf_thres=score_thr,
               iou_thres=iou_thr,
               max_det=max_det,
               save_txt=False,
               save_conf=False,
               save_crop=True,
               exist_ok=True,
               name=str(working / 'results_ensemble'))

    # With save_crop=True, YOLOv5 saves crops under <name>/crops/<class>/,
    # so the crop's parent directory encodes the detected species.
    glob_pattern = f'{working}/results_ensemble/crops/*/{Path(yolo_source).name}'
    print("glob_pattern:", glob_pattern)
    cropped = sorted(glob(glob_pattern))
    assert len(cropped) == 1, f'{len(cropped)} marine species detected'
    cropped = cropped[0]
    species = Path(cropped).parent.name
    cropped_image = mpimg.imread(cropped)
    return cropped_image, species.replace('_', ' ')

# Preload embeddings for known individuals
comp_embeddings = get_comp_embeddings(emb_files, use_folds)
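# comp_embeddings presumably holds the L2-normalized embeddings of all known
# individuals, with the three models' embeddings concatenated feature-wise, so
# one dot product with a test embedding sums the per-model cosine similarities
# (hence the division by n_models in pred_fn below).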
# Preload embedding models and their input sizes
K.clear_session()
embed_models, sizes = [], []
for cfg_file, rst_file, fold in zip(cfg_files, rst_files, use_folds):
    cfg = get_cfg(cfg_file)
    assert cfg.FOLD_TO_RUN == fold
    cfg.pretrained = None  # no pretrained-weight downloads, we load our own
    if isinstance(cfg.IMAGE_SIZE, int):
        cfg.IMAGE_SIZE = (cfg.IMAGE_SIZE, cfg.IMAGE_SIZE)
    sizes.append(cfg.IMAGE_SIZE)
    model, embed_model = get_model(cfg)
    model.load_weights(rst_file)
    print(f"\nWeights loaded from {rst_file}")
    print(f"input_size {cfg.IMAGE_SIZE}, fold {cfg.FOLD_TO_RUN}, arch {cfg.arch_name}, "
          f"DATASET {cfg.DATASET}, dropout_ps {cfg.dropout_ps}, subcenters {cfg.subcenters}")
    embed_models.append(embed_model)
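# get_model presumably returns both the full training model (with the ArcFace
# head) and an embed_model that outputs just the embedding used for matching.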

def pred_fn(image, fake=False):
    """Detect the animal, embed the crop, and compare against known individuals."""
    if fake:
        # Demo mode: center-crop only, skip detector and identifier
        x0, x1 = (int(f * image.shape[0]) for f in (0.2, 0.8))
        y0, y1 = (int(f * image.shape[1]) for f in (0.2, 0.8))
        cropped_image = image[x0:x1, y0:y1, :]
        response_str = "This looks like a common dolphin, but I have not seen this individual before (0.834 confidence).\n" \
                       "Go submit your photo on www.happywhale.com!"
        return cropped_image, response_str

    cropped_image, species = fast_yolo_crop(image)
    test_embedding = get_test_embedding(cropped_image, embed_models, sizes)

    # Cosine similarity to each known individual, averaged over the ensemble,
    # then mapped to a normalized similarity in [0, 1] via max_distance
    cosine_similarity = np.dot(comp_embeddings, test_embedding[0]) / n_models
    cosine_distances = 1 - cosine_similarity
    normalized_distances = cosine_distances / max_distance
    normalized_similarities = 1 - normalized_distances

    min_similarity = normalized_similarities.min()
    max_similarity = normalized_similarities.max()
    confidence = get_confidence(max_similarity, threshold)
    print(f"Similarities: {min_similarity:.4f} ... {max_similarity:.4f}")
    print(f"Threshold: {threshold}")

    if max_similarity > threshold:
        response_str = f"This looks like a {species} I have seen before ({confidence:.3f} confidence).\n" \
                       "You might find its previous encounters on www.happywhale.com"
    else:
        response_str = f"This looks like a {species}, but I have not seen this individual before ({confidence:.3f} confidence).\n" \
                       "Go submit your photo on www.happywhale.com!"
    return cropped_image, response_str

examples = [str(image_root / f'negative{i:03d}.jpg') for i in range(3)]
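# Quick local smoke test (a sketch: uses the first example above and bypasses
# the detector/identifier via fake=True):
# if __name__ == '__main__':
#     img = mpimg.imread(examples[0])
#     crop, msg = pred_fn(img, fake=True)
#     print(msg)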
description = """ | |
Is it possible to identify and track individual marine mammals based on | |
community photos, taken by tourist whale-watchers on their cameras or | |
smartphones? | |
Researchers use [photographic identification](https://whalescientists.com/photo-id/) | |
(photo-ID) of individual whales since | |
decades to study their migration, population, and behavior. While this is a | |
tedious and costly process, it is tempting to leverage the huge amount of | |
image data collected by the whale-watching community and private encounters around | |
the globe. Organizations like [WildMe](https://www.wildme.org) or | |
[happywhale](https://www.happywhale.com) develop AI models for automated identification at | |
scale. To push the state-of-the-art, happywhale hosted two competitions on kaggle, | |
the 2018 [Humpback Whale Identification](https://www.kaggle.com/c/humpback-whale-identification) | |
and the 2022 [Happywhale](https://www.kaggle.com/competitions/happy-whale-and-dolphin) | |
competition, which included 28 marine whale and dolphin species. | |
Top solutions used a two-step process of cropping the raw image using an | |
image detector like [YOLOv5](https://pytorch.org/hub/ultralytics_yolov5) | |
and presenting high-resolution crops to an identifier trained with an | |
ArcFace-based loss function. The detector had to be fine-tuned on the | |
competition images with auto- or manually generated labels. | |
Below you can test my solution (down-cut version) on your own images. | |
The detector is an ensemble of five YOLOv5 models, the identifier ensembles three | |
models with EfficientNet-B7, EfficientNetV2-XL, and ConvNext-base backbone. | |
You can find model code and training pipelines in the | |
[DeepTrane](https://github.com/yellowdolphin/deeptrane) repository. | |
""" # appears between title and input/output | |
article = """ | |
""" # appears below input/output | |
demo = gr.Interface(fn=pred_fn, inputs="image", outputs=["image", "text"], | |
examples=examples, | |
title='Happywhale: Individual Identification for Marine Mammals', | |
description=description, | |
article=None,) | |
demo.launch() | |