import json
import asyncio
import traceback
import bittensor as bt
import utils

from typing import List, Optional
from neurons.validator import Validator
from prompting.forward import handle_response
from prompting.dendrite import DendriteResponseEvent
from prompting.protocol import PromptingSynapse, StreamPromptingSynapse
from prompting.utils.uids import get_random_uids
from aiohttp import web
from aiohttp.web_response import Response


def _select_miners(validator: Validator, k: int, exclude: Optional[List[int]]):
    """Pick `k` random uids and look up their axons on the validator's metagraph.

    Returns:
        A `(uids, axons)` pair of equal-length lists, in matching order.
    """
    # `exclude` may be None; get_random_uids is always given a list.
    uids = get_random_uids(validator, k=k, exclude=exclude or []).tolist()
    axons = [validator.metagraph.axons[uid] for uid in uids]
    return uids, axons


def _build_response(responses, uids: List[int], task_name, prefer: str) -> Response:
    """Turn raw dendrite responses into the JSON HTTP response sent to the client.

    The payload is the `DendriteResponseEvent` state dict extended with:
      * `completion_is_valid` - one `utils.completion_is_valid` result per completion,
      * `task_name`           - the task name passed in by the caller,
      * `ensemble_result`     - `utils.ensemble_result` over the valid completions only.

    Raises:
        Whatever the underlying helpers or `json.dumps` raise; callers are
        expected to translate that into an HTTP 500.
    """
    bt.logging.info(f"Creating DendriteResponseEvent:\n {responses}")
    # Encapsulate the responses in a response event (dataclass)
    response_event = DendriteResponseEvent(responses, uids)

    # convert dict to json
    response = response_event.__state_dict__()
    completions = response['completions']

    valid = list(map(utils.completion_is_valid, completions))
    response['completion_is_valid'] = valid
    valid_completions = [
        completion for completion, is_valid in zip(completions, valid) if is_valid
    ]

    response['task_name'] = task_name
    response['ensemble_result'] = utils.ensemble_result(
        valid_completions, task_name=task_name, prefer=prefer
    )

    bt.logging.info(f"Response:\n {response}")
    return Response(
        status=200,
        reason="I can't believe it's not butter!",
        text=json.dumps(response),
    )


def _internal_error(handler_name: str) -> Response:
    """Log the exception currently being handled and return an HTTP 500 response.

    Must be called from inside an `except` block so that
    `traceback.format_exc()` has an active exception to format.
    """
    bt.logging.error(f'Encountered in {handler_name}:\n{traceback.format_exc()}')
    return Response(status=500, reason="Internal error")


async def single_response(
    validator: Validator,
    messages: List[str],
    roles: List[str],
    k: int = 5,
    timeout: float = 3.0,
    exclude: Optional[List[int]] = None,
    prefer: str = 'longest',
) -> Response:
    """Query `k` random miners with a non-streaming synapse and return the result.

    Args:
        validator: Validator providing the dendrite and the metagraph.
        messages: Conversation messages; the last one is used to guess the task name.
        roles: Roles forwarded in the synapse together with `messages`.
        k: Number of miner uids to sample.
        timeout: Timeout forwarded to the dendrite call.
        exclude: Uids that must not be sampled; `None` means no exclusions.
        prefer: Forwarded to `utils.ensemble_result`.

    Returns:
        A 200 response whose body is the JSON-encoded result, or a 500
        response if anything raised (the traceback is logged, not re-raised).
    """
    try:
        # Guess the task name of current request
        task_name = utils.guess_task_name(messages[-1])

        # Get the list of uids to query for this step.
        uids, axons = _select_miners(validator, k, exclude)

        # Make calls to the network with the prompt.
        bt.logging.info('Calling dendrite')
        responses = await validator.dendrite(
            axons=axons,
            synapse=PromptingSynapse(roles=roles, messages=messages),
            timeout=timeout,
        )

        return _build_response(responses, uids, task_name, prefer)
    except Exception:
        # Top-level request boundary: report the failure as a 500 instead of crashing.
        return _internal_error(single_response.__name__)


async def stream_response(
    validator: Validator,
    messages: List[str],
    roles: List[str],
    k: int = 5,
    timeout: float = 3.0,
    exclude: Optional[List[int]] = None,
    prefer: str = 'longest',
) -> web.StreamResponse:
    """Query `k` random miners with a streaming synapse and return the result.

    The miner streams are consumed to completion by `handle_response` before
    anything is returned, so the client receives one buffered JSON `Response`
    (a `web.StreamResponse` subclass) rather than incremental chunks.

    Args:
        validator: Validator providing the dendrite and the metagraph.
        messages: Conversation messages; the last one is used to guess the task name.
        roles: Roles forwarded in the synapse together with `messages`.
        k: Number of miner uids to sample.
        timeout: Timeout forwarded to the dendrite call.
        exclude: Uids that must not be sampled; `None` means no exclusions.
        prefer: Forwarded to `utils.ensemble_result`.

    Returns:
        A 200 response whose body is the JSON-encoded result, or a 500
        response if anything raised (the traceback is logged, not re-raised).
    """
    try:
        # Guess the task name of current request
        task_name = utils.guess_task_name(messages[-1])

        # Get the list of uids to query for this step.
        uids, axons = _select_miners(validator, k, exclude)

        # Make calls to the network with the prompt.
        bt.logging.info('Calling dendrite')
        streams_responses = await validator.dendrite(
            axons=axons,
            synapse=StreamPromptingSynapse(roles=roles, messages=messages),
            timeout=timeout,
            deserialize=False,
            streaming=True,
        )

        # Prepare the task for handling stream responses
        handle_stream_responses_task = asyncio.create_task(
            handle_response(responses=dict(zip(uids, streams_responses)))
        )
        stream_results = await handle_stream_responses_task

        responses = [stream_result.synapse for stream_result in stream_results]
        return _build_response(responses, uids, task_name, prefer)
    except Exception:
        # Report under this handler's own name (previously logged as single_response).
        return _internal_error(stream_response.__name__)