import os
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import TypeAlias
from zoneinfo import ZoneInfo

import wandb
import wandb.apis.public as wapi
from substrateinterface import Keypair

WANDB_RUN_PATH = os.environ["WANDB_RUN_PATH"]

TIMEZONE = ZoneInfo("America/Los_Angeles")
START_DATE = datetime(2024, 11, 9)

Uid: TypeAlias = int


class BenchmarkStatus(Enum):
    # Each member is (display name, display color, counts as failed).
    NOT_STARTED = ("Not Started", "orange", False)
    IN_PROGRESS = ("In Progress", "orange", False)
    FINISHED = ("Finished", "springgreen", False)
    INITIALIZING = ("Initializing", "orange", False)
    STOPPED = ("Stopped", "red", True)
    CRASHED = ("Crashed", "red", True)
    FAILED = ("Failed", "red", True)
    UNKNOWN = ("Unknown", "red", True)

    def name(self) -> str:
        return self.value[0]

    def color(self) -> str:
        return self.value[1]

    def failed(self) -> bool:
        return self.value[2]


@dataclass
class MetricData:
    generation_time: float
    vram_used: float
    watts_used: float
    load_time: float
    size: int


@dataclass
class SubmissionInfo:
    uid: int
    hotkey: str
    repository: str
    revision: str
    block: int


@dataclass
class Submission:
    info: SubmissionInfo
    metrics: MetricData
    average_similarity: float
    min_similarity: float
    tier: int
    score: float


@dataclass
class InvalidSubmission:
    info: SubmissionInfo
    reason: str


@dataclass
class Run:
    start_date: datetime
    version: str
    uid: int
    name: str
    hotkey: str
    status: BenchmarkStatus
    average_benchmark_time: float
    eta: int
    winner_uid: int | None
    baseline_metrics: MetricData | None
    total_submissions: int
    submissions: dict[Uid, Submission]
    invalid_submissions: dict[Uid, InvalidSubmission]


RUNS: dict[str, list[Run]] = {}


def _is_valid_run(run: wapi.Run):
    # A run is only trusted if its config is signed by the validator hotkey it claims.
    required_config_keys = ["hotkey", "uid", "contest", "signature"]

    for key in required_config_keys:
        if key not in run.config:
            return False

    validator_hotkey = run.config["hotkey"]
    contest_name = run.config["contest"]

    signing_message = f"{run.name}:{validator_hotkey}:{contest_name}"

    return Keypair(validator_hotkey).verify(signing_message, run.config["signature"])


def _date_from_run(run: wapi.Run) -> datetime:
    # wandb reports created_at as a UTC timestamp; convert it to the contest timezone.
    return (
        datetime.strptime(run.created_at, "%Y-%m-%dT%H:%M:%SZ")
        .replace(tzinfo=timezone.utc)
        .astimezone(TIMEZONE)
    )


def _status_from_run(run: wapi.Run) -> BenchmarkStatus:
    match run.state:
        case "finished":
            # A cleanly finished wandb run means the validator stopped;
            # FINISHED only comes from the benchmarking_state summary key.
            return BenchmarkStatus.STOPPED
        case "crashed":
            return BenchmarkStatus.CRASHED
        case "failed":
            return BenchmarkStatus.FAILED
        case "running":
            if "benchmarking_state" in run.summary:
                return BenchmarkStatus[run.summary["benchmarking_state"]]
            return BenchmarkStatus.INITIALIZING
        case _:
            return BenchmarkStatus.UNKNOWN


def _add_runs(wandb_runs: list[wapi.Run]):
    for wandb_run in wandb_runs:
        if not _is_valid_run(wandb_run):
            continue

        metrics = wandb_run.summary

        submission_info: dict[Uid, SubmissionInfo] = {}
        submissions: dict[Uid, Submission] = {}
        invalid_submissions: dict[Uid, InvalidSubmission] = {}

        baseline_metrics: MetricData | None = None
        if "baseline" in metrics:
            baseline = metrics["baseline"]
            baseline_metrics = MetricData(
                generation_time=float(baseline["generation_time"]),
                vram_used=float(baseline["vram_used"]),
                watts_used=float(baseline["watts_used"]),
                load_time=float(baseline["load_time"]),
                size=int(baseline["size"]),
            )

        if "submissions" in metrics:
            for uid, submission in metrics["submissions"].items():
                submission_info[uid] = SubmissionInfo(
                    uid=uid,
                    hotkey=submission["hotkey"],
                    repository=submission["repository"],
                    revision=submission["revision"],
                    block=submission["block"],
                )
        if "benchmarks" in metrics:
            for uid, benchmark in metrics["benchmarks"].items():
                if uid not in submission_info:
                    continue
                model = benchmark["model"]
                submissions[uid] = Submission(
                    info=submission_info[uid],
                    metrics=MetricData(
                        generation_time=float(model["generation_time"]),
                        vram_used=float(model["vram_used"]),
                        watts_used=float(model["watts_used"]),
                        load_time=float(model["load_time"]),
                        size=int(model["size"]),
                    ),
                    average_similarity=float(benchmark["average_similarity"]),
                    min_similarity=float(benchmark["min_similarity"]),
                    tier=int(benchmark["tier"]),
                    score=float(benchmark["score"]),
                )

        if "invalid" in metrics:
            for uid, reason in metrics["invalid"].items():
                if uid not in submission_info:
                    continue
                invalid_submissions[uid] = InvalidSubmission(
                    info=submission_info[uid],
                    reason=reason,
                )

        status = _status_from_run(wandb_run)

        # Highest tier wins; ties are broken by the earliest submission (lowest block).
        winners = sorted(
            submissions.values(),
            key=lambda submission: (submission.tier, -submission.info.block),
            reverse=True,
        )
        winner_uid = winners[0].info.uid if winners and status == BenchmarkStatus.FINISHED else None

        from chain_data import VALIDATOR_IDENTITIES  # deferred to avoid a circular import

        uid = int(wandb_run.config["uid"])
        hotkey = wandb_run.config["hotkey"]
        date = _date_from_run(wandb_run)
        run_id = wandb_run.id

        average_benchmark_time = (
            float(wandb_run.summary["average_benchmark_time"])
            if "average_benchmark_time" in wandb_run.summary
            else 0.0
        )

        # Estimated seconds remaining: average time per benchmark times the number
        # of submissions not yet benchmarked or invalidated.
        if status == BenchmarkStatus.FINISHED or not average_benchmark_time:
            eta = 0
        else:
            remaining = len(submission_info) - len(submissions) - len(invalid_submissions)
            eta = max(int(average_benchmark_time * remaining), 0)

        run = Run(
            start_date=date,
            version=wandb_run.tags[1][8:],  # second tag carries the version after an 8-character prefix
            uid=uid,
            name=VALIDATOR_IDENTITIES.get(uid, f"{hotkey[:6]}..."),
            hotkey=hotkey,
            status=status,
            average_benchmark_time=average_benchmark_time,
            eta=eta,
            winner_uid=winner_uid,
            baseline_metrics=baseline_metrics,
            total_submissions=len(submission_info),
            submissions=submissions,
            invalid_submissions=invalid_submissions,
        )

        # Replace an existing entry for the same validator UID, otherwise append.
        if run_id not in RUNS:
            RUNS[run_id] = [run]
        else:
            for i, existing_run in enumerate(RUNS[run_id]):
                if existing_run.uid == run.uid:
                    RUNS[run_id][i] = run
                    break
            else:
                RUNS[run_id].append(run)


def _get_contest_start() -> datetime:
    # A contest day nominally starts at midnight Pacific, but until noon the
    # previous day's contest is still treated as the current one.
    now = datetime.now(TIMEZONE)
    today = now.replace(hour=0, minute=0, second=0, microsecond=0)
    if now.hour < 12:
        today -= timedelta(days=1)
    return today


def _fetch_history(wandb_api: wandb.Api):
    # Fetch every validator run since the contest began.
    wandb_runs = wandb_api.runs(
        WANDB_RUN_PATH,
        filters={"config.type": "validator", "created_at": {"$gt": str(START_DATE)}},
        order="+created_at",
    )
    _add_runs(wandb_runs)


def _fetch_current_runs(wandb_api: wandb.Api):
    # Fetch only the runs belonging to the current contest day.
    today = _get_contest_start()
    wandb_runs = wandb_api.runs(
        WANDB_RUN_PATH,
        filters={"config.type": "validator", "created_at": {"$gt": str(today)}},
        order="+created_at",
    )
    _add_runs(wandb_runs)


last_sync: datetime = datetime.fromtimestamp(0, TIMEZONE)


def sync(timeout: int = 10):
    global last_sync
    now = datetime.now(TIMEZONE)
    # Rate-limit syncing to once per minute.
    if now - last_sync < timedelta(seconds=60):
        return
    last_sync = now

    def sync_task():
        print("Syncing runs...")
        wandb_api = wandb.Api()
        # Full history on the first sync, only the current contest afterwards.
        if not RUNS:
            _fetch_history(wandb_api)
        else:
            _fetch_current_runs(wandb_api)

    with ThreadPoolExecutor(max_workers=1) as executor:
        future = executor.submit(sync_task)
        try:
            future.result(timeout=timeout)
        except FuturesTimeoutError:
            print("Timed out while syncing runs")
        except Exception as e:
            print(f"Error occurred while syncing runs: {e}")


def get_current_runs() -> list[Run]:
    sync()

    from chain_data import sync_metagraph  # deferred to avoid a circular import
    sync_metagraph()

    today = _get_contest_start()

    current_runs: list[Run] = []

    for runs in RUNS.values():
        for run in runs:
            if run.start_date >= today:
                current_runs.append(run)

    return current_runs
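

# Illustrative usage sketch, not part of the module's original surface: print a
# one-line summary per current contest run. Assumes WANDB_RUN_PATH is set in the
# environment and the chain_data module is importable; everything referenced
# below comes from this module.
if __name__ == "__main__":
    for current_run in get_current_runs():  # triggers sync() internally
        print(
            f"[{current_run.status.name()}] validator {current_run.uid} ({current_run.name}): "
            f"{len(current_run.submissions)} benchmarked, "
            f"{len(current_run.invalid_submissions)} invalid, "
            f"{current_run.total_submissions} total submissions"
        )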