File size: 2,543 Bytes
a378000
e921d65
 
 
a378000
e921d65
 
 
 
955daea
e921d65
 
 
a378000
955daea
e921d65
 
 
 
 
a378000
e921d65
 
 
 
 
 
 
 
 
 
 
a378000
e921d65
 
 
 
 
 
 
955daea
 
a378000
 
 
955daea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a378000
 
 
 
 
 
 
 
 
955daea
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import time
import requests
from io import BytesIO
import numpy as np
import pandas as pd
from PIL import Image
import yolov5
from yolov5.utils.plots import Annotator, colors
import gradio as gr
from huggingface_hub import get_token


def load_model(model_path, img_size=640):
    """Fetch a YOLOv5 model and remember the inference size on it.

    Args:
        model_path: Model identifier/path handed to ``yolov5.load``.
        img_size: Value stored on the returned model as ``img_size``.

    Returns:
        The loaded model, with an extra ``img_size`` attribute.
    """
    hub_token = get_token()
    detector = yolov5.load(model_path, hf_token=hub_token)
    # Stash the size on the model so callers can read it back later.
    detector.img_size = img_size
    return detector


def load_image_from_url(url):
    """Download an image from ``url`` and return it in RGB mode.

    Args:
        url: Address of the image. When empty or ``None`` nothing is
            downloaded.

    Returns:
        ``gr.Image(interactive=True)`` when ``url`` is falsy, otherwise the
        downloaded image converted to RGB.

    Raises:
        gr.Error: If the request fails or the payload cannot be decoded as
            an image. The original exception is chained as the cause.
    """
    if not url:  # empty or None
        return gr.Image(interactive=True)
    try:
        response = requests.get(url, timeout=5)
        # Image.open() is lazy and only parses the header; the pixel data is
        # decoded by convert(). Keep both inside the try so that truncated or
        # corrupt payloads are reported as gr.Error as well.
        return Image.open(BytesIO(response.content)).convert("RGB")
    except Exception as e:
        raise gr.Error("Unable to load image from URL") from e


def inference(model, image):
    """Run the model on ``image`` and draw an unlabeled box per detection.

    Args:
        model: Callable model exposing ``img_size``; its result must expose
            ``pred``.
        image: Input image; it is converted with ``np.asarray`` for drawing.

    Returns:
        The annotator's image (``annotator.im``) after all boxes are drawn.
    """
    output = model(image, size=model.img_size)
    canvas = Annotator(np.asarray(image))
    # Only the first entry of ``pred`` is used, walked in reverse order.
    for detection in reversed(output.pred[0]):
        # Trailing values are (ignored, class); everything before them is the
        # box. NOTE(review): the ignored value is presumably the confidence
        # score - confirm against the yolov5 prediction layout.
        *corners, _, class_id = detection
        canvas.box_label(corners, "", color=colors(class_id, True))
    return canvas.im


def count_flagged_images_via_api(dataset_name, trials=10):
    """Count flagged images via API. Might be slow.

    Queries the datasets-server ``/size`` endpoint up to ``trials`` times,
    pausing 5 seconds between attempts, until it reports a positive row
    count. Every attempt is logged with ``print``.

    Args:
        dataset_name: Dataset identifier interpolated into the query string.
        trials: Maximum number of attempts.

    Returns:
        The reported ``num_rows`` of the dataset, or ``0`` when no attempt
        produced a positive count.
    """
    headers = {"Authorization": f"Bearer {get_token()}"}
    api_url = f"https://datasets-server.huggingface.co/size?dataset={dataset_name}"

    def query():
        response = requests.get(api_url, headers=headers, timeout=5)
        return response.json()

    for attempt in range(1, trials + 1):
        # Reset per attempt: the log line below must never read an unbound
        # name (first request failing) or a stale payload from an earlier try.
        data = None
        failure = None
        try:
            data = query()
            if "error" not in data:
                num_rows = data["size"]["dataset"]["num_rows"]
                if num_rows > 0:
                    print(f"[{attempt}/{trials}] {data}")
                    return num_rows
        except Exception as exc:
            # Best-effort polling: network errors, invalid JSON and unexpected
            # payload shapes all just mean "try again", but are reported.
            failure = exc

        if failure is None:
            print(f"[{attempt}/{trials}] {data}")
        else:
            print(f"[{attempt}/{trials}] {data} ({failure!r})")

        # No point in waiting once the last attempt has been used up.
        if attempt < trials:
            time.sleep(5)

    return 0


def count_flagged_images_from_csv(dataset_name):
    """Return the number of rows in the local flagging CSV of a dataset.

    Only the part of ``dataset_name`` after the last ``/`` is used to locate
    ``./flagged/<name>/data.csv``. Fast, but depends on that local file.
    """
    _, _, folder = dataset_name.rpartition("/")
    flagged = pd.read_csv(f"./flagged/{folder}/data.csv")
    return flagged.shape[0]


def load_badges(n):
    """Build the HTML snippet with the two shields.io badges.

    The first badge is fixed; ``n`` is interpolated into the URL of the
    second badge.
    """
    markup = f"""
        <p style="display: flex">
        <img alt="" src="https://img.shields.io/badge/SEA.AI-beta-blue">
        &nbsp;
        <img alt="" src="https://img.shields.io/badge/%F0%9F%96%BC%EF%B8%8F-{n}-green">
        </p>
        """
    return markup