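"""Collects validator benchmarking runs from Weights & Biases, plus blacklist and
duplicate-submission data from the WOMBO API, and exposes the runs belonging to
the current contest window."""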
import os
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from enum import Enum

import requests
import wandb
import wandb.apis.public as wapi
from cachetools import TTLCache, cached
from pydantic import BaseModel
from substrateinterface import Keypair

from src import TIMEZONE, Key
from src.chain_data import VALIDATOR_IDENTITIES, get_neurons, sync_chain

WANDB_RUN_PATH = os.environ["WANDB_RUN_PATH"]

START_DATE = datetime(2025, 2, 12)
OFFSET_DAYS = 0

BLACKLIST_ENDPOINT = "https://edge-inputs.api.wombo.ai/blacklist"
DUPLICATE_SUBMISSIONS_ENDPOINT = "https://edge-inputs.api.wombo.ai/duplicate_submissions"
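
# Response models for BLACKLIST_ENDPOINT and DUPLICATE_SUBMISSIONS_ENDPOINT.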
class DuplicateSubmission(BaseModel):
    hotkey: Key
    url: str
    revision: str
    copy_of: str


class SafeSubmissions(BaseModel):
    hotkey: Key
    url: str
    revision: str


class DuplicateSelection(BaseModel):
    safe_submissions: list[SafeSubmissions]
    duplicate_submissions: list[DuplicateSubmission]


class Blacklist(BaseModel):
    coldkeys: set[Key]
    hotkeys: set[Key]
    duplicate_selection: DuplicateSelection
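
# Each status is a (label, display colour, failed) tuple, exposed through name(), color() and failed().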
class BenchmarkStatus(Enum):
    NOT_STARTED = ("Not Started", "orange", False)
    IN_PROGRESS = ("In Progress", "orange", False)
    FINISHED = ("Finished", "springgreen", False)
    INITIALIZING = ("Initializing", "orange", False)
    STOPPED = ("Stopped", "red", True)
    CRASHED = ("Crashed", "red", True)
    FAILED = ("Failed", "red", True)
    UNKNOWN = ("Unknown", "red", True)

    def name(self) -> str:
        return self.value[0]

    def color(self) -> str:
        return self.value[1]

    def failed(self) -> bool:
        return self.value[2]
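
# Resource and timing measurements for one benchmarked model (also used for the baseline).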
@dataclass
class Metrics:
    generation_time: float
    vram_used: float
    watts_used: float
    load_time: float
    size: int
    ram_used: float


@dataclass
class SubmissionInfo:
    uid: int
    repository: str
    revision: str
    block: int


@dataclass
class Submission:
    info: SubmissionInfo
    metrics: Metrics
    average_similarity: float
    min_similarity: float
    score: float


@dataclass
class Run:
    start_date: datetime
    version: str
    uid: int
    name: str
    hotkey: str
    status: BenchmarkStatus
    average_benchmarking_time: float
    step: int
    eta: int
    baseline_metrics: Metrics | None
    total_submissions: int
    submissions: dict[Key, Submission]
    invalid_submissions: set[Key]
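
# All parsed runs, keyed by wandb run id; each id keeps one Run entry per validator uid.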
RUNS: dict[str, list[Run]] = {}


def _is_valid_run(run: wapi.Run, version: str) -> bool:
    required_config_keys = ["hotkey", "uid", "signature"]

    for key in required_config_keys:
        if key not in run.config:
            return False

    validator_hotkey = run.config["hotkey"]
    signing_message = f"{version}:{validator_hotkey}"

    # Only trust runs whose config carries a signature of "<version>:<hotkey>" made with the validator's hotkey.
    return Keypair(validator_hotkey).verify(signing_message, run.config["signature"])
def _date_from_run(run: wapi.Run) -> datetime:
    return datetime.strptime(run.created_at, "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc).astimezone(TIMEZONE)


def _status_from_run(run: wapi.Run) -> BenchmarkStatus:
    match run.state:
        case "finished":
            return BenchmarkStatus.STOPPED
        case "crashed":
            return BenchmarkStatus.CRASHED
        case "failed":
            return BenchmarkStatus.FAILED
        case "running":
            if "benchmarking_state" in run.summary:
                states = list(BenchmarkStatus)
                return states[int(run.summary["benchmarking_state"])]
            else:
                return BenchmarkStatus.INITIALIZING
        case _:
            return BenchmarkStatus.UNKNOWN
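
# Parses each wandb run's summary into a Run and merges it into RUNS,
# replacing any earlier entry for the same validator uid.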
def _add_runs(wandb_runs: list[wapi.Run]):
    for wandb_run in wandb_runs:
        version = wandb_run.tags[1][8:]  # tags[1] is expected to look like "version-X.Y.Z"

        if not _is_valid_run(wandb_run, version):
            continue

        metrics = wandb_run.summary

        submission_info: dict[Key, SubmissionInfo] = {}
        submissions: dict[Key, Submission] = {}
        invalid_submissions: set[Key] = set()

        baseline_metrics: Metrics | None = None
        if "baseline" in metrics:
            baseline = metrics["baseline"]
            baseline_metrics = Metrics(
                generation_time=float(baseline["generation_time"]),
                vram_used=float(baseline["vram_used"]),
                ram_used=float(baseline.get("ram_used", 0)),
                watts_used=float(baseline["watts_used"]),
                load_time=float(baseline["load_time"]),
                size=int(baseline["size"]),
            )

        if "submissions" in metrics:
            for hotkey, submission in metrics["submissions"].items():
                neuron = get_neurons().get(hotkey)
                if not neuron:
                    continue

                submission_info[hotkey] = SubmissionInfo(
                    uid=neuron.uid,
                    repository=submission["repository_info"]["url"],
                    revision=submission["repository_info"]["revision"],
                    block=submission["block"],
                )

        if "benchmarks" in metrics:
            for hotkey, benchmark in metrics["benchmarks"].items():
                benchmark_metrics = benchmark["metrics"]
                if hotkey not in submission_info:
                    continue

                scores = metrics["scores"]
                if hotkey not in scores:
                    continue

                submissions[hotkey] = Submission(
                    info=submission_info[hotkey],
                    metrics=Metrics(
                        generation_time=float(benchmark_metrics["generation_time"]),
                        vram_used=float(benchmark_metrics["vram_used"]),
                        ram_used=float(benchmark_metrics.get("ram_used", 0)),
                        watts_used=float(benchmark_metrics["watts_used"]),
                        load_time=float(benchmark_metrics["load_time"]),
                        size=int(benchmark_metrics["size"]),
                    ),
                    average_similarity=float(benchmark["average_similarity"]),
                    min_similarity=float(benchmark["min_similarity"]),
                    score=float(scores[hotkey]),
                )

        if "invalid_submissions" in metrics:
            try:
                for hotkey in metrics["invalid_submissions"]:
                    invalid_submissions.add(hotkey)
            except KeyError:
                ...

        status = _status_from_run(wandb_run)
        uid = int(wandb_run.config["uid"])
        hotkey = wandb_run.config["hotkey"]
        date = _date_from_run(wandb_run)
        run_id = wandb_run.id
        average_benchmarking_time = float(wandb_run.summary["average_benchmarking_time"]) if "average_benchmarking_time" in wandb_run.summary else 0

        # Benchmarking can be spread across several GPUs; default to 1 if the run does not report it.
        num_gpus = int(metrics.get("num_gpus", 1))

        # ETA: remaining submissions times the average benchmarking time, split across the GPUs.
        eta_calculation = (
            max(
                int(average_benchmarking_time * (len(submission_info) - len(submissions) - len(invalid_submissions))) if average_benchmarking_time else 0,
                0,
            ) // num_gpus
            if status != BenchmarkStatus.FINISHED else 0
        )

        run = Run(
            start_date=date,
            version=version,
            uid=uid,
            name=VALIDATOR_IDENTITIES.get(hotkey, f"{hotkey[:6]}..."),
            hotkey=hotkey,
            status=status,
            average_benchmarking_time=average_benchmarking_time,
            step=int(metrics["step"]),
            eta=eta_calculation,
            baseline_metrics=baseline_metrics,
            total_submissions=len(submission_info),
            submissions=submissions,
            invalid_submissions=invalid_submissions,
        )

        if run_id not in RUNS:
            RUNS[run_id] = [run]
        else:
            present = False
            for i, existing_run in enumerate(RUNS[run_id]):
                if existing_run.uid == run.uid:
                    RUNS[run_id][i] = run
                    present = True
                    break

            if not present:
                RUNS[run_id].append(run)
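
# Contest days start at midnight in TIMEZONE; until noon, the previous day's contest
# is still treated as the current one.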
def _get_contest_start() -> datetime:
    now = datetime.now(TIMEZONE)
    today = now.replace(hour=0, minute=0, second=0, microsecond=0)

    if now.hour < 12:
        today -= timedelta(days=1)

    return today
def _fetch_history(wandb_api: wandb.Api):
    wandb_runs = wandb_api.runs(
        WANDB_RUN_PATH,
        filters={"config.type": "validator", "created_at": {"$gt": str(START_DATE)}},
        order="+created_at",
    )
    _add_runs(wandb_runs)


def _fetch_current_runs(wandb_api: wandb.Api):
    today = _get_contest_start()
    wandb_runs = wandb_api.runs(
        WANDB_RUN_PATH,
        filters={"config.type": "validator", "created_at": {"$gt": str(today)}},
        order="+created_at",
    )
    _add_runs(wandb_runs)
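
# Blacklist and duplicate-submission data come from the WOMBO API rather than wandb.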
# Cache decorator restored from the otherwise-unused cachetools import; the TTL value is an assumption.
@cached(cache=TTLCache(maxsize=1, ttl=12 * 60 * 60))
def get_blacklisted_keys() -> Blacklist:
    response = requests.get(BLACKLIST_ENDPOINT)
    response.raise_for_status()
    data = response.json()
    blacklist_hotkeys = set(data["hotkeys"])
    blacklist_coldkeys = set(data["coldkeys"])

    response = requests.get(DUPLICATE_SUBMISSIONS_ENDPOINT)
    response.raise_for_status()
    duplicate_selection = DuplicateSelection.model_validate(response.json())

    return Blacklist(
        hotkeys=blacklist_hotkeys,
        coldkeys=blacklist_coldkeys,
        duplicate_selection=duplicate_selection,
    )
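
# Runs are re-fetched from wandb at most once per minute; the first sync backfills everything since START_DATE.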
last_sync: datetime = datetime.fromtimestamp(0, TIMEZONE)


def sync():
    global last_sync

    now = datetime.now(TIMEZONE)
    if now - last_sync < timedelta(seconds=60):
        return

    last_sync = now

    print("Syncing runs...")
    wandb_api = wandb.Api()

    if not RUNS:
        _fetch_history(wandb_api)
    else:
        _fetch_current_runs(wandb_api)
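
# Returns the runs whose start date falls inside the current contest window (shifted back by OFFSET_DAYS).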
def get_current_runs() -> list[Run]:
    sync_chain()
    sync()

    contest_start = _get_contest_start() - timedelta(days=OFFSET_DAYS)
    contest_end = contest_start + timedelta(days=1)

    current_runs: list[Run] = []

    for runs in RUNS.values():
        for run in runs:
            if contest_start <= run.start_date < contest_end:
                current_runs.append(run)

    return current_runs
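
# Hypothetical usage from a dashboard view (module path and printed fields are illustrative only):
#
#     from wandb_data import get_current_runs
#
#     for run in get_current_runs():
#         print(run.name, run.status.name(), f"{run.eta}s remaining",
#               f"{len(run.submissions)}/{run.total_submissions} benchmarked")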