# detection-RGB / utils.py
# kevinconka's picture
# renamed badge
# a49fee8
import time
import requests
from io import BytesIO
from urllib.parse import quote
from dataclasses import dataclass
import pandas as pd
from PIL import Image
import gradio as gr
from huggingface_hub import get_token
def check_image(image):
    """Validate that an image was supplied.

    Raises:
        gr.Error: If ``image`` is ``None`` (nothing was uploaded).
    """
    if image is not None:
        return
    raise gr.Error("Oops! It looks like you forgot to upload an image.")
def load_image_from_url(url):
    """Load an image from a URL and return it in RGB mode.

    Args:
        url: Address of the image to download. May be empty or ``None``.

    Returns:
        The downloaded image converted to ``"RGB"``. When ``url`` is empty
        or ``None``, a new ``gr.Image(interactive=True)`` component is
        returned instead.

    Raises:
        gr.Error: If the download fails, the server answers with an HTTP
            error status, or the payload cannot be decoded as an image.
    """
    if not url:  # empty or None
        return gr.Image(interactive=True)

    try:
        response = requests.get(url, timeout=5)
        # Surface 4xx/5xx responses as the real cause instead of letting an
        # HTML error page fail later inside the image decoder.
        response.raise_for_status()
        # Image.open() only reads the header; convert() forces the full
        # decode, so it must stay inside the try block for truncated or
        # corrupt data to be reported as gr.Error as well.
        return Image.open(BytesIO(response.content)).convert("RGB")
    except Exception as e:
        raise gr.Error("Unable to load image from URL") from e
def load_badges(n):
    """Build the HTML snippet showing the app's shields.io badges.

    Args:
        n: Value displayed on the flagged-images badge.

    Returns:
        An HTML string with a flex ``<p>`` holding one ``<img>`` per badge:
        a static "version beta" badge and a badge whose label is the
        picture + flag emoji pair and whose message is ``n``.
    """
    # The label is written with escapes so the source survives being
    # re-saved with the wrong encoding: U+1F5BC U+FE0F (framed picture)
    # followed by U+1F6A9 (flag), percent-encoded for the URL path.
    flagged_label = quote("\U0001F5BC\uFE0F") + quote("\U0001F6A9")
    badges = [
        "https://img.shields.io/badge/version-beta-blue",
        f"https://img.shields.io/badge/{flagged_label}-{n}-green",
    ]
    # "&nbsp;" (with the terminating semicolon) separates the badges.
    images = "&nbsp;".join(f'<img alt="" src="{badge}">' for badge in badges)
    return f"""
    <p style="display: flex">
    {images}
    </p>
    """
@dataclass
class FlaggedCounter:
    """Count flagged images in dataset.

    The count is taken from the local flagging CSV when that file exists,
    and otherwise from the datasets-server ``/size`` endpoint.

    Attributes:
        dataset_name: Name of the dataset; may be namespaced with ``/``
            (only the last path component is used for the local CSV path).
        headers: HTTP headers sent with every API request. When ``None``,
            an ``Authorization`` header is built from ``get_token()``.
    """

    dataset_name: str
    headers: dict = None

    def __post_init__(self):
        self.API_URL = (
            f"https://datasets-server.huggingface.co/size?dataset={self.dataset_name}"
        )
        self.trials = 10  # maximum number of API attempts in from_query()
        if self.headers is None:
            token = get_token()
            # get_token() may return None; avoid sending "Bearer None".
            self.headers = {"Authorization": f"Bearer {token}"} if token else {}

    def query(self):
        """Query the API once and return the decoded JSON response."""
        response = requests.get(self.API_URL, headers=self.headers, timeout=5)
        return response.json()

    @staticmethod
    def _extract_num_rows(data):
        """Return the positive row count found in ``data``, else 0."""
        if not isinstance(data, dict) or "error" in data:
            return 0
        try:
            num_rows = data["size"]["dataset"]["num_rows"]
        except (KeyError, TypeError):
            # Response does not have the expected nested shape.
            return 0
        if isinstance(num_rows, int) and num_rows > 0:
            return num_rows
        return 0

    def from_query(self, data=None):
        """Count flagged images via API. Might be slow.

        Polls the API up to ``self.trials`` times, waiting 5 seconds
        between attempts, until a response with a positive row count
        arrives.

        Args:
            data: Ignored; accepted only for backward compatibility. A fresh
                response is fetched on every attempt.

        Returns:
            The reported number of rows, or 0 if no attempt succeeded.
        """
        for attempt in range(1, self.trials + 1):
            try:
                data = self.query()
            except (requests.exceptions.RequestException, ValueError) as err:
                # Network failure or a body that is not valid JSON: retry.
                print(f"[{attempt}/{self.trials}] request failed: {err!r}")
            else:
                print(f"[{attempt}/{self.trials}] {data}")
                num_rows = self._extract_num_rows(data)
                if num_rows > 0:
                    return num_rows
            # No point in waiting once the last attempt has been used up.
            if attempt < self.trials:
                time.sleep(5)
        return 0

    def from_csv(self):
        """Count flagged images from CSV. Fast but relies on local files.

        Raises:
            FileNotFoundError: If the local CSV does not exist.
        """
        dataset_name = self.dataset_name.split("/")[-1]
        df = pd.read_csv(f"./flagged/{dataset_name}/data.csv")
        return len(df)

    def count(self):
        """Count flagged images, preferring the local CSV over the API."""
        try:
            return self.from_csv()
        except FileNotFoundError:
            # from_query() performs (and guards) its own requests, so no
            # unprotected query is issued here.
            return self.from_query()