File size: 2,412 Bytes
d3ebdae
9d1a999
b073cce
 
 
2a5b08d
a415c67
b073cce
2a5b08d
cd059dd
d3ebdae
fdc8fdb
ab063cf
b073cce
fdc8fdb
a415c67
cd059dd
 
 
 
 
 
 
 
 
 
 
 
 
7ea7f29
fdc8fdb
2a5b08d
 
 
 
cd059dd
2a5b08d
 
 
9e8df1b
2a5b08d
 
 
 
 
 
 
 
 
 
 
 
d3ebdae
1862b79
d3ebdae
1862b79
d3ebdae
b073cce
1b0bdb5
7ea7f29
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import asyncio
import random
import bittensor as bt
from prompting.validator import Validator
from prompting.utils.uids import get_random_uids
from prompting.protocol import StreamPromptingSynapse
from .base import QueryValidatorParams, ValidatorAPI
from aiohttp.web_response import Response, StreamResponse
from .streamer import AsyncResponseDataStreamer
from .validator_utils import get_top_incentive_uids
from .stream_manager import StreamManager

class S1ValidatorAPI(ValidatorAPI):
    """ValidatorAPI implementation that forwards prompts to miners through a ``Validator``.

    A request is answered by sampling miner uids, sending those miners a
    streaming ``StreamPromptingSynapse`` through the validator's dendrite, and
    handing the resulting streams to a ``StreamManager``.
    """

    def __init__(self):
        # The Validator's metagraph, config and dendrite are used by the
        # methods below.
        self.validator = Validator()

    def sample_uids(self, params: QueryValidatorParams):
        """Select the miner uids to query according to ``params.sampling_mode``.

        Args:
            params: Request parameters. ``sampling_mode`` selects the strategy
                and ``k_miners`` is passed as ``k`` to the sampling helper. In
                ``"random"`` mode, ``exclude`` (or ``[]`` when falsy) is
                forwarded as the uids to exclude.

        Returns:
            For ``"random"``: the result of ``get_random_uids(...)`` converted
            with ``.tolist()``. For ``"top_incentive"``: the result of
            ``get_top_incentive_uids(...)`` using the validator's metagraph and
            ``config.neuron.vpermit_tao_limit``.

        Raises:
            ValueError: If ``sampling_mode`` is not ``"random"`` or
                ``"top_incentive"``.
        """
        mode = params.sampling_mode
        if mode == "random":
            return get_random_uids(
                self.validator, k=params.k_miners, exclude=params.exclude or []
            ).tolist()
        if mode == "top_incentive":
            metagraph = self.validator.metagraph
            vpermit_tao_limit = self.validator.config.neuron.vpermit_tao_limit
            return get_top_incentive_uids(
                metagraph, k=params.k_miners, vpermit_tao_limit=vpermit_tao_limit
            )
        # Fail fast with a clear message instead of implicitly returning None,
        # which callers would otherwise only trip over when iterating the result.
        raise ValueError(
            f"Unsupported sampling_mode {mode!r}; expected 'random' or 'top_incentive'"
        )

    async def get_stream_response(self, params: QueryValidatorParams) -> StreamResponse:
        """Query sampled miners with a streaming prompt and return the processed stream.

        Args:
            params: Request parameters. ``roles`` and ``messages`` build the
                ``StreamPromptingSynapse``, and ``timeout`` is passed to the
                dendrite call. ``request`` is passed to
                ``StreamManager.process_streams`` (presumably the incoming
                aiohttp request; confirm against callers).

        Returns:
            Whatever ``StreamManager.process_streams`` produces from the miner
            streams.

        Raises:
            ValueError: Propagated from ``sample_uids`` for an unsupported
                ``sampling_mode``.
        """
        # TODO: task-name guessing (utils.guess_task_name on the last message)
        # was sketched here but never wired in.

        uids = self.sample_uids(params)
        axons = [self.validator.metagraph.axons[uid] for uid in uids]

        # NOTE(review): this logs full message contents at info level; consider
        # redacting if requests may carry sensitive user data.
        bt.logging.info(f"Sampling dendrite by {params.sampling_mode} with roles {params.roles} and messages {params.messages}")

        # deserialize=False / streaming=True: the raw streaming responses are
        # handed to StreamManager instead of being deserialized here.
        streams_responses = await self.validator.dendrite(
            axons=axons,
            synapse=StreamPromptingSynapse(
                roles=params.roles, messages=params.messages
            ),
            timeout=params.timeout,
            deserialize=False,
            streaming=True,
        )

        # StreamManager picks/wraps the stream that is returned to the client.
        stream_manager = StreamManager()
        return await stream_manager.process_streams(params.request, streams_responses, uids)

    async def query_validator(self, params: QueryValidatorParams) -> Response:
        """Answer a validator query by delegating to ``get_stream_response``."""
        # NOTE(review): get_stream_response is annotated as returning
        # StreamResponse (the base class of Response). The annotation here is
        # kept unchanged in case ValidatorAPI declares it.
        return await self.get_stream_response(params)