File size: 4,409 Bytes
48f02b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107

import json
import asyncio

import traceback
import bittensor as bt

import utils

from typing import List
from neurons.validator import Validator
from prompting.forward import handle_response
from prompting.dendrite import DendriteResponseEvent
from prompting.protocol import PromptingSynapse, StreamPromptingSynapse
from prompting.utils.uids import get_random_uids

from aiohttp import web

from aiohttp.web_response import Response


async def single_response(validator: Validator, messages: List[str], roles: List[str], k: int = 5, timeout: float = 3.0, exclude: List[int] = None, prefer: str = 'longest') -> Response:
    """Query a random set of miners once and reply with an ensembled JSON payload.

    Args:
        validator: Validator whose metagraph and dendrite are used for the query.
        messages: Conversation messages; the last one is used to guess the task.
        roles: Roles forwarded with ``messages`` in the ``PromptingSynapse``.
        k: Number of uids requested from ``get_random_uids``.
        timeout: Timeout handed to the dendrite call.
        exclude: Uids to leave out of the sample; ``None`` is treated as empty.
        prefer: Forwarded unchanged to ``utils.ensemble_result``.

    Returns:
        A 200 ``Response`` whose text is the JSON-encoded response state
        (extended with ``completion_is_valid``, ``task_name`` and
        ``ensemble_result``), or a 500 ``Response`` if anything raised.
    """
    try:
        # The task is inferred from the most recent message only.
        task = utils.guess_task_name(messages[-1])

        # Pick the miners to contact and resolve their axons.
        excluded = exclude or []
        uids = get_random_uids(validator, k=k, exclude=excluded).tolist()
        axons = [validator.metagraph.axons[uid] for uid in uids]

        # Send the prompt to every selected axon.
        bt.logging.info('Calling dendrite')
        synapse = PromptingSynapse(roles=roles, messages=messages)
        raw_responses = await validator.dendrite(
            axons=axons,
            synapse=synapse,
            timeout=timeout,
        )

        bt.logging.info(f"Creating DendriteResponseEvent:\n {raw_responses}")
        # Wrap the raw responses together with the uids that produced them.
        event = DendriteResponseEvent(raw_responses, uids)

        # Plain-dict view of the event; this is what gets serialised.
        payload = event.__state_dict__()

        # Flag each completion, then keep only the ones that passed.
        completions = payload['completions']
        validity = [utils.completion_is_valid(text) for text in completions]
        payload['completion_is_valid'] = validity
        accepted = [text for text, ok in zip(completions, validity) if ok]

        payload['task_name'] = task
        payload['ensemble_result'] = utils.ensemble_result(accepted, task_name=task, prefer=prefer)

        bt.logging.info(f"Response:\n {payload}")
        body = json.dumps(payload)
        return Response(status=200, reason="I can't believe it's not butter!", text=body)

    except Exception:
        # Handler boundary: log the full traceback and answer with a generic 500.
        trace = traceback.format_exc()
        bt.logging.error(f'Encountered in {single_response.__name__}:\n{trace}')
        return Response(status=500, reason="Internal error")


async def stream_response(validator: Validator, messages: List[str], roles: List[str], k: int = 5, timeout: float = 3.0, exclude: List[int] = None, prefer: str = 'longest') -> web.StreamResponse:
    """Query a random set of miners over the streaming protocol and reply with JSON.

    The miners are contacted with a ``StreamPromptingSynapse`` and their
    streams are collected via ``handle_response``. This function waits for
    that collection to finish and returns a single, fully buffered
    ``Response``; it does not forward chunks to the caller incrementally.

    Args:
        validator: Validator whose metagraph and dendrite are used for the query.
        messages: Conversation messages; the last one is used to guess the task.
        roles: Roles forwarded with ``messages`` in the synapse.
        k: Number of uids requested from ``get_random_uids``.
        timeout: Timeout handed to the dendrite call.
        exclude: Uids to leave out of the sample; ``None`` is treated as empty.
        prefer: Forwarded unchanged to ``utils.ensemble_result``.

    Returns:
        A 200 ``Response`` whose text is the JSON-encoded response state
        (extended with ``completion_is_valid``, ``task_name`` and
        ``ensemble_result``), or a 500 ``Response`` if anything raised.
    """
    try:
        # Guess the task name of current request
        task_name = utils.guess_task_name(messages[-1])

        # Get the list of uids to query for this step.
        uids = get_random_uids(validator, k=k, exclude=exclude or []).tolist()
        axons = [validator.metagraph.axons[uid] for uid in uids]

        # Make calls to the network with the prompt.
        bt.logging.info('Calling dendrite')
        streams_responses = await validator.dendrite(
            axons=axons,
            synapse=StreamPromptingSynapse(roles=roles, messages=messages),
            timeout=timeout,
            deserialize=False,
            streaming=True,
        )

        # Pair every stream with the uid it was requested from and collect
        # the streams in a dedicated task.
        handle_stream_responses_task = asyncio.create_task(
            handle_response(responses=dict(zip(uids, streams_responses)))
        )

        stream_results = await handle_stream_responses_task

        # Each stream result carries the synapse it ended up with.
        responses = [stream_result.synapse for stream_result in stream_results]
        bt.logging.info(f"Creating DendriteResponseEvent:\n {responses}")
        # Encapsulate the responses in a response event (dataclass)
        response_event = DendriteResponseEvent(responses, uids)

        # Plain-dict view of the event; this is what gets serialised.
        response = response_event.__state_dict__()

        # Flag each completion, then keep only the ones that passed.
        response['completion_is_valid'] = valid = list(map(utils.completion_is_valid, response['completions']))
        valid_completions = [response['completions'][i] for i, v in enumerate(valid) if v]

        response['task_name'] = task_name
        response['ensemble_result'] = utils.ensemble_result(valid_completions, task_name=task_name, prefer=prefer)

        bt.logging.info(f"Response:\n {response}")
        return Response(status=200, reason="I can't believe it's not butter!", text=json.dumps(response))

    except Exception:
        # Handler boundary: log the full traceback under THIS function's name
        # (it was previously attributed to single_response) and answer 500.
        bt.logging.error(f'Encountered in {stream_response.__name__}:\n{traceback.format_exc()}')
        return Response(status=500, reason="Internal error")