File size: 4,793 Bytes
b073cce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import json
import utils
import traceback
import bittensor as bt
import asyncio
from prompting.forward import handle_response
from prompting.validator import Validator
from prompting.utils.uids import get_random_uids
from prompting.protocol import PromptingSynapse, StreamPromptingSynapse
from prompting.dendrite import DendriteResponseEvent
from base import QueryValidatorParams, ValidatorWrapper
from aiohttp.web_response import Response, StreamResponse
from deprecated import deprecated

class S1ValidatorWrapper(ValidatorWrapper):
    """Validator wrapper that forwards API queries to SN1 miners through ``prompting``.

    Owns a ``prompting.validator.Validator`` and uses its metagraph (to map
    uids to axons) and its dendrite (to query those axons).
    """

    def __init__(self):
        self.validator = Validator()

    def _select_uids_and_axons(self, params: QueryValidatorParams):
        """Pick ``params.k_miners`` random uids (honouring ``params.exclude``) and return ``(uids, axons)``."""
        uids = get_random_uids(self.validator, k=params.k_miners, exclude=params.exclude or []).tolist()
        axons = [self.validator.metagraph.axons[uid] for uid in uids]
        return uids, axons

    @staticmethod
    def _as_stream_list(streams):
        """Normalize a streaming dendrite result into a list of async iterables.

        The dendrite result may be a single async iterable or a sequence of them
        (e.g. one per queried axon). ``async for`` only accepts the former, so a
        lone async iterable is wrapped in a list and anything else is listed.
        """
        if hasattr(streams, "__aiter__"):
            return [streams]
        return list(streams)

    @deprecated(reason="This function is deprecated. Validators use stream synapse now, use get_stream_response instead.")
    async def get_response(self, params: QueryValidatorParams) -> Response:
        """Query miners with a non-streaming ``PromptingSynapse`` and return one JSON response.

        Deprecated: only reached via ``query_validator(..., stream=False)``.

        Args:
            params: Query parameters (messages, roles, miner count, exclusions,
                timeout and ensemble preference).

        Returns:
            A 200 ``Response`` whose body is the JSON-encoded response event,
            augmented with ``completion_is_valid``, ``task_name`` and
            ``ensemble_result``; or a 500 ``Response`` if any step raised.
        """
        try:
            # Guess the task name of current request
            task_name = utils.guess_task_name(params.messages[-1])

            # Get the list of uids (and their axons) to query for this step.
            uids, axons = self._select_uids_and_axons(params)

            # Make calls to the network with the prompt.
            bt.logging.info('Calling dendrite')
            responses = await self.validator.dendrite(
                axons=axons,
                synapse=PromptingSynapse(roles=params.roles, messages=params.messages),
                timeout=params.timeout,
            )

            bt.logging.info(f"Creating DendriteResponseEvent:\n {responses}")
            # Encapsulate the responses in a response event
            response_event = DendriteResponseEvent(responses, uids)

            # Mapping form of the event, so it can be JSON-encoded below.
            response = response_event.__state_dict__()

            completions = response['completions']
            valid = list(map(utils.completion_is_valid, completions))
            valid_completions = [completion for completion, ok in zip(completions, valid) if ok]

            response['completion_is_valid'] = valid
            response['task_name'] = task_name
            response['ensemble_result'] = utils.ensemble_result(valid_completions, task_name=task_name, prefer=params.prefer)

            bt.logging.info(f"Response:\n {response}")
            return Response(status=200, reason="I can't believe it's not butter!", text=json.dumps(response))

        except Exception:
            bt.logging.error(f'Encountered in {self.__class__.__name__}:get_response:\n{traceback.format_exc()}')
            return Response(status=500, reason="Internal error")

    async def get_stream_response(self, params: QueryValidatorParams, *, request=None) -> StreamResponse:
        """Query miners with ``StreamPromptingSynapse`` and relay their chunks as they arrive.

        Each forwarded item is written as its own JSON document. The documents
        are concatenated back to back, not wrapped in an array.

        Args:
            params: Query parameters (messages, roles, miner count, exclusions, timeout).
            request: The incoming aiohttp request to stream the response to.
                ``StreamResponse.prepare`` requires it. When omitted, it is read
                from ``params.request``.

        Returns:
            The prepared ``StreamResponse``, already terminated with EOF.

        Raises:
            TypeError: if no request was passed and ``params`` carries none.
        """
        if request is None:
            # NOTE(review): assumes QueryValidatorParams may carry the originating
            # aiohttp request as ``request`` — confirm against base.py.
            request = getattr(params, "request", None)
        if request is None:
            raise TypeError(
                f"{self.__class__.__name__}.get_stream_response needs the incoming aiohttp request "
                "(pass request=... or set params.request)"
            )

        response = StreamResponse(status=200, reason="OK")
        response.headers['Content-Type'] = 'application/json'

        # Sends the status line and headers now; the status cannot change afterwards.
        await response.prepare(request)

        try:
            # Get the list of uids (and their axons) to query for this step.
            uids, axons = self._select_uids_and_axons(params)

            # Make calls to the network with the prompt.
            bt.logging.info('Calling dendrite')
            streams_responses = await self.validator.dendrite(
                axons=axons,
                synapse=StreamPromptingSynapse(roles=params.roles, messages=params.messages),
                timeout=params.timeout,
                deserialize=False,
                streaming=True,
            )

            # Drain each miner's stream in turn, forwarding every JSON-encodable item.
            for stream in self._as_stream_list(streams_responses):
                if stream is None:
                    continue
                async for stream_result in stream:
                    if stream_result is None:
                        continue
                    try:
                        json_data = json.dumps(stream_result)
                    except TypeError:
                        # NOTE(review): the stream may end with a non-JSON object
                        # (e.g. the synapse itself); skip it instead of aborting
                        # everything already streamed.
                        bt.logging.debug(
                            f'Skipping non-JSON-serializable stream item of type {type(stream_result).__name__}'
                        )
                        continue
                    await response.write(json_data.encode('utf-8'))

        except Exception as e:
            bt.logging.error(f'Encountered an error in {self.__class__.__name__}:get_stream_response:\n{traceback.format_exc()}')
            # The 200 status was already sent by prepare(); aiohttp forbids
            # set_status() at this point, so the failure is reported in the body.
            await response.write(json.dumps({'error': str(e)}).encode('utf-8'))
        finally:
            await response.write_eof()  # Ensure to close the response properly

        return response

    async def query_validator(self, params: QueryValidatorParams, stream: bool = True, *, request=None) -> StreamResponse:
        """Answer a query, streaming by default.

        Args:
            params: Query parameters forwarded to the chosen path.
            stream: If true, use ``get_stream_response``. Otherwise use the
                deprecated, non-streaming ``get_response``.
            request: The incoming aiohttp request. It is only used by the
                streaming path (see ``get_stream_response``).

        Returns:
            The ``StreamResponse`` / ``Response`` produced by the chosen path.
        """
        if stream:
            return await self.get_stream_response(params, request=request)
        # DEPRECATED
        return await self.get_response(params)