Spaces:
Sleeping
Sleeping
File size: 4,409 Bytes
48f02b6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
import json
import asyncio
import traceback
import bittensor as bt
import utils
from typing import List
from neurons.validator import Validator
from prompting.forward import handle_response
from prompting.dendrite import DendriteResponseEvent
from prompting.protocol import PromptingSynapse, StreamPromptingSynapse
from prompting.utils.uids import get_random_uids
from aiohttp import web
from aiohttp.web_response import Response
async def single_response(validator: Validator, messages: List[str], roles: List[str], k: int = 5, timeout: float = 3.0, exclude: List[int] = None, prefer: str = 'longest') -> Response:
    """Query a random set of axons once and reply with the combined result as JSON.

    Args:
        validator: Supplies the ``metagraph`` (for axon lookup) and the
            ``dendrite`` used to make the network call.
        messages: Conversation messages; the last one is used to guess the
            task name.
        roles: Roles forwarded, together with ``messages``, in the
            ``PromptingSynapse``.
        k: Number of uids requested from ``get_random_uids``.
        timeout: Timeout forwarded to the dendrite call.
        exclude: Uids passed to ``get_random_uids`` as exclusions; ``None``
            is treated as an empty list.
        prefer: Forwarded unchanged to ``utils.ensemble_result``.

    Returns:
        A 200 ``Response`` whose text is the JSON-encoded response event state
        extended with ``completion_is_valid``, ``task_name`` and
        ``ensemble_result``; or a 500 ``Response`` if anything raised.
    """
    try:
        # The most recent message decides which task this request is treated as.
        task_name = utils.guess_task_name(messages[-1])

        # Choose the uids to query for this step and resolve their axons.
        uids = get_random_uids(validator, k=k, exclude=exclude or []).tolist()
        target_axons = [validator.metagraph.axons[uid] for uid in uids]

        # Send the prompt to the network.
        bt.logging.info('Calling dendrite')
        synapse = PromptingSynapse(roles=roles, messages=messages)
        responses = await validator.dendrite(
            axons=target_axons,
            synapse=synapse,
            timeout=timeout,
        )

        bt.logging.info(f"Creating DendriteResponseEvent:\n {responses}")
        # Wrap the raw responses in a response event and take its dict form.
        payload = DendriteResponseEvent(responses, uids).__state_dict__()

        # Flag each completion, then keep only those that passed the check.
        completions = payload['completions']
        validity = [utils.completion_is_valid(completion) for completion in completions]
        payload['completion_is_valid'] = validity
        payload['task_name'] = task_name
        accepted = [completion for completion, ok in zip(completions, validity) if ok]
        payload['ensemble_result'] = utils.ensemble_result(accepted, task_name=task_name, prefer=prefer)

        bt.logging.info(f"Response:\n {payload}")
        return Response(status=200, reason="I can't believe it's not butter!", text=json.dumps(payload))

    except Exception:
        # Boundary handler: log the full traceback and answer with a generic 500.
        bt.logging.error(f'Encountered in {single_response.__name__}:\n{traceback.format_exc()}')
        return Response(status=500, reason="Internal error")
async def stream_response(validator: Validator, messages: List[str], roles: List[str], k: int = 5, timeout: float = 3.0, exclude: List[int] = None, prefer: str = 'longest') -> web.StreamResponse:
    """Query a random set of axons via the streaming synapse and reply with JSON.

    The dendrite is called with ``streaming=True``, but the streams are fully
    consumed by ``handle_response`` before replying, so the caller receives
    one complete ``Response`` rather than incremental chunks.

    Args:
        validator: Supplies the ``metagraph`` (for axon lookup) and the
            ``dendrite`` used to make the network call.
        messages: Conversation messages; the last one is used to guess the
            task name.
        roles: Roles forwarded, together with ``messages``, in the
            ``StreamPromptingSynapse``.
        k: Number of uids requested from ``get_random_uids``.
        timeout: Timeout forwarded to the dendrite call.
        exclude: Uids passed to ``get_random_uids`` as exclusions; ``None``
            is treated as an empty list.
        prefer: Forwarded unchanged to ``utils.ensemble_result``.

    Returns:
        A 200 ``Response`` whose text is the JSON-encoded response event state
        extended with ``completion_is_valid``, ``task_name`` and
        ``ensemble_result``; or a 500 ``Response`` if anything raised.
    """
    try:
        # Guess the task name of the current request from the latest message.
        task_name = utils.guess_task_name(messages[-1])

        # Get the list of uids to query for this step.
        uids = get_random_uids(validator, k=k, exclude=exclude or []).tolist()
        axons = [validator.metagraph.axons[uid] for uid in uids]

        # Make calls to the network with the prompt.
        bt.logging.info('Calling dendrite')
        streams_responses = await validator.dendrite(
            axons=axons,
            synapse=StreamPromptingSynapse(roles=roles, messages=messages),
            timeout=timeout,
            deserialize=False,
            streaming=True,
        )

        # Pair each stream with its uid and let handle_response consume them.
        handle_stream_responses_task = asyncio.create_task(
            handle_response(responses=dict(zip(uids, streams_responses)))
        )
        stream_results = await handle_stream_responses_task

        # Each stream result exposes the final synapse for its uid.
        responses = [stream_result.synapse for stream_result in stream_results]
        bt.logging.info(f"Creating DendriteResponseEvent:\n {responses}")

        # Encapsulate the responses in a response event (dataclass).
        response_event = DendriteResponseEvent(responses, uids)

        # Convert the event to a plain dict so it can be JSON-encoded.
        response = response_event.__state_dict__()
        response['completion_is_valid'] = valid = list(map(utils.completion_is_valid, response['completions']))
        valid_completions = [response['completions'][i] for i, v in enumerate(valid) if v]
        response['task_name'] = task_name
        response['ensemble_result'] = utils.ensemble_result(valid_completions, task_name=task_name, prefer=prefer)

        bt.logging.info(f"Response:\n {response}")
        return Response(status=200, reason="I can't believe it's not butter!", text=json.dumps(response))

    except Exception:
        # Report this function's own name so failures in the streaming path
        # are not attributed to single_response in the logs.
        bt.logging.error(f'Encountered in {stream_response.__name__}:\n{traceback.format_exc()}')
        return Response(status=500, reason="Internal error")