Spaces:
Sleeping
Sleeping
import datetime | |
import time | |
from collections import Counter | |
from contextlib import suppress | |
from dataclasses import asdict | |
from functools import partial | |
import hivemind | |
import numpy as np | |
from multiaddr import Multiaddr | |
from petals.data_structures import UID_DELIMITER, ServerState | |
from petals.utils.dht import compute_spans, get_remote_module_infos | |
import config | |
from data_structures import ModelInfo | |
from p2p_utils import check_reachability_parallel, get_peers_ips, extract_peer_ip_info | |
logger = hivemind.get_logger(__name__)

# A server's public name is displayed (and credited in top_contributors) only when it is
# online and its span covers at least this many blocks.
_MIN_BLOCKS_TO_SHOW_PUBLIC_NAME = 10


def _get_bootstrap_peer_ids() -> list:
    """Return the unique PeerIDs named in ``config.INITIAL_PEERS``, in first-seen order."""
    # dict.fromkeys deduplicates while preserving insertion order.
    return list(
        dict.fromkeys(hivemind.PeerID.from_base58(Multiaddr(addr)["p2p"]) for addr in config.INITIAL_PEERS)
    )


def _get_models(dht: hivemind.DHT) -> list:
    """Return the official models from ``config.MODELS`` followed by custom models found in the DHT.

    Custom models are read from the ``_petals.models`` DHT key. An entry is kept only when
    its prefix is not an official one, it parses into a ``ModelInfo``, and its repository
    is a ``https://huggingface.co/`` URL. Custom models are sorted by block count
    (descending), then by DHT prefix.
    """
    models = config.MODELS[:]  # copy so the config list itself is never extended
    model_index = dht.get("_petals.models", latest=True)
    if model_index is None or not isinstance(model_index.value, dict):
        return models

    official_dht_prefixes = {model.dht_prefix for model in models}
    custom_models = []
    for dht_prefix, entry in model_index.value.items():
        if dht_prefix in official_dht_prefixes:
            continue
        # Entries are not validated before this point, so malformed ones are skipped, not fatal.
        with suppress(TypeError, ValueError):
            model_info = ModelInfo.from_dict(entry.value)
            if model_info.repository is None or not model_info.repository.startswith("https://huggingface.co/"):
                continue
            model_info.dht_prefix = dht_prefix
            model_info.official = False
            custom_models.append(model_info)

    models.extend(sorted(custom_models, key=lambda info: (-info.num_blocks, info.dht_prefix)))
    return models


def _get_peers_info(dht: hivemind.DHT) -> dict:
    """Map ``str(peer_id)`` to ``{"location": ..., "multiaddrs": [...]}`` for peers from ``get_peers_ips``.

    The location is derived from the peer's first address. Peers without any known address
    are skipped, so callers fall back to their own default for them instead of crashing.
    """
    peers_info = {}
    for peer in dht.run_coroutine(get_peers_ips):
        multiaddrs = [str(multiaddr) for multiaddr in peer.addrs]
        if not multiaddrs:
            continue
        peers_info[str(peer.peer_id)] = {
            "location": extract_peer_ip_info(multiaddrs[0]),
            "multiaddrs": multiaddrs,
        }
    return peers_info


def fetch_health_state(dht: hivemind.DHT) -> dict:
    """Collect a health snapshot of the swarm reachable through ``dht``.

    Steps:
      1. Probe the bootstrap peers from ``config.INITIAL_PEERS`` for reachability.
      2. Discover models: official ones from ``config.MODELS`` plus custom ones from the DHT.
      3. Compute which servers hold which block spans of each model.
      4. Probe the online servers for reachability.
      5. Build one report per model.

    Args:
        dht: DHT node used for every lookup and reachability probe.

    Returns:
        A dict with keys ``bootstrap_states`` (one state per unique bootstrap peer),
        ``top_contributors`` (Counter of public name -> blocks served, official models only),
        ``model_reports``, ``reachability_issues``, ``last_updated`` (timezone-aware UTC
        datetime), ``update_period`` and ``update_duration`` (seconds spent in this call).
    """
    start_time = time.perf_counter()

    bootstrap_peer_ids = _get_bootstrap_peer_ids()
    reach_infos = dht.run_coroutine(partial(check_reachability_parallel, bootstrap_peer_ids))
    bootstrap_states = ["online" if reach_infos[peer_id]["ok"] else "unreachable" for peer_id in bootstrap_peer_ids]

    models = _get_models(dht)
    logger.info(f"Fetching info for models {[info.name for info in models]}")

    block_uids = [f"{model.dht_prefix}{UID_DELIMITER}{i}" for model in models for i in range(model.num_blocks)]
    module_infos = get_remote_module_infos(dht, block_uids, latest=True)

    model_servers = {}
    all_servers = {}
    offset = 0
    for model in models:
        # module_infos follows block_uids order, so each model owns the next num_blocks entries.
        model_servers[model.dht_prefix] = compute_spans(
            module_infos[offset : offset + model.num_blocks], min_state=ServerState.OFFLINE
        )
        all_servers.update(model_servers[model.dht_prefix])
        offset += model.num_blocks

    online_servers = [peer_id for peer_id, span in all_servers.items() if span.state == ServerState.ONLINE]
    reach_infos.update(dht.run_coroutine(partial(check_reachability_parallel, online_servers, fetch_info=True)))
    peers_info = _get_peers_info(dht)

    top_contributors = Counter()
    model_reports = []
    for model in models:
        servers = model_servers[model.dht_prefix]
        # A model is healthy only if every block has at least one online, reachable server.
        block_healthy = np.zeros(model.num_blocks, dtype=bool)
        server_rows = []
        for peer_id, span in sorted(servers.items()):
            peer_id_str = str(peer_id)
            # Only bootstrap peers and online servers were probed; other peers are not penalized.
            reachable = reach_infos[peer_id]["ok"] if peer_id in reach_infos else True
            state = span.state.name.lower() if reachable else "unreachable"
            if state == "online":
                block_healthy[span.start : span.end] = True

            show_public_name = state == "online" and span.length >= _MIN_BLOCKS_TO_SHOW_PUBLIC_NAME
            if model.official and span.server_info.public_name and show_public_name:
                top_contributors[span.server_info.public_name] += span.length

            row = {
                "short_peer_id": "..." + peer_id_str[-6:],
                "peer_id": peer_id,
                "peer_ip_info": peers_info.get(peer_id_str, "unknown"),
                "show_public_name": show_public_name,
                "state": state,
                "span": span,
                "adapters": [dict(name=name, short_name=name.split("/")[-1]) for name in span.server_info.adapters],
                # Values that other servers of this model report for this peer in their
                # next_pings (presumably measured RTTs), keyed by the reporting peer's ID.
                "pings_to_me": {
                    str(origin_id): origin.server_info.next_pings[peer_id_str]
                    for origin_id, origin in servers.items()
                    if origin.server_info.next_pings is not None and peer_id_str in origin.server_info.next_pings
                },
            }
            if span.server_info.cache_tokens_left is not None:
                # Divide by span.length * 2: each block in the span caches both keys and values.
                row["cache_tokens_left_per_block"] = span.server_info.cache_tokens_left // (span.length * 2)
            server_rows.append(row)

        model_reports.append(
            dict(
                name=model.name,
                short_name=model.short_name,
                state="healthy" if block_healthy.all() else "broken",
                server_rows=server_rows,
                **asdict(model),
            )
        )

    reachability_issues = [
        dict(peer_id=peer_id, err=info["error"]) for peer_id, info in sorted(reach_infos.items()) if not info["ok"]
    ]
    return dict(
        bootstrap_states=bootstrap_states,
        top_contributors=top_contributors,
        model_reports=model_reports,
        reachability_issues=reachability_issues,
        last_updated=datetime.datetime.now(datetime.timezone.utc),
        update_period=config.UPDATE_PERIOD,
        update_duration=time.perf_counter() - start_time,
    )