Spaces:
Running
Running
File size: 8,794 Bytes
6c858ba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 |
import os
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import TypeAlias
from zoneinfo import ZoneInfo
import wandb
import wandb.apis.public as wapi
from substrateinterface import Keypair
# "entity/project" path of the W&B project all run queries go through; required at startup.
WANDB_RUN_PATH = os.environ["WANDB_RUN_PATH"]
# Dashboard-local timezone: all dates are converted to and compared in Pacific time.
TIMEZONE = ZoneInfo("America/Los_Angeles")
# Earliest run creation date fetched during the initial full history sync (naive; stringified into the W&B filter).
START_DATE = datetime(2024, 11, 7)
# Alias for a neuron uid used as a dict key throughout this module.
Uid: TypeAlias = int
class BenchmarkStatus(Enum):
    """Lifecycle state of a validator benchmark run.

    Each member's value is a ``(display name, display color, is-failure)``
    triple; the accessor methods below expose the individual fields.
    """

    NOT_STARTED = ("Not Started", "orange", False)
    IN_PROGRESS = ("In Progress", "orange", False)
    FINISHED = ("Finished", "springgreen", False)
    INITIALISING = ("Initialising", "orange", False)
    STOPPED = ("Stopped", "red", True)
    CRASHED = ("Crashed", "red", True)
    FAILED = ("Failed", "red", True)
    UNKNOWN = ("Unknown", "red", True)

    def name(self):
        """Human-readable label shown in the UI."""
        display, _, _ = self.value
        return display

    def color(self):
        """CSS color name used when rendering this status."""
        _, color, _ = self.value
        return color

    def failed(self):
        """Whether this status represents a failure state."""
        _, _, is_failure = self.value
        return is_failure
@dataclass
class MetricData:
    """Hardware/performance measurements reported for one benchmarked model."""
    generation_time: float  # time spent generating output (presumably seconds — confirm with validator)
    vram_used: float  # VRAM consumption reported by the validator
    watts_used: float  # power draw reported by the validator
    load_time: float  # time to load the model (presumably seconds — confirm)
    size: int  # model size (units not shown here — presumably bytes; confirm)
@dataclass
class SubmissionInfo:
    """Identity and provenance of one miner's submitted model."""
    uid: int  # miner uid
    hotkey: str  # miner hotkey; "unknown" when an older validator omits it from the summary
    repository: str  # model repository identifier
    revision: str  # revision/commit of the repository
    block: int  # block number of the submission (earlier block wins tier ties — see winner sort in _add_runs)
@dataclass
class Submission:
    """A miner submission that has completed benchmarking."""
    info: SubmissionInfo
    metrics: MetricData  # measured performance of the submitted model
    average_similarity: float  # validator-reported output similarity score
    min_similarity: float  # validator-reported minimum similarity score
    tier: int  # ranking tier; higher tier sorts first when picking the winner
    score: float  # validator-reported overall score
@dataclass
class InvalidSubmission:
    """A submission the validator rejected, with the reported reason."""
    info: SubmissionInfo
    reason: str  # free-text rejection reason from the validator's summary
@dataclass
class Run:
    """One validator's benchmark run, parsed from a W&B run."""
    start_date: datetime  # run creation time, converted to the dashboard TIMEZONE
    version: str  # validator version parsed from the run's second W&B tag
    uid: int  # validator uid
    name: str  # display name: chain identity if known, otherwise the hotkey
    hotkey: str  # validator hotkey
    status: BenchmarkStatus
    average_benchmark_time: float  # per-submission benchmark time; 0 when not yet reported
    eta: int  # estimated time remaining for the unfinished submissions; 0 when unknown
    winner_uid: int | None  # uid of the top-ranked submission, set only once the run is FINISHED
    baseline_metrics: MetricData | None  # baseline hardware metrics, if the validator reported them
    total_submissions: int  # number of submissions the validator received
    submissions: dict[Uid, Submission]  # completed benchmarks keyed by miner uid
    invalid_submissions: dict[Uid, InvalidSubmission]  # rejected submissions keyed by miner uid
# All runs seen so far, keyed by W&B run id; each entry holds at most one Run per validator uid.
RUNS: dict[str, list[Run]] = {}
def _is_valid_run(run: wapi.Run):
    """Return True when the run has all required config keys and a valid signature.

    The signature must verify against the message "<run name>:<hotkey>:<contest>"
    using the validator's own hotkey, proving the run was produced by that validator.
    """
    config = run.config
    if any(key not in config for key in ("hotkey", "uid", "contest", "signature")):
        return False
    hotkey = config["hotkey"]
    message = f"{run.name}:{hotkey}:{config['contest']}"
    return Keypair(hotkey).verify(message, config["signature"])
def _date_from_run(run: wapi.Run) -> datetime:
    """Parse the run's UTC creation timestamp and convert it to the dashboard timezone."""
    created = datetime.strptime(run.created_at, "%Y-%m-%dT%H:%M:%SZ")
    return created.replace(tzinfo=timezone.utc).astimezone(TIMEZONE)
def _status_from_run(run: wapi.Run) -> BenchmarkStatus:
match run.state:
case "finished":
return BenchmarkStatus.STOPPED
case "crashed":
return BenchmarkStatus.CRASHED
case "failed":
return BenchmarkStatus.FAILED
case "running":
if "benchmarking_state" in run.summary:
return BenchmarkStatus[run.summary["benchmarking_state"]]
else:
return BenchmarkStatus.INITIALISING
case _:
return BenchmarkStatus.UNKNOWN
def _metric_data_from(raw) -> MetricData:
    """Coerce a raw summary mapping into a MetricData."""
    return MetricData(
        generation_time=float(raw["generation_time"]),
        vram_used=float(raw["vram_used"]),
        watts_used=float(raw["watts_used"]),
        load_time=float(raw["load_time"]),
        size=int(raw["size"]),
    )


def _parse_submission_info(metrics) -> dict[Uid, SubmissionInfo]:
    """Collect per-uid SubmissionInfo from the run summary's "submissions" section."""
    submission_info: dict[Uid, SubmissionInfo] = {}
    if "submissions" not in metrics:
        return submission_info
    # Guard the fallback lookup: older summaries may lack "benchmarks" entirely.
    benchmarks = metrics["benchmarks"] if "benchmarks" in metrics else {}
    for uid, submission in metrics["submissions"].items():
        # TODO use submission["hotkey"] unconditionally once validators update.
        if "hotkey" in submission:
            hotkey = submission["hotkey"]
        elif uid in benchmarks:
            hotkey = benchmarks[uid]["hotkey"]
        else:
            hotkey = "unknown"
        submission_info[uid] = SubmissionInfo(
            uid=uid,
            hotkey=hotkey,
            repository=submission["repository"],
            revision=submission["revision"],
            block=submission["block"],
        )
    return submission_info


def _parse_benchmarks(metrics, submission_info: dict[Uid, SubmissionInfo]) -> dict[Uid, Submission]:
    """Build completed Submission entries for every benchmarked uid."""
    submissions: dict[Uid, Submission] = {}
    if "benchmarks" not in metrics:
        return submissions
    for uid, benchmark in metrics["benchmarks"].items():
        submissions[uid] = Submission(
            info=submission_info[uid],
            metrics=_metric_data_from(benchmark["model"]),
            average_similarity=float(benchmark["average_similarity"]),
            min_similarity=float(benchmark["min_similarity"]),
            tier=int(benchmark["tier"]),
            score=float(benchmark["score"]),
        )
    return submissions


def _parse_invalid(metrics, submission_info: dict[Uid, SubmissionInfo]) -> dict[Uid, InvalidSubmission]:
    """Build InvalidSubmission entries, skipping uids with no known submission info."""
    invalid_submissions: dict[Uid, InvalidSubmission] = {}
    if "invalid" not in metrics:
        return invalid_submissions
    for uid, reason in metrics["invalid"].items():
        if uid not in submission_info:
            continue
        invalid_submissions[uid] = InvalidSubmission(
            info=submission_info[uid],
            reason=reason,
        )
    return invalid_submissions


def _add_runs(wandb_runs: list[wapi.Run]):
    """Parse each valid W&B run into a Run and upsert it into RUNS keyed by run id."""
    # Local import hoisted out of the loop; kept local, presumably to avoid a
    # circular import with chain_data — confirm before moving to module level.
    from chain_data import VALIDATOR_IDENTITIES

    for wandb_run in wandb_runs:
        if not _is_valid_run(wandb_run):
            continue
        metrics = wandb_run.summary
        submission_info = _parse_submission_info(metrics)
        submissions = _parse_benchmarks(metrics, submission_info)
        invalid_submissions = _parse_invalid(metrics, submission_info)
        baseline_metrics = _metric_data_from(metrics["baseline"]) if "baseline" in metrics else None

        status = _status_from_run(wandb_run)
        # Winner: highest tier first, earliest submission block breaking ties.
        winners = sorted(
            submissions.values(),
            key=lambda submission: (submission.tier, -submission.info.block),
            reverse=True,
        )
        winner_uid = winners[0].info.uid if winners and status == BenchmarkStatus.FINISHED else None

        uid = int(wandb_run.config["uid"])
        hotkey = wandb_run.config["hotkey"]
        average_benchmark_time = (
            float(wandb_run.summary["average_benchmark_time"])
            if "average_benchmark_time" in wandb_run.summary
            else 0
        )
        remaining = len(submission_info) - len(submissions) - len(invalid_submissions)
        run = Run(
            start_date=_date_from_run(wandb_run),
            # Strips an 8-character prefix from the second tag — presumably a
            # "version:" style prefix; confirm against the validator's tagging.
            version=wandb_run.tags[1][8:],
            uid=uid,
            name=VALIDATOR_IDENTITIES.get(uid, hotkey),
            hotkey=hotkey,
            status=status,
            average_benchmark_time=average_benchmark_time,
            eta=int(average_benchmark_time * remaining) if average_benchmark_time else 0,
            winner_uid=winner_uid,
            baseline_metrics=baseline_metrics,
            total_submissions=len(submission_info),
            submissions=submissions,
            invalid_submissions=invalid_submissions,
        )
        run_id = wandb_run.id  # renamed from `id` to avoid shadowing the builtin
        existing_runs = RUNS.setdefault(run_id, [])
        # Replace the entry for this validator uid if present, else append.
        for index, existing_run in enumerate(existing_runs):
            if existing_run.uid == run.uid:
                existing_runs[index] = run
                break
        else:
            existing_runs.append(run)
def _fetch_history(wandb_api: wandb.Api):
    """Load every validator run created since START_DATE (initial full sync)."""
    filters = {"config.type": "validator", "created_at": {"$gt": str(START_DATE)}}
    matching_runs = wandb_api.runs(WANDB_RUN_PATH, filters=filters, order="+created_at")
    _add_runs(matching_runs)
def _fetch_current_runs(wandb_api: wandb.Api, now: datetime):
    """Load only the validator runs created since midnight of the current day."""
    contest_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
    filters = {"config.type": "validator", "created_at": {"$gt": str(contest_start)}}
    matching_runs = wandb_api.runs(WANDB_RUN_PATH, filters=filters, order="+created_at")
    _add_runs(matching_runs)
# Timestamp of the most recent sync; the epoch start forces a full fetch on first call.
last_sync: datetime = datetime.fromtimestamp(0, TIMEZONE)
def sync():
    """Refresh RUNS from W&B, throttled to at most once every 30 seconds.

    The first call (empty RUNS) pulls the full history; later calls only
    fetch runs created since midnight of the current day.
    """
    global last_sync
    now = datetime.now(TIMEZONE)
    if now - last_sync < timedelta(seconds=30):
        return
    print("Syncing runs...")
    last_sync = now
    wandb_api = wandb.Api()
    if RUNS:
        _fetch_current_runs(wandb_api, now)
    else:
        _fetch_history(wandb_api)
def get_current_runs() -> list[Run]:
    """Return every run belonging to the current contest day.

    The contest day starts at midnight local time, but before noon the
    previous day is still considered "current".
    """
    sync()
    from chain_data import sync_metagraph
    sync_metagraph()
    now = datetime.now(TIMEZONE)
    day_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
    if now.hour < 12:
        day_start -= timedelta(days=1)
    return [
        run
        for runs in RUNS.values()
        for run in runs
        if run.start_date >= day_start
    ]