happywhale-demo / app.py
yellowdolphin's picture
replace ipynb magic
6677709
raw
history blame
4.93 kB
# If TF version is not understood by tfimm requirements, try this:
#try:
# import tfimm
#except ModuleNotFoundError:
# !pip install --no-deps tfimm timm
# import timm
# import tfimm
import os
import glob
from shutil import rmtree
from pathlib import Path
from subprocess import run
import json

import gradio as gr
import numpy as np

from yolov5 import detect
from utils import get_models, get_cfg, get_embeddings, get_comp_embeddings, get_test_embedding
# YOLOv5 parameters
yolo_input_size = 384  # forwarded to detect.run as `imgsz`
# Suffixes that follow "fold" in the checkpoint file names built below.
versions = ('2_v108', '4_v109', '0_int6', '1_v110', '3_v111')
score_thr = 0.025  # forwarded to detect.run as `conf_thres`
iou_thr = 0.6  # forwarded to detect.run as `iou_thres`
max_det = 1  # forwarded to detect.run as `max_det`
working = Path(os.getcwd())
modelbox = working / 'models'
checkpoint_files = [modelbox / f'yolov5_l6_{yolo_input_size}_fold{x}.pt' for x in versions]
image_root = working / 'images'

# Individual identifier parameters
embedding_size = 1024
n_images = 51033 + 27956
max_distance = 0.865  # cosine distances are divided by this value in pred_fn
normalize_similarity = None  # test-train, None
gamma = 0.4
threshold = 0.09951 if (normalize_similarity == 'test-train') else 0.6  # 0.381
knn = 300
emb_path = '/kaggle/input/happywhale-embeddings'
rst_path = '/kaggle/input/happywhale-models'
# `glob` is imported as a module, so the function must be reached as
# `glob.glob`; calling the module itself raises TypeError.
# The paths stay plain strings because later code calls `str.replace` on them.
rst_files = sorted(glob.glob(f'{rst_path}/*.h5'))
n_models = len(rst_files)
def fast_yolo_crop(image):
    """Crop *image* to the single object found by the YOLOv5 detector.

    The image is written to ``yolo_source``, ``detect.run`` is invoked with
    ``save_crop=True``, and the resulting crop is looked up under
    ``<working>/results_ensemble/crops/<class name>/``.

    Args:
        image: Image data accepted by ``mpimg.imsave``.

    Returns:
        A tuple ``(cropped_image, species)``. ``cropped_image`` is the crop
        as loaded by ``mpimg.imread``; ``species`` is the name of the crop's
        parent directory with underscores replaced by spaces.

    Raises:
        AssertionError: If the number of crop files found is not exactly one.
    """
    # NOTE(review): `mpimg` (presumably matplotlib.image) and `yolo_source`
    # are not defined or imported in the visible part of this file -- confirm.

    # Drop the output of a previous call. `ignore_errors=True` keeps this from
    # raising FileNotFoundError when a directory does not exist (first call,
    # or a directory that was removed earlier and never recreated).
    rmtree(working / 'labels', ignore_errors=True)
    rmtree(working / 'results_ensemble', ignore_errors=True)

    mpimg.imsave(yolo_source, image)

    # Only the last checkpoint (index 4 of `checkpoint_files`) is used.
    detect.run(weights=checkpoint_files[4:],
               source=yolo_source,
               data='data/dataset.yaml',
               imgsz=yolo_input_size,
               conf_thres=score_thr,
               iou_thres=iou_thr,
               max_det=max_det,
               save_txt=False,
               save_conf=False,
               save_crop=True,
               exist_ok=True,
               name=str(working / 'results_ensemble'))

    # `glob` is imported as a module; the function is `glob.glob`.
    cropped = sorted(glob.glob(f'{working}/results_ensemble/crops/*/{Path(yolo_source).name}'))
    assert len(cropped) == 1, f'{len(cropped)} maritime species detected'
    cropped = cropped[0]
    species = Path(cropped).parent.name
    cropped_image = mpimg.imread(cropped)
    return cropped_image, species.replace('_', ' ')
# Preload embeddings for known individuals
comp_embeddings = get_comp_embeddings(rst_files)

# Preload embedding models, input sizes
# NOTE(review): `K`, `use_fold` and `get_model` are not defined or imported in
# the visible part of this file (only `get_models` is imported from utils) --
# confirm where they are supposed to come from.
K.clear_session()
embed_models, sizes = [], []
for rst_file in rst_files:
    cfg = get_cfg(rst_file)
    npz_file = Path(rst_file.replace('.h5', '_emb.npz')).name
    # The fold recorded in the config must match the one registered for the
    # corresponding embedding file.
    assert cfg.FOLD_TO_RUN == use_fold[npz_file]
    cfg.pretrained = None  # avoid weight downloads
    # Normalise a scalar image size to a pair.
    if isinstance(cfg.IMAGE_SIZE, int):
        cfg.IMAGE_SIZE = (cfg.IMAGE_SIZE, cfg.IMAGE_SIZE)
    sizes.append(cfg.IMAGE_SIZE)
    model, embed_model = get_model(cfg)
    model.load_weights(rst_file)
    print(f"\nWeights loaded from {rst_file}")
    # Report the configured input size; no `scaled_img` exists at module
    # level, so referencing it here raised NameError.
    print(f"input_size {cfg.IMAGE_SIZE}, fold {cfg.FOLD_TO_RUN}, arch {cfg.arch_name}, ",
          f"DATASET {cfg.DATASET}, dropout_ps {cfg.dropout_ps}, subcenters {cfg.subcenters}")
    embed_models.append(embed_model)
def pred_fn(image, fake=False):
    """Crop the animal from *image* and report whether it looks familiar.

    Args:
        image: Image array indexed as ``image[rows, cols, channels]``.
        fake: When true, skip detection and identification entirely and
            return the central 20%-80% region of the image together with a
            fixed response string.

    Returns:
        A tuple ``(cropped_image, response_str)``.
    """
    if fake:
        row0, row1 = (int(f * image.shape[0]) for f in (0.2, 0.8))
        col0, col1 = (int(f * image.shape[1]) for f in (0.2, 0.8))
        cropped_image = image[row0:row1, col0:col1, :]
        response_str = "This looks like a common dolphin, but I have not seen this individual before (0.834 confidence).\n" \
                       "Go submit your photo on www.happywhale.com!"
        return cropped_image, response_str

    cropped_image, species = fast_yolo_crop(image)

    # NOTE(review): no image is passed to get_test_embedding here -- confirm
    # how the helper obtains the crop it is supposed to embed.
    test_embedding = get_test_embedding(embed_models, sizes)

    # Dot product against the preloaded embeddings, divided by the number of
    # models, then rescaled so that a cosine distance of `max_distance` maps
    # to a similarity of 0.
    cosine_similarity = np.dot(comp_embeddings, test_embedding[0]) / n_models
    cosine_distances = 1 - cosine_similarity
    normalized_distances = cosine_distances / max_distance
    normalized_similarities = 1 - normalized_distances

    min_similarity = normalized_similarities.min()
    max_similarity = normalized_similarities.max()

    # NOTE(review): `get_confidence` is not defined or imported in the visible
    # part of this file -- confirm.
    confidence = get_confidence(max_similarity, threshold)

    print(f"Similarities: {min_similarity:.4f} ... {max_similarity:.4f}")
    print("Threshold:", threshold)

    if max_similarity > threshold:
        response_str = f"This looks like a {species} I have seen before ({confidence:.3f} confidence).\n" \
                       "You might find its previous encounters on www.happywhale.com"
    else:
        response_str = f"This looks like a {species}, but I have not seen this individual before ({confidence:.3f} confidence).\n" \
                       "Go submit your photo on www.happywhale.com!"

    return cropped_image, response_str
# Example inputs offered in the UI: negative000 ... negative002 under image_root.
examples = [str(image_root.joinpath(f'negative{n:03d}')) for n in range(3)]

# One image input; the outputs are the cropped image and the response text
# returned by pred_fn.
demo = gr.Interface(
    fn=pred_fn,
    inputs="image",
    outputs=["image", "text"],
    examples=examples,
)
demo.launch()