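"""Leaderboard data sync: pulls validator benchmark runs from Weights & Biases,
verifies each run's hotkey signature, and parses the reported baseline,
submission, and benchmark metrics into typed records cached in RUNS."""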
import os
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import TypeAlias
from zoneinfo import ZoneInfo

import wandb
import wandb.apis.public as wapi
from substrateinterface import Keypair

WANDB_RUN_PATH = os.environ["WANDB_RUN_PATH"]

TIMEZONE = ZoneInfo("America/Los_Angeles")
START_DATE = datetime(2024, 11, 7)

Uid: TypeAlias = int

class BenchmarkStatus(Enum):
    # Each member is (display name, display color, counts as failed)
    NOT_STARTED = ("Not Started", "orange", False)
    IN_PROGRESS = ("In Progress", "orange", False)
    FINISHED = ("Finished", "springgreen", False)
    INITIALISING = ("Initialising", "orange", False)
    STOPPED = ("Stopped", "red", True)
    CRASHED = ("Crashed", "red", True)
    FAILED = ("Failed", "red", True)
    UNKNOWN = ("Unknown", "red", True)

    def name(self):
        return self.value[0]

    def color(self):
        return self.value[1]

    def failed(self):
        return self.value[2]

@dataclass
class MetricData:
    generation_time: float
    vram_used: float
    watts_used: float
    load_time: float
    size: int

@dataclass
class SubmissionInfo:
    uid: int
    hotkey: str
    repository: str
    revision: str
    block: int

@dataclass
class Submission:
    info: SubmissionInfo
    metrics: MetricData
    average_similarity: float
    min_similarity: float
    tier: int
    score: float

@dataclass
class InvalidSubmission:
    info: SubmissionInfo
    reason: str

@dataclass
class Run:
    start_date: datetime
    version: str
    uid: int
    name: str
    hotkey: str
    status: BenchmarkStatus
    average_benchmark_time: float
    eta: int
    winner_uid: int | None
    baseline_metrics: MetricData | None
    total_submissions: int
    submissions: dict[Uid, Submission]
    invalid_submissions: dict[Uid, InvalidSubmission]
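
# Cache of parsed runs keyed by W&B run id; each entry holds the latest
# Run snapshot per validator uid seen under that id (see _add_runs below).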
RUNS: dict[str, list[Run]] = {}

def _is_valid_run(run: wapi.Run) -> bool:
    required_config_keys = ["hotkey", "uid", "contest", "signature"]

    for key in required_config_keys:
        if key not in run.config:
            return False

    validator_hotkey = run.config["hotkey"]
    contest_name = run.config["contest"]

    # Validators sign "<run name>:<hotkey>:<contest>" with their hotkey;
    # discard any run whose signature does not verify.
    signing_message = f"{run.name}:{validator_hotkey}:{contest_name}"

    return Keypair(validator_hotkey).verify(signing_message, run.config["signature"])

def _date_from_run(run: wapi.Run) -> datetime:
    # W&B reports created_at as a UTC timestamp; convert to the leaderboard timezone
    return datetime.strptime(run.created_at, "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc).astimezone(TIMEZONE)

def _status_from_run(run: wapi.Run) -> BenchmarkStatus:
    match run.state:
        case "finished":
            # A W&B run in the "finished" state has shut down; FINISHED itself
            # is only reported via benchmarking_state while the run is alive.
            return BenchmarkStatus.STOPPED
        case "crashed":
            return BenchmarkStatus.CRASHED
        case "failed":
            return BenchmarkStatus.FAILED
        case "running":
            if "benchmarking_state" in run.summary:
                return BenchmarkStatus[run.summary["benchmarking_state"]]
            else:
                return BenchmarkStatus.INITIALISING
        case _:
            return BenchmarkStatus.UNKNOWN

def _add_runs(wandb_runs: list[wapi.Run]):
    for wandb_run in wandb_runs:
        if not _is_valid_run(wandb_run):
            continue

        metrics = wandb_run.summary

        submission_info: dict[Uid, SubmissionInfo] = {}
        submissions: dict[Uid, Submission] = {}
        invalid_submissions: dict[Uid, InvalidSubmission] = {}

        baseline_metrics: MetricData | None = None
        if "baseline" in metrics:
            baseline = metrics["baseline"]
            baseline_metrics = MetricData(
                generation_time=float(baseline["generation_time"]),
                vram_used=float(baseline["vram_used"]),
                watts_used=float(baseline["watts_used"]),
                load_time=float(baseline["load_time"]),
                size=int(baseline["size"]),
            )

        if "submissions" in metrics:
            for uid, submission in metrics["submissions"].items():
                submission_info[uid] = SubmissionInfo(
                    uid=uid,
                    # Fall back to the benchmark entry for validators that don't
                    # yet report the hotkey per submission
                    hotkey=submission["hotkey"] if "hotkey" in submission
                    else metrics["benchmarks"][uid]["hotkey"] if uid in metrics["benchmarks"]
                    else "unknown",
                    # hotkey=submission["hotkey"],  # TODO use this once validators update
                    repository=submission["repository"],
                    revision=submission["revision"],
                    block=submission["block"],
                )

        if "benchmarks" in metrics:
            for uid, benchmark in metrics["benchmarks"].items():
                model = benchmark["model"]
                submissions[uid] = Submission(
                    info=submission_info[uid],
                    metrics=MetricData(
                        generation_time=float(model["generation_time"]),
                        vram_used=float(model["vram_used"]),
                        watts_used=float(model["watts_used"]),
                        load_time=float(model["load_time"]),
                        size=int(model["size"]),
                    ),
                    average_similarity=float(benchmark["average_similarity"]),
                    min_similarity=float(benchmark["min_similarity"]),
                    tier=int(benchmark["tier"]),
                    score=float(benchmark["score"]),
                )

        if "invalid" in metrics:
            for uid, reason in metrics["invalid"].items():
                if uid not in submission_info:
                    continue
                invalid_submissions[uid] = InvalidSubmission(
                    info=submission_info[uid],
                    reason=reason,
                )

        status = _status_from_run(wandb_run)

        # Highest tier wins; within a tier, the earliest submission (lowest block) wins
        winners = sorted(
            submissions.values(),
            key=lambda submission: (submission.tier, -submission.info.block),
            reverse=True,
        )
        winner_uid = winners[0].info.uid if winners and status == BenchmarkStatus.FINISHED else None

        from chain_data import VALIDATOR_IDENTITIES  # deferred import to avoid a circular dependency

        uid = int(wandb_run.config["uid"])
        hotkey = wandb_run.config["hotkey"]
        date = _date_from_run(wandb_run)
        run_id = wandb_run.id
        average_benchmark_time = float(wandb_run.summary["average_benchmark_time"]) if "average_benchmark_time" in wandb_run.summary else 0

        run = Run(
            start_date=date,
            version=wandb_run.tags[1][8:],  # tags[1] is assumed to carry a "version=..."-style 8-character prefix
            uid=uid,
            name=VALIDATOR_IDENTITIES.get(uid, hotkey),
            hotkey=hotkey,
            status=status,
            average_benchmark_time=average_benchmark_time,
            eta=int(average_benchmark_time * (len(submission_info) - len(submissions) - len(invalid_submissions))) if average_benchmark_time else 0,
            winner_uid=winner_uid,
            baseline_metrics=baseline_metrics,
            total_submissions=len(submission_info),
            submissions=submissions,
            invalid_submissions=invalid_submissions,
        )

        if run_id not in RUNS:
            RUNS[run_id] = [run]
        else:
            present = False
            for i, existing_run in enumerate(RUNS[run_id]):
                if existing_run.uid == run.uid:
                    # Refresh the snapshot for a validator we've already seen
                    RUNS[run_id][i] = run
                    present = True
                    break
            if not present:
                RUNS[run_id].append(run)

def _fetch_history(wandb_api: wandb.Api):
    wandb_runs = wandb_api.runs(
        WANDB_RUN_PATH,
        filters={"config.type": "validator", "created_at": {"$gt": str(START_DATE)}},
        order="+created_at",
    )
    _add_runs(wandb_runs)

def _fetch_current_runs(wandb_api: wandb.Api, now: datetime):
    contest_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
    wandb_runs = wandb_api.runs(
        WANDB_RUN_PATH,
        filters={"config.type": "validator", "created_at": {"$gt": str(contest_start)}},
        order="+created_at",
    )
    _add_runs(wandb_runs)

last_sync: datetime = datetime.fromtimestamp(0, TIMEZONE)


def sync():
    global last_sync
    now = datetime.now(TIMEZONE)

    # Throttle: sync at most once every 30 seconds
    if now - last_sync < timedelta(seconds=30):
        return
    print("Syncing runs...")
    last_sync = now

    wandb_api = wandb.Api()
    if not RUNS:
        # First sync: backfill everything since START_DATE
        _fetch_history(wandb_api)
    else:
        # Subsequent syncs: only refresh runs created today
        _fetch_current_runs(wandb_api, now)

def get_current_runs() -> list[Run]:
    sync()

    from chain_data import sync_metagraph  # deferred import to avoid a circular dependency
    sync_metagraph()

    now = datetime.now(TIMEZONE)
    today = now.replace(hour=0, minute=0, second=0, microsecond=0)
    if now.hour < 12:
        # Before noon, the active contest day is still the one that started yesterday
        today -= timedelta(days=1)

    current_runs: list[Run] = []

    for runs in RUNS.values():
        for run in runs:
            if run.start_date >= today:
                current_runs.append(run)

    return current_runs
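

if __name__ == "__main__":
    # Minimal usage sketch (an assumption, not part of the original module):
    # requires WANDB_RUN_PATH to be set and chain_data to be importable.
    for current_run in get_current_runs():
        print(f"{current_run.name}: {current_run.status.name()} "
              f"({len(current_run.submissions)}/{current_run.total_submissions} benchmarked)")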