detection-demo / utils.py
kevinconka's picture
Refactor flagged image counting logic
246a775
raw
history blame
3.1 kB
import time
import requests
from io import BytesIO
from dataclasses import dataclass
import numpy as np
import pandas as pd
from PIL import Image
import yolov5
from yolov5.utils.plots import Annotator, colors
import gradio as gr
from huggingface_hub import get_token
def load_model(model_path, img_size=640):
    """Load a YOLOv5 model from the HuggingFace Hub.

    Args:
        model_path: Hub repository id / path passed straight to ``yolov5.load``.
        img_size: Inference resolution to remember on the model (default 640).

    Returns:
        The loaded model, with an extra ``img_size`` attribute attached.
    """
    token = get_token()  # locally stored HF token (may be None)
    loaded = yolov5.load(model_path, hf_token=token)
    # Attach the resolution so callers can read it back via `model.img_size`.
    loaded.img_size = img_size
    return loaded
def load_image_from_url(url):
    """Load an image from a URL as an RGB PIL image.

    Args:
        url: Address of the image. If empty or ``None``, nothing is downloaded.

    Returns:
        A ``PIL.Image.Image`` in RGB mode. When ``url`` is empty, a fresh
        ``gr.Image(interactive=True)`` component is returned instead
        (presumably to reset the bound Gradio output — confirm with caller).

    Raises:
        gr.Error: If the request fails (network error, timeout, non-2xx HTTP
            status) or the payload cannot be decoded as an image.
    """
    if not url:  # empty or None
        return gr.Image(interactive=True)
    try:
        response = requests.get(url, timeout=5)
        # Fail on 4xx/5xx instead of trying to decode an HTML error page.
        response.raise_for_status()
        # PIL's Image.open is lazy: pixel data is only decoded on load(), which
        # convert() triggers. Keep it inside the try so truncated or corrupt
        # payloads are also reported as gr.Error rather than a raw OSError.
        image = Image.open(BytesIO(response.content)).convert("RGB")
    except Exception as e:
        raise gr.Error("Unable to load image from URL") from e
    return image
def inference(model, image):
    """Run the detector on an image and return it with detections drawn.

    Args:
        model: Model from ``load_model``. Its ``img_size`` attribute sets the
            inference resolution.
        image: Input image. It is passed to the model as-is and converted with
            ``np.asarray`` for drawing.

    Returns:
        ``annotator.im``: the image with one unlabeled box per detection,
        colored by class index.
    """
    prediction = model(image, size=model.img_size)
    annotator = Annotator(np.asarray(image))
    detections = prediction.pred[0]  # detections for the first (only) image
    # Rows are presumably (x1, y1, x2, y2, conf, cls); confidence is unused.
    for detection in reversed(detections):
        *xyxy, _score, class_idx = detection
        annotator.box_label(xyxy, "", color=colors(class_idx, True))
    return annotator.im
def load_badges(n):
    """Build the HTML snippet with the header badges.

    Args:
        n: Value shown on the image-count badge (formatted with ``str``-like
            f-string formatting).

    Returns:
        An HTML ``<p>`` block, with a leading and a trailing newline, holding a
        "SEA.AI beta" badge and an image-count badge.
    """
    lines = [
        '<p style="display: flex">',
        '<img alt="" src="https://img.shields.io/badge/SEA.AI-beta-blue">',
        "&nbsp;",
        f'<img alt="" src="https://img.shields.io/badge/%F0%9F%96%BC%EF%B8%8F-{n}-green">',
        "</p>",
    ]
    return "\n" + "\n".join(lines) + "\n"
@dataclass
class FlaggedCounter:
    """Count flagged images in dataset.

    The count comes from the local ``./flagged/<name>/data.csv`` when that file
    exists, otherwise from the HuggingFace datasets-server ``/size`` endpoint.

    Attributes:
        dataset_name: Hub dataset id, e.g. ``"org/name"``.
        headers: HTTP headers for the API request. Defaults to a bearer
            token built from the locally stored HF token.
    """

    dataset_name: str
    headers: dict = None  # None -> "Authorization: Bearer <local HF token>"

    def __post_init__(self):
        self.API_URL = (
            f"https://datasets-server.huggingface.co/size?dataset={self.dataset_name}"
        )
        self.trials = 10  # max number of API attempts in from_query
        if self.headers is None:
            self.headers = {"Authorization": f"Bearer {get_token()}"}

    def query(self):
        """Send one GET to the datasets-server size endpoint; return the decoded JSON."""
        response = requests.get(self.API_URL, headers=self.headers, timeout=5)
        return response.json()

    def from_query(self, data=None):
        """Count flagged images via API. Might be slow.

        Polls :meth:`query` up to ``self.trials`` times, waiting 5 s between
        attempts, until the response has no ``"error"`` key and reports a
        positive ``num_rows``.

        Args:
            data: Ignored. It is kept only for backward compatibility with
                callers that passed a pre-fetched response. Every attempt
                issues a fresh query.

        Returns:
            The dataset's ``num_rows``, or 0 if no attempt succeeded.
        """
        for i in range(self.trials):
            try:
                data = self.query()
                if "error" not in data and data["size"]["dataset"]["num_rows"] > 0:
                    print(f"[{i+1}/{self.trials}] {data}")
                    return data["size"]["dataset"]["num_rows"]
                print(f"[{i+1}/{self.trials}] {data}")
            except Exception as err:  # network / JSON / schema errors: log and retry
                print(f"[{i+1}/{self.trials}] query failed: {err!r}")
            if i < self.trials - 1:  # no pointless wait after the last attempt
                time.sleep(5)
        return 0

    def from_csv(self):
        """Count flagged images from CSV. Fast but relies on local files."""
        dataset_name = self.dataset_name.split("/")[-1]
        df = pd.read_csv(f"./flagged/{dataset_name}/data.csv")
        return len(df)

    def count(self):
        """Count flagged images: local CSV first, datasets-server API as fallback."""
        try:
            return self.from_csv()
        except FileNotFoundError:
            # Let from_query do all querying so every request is retried and
            # guarded; the old extra self.query() here could crash count().
            return self.from_query()