import os
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from enum import Enum

import requests
import wandb
from cachetools import TTLCache, cached
import wandb.apis.public as wapi
from pydantic import BaseModel
from substrateinterface import Keypair

from chain_data import VALIDATOR_IDENTITIES, sync_chain
from src import TIMEZONE, Key
from src.chain_data import get_neurons

# W&B path whose runs are queried; read from the environment, so a missing
# variable raises KeyError at import time.
WANDB_RUN_PATH = os.environ["WANDB_RUN_PATH"]

# Lower bound on `created_at` used by _fetch_history() when loading all runs.
START_DATE = datetime(2025, 2, 12)
# Number of days get_current_runs() shifts the contest window into the past.
OFFSET_DAYS = 0

# HTTP endpoints queried by get_blacklisted_keys().
BLACKLIST_ENDPOINT = "https://edge-inputs.api.wombo.ai/blacklist"
DUPLICATE_SUBMISSIONS_ENDPOINT = "https://edge-inputs.api.wombo.ai/duplicate_submissions"


class DuplicateSubmission(BaseModel):
    """One entry of `DuplicateSelection.duplicate_submissions`."""

    hotkey: Key
    url: str
    revision: str
    # NOTE(review): presumably identifies the submission this one duplicates;
    # the exact format is not visible in this file - confirm against the API.
    copy_of: str


class SafeSubmissions(BaseModel):
    """One entry of `DuplicateSelection.safe_submissions`."""

    hotkey: Key
    url: str
    revision: str


class DuplicateSelection(BaseModel):
    """Payload of DUPLICATE_SUBMISSIONS_ENDPOINT, as validated in get_blacklisted_keys()."""

    safe_submissions: list[SafeSubmissions]
    duplicate_submissions: list[DuplicateSubmission]

class Blacklist(BaseModel):
    """Result of get_blacklisted_keys(): blacklisted keys plus the duplicate selection."""

    # Built from the "coldkeys" / "hotkeys" lists of the BLACKLIST_ENDPOINT response.
    coldkeys: set[Key]
    hotkeys: set[Key]
    # Parsed from the DUPLICATE_SUBMISSIONS_ENDPOINT response.
    duplicate_selection: DuplicateSelection


class BenchmarkStatus(Enum):
    """Benchmarking status of a run.

    Each member's value is a ``(label, color, failed)`` tuple. The declaration
    order is significant: `_status_from_run` indexes ``list(BenchmarkStatus)``
    with the integer ``benchmarking_state`` read from a run summary.
    """

    NOT_STARTED = ("Not Started", "orange", False)
    IN_PROGRESS = ("In Progress", "orange", False)
    FINISHED = ("Finished", "springgreen", False)
    INITIALIZING = ("Initializing", "orange", False)
    STOPPED = ("Stopped", "red", True)
    CRASHED = ("Crashed", "red", True)
    FAILED = ("Failed", "red", True)
    UNKNOWN = ("Unknown", "red", True)

    def name(self):
        """Return the display label (first tuple element).

        Note that this is a method and replaces Enum's usual ``name`` attribute.
        """
        label, _color, _failed = self.value
        return label

    def color(self):
        """Return the color string associated with the status (second tuple element)."""
        _label, color, _failed = self.value
        return color

    def failed(self):
        """Return the boolean failure flag of the status (third tuple element)."""
        _label, _color, is_failure = self.value
        return is_failure


@dataclass
class Metrics:
    """Measurements recorded for a baseline or a benchmarked submission.

    Values are copied from a run summary by `_add_runs`; units are not
    specified in this file.
    """

    generation_time: float
    vram_used: float
    watts_used: float
    load_time: float
    size: int
    # Recorded as 0 when the summary entry has no "ram_used" key.
    ram_used: float


@dataclass
class SubmissionInfo:
    """Identity of a submission as listed under a run summary's "submissions" entry."""

    # uid of the neuron found for the submitting hotkey via get_neurons().
    uid: int
    # Taken from the submission's repository_info "url" and "revision" fields.
    repository: str
    revision: str
    # Taken from the submission's "block" field.
    block: int


@dataclass
class Submission:
    """A benchmarked submission: its identity, measurements, similarity figures and score."""

    info: SubmissionInfo
    metrics: Metrics
    average_similarity: float
    min_similarity: float
    # Read from the run summary's "scores" entry for the same hotkey.
    score: float


@dataclass
class Run:
    """Parsed snapshot of one validator's W&B run, as built by `_add_runs`."""

    # Run creation time converted to TIMEZONE.
    start_date: datetime
    # Version string extracted from the run's second tag.
    version: str
    # Validator uid and hotkey taken from the run config.
    uid: int
    # Display name: VALIDATOR_IDENTITIES entry for the hotkey, else a truncated hotkey.
    name: str
    hotkey: str
    status: BenchmarkStatus
    average_benchmarking_time: float
    step: int
    # Derived from average_benchmarking_time, the number of submissions still
    # pending, and the GPU count; 0 once the status is FINISHED.
    eta: int
    # None when the run summary has no "baseline" entry.
    baseline_metrics: Metrics | None
    # Number of submissions whose hotkey matched a known neuron.
    total_submissions: int
    # Benchmarked submissions keyed by submitter hotkey.
    submissions: dict[Key, Submission]
    # Hotkeys listed under the summary's "invalid_submissions" entry.
    invalid_submissions: set[Key]


# In-memory store of parsed runs, keyed by W&B run id. Within each list,
# _add_runs() keeps at most one entry per validator uid (newer replaces older).
RUNS: dict[str, list[Run]] = {}


def _is_valid_run(run: wapi.Run, version: str) -> bool:
    """Check that a W&B run carries a valid validator signature.

    The run config must contain "hotkey", "uid" and "signature", and the
    signature must verify the message ``"{version}:{hotkey}"`` against the
    keypair built from the hotkey.

    Args:
        run: The W&B run whose config is inspected.
        version: Version string that is part of the signed message.

    Returns:
        True if the signature verifies; False if a required key is missing,
        the hotkey is not a string, or the hotkey/signature is malformed.
    """
    required_config_keys = ("hotkey", "uid", "signature")
    if any(key not in run.config for key in required_config_keys):
        return False

    validator_hotkey = run.config["hotkey"]
    if not isinstance(validator_hotkey, str):
        return False

    signing_message = f"{version}:{validator_hotkey}"

    # The config comes from an unverified run, so a malformed address or
    # signature must count as "invalid" rather than abort the whole sync.
    try:
        return Keypair(validator_hotkey).verify(signing_message, run.config["signature"])
    except (TypeError, ValueError):
        return False


def _date_from_run(run: wapi.Run) -> datetime:
    """Return the run's creation time as an aware datetime in TIMEZONE.

    `run.created_at` is parsed with the format ``%Y-%m-%dT%H:%M:%SZ`` and
    interpreted as UTC before conversion.
    """
    created_naive = datetime.strptime(run.created_at, "%Y-%m-%dT%H:%M:%SZ")
    created_utc = created_naive.replace(tzinfo=timezone.utc)
    return created_utc.astimezone(TIMEZONE)


def _status_from_run(run: wapi.Run) -> BenchmarkStatus:
    match run.state:
        case "finished":
            return BenchmarkStatus.STOPPED
        case "crashed":
            return BenchmarkStatus.CRASHED
        case "failed":
            return BenchmarkStatus.FAILED
        case "running":
            if "benchmarking_state" in run.summary:
                states = list(BenchmarkStatus)
                return states[int(run.summary["benchmarking_state"])]
            else:
                return BenchmarkStatus.INITIALIZING
        case _:
            return BenchmarkStatus.UNKNOWN


def _add_runs(wandb_runs: list[wapi.Run]):
    """Parse W&B runs and merge them into the module-level RUNS store.

    Runs that have fewer than two tags or fail `_is_valid_run` are skipped.
    For every remaining run the summary is converted into a `Run`, which then
    replaces the stored entry with the same uid under the run's id, or is
    appended if there is none.

    Args:
        wandb_runs: Iterable of W&B runs to process.

    Raises:
        KeyError: If a validated run's summary lacks a field that is read
            unconditionally (for example "step").
    """
    for wandb_run in wandb_runs:
        tags = wandb_run.tags
        # The version is the second tag minus its 8-character prefix. This is
        # read before the signature check, so a run without that tag is
        # skipped instead of aborting the whole sync with an IndexError.
        if len(tags) < 2:
            continue
        version = tags[1][8:]
        if not _is_valid_run(wandb_run, version):
            continue

        metrics = wandb_run.summary

        baseline_metrics: Metrics | None = None
        if "baseline" in metrics:
            baseline_metrics = _metrics_from_summary(metrics["baseline"])

        submission_info: dict[Key, SubmissionInfo] = {}
        if "submissions" in metrics:
            # Looked up once per run rather than once per submission.
            neurons = get_neurons()
            for hotkey, submission in metrics["submissions"].items():
                neuron = neurons.get(hotkey)
                if not neuron:
                    continue
                repository_info = submission["repository_info"]
                submission_info[hotkey] = SubmissionInfo(
                    uid=neuron.uid,
                    repository=repository_info["url"],
                    revision=repository_info["revision"],
                    block=submission["block"],
                )

        submissions: dict[Key, Submission] = {}
        if "benchmarks" in metrics:
            # A benchmark without a score is skipped below, so a summary with
            # no "scores" entry at all is treated the same way.
            scores = metrics.get("scores", {})
            for hotkey, benchmark in metrics["benchmarks"].items():
                if hotkey not in submission_info or hotkey not in scores:
                    continue
                submissions[hotkey] = Submission(
                    info=submission_info[hotkey],
                    metrics=_metrics_from_summary(benchmark["metrics"]),
                    average_similarity=float(benchmark["average_similarity"]),
                    min_similarity=float(benchmark["min_similarity"]),
                    score=float(scores[hotkey]),
                )

        invalid_submissions: set[Key] = set()
        if "invalid_submissions" in metrics:
            try:
                for hotkey in metrics["invalid_submissions"]:
                    invalid_submissions.add(hotkey)
            except KeyError:
                # Best effort: keep whatever was collected before the lookup failed.
                pass

        status = _status_from_run(wandb_run)
        uid = int(wandb_run.config["uid"])
        validator_hotkey = wandb_run.config["hotkey"]
        start_date = _date_from_run(wandb_run)
        run_id = wandb_run.id
        average_benchmarking_time = float(metrics["average_benchmarking_time"]) if "average_benchmarking_time" in metrics else 0

        # Get num_gpus from metrics, default to 1 if not found. Clamped to at
        # least 1 so a reported value of 0 cannot cause a division by zero.
        num_gpus = max(int(metrics.get("num_gpus", 1)), 1)

        # ETA: time for the submissions still pending, split across the GPUs.
        if status == BenchmarkStatus.FINISHED or not average_benchmarking_time:
            eta = 0
        else:
            pending = len(submission_info) - len(submissions) - len(invalid_submissions)
            eta = max(int(average_benchmarking_time * pending), 0) // num_gpus

        run = Run(
            start_date=start_date,
            version=version,
            uid=uid,
            name=VALIDATOR_IDENTITIES.get(validator_hotkey, f"{validator_hotkey[:6]}..."),
            hotkey=validator_hotkey,
            status=status,
            average_benchmarking_time=average_benchmarking_time,
            step=int(metrics["step"]),
            eta=eta,
            baseline_metrics=baseline_metrics,
            total_submissions=len(submission_info),
            submissions=submissions,
            invalid_submissions=invalid_submissions,
        )

        # Replace the stored entry with the same uid, or append a new one.
        runs_for_id = RUNS.setdefault(run_id, [])
        for i, existing_run in enumerate(runs_for_id):
            if existing_run.uid == run.uid:
                runs_for_id[i] = run
                break
        else:
            runs_for_id.append(run)


def _metrics_from_summary(raw) -> Metrics:
    """Build a Metrics record from a dict-like metrics entry of a run summary."""
    return Metrics(
        generation_time=float(raw["generation_time"]),
        vram_used=float(raw["vram_used"]),
        # "ram_used" is optional; a missing value is recorded as 0.
        ram_used=float(raw.get("ram_used", 0)),
        watts_used=float(raw["watts_used"]),
        load_time=float(raw["load_time"]),
        size=int(raw["size"]),
    )


def _get_contest_start() -> datetime:
    """Return the midnight (in TIMEZONE) that opens the current contest window.

    Before 12:00 local time this is yesterday's midnight; from 12:00 onwards
    it is today's midnight.
    """
    now = datetime.now(TIMEZONE)
    midnight = now.replace(hour=0, minute=0, second=0, microsecond=0)
    return midnight - timedelta(days=1) if now.hour < 12 else midnight


def _fetch_history(wandb_api: wandb.Api):
    """Load every validator run created after START_DATE into RUNS.

    Args:
        wandb_api: W&B API client used to query WANDB_RUN_PATH.
    """
    history_filters = {
        "config.type": "validator",
        "created_at": {"$gt": str(START_DATE)},
    }
    history_runs = wandb_api.runs(WANDB_RUN_PATH, filters=history_filters, order="+created_at")
    _add_runs(history_runs)


def _fetch_current_runs(wandb_api: wandb.Api):
    """Load validator runs created since the current contest start into RUNS.

    Args:
        wandb_api: W&B API client used to query WANDB_RUN_PATH.
    """
    contest_start = _get_contest_start()
    current_filters = {
        "config.type": "validator",
        "created_at": {"$gt": str(contest_start)},
    }
    current_runs = wandb_api.runs(WANDB_RUN_PATH, filters=current_filters, order="+created_at")
    _add_runs(current_runs)

@cached(cache=TTLCache(maxsize=1, ttl=300))
def get_blacklisted_keys() -> Blacklist:
    """Fetch the blacklist and the duplicate-submission selection.

    The result is cached for 300 seconds, so the endpoints are queried at
    most once per cache lifetime.

    Returns:
        A Blacklist with the blacklisted hotkeys and coldkeys and the parsed
        duplicate selection.

    Raises:
        requests.RequestException: If a request fails, times out, or returns
            an error status.
    """
    # requests applies no timeout by default; without one a stalled endpoint
    # would block the caller indefinitely.
    timeout_seconds = 10

    blacklist_response = requests.get(BLACKLIST_ENDPOINT, timeout=timeout_seconds)
    blacklist_response.raise_for_status()
    data = blacklist_response.json()

    blacklist_hotkeys = set(data["hotkeys"])
    blacklist_coldkeys = set(data["coldkeys"])

    duplicates_response = requests.get(DUPLICATE_SUBMISSIONS_ENDPOINT, timeout=timeout_seconds)
    duplicates_response.raise_for_status()

    duplicate_selection = DuplicateSelection.model_validate(duplicates_response.json())

    return Blacklist(
        hotkeys=blacklist_hotkeys,
        coldkeys=blacklist_coldkeys,
        duplicate_selection=duplicate_selection,
    )


# Time of the most recent sync() attempt; starts at the Unix epoch so that the
# first sync() call is never throttled.
last_sync: datetime = datetime.fromtimestamp(0, TIMEZONE)


def sync():
    """Refresh RUNS from W&B, at most once every 60 seconds.

    The first successful call (RUNS still empty) loads the full history;
    later calls only re-fetch runs from the current contest window.
    """
    global last_sync
    now = datetime.now(TIMEZONE)
    elapsed = now - last_sync
    if elapsed < timedelta(seconds=60):
        return
    # Recorded before fetching, so a failing fetch is also throttled.
    last_sync = now

    print("Syncing runs...")
    wandb_api = wandb.Api()
    fetch = _fetch_current_runs if RUNS else _fetch_history
    fetch(wandb_api)


def get_current_runs() -> list[Run]:
    """Return the runs that started inside the current contest window.

    Chain data and runs are synced first. The window is the 24 hours starting
    at the contest start, shifted OFFSET_DAYS days into the past.
    """
    sync_chain()
    sync()

    window_start = _get_contest_start() - timedelta(days=OFFSET_DAYS)
    window_end = window_start + timedelta(days=1)

    return [
        run
        for runs in RUNS.values()
        for run in runs
        if window_start <= run.start_date < window_end
    ]