File size: 2,441 Bytes
d3ebdae
9d1a999
b073cce
 
 
2a5b08d
a415c67
b073cce
2a5b08d
cd059dd
d3ebdae
fdc8fdb
df1e6f4
ab063cf
b073cce
fdc8fdb
a415c67
df1e6f4
cd059dd
 
 
 
 
 
 
 
df1e6f4
 
 
 
 
cd059dd
7ea7f29
fdc8fdb
2a5b08d
 
 
 
df1e6f4
2a5b08d
 
 
df1e6f4
 
 
2a5b08d
 
 
 
 
 
 
 
 
 
df1e6f4
2a5b08d
d3ebdae
df1e6f4
 
 
d3ebdae
df1e6f4
b073cce
1b0bdb5
7ea7f29
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import asyncio
import random
import bittensor as bt
from prompting.validator import Validator
from prompting.utils.uids import get_random_uids
from prompting.protocol import StreamPromptingSynapse
from .base import QueryValidatorParams, ValidatorAPI
from aiohttp.web_response import Response, StreamResponse
from .streamer import AsyncResponseDataStreamer
from .validator_utils import get_top_incentive_uids
from .stream_manager import StreamManager


class S1ValidatorAPI(ValidatorAPI):
    """Validator API that queries sampled miners and returns a streamed reply.

    Wraps a ``Validator`` instance: miner uids are sampled according to the
    request parameters, the corresponding axons are queried through the
    validator's dendrite in streaming mode, and the resulting streams are
    handed to a ``StreamManager`` which produces the response.
    """

    def __init__(self):
        """Create the ``Validator`` used for sampling and network queries."""
        self.validator = Validator()

    def sample_uids(self, params: QueryValidatorParams):
        """Select the miner uids to query for this request.

        Args:
            params: Request parameters. Reads ``sampling_mode``, ``k_miners``
                and, for random sampling, ``exclude``.

        Returns:
            For ``"random"``: the result of ``get_random_uids`` converted with
            ``.tolist()``. For ``"top_incentive"``: whatever
            ``get_top_incentive_uids`` returns (presumably a list of uids;
            verify against that helper).

        Raises:
            ValueError: If ``params.sampling_mode`` is not one of the
                supported modes.
        """
        mode = params.sampling_mode

        if mode == "random":
            return get_random_uids(
                self.validator, k=params.k_miners, exclude=params.exclude or []
            ).tolist()

        if mode == "top_incentive":
            return get_top_incentive_uids(
                self.validator.metagraph,
                k=params.k_miners,
                vpermit_tao_limit=self.validator.config.neuron.vpermit_tao_limit,
            )

        # Previously an unknown mode fell through and returned None, which only
        # surfaced later as an opaque "NoneType is not iterable" TypeError.
        raise ValueError(
            f"Unsupported sampling_mode {mode!r}; "
            "expected 'random' or 'top_incentive'"
        )

    async def get_stream_response(self, params: QueryValidatorParams) -> StreamResponse:
        """Query the sampled miners in streaming mode and return the response.

        Args:
            params: Request parameters. Reads ``roles``, ``messages``,
                ``timeout`` and ``request`` in addition to the fields used by
                ``sample_uids``.

        Returns:
            The value produced by ``StreamManager.process_streams`` for the
            incoming request, the miner streams and their uids.

        Raises:
            ValueError: Propagated from ``sample_uids`` for an unsupported
                sampling mode.
        """
        # Get the list of uids to query for this step and resolve their axons.
        uids = self.sample_uids(params)
        axons = [self.validator.metagraph.axons[uid] for uid in uids]

        # Make calls to the network with the prompt.
        bt.logging.info(
            f"Sampling dendrite by {params.sampling_mode} with roles {params.roles} and messages {params.messages}"
        )

        synapse = StreamPromptingSynapse(roles=params.roles, messages=params.messages)
        streams_responses = await self.validator.dendrite(
            axons=axons,
            synapse=synapse,
            timeout=params.timeout,
            deserialize=False,
            streaming=True,
        )

        # Hand the miner streams to the stream manager, which builds the
        # response that is sent back for ``params.request``.
        stream_manager = StreamManager()
        return await stream_manager.process_streams(
            params.request, streams_responses, uids
        )

    async def query_validator(self, params: QueryValidatorParams) -> Response:
        """Entry point for a validator query; delegates to ``get_stream_response``.

        NOTE(review): annotated as ``Response`` while ``get_stream_response``
        is annotated ``StreamResponse`` -- confirm which one the base class
        contract expects.
        """
        return await self.get_stream_response(params)