File size: 3,103 Bytes
a378000
e921d65
 
246a775
e921d65
a378000
e921d65
 
 
 
955daea
e921d65
 
 
a378000
955daea
e921d65
 
 
 
 
a378000
e921d65
 
 
 
 
 
 
 
 
 
 
a378000
e921d65
 
 
 
 
 
 
955daea
 
a378000
 
955daea
 
 
 
 
 
 
246a775
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import time
import requests
from io import BytesIO
from dataclasses import dataclass
import numpy as np
import pandas as pd
from PIL import Image
import yolov5
from yolov5.utils.plots import Annotator, colors
import gradio as gr
from huggingface_hub import get_token


def load_model(model_path, img_size=640):
    """Load a YOLOv5 model and remember the image size to use for inference.

    Args:
        model_path: Model location handed to ``yolov5.load`` (the locally
            stored Hugging Face token is forwarded as ``hf_token``).
        img_size: Inference image size stored on the model as ``img_size``.

    Returns:
        The loaded model with an extra ``img_size`` attribute.
    """
    detector = yolov5.load(model_path, hf_token=get_token())
    # yolov5 models carry no such attribute; attach it so callers can read it.
    setattr(detector, "img_size", img_size)
    return detector


def load_image_from_url(url):
    """Download an image from a URL and return it as an RGB PIL image.

    Args:
        url: Address of the image. When empty or ``None`` nothing is fetched.

    Returns:
        The downloaded image converted to RGB, or an interactive
        ``gr.Image`` component when ``url`` is empty.

    Raises:
        gr.Error: If the request fails, the server answers with an HTTP
            error status, or the payload cannot be decoded as an image.
    """
    if not url:  # empty or None
        return gr.Image(interactive=True)
    try:
        response = requests.get(url, timeout=5)
        # Fail early on 4xx/5xx instead of trying to decode an error page.
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        # Image.open is lazy: pixel data is only decoded by convert(), so the
        # conversion must stay inside the try to report truncated/corrupt
        # files as a gr.Error too.
        return image.convert("RGB")
    except Exception as e:
        raise gr.Error("Unable to load image from URL") from e


def inference(model, image):
    """Detect objects in ``image`` and return it with the boxes drawn on.

    Args:
        model: Callable detection model exposing an ``img_size`` attribute.
        image: Input image; it is converted with ``np.asarray`` for drawing.

    Returns:
        The annotator's image (``Annotator.im``) with one unlabeled box per
        detection.
    """
    results = model(image, size=model.img_size)
    painter = Annotator(np.asarray(image))
    for detection in reversed(results.pred[0]):
        # Each row is unpacked as box coordinates, one ignored value
        # (presumably the confidence score) and the class index.
        *xyxy, _, class_id = detection
        # Boxes are drawn without text; only the color encodes the class.
        painter.box_label(xyxy, "", color=colors(class_id, True))
    return painter.im


def load_badges(n):
    """Build the HTML snippet showing the app badges.

    Args:
        n: Value rendered inside the second (green) shields.io badge.

    Returns:
        An HTML paragraph with the "SEA.AI beta" badge followed by a badge
        displaying ``n``.
    """
    badges_html = f"""
        <p style="display: flex">
        <img alt="" src="https://img.shields.io/badge/SEA.AI-beta-blue">
        &nbsp;
        <img alt="" src="https://img.shields.io/badge/%F0%9F%96%BC%EF%B8%8F-{n}-green">
        </p>
        """
    return badges_html


@dataclass
class FlaggedCounter:
    """Count flagged images in dataset.

    The count is read from the local flagging CSV when it exists; otherwise
    the Hugging Face datasets-server ``size`` endpoint is polled.

    Attributes:
        dataset_name: Dataset identifier. It is interpolated into the
            ``size`` API URL, and its last "/"-separated segment names the
            local ``./flagged/<name>/`` folder.
        headers: HTTP headers sent with every API request. When omitted, an
            ``Authorization`` header is built from the locally stored token,
            or no header at all if there is no token.
    """

    dataset_name: str
    headers: dict = None

    def __post_init__(self):
        self.API_URL = (
            f"https://datasets-server.huggingface.co/size?dataset={self.dataset_name}"
        )
        self.trials = 10  # maximum number of API attempts in from_query
        if self.headers is None:
            token = get_token()
            # Without a token send no Authorization header rather than the
            # literal string "Bearer None".
            self.headers = {"Authorization": f"Bearer {token}"} if token else {}

    def query(self):
        """Query API.

        Returns:
            The decoded JSON body of the ``size`` endpoint response.
        """
        response = requests.get(self.API_URL, headers=self.headers, timeout=5)
        return response.json()

    def from_query(self, data=None):
        """Count flagged images via API. Might be slow.

        Makes up to ``self.trials`` attempts, waiting 5 seconds between them,
        until the API reports a positive row count.

        Args:
            data: Optional, already fetched API response. When given it is
                inspected as the first attempt instead of issuing a request.

        Returns:
            The number of rows reported by the API, or 0 if no attempt
            produced a positive count.
        """
        for attempt in range(1, self.trials + 1):
            try:
                if data is None:
                    data = self.query()
                if "error" not in data:
                    num_rows = data["size"]["dataset"]["num_rows"]
                    if num_rows > 0:
                        print(f"[{attempt}/{self.trials}] {data}")
                        return num_rows
            except Exception as err:
                # Deliberately best-effort: any failure (network, bad JSON,
                # unexpected payload) just costs one attempt, but is logged.
                print(f"[{attempt}/{self.trials}] query failed: {err!r}")
            else:
                print(f"[{attempt}/{self.trials}] {data}")
            data = None  # force a fresh request on the next attempt
            if attempt < self.trials:  # no point in waiting after the last try
                time.sleep(5)

        return 0

    def from_csv(self):
        """Count flagged images from CSV. Fast but relies on local files.

        Returns:
            The number of rows in ``./flagged/<name>/data.csv``.

        Raises:
            FileNotFoundError: If the CSV file does not exist.
        """
        dataset_name = self.dataset_name.split("/")[-1]
        df = pd.read_csv(f"./flagged/{dataset_name}/data.csv")
        return len(df)

    def count(self):
        """Count flagged images.

        Returns:
            The CSV row count when the local file exists, otherwise the
            count obtained from the API (0 if the API never answers).
        """
        try:
            return self.from_csv()
        except FileNotFoundError:
            # Let from_query do all requests so every one of them is covered
            # by its retry/error handling.
            return self.from_query()