api_for_chat / health.py
ldhldh's picture
Upload health.py
1576271
raw
history blame
5.59 kB
import datetime
import time
from collections import Counter
from contextlib import suppress
from dataclasses import asdict
from functools import partial
import hivemind
import numpy as np
from multiaddr import Multiaddr
from petals.data_structures import UID_DELIMITER, ServerState
from petals.utils.dht import compute_spans, get_remote_module_infos
import config
from data_structures import ModelInfo
from p2p_utils import check_reachability_parallel, get_peers_ips, extract_peer_ip_info
# Module-level logger obtained through hivemind's logging helper, named after this module.
logger = hivemind.get_logger(__name__)
def _unique_bootstrap_peer_ids() -> list:
    """Return the distinct peer IDs named in ``config.INITIAL_PEERS``, in first-seen order.

    Each entry is parsed as a multiaddr and its ``/p2p/`` component is decoded as a base58 peer ID.
    """
    peer_ids = (hivemind.PeerID.from_base58(Multiaddr(addr)["p2p"]) for addr in config.INITIAL_PEERS)
    # dict.fromkeys de-duplicates while preserving insertion order, replacing a quadratic
    # "if x not in list" loop.
    return list(dict.fromkeys(peer_ids))


def _discover_custom_models(dht: hivemind.DHT, official_dht_prefixes: set) -> list:
    """Return non-official models announced under the ``_petals.models`` DHT key.

    Entries whose DHT prefix belongs to an official model, that fail to parse (TypeError /
    ValueError), or whose repository is not a ``https://huggingface.co/`` URL are skipped.
    The result is sorted by descending ``num_blocks``, then by ``dht_prefix``.
    """
    model_index = dht.get("_petals.models", latest=True)
    if model_index is None or not isinstance(model_index.value, dict):
        return []

    custom_models = []
    for dht_prefix, model in model_index.value.items():
        if dht_prefix in official_dht_prefixes:
            continue
        # Malformed announcements are ignored rather than breaking the whole health check.
        with suppress(TypeError, ValueError):
            model_info = ModelInfo.from_dict(model.value)
            if model_info.repository is None or not model_info.repository.startswith("https://huggingface.co/"):
                continue
            model_info.dht_prefix = dht_prefix
            model_info.official = False
            custom_models.append(model_info)
    return sorted(custom_models, key=lambda info: (-info.num_blocks, info.dht_prefix))


def _collect_peers_info(dht: hivemind.DHT) -> dict:
    """Map ``str(peer_id)`` -> ``{"location", "multiaddrs"}`` for peers returned by ``get_peers_ips``.

    The location is derived from the peer's first advertised address. Peers that advertise no
    address are skipped (indexing ``addrs[0]`` would raise IndexError). Callers then fall back
    to their own default for such peers.
    """
    peers_info = {}
    for peer in dht.run_coroutine(get_peers_ips):
        multiaddrs = [str(multiaddr) for multiaddr in peer.addrs]
        if not multiaddrs:
            continue
        peers_info[str(peer.peer_id)] = {
            "location": extract_peer_ip_info(multiaddrs[0]),
            "multiaddrs": multiaddrs,
        }
    return peers_info


def _build_server_row(peer_id, span, state: str, show_public_name: bool, servers: dict, peers_info: dict) -> dict:
    """Build the report row describing one server span of a model.

    Args:
        peer_id: ID of the server being described.
        span: that server's span as returned by ``compute_spans``.
        state: lower-cased span state, or ``"unreachable"``.
        show_public_name: whether the report may display the server's public name.
        servers: all spans of the same model, used to gather pings reported toward this peer.
        peers_info: output of ``_collect_peers_info``.
    """
    row = {
        "short_peer_id": "..." + str(peer_id)[-6:],
        "peer_id": peer_id,
        "peer_ip_info": peers_info.get(str(peer_id), "unknown"),
        "show_public_name": show_public_name,
        "state": state,
        "span": span,
        "adapters": [dict(name=name, short_name=name.split("/")[-1]) for name in span.server_info.adapters],
        # Ping values that other servers of this model list for this peer in their next_pings.
        "pings_to_me": {
            str(origin_id): origin.server_info.next_pings[str(peer_id)]
            for origin_id, origin in servers.items()
            if origin.server_info.next_pings is not None and str(peer_id) in origin.server_info.next_pings
        },
    }
    if span.server_info.cache_tokens_left is not None:
        # Divide by span.length * 2 to account for both keys and values in every served block.
        row["cache_tokens_left_per_block"] = span.server_info.cache_tokens_left // (span.length * 2)
    return row


def fetch_health_state(dht: hivemind.DHT) -> dict:
    """Collect a snapshot of swarm health.

    Steps:
    1. Check reachability of the bootstrap peers listed in ``config.INITIAL_PEERS``.
    2. Build the model list: ``config.MODELS`` plus custom models announced in the DHT.
    3. Fetch module infos for every block of every model and group them into server spans.
    4. Check reachability of servers whose span state is ONLINE and gather known peer addresses.
    5. Build one report per model and aggregate top contributors.

    Args:
        dht: DHT node used for all lookups and coroutine calls.

    Returns:
        A dict with keys:
        - ``bootstrap_states``: "online"/"unreachable" per unique bootstrap peer, in config order.
        - ``top_contributors``: Counter of public name -> blocks served, counting only official
          models and online spans of at least 10 blocks.
        - ``model_reports``: one dict per model (report fields merged with ``asdict(model)``).
        - ``reachability_issues``: ``{peer_id, err}`` for every checked peer that was not reachable.
        - ``last_updated``: timezone-aware UTC timestamp.
        - ``update_period``: ``config.UPDATE_PERIOD``.
        - ``update_duration``: seconds spent in this call (``time.perf_counter`` based).
    """
    start_time = time.perf_counter()

    bootstrap_peer_ids = _unique_bootstrap_peer_ids()
    reach_infos = dht.run_coroutine(partial(check_reachability_parallel, bootstrap_peer_ids))
    bootstrap_states = ["online" if reach_infos[peer_id]["ok"] else "unreachable" for peer_id in bootstrap_peer_ids]

    models = config.MODELS[:]  # copy so the config list itself is never mutated
    models.extend(_discover_custom_models(dht, {model.dht_prefix for model in models}))
    logger.info(f"Fetching info for models {[info.name for info in models]}")

    # One flat lookup for all blocks of all models; results are sliced back per model below.
    block_uids = [f"{model.dht_prefix}{UID_DELIMITER}{i}" for model in models for i in range(model.num_blocks)]
    module_infos = get_remote_module_infos(dht, block_uids, latest=True)

    model_servers = {}
    all_servers = {}
    offset = 0
    for model in models:
        servers = compute_spans(module_infos[offset : offset + model.num_blocks], min_state=ServerState.OFFLINE)
        model_servers[model.dht_prefix] = servers
        all_servers.update(servers)
        offset += model.num_blocks

    online_servers = [peer_id for peer_id, span in all_servers.items() if span.state == ServerState.ONLINE]
    reach_infos.update(dht.run_coroutine(partial(check_reachability_parallel, online_servers, fetch_info=True)))
    peers_info = _collect_peers_info(dht)

    top_contributors = Counter()
    model_reports = []
    for model in models:
        servers = model_servers[model.dht_prefix]
        # A model is healthy only if every block is covered by at least one online, reachable span.
        block_healthy = np.zeros(model.num_blocks, dtype=bool)
        server_rows = []
        for peer_id, span in sorted(servers.items()):
            # Peers that were never checked (i.e. not ONLINE) are not flagged as unreachable.
            reachable = reach_infos[peer_id]["ok"] if peer_id in reach_infos else True
            state = span.state.name.lower() if reachable else "unreachable"
            if state == "online":
                block_healthy[span.start : span.end] = True

            show_public_name = state == "online" and span.length >= 10
            if model.official and span.server_info.public_name and show_public_name:
                top_contributors[span.server_info.public_name] += span.length

            server_rows.append(_build_server_row(peer_id, span, state, show_public_name, servers, peers_info))

        model_reports.append(
            dict(
                name=model.name,
                short_name=model.short_name,
                state="healthy" if block_healthy.all() else "broken",
                server_rows=server_rows,
                **asdict(model),
            )
        )

    reachability_issues = [
        dict(peer_id=peer_id, err=info["error"]) for peer_id, info in sorted(reach_infos.items()) if not info["ok"]
    ]

    return dict(
        bootstrap_states=bootstrap_states,
        top_contributors=top_contributors,
        model_reports=model_reports,
        reachability_issues=reachability_issues,
        last_updated=datetime.datetime.now(datetime.timezone.utc),
        update_period=config.UPDATE_PERIOD,
        update_duration=time.perf_counter() - start_time,
    )