# edge-maxxing-dashboard/src/wandb_data.py
import os
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from enum import Enum
import requests
import wandb
from cachetools import TTLCache, cached
import wandb.apis.public as wapi
from pydantic import BaseModel
from substrateinterface import Keypair
from src import TIMEZONE, Key
from src.chain_data import VALIDATOR_IDENTITIES, get_neurons, sync_chain
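# Required environment variable: the W&B "entity/project" path to pull
# validator runs from. Missing it raises KeyError at import time.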
WANDB_RUN_PATH = os.environ["WANDB_RUN_PATH"]
START_DATE = datetime(2025, 2, 12)
OFFSET_DAYS = 0
BLACKLIST_ENDPOINT = "https://edge-inputs.api.wombo.ai/blacklist"
DUPLICATE_SUBMISSIONS_ENDPOINT = "https://edge-inputs.api.wombo.ai/duplicate_submissions"
class DuplicateSubmission(BaseModel):
hotkey: Key
url: str
revision: str
copy_of: str
class SafeSubmissions(BaseModel):
hotkey: Key
url: str
revision: str
class DuplicateSelection(BaseModel):
safe_submissions: list[SafeSubmissions]
duplicate_submissions: list[DuplicateSubmission]
class Blacklist(BaseModel):
coldkeys: set[Key]
hotkeys: set[Key]
duplicate_selection: DuplicateSelection
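# Each status is (label, display color, failed flag). Member order matters:
# _status_from_run indexes list(BenchmarkStatus) by the reported
# "benchmarking_state" value, so it must mirror the validator's state enum.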
class BenchmarkStatus(Enum):
NOT_STARTED = ("Not Started", "orange", False)
IN_PROGRESS = ("In Progress", "orange", False)
FINISHED = ("Finished", "springgreen", False)
INITIALIZING = ("Initializing", "orange", False)
STOPPED = ("Stopped", "red", True)
CRASHED = ("Crashed", "red", True)
FAILED = ("Failed", "red", True)
UNKNOWN = ("Unknown", "red", True)
def name(self):
return self.value[0]
def color(self):
return self.value[1]
def failed(self):
return self.value[2]
@dataclass
class Metrics:
generation_time: float
vram_used: float
watts_used: float
load_time: float
size: int
ram_used: float
@dataclass
class SubmissionInfo:
uid: int
repository: str
revision: str
block: int
@dataclass
class Submission:
info: SubmissionInfo
metrics: Metrics
average_similarity: float
min_similarity: float
score: float
@dataclass
class Run:
start_date: datetime
version: str
uid: int
name: str
hotkey: str
status: BenchmarkStatus
average_benchmarking_time: float
step: int
eta: int
baseline_metrics: Metrics | None
total_submissions: int
submissions: dict[Key, Submission]
invalid_submissions: set[Key]
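# All parsed runs, keyed by W&B run id; entries are deduplicated per validator
# uid when a run is re-fetched (see _add_runs).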
RUNS: dict[str, list[Run]] = {}
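# A run is only trusted if its config is signed: the signature must verify as
# "<version>:<hotkey>" under the validator's own hotkey.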
def _is_valid_run(run: wapi.Run, version: str) -> bool:
required_config_keys = ["hotkey", "uid", "signature"]
for key in required_config_keys:
if key not in run.config:
return False
validator_hotkey = run.config["hotkey"]
signing_message = f"{version}:{validator_hotkey}"
return Keypair(validator_hotkey).verify(signing_message, run.config["signature"])
def _date_from_run(run: wapi.Run) -> datetime:
return datetime.strptime(run.created_at, "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc).astimezone(TIMEZONE)
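# Map the raw W&B run state to a dashboard status. A run whose process has
# exited reports "finished" to W&B but is shown as Stopped here; FINISHED is
# only reached via the in-run "benchmarking_state" summary value.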
def _status_from_run(run: wapi.Run) -> BenchmarkStatus:
match run.state:
case "finished":
return BenchmarkStatus.STOPPED
case "crashed":
return BenchmarkStatus.CRASHED
case "failed":
return BenchmarkStatus.FAILED
case "running":
if "benchmarking_state" in run.summary:
states = list(BenchmarkStatus)
return states[int(run.summary["benchmarking_state"])]
else:
return BenchmarkStatus.INITIALIZING
case _:
return BenchmarkStatus.UNKNOWN
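# Parse a batch of W&B runs into Run records and merge them into RUNS,
# replacing any stale entry for the same run id and validator uid.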
def _add_runs(wandb_runs: list[wapi.Run]):
for wandb_run in wandb_runs:
        # The second W&B tag encodes the validator version behind an
        # 8-character prefix; strip it to recover the bare version string.
        version = wandb_run.tags[1][8:]
if not _is_valid_run(wandb_run, version):
continue
metrics = wandb_run.summary
submission_info: dict[Key, SubmissionInfo] = {}
submissions: dict[Key, Submission] = {}
invalid_submissions: set[Key] = set()
baseline_metrics: Metrics | None = None
if "baseline" in metrics:
baseline = metrics["baseline"]
baseline_metrics = Metrics(
generation_time=float(baseline["generation_time"]),
vram_used=float(baseline["vram_used"]),
ram_used=float(baseline.get("ram_used", 0)),
watts_used=float(baseline["watts_used"]),
load_time=float(baseline["load_time"]),
size=int(baseline["size"]),
)
if "submissions" in metrics:
for hotkey, submission in metrics["submissions"].items():
neuron = get_neurons().get(hotkey)
if not neuron:
continue
submission_info[hotkey] = SubmissionInfo(
uid=neuron.uid,
repository=submission["repository_info"]["url"],
revision=submission["repository_info"]["revision"],
block=submission["block"],
)
if "benchmarks" in metrics:
for hotkey, benchmark in metrics["benchmarks"].items():
benchmark_metrics = benchmark["metrics"]
if hotkey not in submission_info:
continue
scores = metrics["scores"]
if hotkey not in scores:
continue
submissions[hotkey] = Submission(
info=submission_info[hotkey],
metrics=Metrics(
generation_time=float(benchmark_metrics["generation_time"]),
vram_used=float(benchmark_metrics["vram_used"]),
ram_used=float(benchmark_metrics.get("ram_used", 0)),
watts_used=float(benchmark_metrics["watts_used"]),
load_time=float(benchmark_metrics["load_time"]),
size=int(benchmark_metrics["size"]),
),
average_similarity=float(benchmark["average_similarity"]),
min_similarity=float(benchmark["min_similarity"]),
score=float(scores[hotkey]),
)
if "invalid_submissions" in metrics:
try:
for hotkey in metrics["invalid_submissions"]:
invalid_submissions.add(hotkey)
except KeyError:
...
status = _status_from_run(wandb_run)
uid = int(wandb_run.config["uid"])
hotkey = wandb_run.config["hotkey"]
date = _date_from_run(wandb_run)
        run_id = wandb_run.id
        average_benchmarking_time = float(metrics.get("average_benchmarking_time", 0))
        # Number of GPUs benchmarking in parallel; defaults to 1 if unreported.
        num_gpus = int(metrics.get("num_gpus", 1))
        # ETA: remaining submissions times the average benchmarking time,
        # split across the available GPUs; zero once the run has finished.
        eta_calculation = (
            max(int(average_benchmarking_time * (len(submission_info) - len(submissions) - len(invalid_submissions))), 0) // num_gpus
            if status != BenchmarkStatus.FINISHED and average_benchmarking_time
            else 0
        )
run = Run(
start_date=date,
version=version,
uid=uid,
name=VALIDATOR_IDENTITIES.get(hotkey, f"{hotkey[:6]}..."),
hotkey=hotkey,
status=status,
average_benchmarking_time=average_benchmarking_time,
step=int(metrics["step"]),
eta=eta_calculation,
baseline_metrics=baseline_metrics,
total_submissions=len(submission_info),
submissions=submissions,
invalid_submissions=invalid_submissions,
)
        if run_id not in RUNS:
            RUNS[run_id] = [run]
        else:
            # Replace any stale entry for this validator uid, otherwise append.
            for i, existing_run in enumerate(RUNS[run_id]):
                if existing_run.uid == run.uid:
                    RUNS[run_id][i] = run
                    break
            else:
                RUNS[run_id].append(run)
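# Contest days are anchored at midnight TIMEZONE but roll over at noon: before
# noon, the previous day's window is still the active contest.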
def _get_contest_start() -> datetime:
now = datetime.now(TIMEZONE)
today = now.replace(hour=0, minute=0, second=0, microsecond=0)
if now.hour < 12:
today -= timedelta(days=1)
return today
def _fetch_history(wandb_api: wandb.Api):
wandb_runs = wandb_api.runs(
WANDB_RUN_PATH,
filters={"config.type": "validator", "created_at": {'$gt': str(START_DATE)}},
order="+created_at",
)
_add_runs(wandb_runs)
def _fetch_current_runs(wandb_api: wandb.Api):
today = _get_contest_start()
wandb_runs = wandb_api.runs(
WANDB_RUN_PATH,
filters={"config.type": "validator", "created_at": {'$gt': str(today)}},
order="+created_at",
)
_add_runs(wandb_runs)
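# Cached for five minutes so repeated dashboard refreshes don't hammer the
# blacklist and duplicate-submission endpoints.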
@cached(cache=TTLCache(maxsize=1, ttl=300))
def get_blacklisted_keys() -> Blacklist:
    response = requests.get(BLACKLIST_ENDPOINT, timeout=10)  # bounded wait so a slow endpoint can't hang the dashboard
response.raise_for_status()
data = response.json()
blacklist_hotkeys = set(data["hotkeys"])
blacklist_coldkeys = set(data["coldkeys"])
    response = requests.get(DUPLICATE_SUBMISSIONS_ENDPOINT, timeout=10)
response.raise_for_status()
duplicate_selection = DuplicateSelection.model_validate(response.json())
return Blacklist(
hotkeys=blacklist_hotkeys,
coldkeys=blacklist_coldkeys,
duplicate_selection=duplicate_selection
)
last_sync: datetime = datetime.fromtimestamp(0, TIMEZONE)
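# Throttled to once per minute: the first sync backfills everything since
# START_DATE, later syncs only fetch runs created during the current contest.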
def sync():
global last_sync
now = datetime.now(TIMEZONE)
if now - last_sync < timedelta(seconds=60):
return
last_sync = now
print("Syncing runs...")
wandb_api = wandb.Api()
if not RUNS:
_fetch_history(wandb_api)
else:
_fetch_current_runs(wandb_api)
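# Dashboard entry point: sync chain and W&B state, then return the runs whose
# start date falls inside the current contest window (shifted by OFFSET_DAYS).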
def get_current_runs() -> list[Run]:
sync_chain()
sync()
contest_start = _get_contest_start() - timedelta(days=OFFSET_DAYS)
contest_end = contest_start + timedelta(days=1)
current_runs: list[Run] = []
for runs in RUNS.values():
for run in runs:
if contest_start <= run.start_date < contest_end:
current_runs.append(run)
return current_runs
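
# Minimal usage sketch (not part of the dashboard itself): assumes
# WANDB_RUN_PATH is set and the chain/W&B backends are reachable.
if __name__ == "__main__":
    for current_run in get_current_runs():
        print(
            f"{current_run.name} (uid {current_run.uid}): {current_run.status.name()}, "
            f"{len(current_run.submissions)}/{current_run.total_submissions} benchmarked, "
            f"eta {current_run.eta}s"
        )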