File size: 5,585 Bytes
90e26fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import datetime
import time
from collections import Counter
from contextlib import suppress
from dataclasses import asdict
from functools import partial

import hivemind
import numpy as np
from multiaddr import Multiaddr
from petals.data_structures import UID_DELIMITER, ServerState
from petals.utils.dht import compute_spans, get_remote_module_infos

import config
from data_structures import ModelInfo
from p2p_utils import check_reachability_parallel, get_peers_ips, extract_peer_ip_info

# Module-level logger obtained from hivemind's logging helper, keyed by this module's name.
logger = hivemind.get_logger(__name__)


def fetch_health_state(dht: hivemind.DHT) -> dict:
    """Collect a snapshot of the swarm's health and return it as a plain dict.

    The work is done in this order:

    1. Check reachability of the bootstrap peers listed in ``config.INITIAL_PEERS``.
    2. Build the model list from ``config.MODELS`` plus extra entries found under
       the ``"_petals.models"`` DHT key (only those whose repository starts with
       ``https://huggingface.co/`` and whose prefix is not already configured).
    3. Fetch module infos for every block of every model and turn them into
       per-server spans via ``compute_spans``.
    4. Check reachability of every server whose span state is ``ONLINE`` and
       collect address info for the peers returned by ``get_peers_ips``.
    5. Assemble one report per model and a list of reachability problems.

    Args:
        dht: DHT instance used for key lookups and for running the helper coroutines.

    Returns:
        A dict with the keys ``bootstrap_states``, ``top_contributors``,
        ``model_reports``, ``reachability_issues``, ``last_updated``,
        ``update_period`` and ``update_duration``.
    """
    started_at = time.perf_counter()

    # De-duplicate bootstrap peers while keeping the order of config.INITIAL_PEERS.
    bootstrap_ids = []
    for initial_addr in config.INITIAL_PEERS:
        candidate_id = hivemind.PeerID.from_base58(Multiaddr(initial_addr)["p2p"])
        if candidate_id in bootstrap_ids:
            continue
        bootstrap_ids.append(candidate_id)

    reachability = dht.run_coroutine(partial(check_reachability_parallel, bootstrap_ids))
    bootstrap_states = []
    for boot_id in bootstrap_ids:
        bootstrap_states.append("online" if reachability[boot_id]["ok"] else "unreachable")

    # Start from the configured models, then append the ones announced through the DHT.
    models = config.MODELS[:]
    index_entry = dht.get("_petals.models", latest=True)
    if index_entry is not None and isinstance(index_entry.value, dict):
        known_prefixes = {known.dht_prefix for known in models}
        extra_models = []
        for prefix, announced in index_entry.value.items():
            if prefix in known_prefixes:
                continue
            # Entries that fail with TypeError/ValueError while being parsed are skipped.
            with suppress(TypeError, ValueError):
                candidate = ModelInfo.from_dict(announced.value)
                repo_url = candidate.repository
                if repo_url is not None and repo_url.startswith("https://huggingface.co/"):
                    candidate.dht_prefix = prefix
                    candidate.official = False
                    extra_models.append(candidate)
        # Largest models first; ties are ordered by DHT prefix.
        extra_models.sort(key=lambda extra: (-extra.num_blocks, extra.dht_prefix))
        models.extend(extra_models)
    model_names = [listed.name for listed in models]
    logger.info(f"Fetching info for models {model_names}")

    # One UID per block, laid out model after model so the result can be sliced back by offset.
    block_uids = []
    for model in models:
        for block_index in range(model.num_blocks):
            block_uids.append(f"{model.dht_prefix}{UID_DELIMITER}{block_index}")
    module_infos = get_remote_module_infos(dht, block_uids, latest=True)

    spans_by_prefix = {}
    spans_by_peer = {}
    cursor = 0
    for model in models:
        next_cursor = cursor + model.num_blocks
        model_spans = compute_spans(module_infos[cursor:next_cursor], min_state=ServerState.OFFLINE)
        spans_by_prefix[model.dht_prefix] = model_spans
        spans_by_peer.update(model_spans)
        cursor = next_cursor

    online_servers = []
    for server_id, server_span in spans_by_peer.items():
        if server_span.state == ServerState.ONLINE:
            online_servers.append(server_id)

    online_reachability = dht.run_coroutine(
        partial(check_reachability_parallel, online_servers, fetch_info=True)
    )
    reachability.update(online_reachability)

    peers_info = {}
    for peer in dht.run_coroutine(get_peers_ips):
        peer_key = str(peer.peer_id)
        peers_info[peer_key] = {
            "location": extract_peer_ip_info(str(peer.addrs[0])),
            "multiaddrs": [str(address) for address in peer.addrs],
        }

    top_contributors = Counter()
    model_reports = []
    for model in models:
        model_spans = spans_by_prefix[model.dht_prefix]
        block_healthy = np.zeros(model.num_blocks, dtype=bool)
        server_rows = []

        for peer_id, span in sorted(model_spans.items()):
            # Peers that were never probed are treated as reachable.
            if peer_id in reachability:
                reachable = reachability[peer_id]["ok"]
            else:
                reachable = True

            if reachable:
                state = span.state.name.lower()
            else:
                state = "unreachable"

            is_online = state == "online"
            if is_online:
                block_healthy[span.start : span.end] = True

            show_public_name = is_online and span.length >= 10
            if model.official and span.server_info.public_name and show_public_name:
                top_contributors[span.server_info.public_name] += span.length

            peer_key = str(peer_id)
            adapters = [
                {"name": adapter, "short_name": adapter.split("/")[-1]}
                for adapter in span.server_info.adapters
            ]

            # Pings reported by the other servers of this model towards the current peer.
            pings_to_me = {}
            for origin_id, origin_span in model_spans.items():
                origin_pings = origin_span.server_info.next_pings
                if origin_pings is None or peer_key not in origin_pings:
                    continue
                pings_to_me[str(origin_id)] = origin_pings[peer_key]

            row = {
                "short_peer_id": "..." + peer_key[-6:],
                "peer_id": peer_id,
                "peer_ip_info": peers_info.get(peer_key, "unknown"),
                "show_public_name": show_public_name,
                "state": state,
                "span": span,
                "adapters": adapters,
                "pings_to_me": pings_to_me,
            }

            tokens_left = span.server_info.cache_tokens_left
            if tokens_left is not None:
                # The factor of 2 accounts for both keys and values in each block
                row["cache_tokens_left_per_block"] = tokens_left // (span.length * 2)
            server_rows.append(row)

        # A model counts as healthy only if every block is covered by an online server.
        model_state = "healthy" if block_healthy.all() else "broken"
        report = dict(
            name=model.name,
            short_name=model.short_name,
            state=model_state,
            server_rows=server_rows,
            **asdict(model),
        )
        model_reports.append(report)

    reachability_issues = []
    for peer_id, probe in sorted(reachability.items()):
        if probe["ok"]:
            continue
        reachability_issues.append({"peer_id": peer_id, "err": probe["error"]})

    return {
        "bootstrap_states": bootstrap_states,
        "top_contributors": top_contributors,
        "model_reports": model_reports,
        "reachability_issues": reachability_issues,
        "last_updated": datetime.datetime.now(datetime.timezone.utc),
        "update_period": config.UPDATE_PERIOD,
        "update_duration": time.perf_counter() - started_at,
    }