# detection-demo / utils.py
# Last commit by kevinconka: "Refactor app.py and utils.py" (a378000)
# File size: 2.54 kB
import time
import requests
from io import BytesIO
import numpy as np
import pandas as pd
from PIL import Image
import yolov5
from yolov5.utils.plots import Annotator, colors
import gradio as gr
from huggingface_hub import get_token
def load_model(model_path, img_size=640):
    """Load a YOLOv5 model and remember the inference size to use with it.

    Args:
        model_path: Model identifier/path handed straight to ``yolov5.load``.
        img_size: Value stored on the returned model as ``img_size``
            (read back later when running inference). Defaults to 640.

    Returns:
        The object returned by ``yolov5.load`` with an extra ``img_size``
        attribute attached.
    """
    detector = yolov5.load(model_path, hf_token=get_token())
    # The loaded model carries no size setting of its own here, so attach one
    # for downstream callers to read.
    setattr(detector, "img_size", img_size)
    return detector
def load_image_from_url(url):
    """Download an image from ``url`` and return it as an RGB PIL image.

    Args:
        url: Address of the image to fetch. When empty or ``None`` no request
            is made.

    Returns:
        An interactive ``gr.Image`` component when ``url`` is falsy; otherwise
        the downloaded image converted to RGB mode.

    Raises:
        gr.Error: If the request fails or times out (5 s), the server answers
            with an HTTP error status, or the payload cannot be decoded as an
            image.
    """
    if not url:  # empty or None
        return gr.Image(interactive=True)

    try:
        response = requests.get(url, timeout=5)
        # Fail on 4xx/5xx explicitly so the chained cause names the HTTP
        # status instead of a confusing "cannot identify image file".
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        # Image.open only parses the header; convert() is what actually
        # decodes the pixel data, so it must sit inside the try block for
        # truncated/corrupt downloads to be reported as gr.Error as well.
        return image.convert("RGB")
    except Exception as e:
        raise gr.Error("Unable to load image from URL") from e
def inference(model, image):
    """Run the model on ``image`` and return the image with boxes drawn on it.

    Args:
        model: Callable detector exposing an ``img_size`` attribute; invoked
            as ``model(image, size=model.img_size)``.
        image: Input image; also converted with ``np.asarray`` to serve as
            the drawing canvas.

    Returns:
        The annotator's ``im`` attribute after all boxes have been drawn.
    """
    results = model(image, size=model.img_size)
    canvas = Annotator(np.asarray(image))

    # results.pred[0] holds the detections for the (single) input image.
    detections = results.pred[0]
    for detection in reversed(detections):
        # Each row unpacks into box coordinates, one ignored value
        # (presumably the confidence score -- TODO confirm) and the class id.
        *xyxy, _score, class_id = detection
        # Boxes are drawn without a text label; colour is chosen per class.
        canvas.box_label(xyxy, "", color=colors(class_id, True))

    return canvas.im
def count_flagged_images_via_api(dataset_name, trials=10):
    """Count flagged images by polling the datasets-server ``/size`` endpoint.

    Might be slow: a failed or empty attempt is followed by a 5 second pause
    before the next one.

    Args:
        dataset_name: Dataset identifier placed in the ``dataset`` query
            parameter of the request.
        trials: Maximum number of attempts before giving up. Defaults to 10.

    Returns:
        The ``size.dataset.num_rows`` value from the first response that has
        no ``"error"`` key and reports more than zero rows, or ``0`` if no
        attempt produced such a response.
    """
    headers = {"Authorization": f"Bearer {get_token()}"}
    api_url = f"https://datasets-server.huggingface.co/size?dataset={dataset_name}"

    for attempt in range(1, trials + 1):
        try:
            response = requests.get(api_url, headers=headers, timeout=5)
            data = response.json()
            if "error" not in data:
                num_rows = data["size"]["dataset"]["num_rows"]
                if num_rows > 0:
                    print(f"[{attempt}/{trials}] {data}")
                    return num_rows
            status = data
        except Exception as exc:
            # Best-effort polling: any failure (network, bad JSON, unexpected
            # payload shape) just means "try again", but report why instead
            # of swallowing it silently or printing a stale/unbound payload.
            status = f"{type(exc).__name__}: {exc}"

        print(f"[{attempt}/{trials}] {status}")
        # No point in waiting after the final attempt.
        if attempt < trials:
            time.sleep(5)

    return 0
def count_flagged_images_from_csv(dataset_name):
    """Count flagged images by reading the local flagging CSV.

    Fast, but only works when ``./flagged/<name>/data.csv`` exists relative
    to the current working directory.

    Args:
        dataset_name: Dataset identifier; only the part after the last ``/``
            is used as the directory name.

    Returns:
        Number of data rows in the CSV file.
    """
    # Drop any "owner/" prefix: the local folder is named after the repo only.
    folder = dataset_name.rsplit("/", 1)[-1]
    csv_path = f"./flagged/{folder}/data.csv"
    flagged = pd.read_csv(csv_path)
    return flagged.shape[0]
def load_badges(n):
    """Build the HTML snippet showing the two shields.io badges.

    Args:
        n: Value interpolated into the second badge's URL as its message
            (shown next to a URL-encoded framed-picture emoji label).

    Returns:
        An HTML string: a flex ``<p>`` holding the static "SEA.AI beta" badge,
        a non-breaking space, and the green badge displaying ``n``.
    """
    beta_badge = '<img alt="" src="https://img.shields.io/badge/SEA.AI-beta-blue">'
    count_badge = (
        '<img alt="" src="https://img.shields.io/badge/'
        f'%F0%9F%96%BC%EF%B8%8F-{n}-green">'
    )
    # Leading/trailing empty entries reproduce the newlines that wrap the
    # markup.
    pieces = [
        "",
        '<p style="display: flex">',
        beta_badge,
        "&nbsp;",
        count_badge,
        "</p>",
        "",
    ]
    return "\n".join(pieces)