# edge-maxxing-dashboard/src/wandb_data.py
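"""Fetches, validates, and caches validator benchmark runs from Weights & Biases for the dashboard."""
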
import os
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import TypeAlias
from zoneinfo import ZoneInfo

import wandb
import wandb.apis.public as wapi
from substrateinterface import Keypair

WANDB_RUN_PATH = os.environ["WANDB_RUN_PATH"]
TIMEZONE = ZoneInfo("America/Los_Angeles")
START_DATE = datetime(2024, 11, 7)

Uid: TypeAlias = int
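

# Benchmark status shown on the dashboard. Each member's value is a
# (label, display color, is-failure) tuple, exposed via name(), color() and failed().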
class BenchmarkStatus(Enum):
    NOT_STARTED = ("Not Started", "orange", False)
    IN_PROGRESS = ("In Progress", "orange", False)
    FINISHED = ("Finished", "springgreen", False)
    INITIALISING = ("Initialising", "orange", False)
    STOPPED = ("Stopped", "red", True)
    CRASHED = ("Crashed", "red", True)
    FAILED = ("Failed", "red", True)
    UNKNOWN = ("Unknown", "red", True)

    def name(self):
        return self.value[0]

    def color(self):
        return self.value[1]

    def failed(self):
        return self.value[2]


@dataclass
class MetricData:
    generation_time: float
    vram_used: float
    watts_used: float
    load_time: float
    size: int


@dataclass
class SubmissionInfo:
    uid: int
    hotkey: str
    repository: str
    revision: str
    block: int


@dataclass
class Submission:
    info: SubmissionInfo
    metrics: MetricData
    average_similarity: float
    min_similarity: float
    tier: int
    score: float


@dataclass
class InvalidSubmission:
    info: SubmissionInfo
    reason: str


@dataclass
class Run:
    start_date: datetime
    version: str
    uid: int
    name: str
    hotkey: str
    status: BenchmarkStatus
    average_benchmark_time: float
    eta: int
    winner_uid: int | None
    baseline_metrics: MetricData | None
    total_submissions: int
    submissions: dict[Uid, Submission]
    invalid_submissions: dict[Uid, InvalidSubmission]
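

# Cache of parsed runs keyed by W&B run ID; each entry holds one Run per validator UID.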
RUNS: dict[str, list[Run]] = {}
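

# A run is only trusted if its config carries all required keys and the signature
# verifies against the validator hotkey it claims.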
def _is_valid_run(run: wapi.Run):
    required_config_keys = ["hotkey", "uid", "contest", "signature"]
    for key in required_config_keys:
        if key not in run.config:
            return False

    validator_hotkey = run.config["hotkey"]
    contest_name = run.config["contest"]
    signing_message = f"{run.name}:{validator_hotkey}:{contest_name}"
    return Keypair(validator_hotkey).verify(signing_message, run.config["signature"])


def _date_from_run(run: wapi.Run) -> datetime:
    return datetime.strptime(run.created_at, "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc).astimezone(TIMEZONE)


def _status_from_run(run: wapi.Run) -> BenchmarkStatus:
    match run.state:
        case "finished":
            # A W&B run in the "finished" state means the validator process exited,
            # so the benchmark is reported as stopped rather than finished
            return BenchmarkStatus.STOPPED
        case "crashed":
            return BenchmarkStatus.CRASHED
        case "failed":
            return BenchmarkStatus.FAILED
        case "running":
            if "benchmarking_state" in run.summary:
                return BenchmarkStatus[run.summary["benchmarking_state"]]
            else:
                return BenchmarkStatus.INITIALISING
        case _:
            return BenchmarkStatus.UNKNOWN
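

# Parse each valid W&B run's summary into a Run and merge it into the RUNS cache.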
def _add_runs(wandb_runs: list[wapi.Run]):
    for wandb_run in wandb_runs:
        if not _is_valid_run(wandb_run):
            continue

        metrics = wandb_run.summary
        submission_info: dict[Uid, SubmissionInfo] = {}
        submissions: dict[Uid, Submission] = {}
        invalid_submissions: dict[Uid, InvalidSubmission] = {}

        baseline_metrics: MetricData | None = None
        if "baseline" in metrics:
            baseline = metrics["baseline"]
            baseline_metrics = MetricData(
                generation_time=float(baseline["generation_time"]),
                vram_used=float(baseline["vram_used"]),
                watts_used=float(baseline["watts_used"]),
                load_time=float(baseline["load_time"]),
                size=int(baseline["size"]),
            )

        if "submissions" in metrics:
            for uid, submission in metrics["submissions"].items():
                submission_info[uid] = SubmissionInfo(
                    uid=uid,
                    # Older validators only report the hotkey under "benchmarks",
                    # so fall back to that, then to "unknown"
                    # TODO use submission["hotkey"] directly once validators update
                    hotkey=(
                        submission["hotkey"]
                        if "hotkey" in submission
                        else metrics["benchmarks"][uid]["hotkey"]
                        if "benchmarks" in metrics and uid in metrics["benchmarks"]
                        else "unknown"
                    ),
                    repository=submission["repository"],
                    revision=submission["revision"],
                    block=submission["block"],
                )

        if "benchmarks" in metrics:
            for uid, benchmark in metrics["benchmarks"].items():
                model = benchmark["model"]
                submissions[uid] = Submission(
                    info=submission_info[uid],
                    metrics=MetricData(
                        generation_time=float(model["generation_time"]),
                        vram_used=float(model["vram_used"]),
                        watts_used=float(model["watts_used"]),
                        load_time=float(model["load_time"]),
                        size=int(model["size"]),
                    ),
                    average_similarity=float(benchmark["average_similarity"]),
                    min_similarity=float(benchmark["min_similarity"]),
                    tier=int(benchmark["tier"]),
                    score=float(benchmark["score"]),
                )

        if "invalid" in metrics:
            for uid, reason in metrics["invalid"].items():
                if uid not in submission_info:
                    continue
                invalid_submissions[uid] = InvalidSubmission(
                    info=submission_info[uid],
                    reason=reason,
                )
        status = _status_from_run(wandb_run)

        # Rank by tier (highest first), breaking ties in favor of the earliest block
        winners = sorted(
            submissions.values(),
            key=lambda submission: (submission.tier, -submission.info.block),
            reverse=True,
        )
        winner_uid = winners[0].info.uid if winners and status == BenchmarkStatus.FINISHED else None

        from chain_data import VALIDATOR_IDENTITIES  # imported here to avoid a circular import

        uid = int(wandb_run.config["uid"])
        hotkey = wandb_run.config["hotkey"]
        date = _date_from_run(wandb_run)
        run_id = wandb_run.id
        average_benchmark_time = float(wandb_run.summary["average_benchmark_time"]) if "average_benchmark_time" in wandb_run.summary else 0

        run = Run(
            start_date=date,
            version=wandb_run.tags[1][8:],  # strips the "version=" prefix from the tag
            uid=uid,
            name=VALIDATOR_IDENTITIES.get(uid, hotkey),
            hotkey=hotkey,
            status=status,
            average_benchmark_time=average_benchmark_time,
            eta=int(average_benchmark_time * (len(submission_info) - len(submissions) - len(invalid_submissions))) if average_benchmark_time else 0,
            winner_uid=winner_uid,
            baseline_metrics=baseline_metrics,
            total_submissions=len(submission_info),
            submissions=submissions,
            invalid_submissions=invalid_submissions,
        )

        # Replace any existing entry for this validator UID, otherwise append
        if run_id not in RUNS:
            RUNS[run_id] = [run]
        else:
            present = False
            for i, existing_run in enumerate(RUNS[run_id]):
                if existing_run.uid == run.uid:
                    RUNS[run_id][i] = run
                    present = True
                    break
            if not present:
                RUNS[run_id].append(run)


def _fetch_history(wandb_api: wandb.Api):
    wandb_runs = wandb_api.runs(
        WANDB_RUN_PATH,
        filters={"config.type": "validator", "created_at": {"$gt": str(START_DATE)}},
        order="+created_at",
    )
    _add_runs(wandb_runs)


def _fetch_current_runs(wandb_api: wandb.Api, now: datetime):
    contest_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
    wandb_runs = wandb_api.runs(
        WANDB_RUN_PATH,
        filters={"config.type": "validator", "created_at": {"$gt": str(contest_start)}},
        order="+created_at",
    )
    _add_runs(wandb_runs)


last_sync: datetime = datetime.fromtimestamp(0, TIMEZONE)


def sync():
    global last_sync
    now = datetime.now(TIMEZONE)

    # Throttle syncing to at most once every 30 seconds
    if now - last_sync < timedelta(seconds=30):
        return

    print("Syncing runs...")
    last_sync = now

    wandb_api = wandb.Api()

    # The first sync pulls the full history; later syncs only pull today's runs
    if not RUNS:
        _fetch_history(wandb_api)
    else:
        _fetch_current_runs(wandb_api, now)


def get_current_runs() -> list[Run]:
    sync()

    from chain_data import sync_metagraph  # imported here to avoid a circular import
    sync_metagraph()

    now = datetime.now(TIMEZONE)
    today = now.replace(hour=0, minute=0, second=0, microsecond=0)
    if now.hour < 12:
        # Before noon, include yesterday's runs as well, since they may still be benchmarking
        today -= timedelta(days=1)

    current_runs: list[Run] = []
    for runs in RUNS.values():
        for run in runs:
            if run.start_date >= today:
                current_runs.append(run)

    return current_runs
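

# A minimal usage sketch, assuming WANDB_RUN_PATH is set, W&B is reachable, and
# chain_data is importable; not part of the dashboard itself.
if __name__ == "__main__":
    for current_run in get_current_runs():
        print(current_run.name, current_run.status.name(), f"{len(current_run.submissions)}/{current_run.total_submissions} benchmarked")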