import time
import requests
from io import BytesIO
from urllib.parse import quote
from dataclasses import dataclass
import pandas as pd
from PIL import Image
import gradio as gr
from huggingface_hub import get_token


def check_image(image):
    """Validate that an image was provided.

    Args:
        image: The value to check; anything other than ``None`` is accepted.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    if image is not None:
        return
    raise gr.Error("Oops! It looks like you forgot to upload an image.")


def load_image_from_url(url):
    """Download an image from a URL and return it in RGB mode.

    Args:
        url: Address of the image. An empty string or ``None`` means
            "nothing to load".

    Returns:
        The downloaded image converted to RGB, or ``gr.Image(interactive=True)``
        when ``url`` is empty.

    Raises:
        gr.Error: If the request fails, the server answers with an HTTP error
            status, or the payload cannot be decoded as an image.
    """
    if not url:  # empty or None
        return gr.Image(interactive=True)
    try:
        response = requests.get(url, timeout=5)
        # Fail on 4xx/5xx instead of trying to decode an error page as an image.
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        # Image.open only reads the header; the pixel data is decoded by
        # convert(), so it must stay inside the try for corrupt or truncated
        # files to be reported as gr.Error as well.
        return image.convert("RGB")
    except Exception as e:
        # Deliberately broad: any failure here is reported the same way.
        raise gr.Error("Unable to load image from URL") from e


def load_badges(n):
    """Build the HTML snippet that shows the shields.io badges.

    Two badges are rendered: a static "version / beta" badge and a badge
    whose value is ``n``.

    Args:
        n: Value displayed on the second badge; interpolated into the badge
            URL as-is.

    Returns:
        An HTML string with one ``<p>`` element containing both ``<img>`` tags
        separated by a non-breaking space.
    """
    badges = [
        "https://img.shields.io/badge/version-beta-blue",
        f"https://img.shields.io/badge/{quote('🖼️')}{quote('🚩')}-{n}-green",
    ]
    # "&nbsp;" needs its terminating semicolon to be a well-formed entity.
    images = "&nbsp;".join(f'<img alt="" src="{badge}">' for badge in badges)
    return f"""
        <p style="display: flex">
        {images}
        </p>
        """


@dataclass
class FlaggedCounter:
    """Count flagged images in dataset.

    The count is read from a local CSV file when one exists; otherwise it is
    requested from the ``datasets-server.huggingface.co`` ``/size`` endpoint.

    Attributes:
        dataset_name: Dataset identifier. It is interpolated into the API URL,
            and its last "/"-separated segment names the local CSV folder.
        headers: HTTP headers sent with every API request. When omitted, an
            ``Authorization`` header is built from ``get_token()``.
    """

    dataset_name: str
    # Defaults to None and is filled in by __post_init__.
    headers: dict = None

    def __post_init__(self):
        """Derive the API URL, the retry budget and the default headers."""
        self.API_URL = (
            f"https://datasets-server.huggingface.co/size?dataset={self.dataset_name}"
        )
        self.trials = 10
        if self.headers is None:
            token = get_token()
            # get_token() returns None when no token is configured; avoid
            # sending a meaningless "Bearer None" header in that case.
            self.headers = {"Authorization": f"Bearer {token}"} if token else {}

    def query(self):
        """Query API.

        Returns:
            The decoded JSON body of the response.
        """
        response = requests.get(self.API_URL, headers=self.headers, timeout=5)
        return response.json()

    @staticmethod
    def _num_rows(data):
        """Return the row count found in an API payload, or 0 if unusable."""
        if not isinstance(data, dict) or "error" in data:
            return 0
        try:
            return int(data["size"]["dataset"]["num_rows"])
        except (KeyError, TypeError, ValueError):
            # Payload does not have the expected shape.
            return 0

    def from_query(self, data=None):
        """Count flagged images via API. Might be slow.

        The API is polled up to ``self.trials`` times, with a 5 second pause
        between attempts, until it reports a positive row count.

        Args:
            data: Ignored; kept only so existing callers that pass a payload
                keep working. A fresh payload is fetched on every attempt.

        Returns:
            The reported number of rows, or 0 if no attempt succeeded.
        """
        for attempt in range(1, self.trials + 1):
            try:
                data = self.query()
            except (requests.exceptions.RequestException, ValueError) as exc:
                # Network failure or non-JSON body: log it and retry.
                print(f"[{attempt}/{self.trials}] request failed: {exc}")
            else:
                print(f"[{attempt}/{self.trials}] {data}")
                num_rows = self._num_rows(data)
                if num_rows > 0:
                    return num_rows
            # No point in waiting once the retry budget is exhausted.
            if attempt < self.trials:
                time.sleep(5)

        return 0

    def from_csv(self):
        """Count flagged images from CSV. Fast but relies on local files.

        Returns:
            The number of rows in ``./flagged/<name>/data.csv``, where
            ``<name>`` is the last "/"-separated segment of ``dataset_name``.

        Raises:
            FileNotFoundError: If the CSV file does not exist.
        """
        dataset_name = self.dataset_name.split("/")[-1]
        df = pd.read_csv(f"./flagged/{dataset_name}/data.csv")
        return len(df)

    def count(self):
        """Count flagged images.

        Returns:
            The local CSV row count when the file exists, otherwise the count
            obtained from the API (0 if the API never answers usefully).
        """
        try:
            return self.from_csv()
        except FileNotFoundError:
            return self.from_query()