# prompting-dashboard / forward.py
# Commit 32e1e2e by pedroferreira: "add middlewares + api refactoring" (file size 9.55 kB)
# import time
# import sys
# import asyncio
# import numpy as np
# import bittensor as bt
# import traceback
# from typing import List, Dict, Awaitable
# from prompting.agent import HumanAgent
# from prompting.dendrite import DendriteResponseEvent
# from prompting.conversation import create_task
# from prompting.protocol import StreamPromptingSynapse
# from prompting.rewards import RewardResult
# from prompting.utils.uids import get_random_uids
# from prompting.utils.logging import log_event
# from prompting.utils.misc import async_log, serialize_exception_to_string
# from dataclasses import dataclass
# @async_log
# async def generate_reference(agent):
# loop = asyncio.get_running_loop()
# result = await loop.run_in_executor(None, agent.task.generate_reference, agent.llm_pipeline)
# return result
# @async_log
# async def execute_dendrite_call(dendrite_call):
# responses = await dendrite_call
# return responses
# @dataclass
# class StreamResult:
# synapse: StreamPromptingSynapse = None
# exception: BaseException = None
# uid: int = None
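# async def process_response(uid: int, async_generator: Awaitable):
# """Process a single response asynchronously.
# NOTE(review): the visible code returns a fallback synapse only from the `except` path.
# If the stream ends on a chunk that is not a StreamPromptingSynapse (or yields nothing)
# without raising, the function appears to fall through and return None, which
# handle_response rejects with a ValueError; the "completion set to ''" debug message
# does not match that. Confirm once the original indentation is restored.
# """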
# try:
# chunk = None # Initialize chunk with a default value
# async for chunk in async_generator: # most important loop, as this is where we acquire the final synapse.
# bt.logging.debug(f"\nchunk for uid {uid}: {chunk}")
# if chunk is not None:
# synapse = chunk # last object yielded is the synapse itself with completion filled
# # Assuming chunk holds the last value yielded which should be a synapse
# if isinstance(synapse, StreamPromptingSynapse):
# return synapse
# bt.logging.debug(
# f"Synapse is not StreamPromptingSynapse. Miner uid {uid} completion set to '' "
# )
# except Exception as e:
# # bt.logging.error(f"Error in generating reference or handling responses: {e}", exc_info=True)
# traceback_details = traceback.format_exc()
# bt.logging.error(
# f"Error in generating reference or handling responses for uid {uid}: {e}\n{traceback_details}"
# )
# failed_synapse = StreamPromptingSynapse(
# roles=["user"], messages=["failure"], completion=""
# )
# return failed_synapse
# @async_log
# async def handle_response(responses: Dict[int, Awaitable]) -> List[StreamResult]:
# """The handle_response function is responsible for creating asyncio tasks around acquiring streamed miner chunks
# and processing them asynchronously. It then pairs the results with their original UIDs and returns a list of StreamResults.
# Args:
# responses (Dict[int, Awaitable]): Responses contains awaitables that are used to acquire streamed miner chunks.
# Raises:
# ValueError: If a result is neither a StreamPromptingSynapse nor an exception.
# Returns:
# List[StreamResult]: DataClass containing the synapse, exception, and uid
# """
# tasks_with_uid = [
# (uid, responses[uid]) for uid, _ in responses.items()
# ] # Pair UIDs with their tasks
# # Start tasks, preserving order and their associated UIDs
# tasks = [process_response(uid, resp) for uid, resp in tasks_with_uid]
# results = await asyncio.gather(*tasks, return_exceptions=True)
# mapped_results = []
# # Pair each result with its original uid
# for (uid, _), result in zip(tasks_with_uid, results):
# # If the result is a StreamPromptingSynapse, the response was successful and the stream result is added without exceptions
# if isinstance(result, StreamPromptingSynapse):
# mapped_results.append(StreamResult(synapse=result, uid=uid))
# # If the result is an exception, the response was unsuccessful and the stream result is added with the exception and an empty synapse
# elif isinstance(result, BaseException):
# failed_synapse = StreamPromptingSynapse(
# roles=["user"], messages=["failure"], completion=""
# )
# mapped_results.append(
# StreamResult(synapse=failed_synapse, exception=result, uid=uid)
# )
# # If the result is neither an exception nor a StreamPromptingSynapse, log the error and raise a ValueError
# else:
# bt.logging.error(f"Unexpected result type for UID {uid}: {result}")
# raise ValueError(f"Unexpected result type for UID {uid}: {result}")
# return mapped_results
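# # NOTE(review): generate_reference is also defined earlier in this file with the same body; if re-enabled, this later definition shadows that one.
# @async_log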
# async def generate_reference(agent: HumanAgent):
# loop = asyncio.get_running_loop()
# result = await loop.run_in_executor(
# None, agent.task.generate_reference, agent.llm_pipeline
# )
# return result
# def log_stream_results(stream_results: List[StreamResult]):
# failed_responses = [
# response for response in stream_results if response.exception is not None
# ]
# empty_responses = [
# response
# for response in stream_results
# if response.exception is None and response.synapse.completion == ""
# ]
# non_empty_responses = [
# response
# for response in stream_results
# if response.exception is None and response.synapse.completion != ""
# ]
# bt.logging.info(f"Total of non_empty responses: ({len(non_empty_responses)})")
# bt.logging.info(f"Total of empty responses: ({len(empty_responses)})")
# bt.logging.info(
# f"Total of failed responses: ({len(failed_responses)}):\n {failed_responses}"
# )
# for failed_response in failed_responses:
# formatted_exception = serialize_exception_to_string(failed_response.exception)
# bt.logging.error(
# f"Failed response for uid {failed_response.uid}: {formatted_exception}"
# )
# async def run_step(
# self, agent: HumanAgent, k: int, timeout: float, exclude: list = None
# ):
# """Executes a single step of the agent, which consists of:
# - Getting a list of uids to query
# - Querying the network
# - Rewarding the network
# - Updating the scores
# - Logging the event
# Args:
# agent (HumanAgent): The agent to run the step for.
# k (int): The number of uids to query.
# timeout (float): The timeout for the queries.
# exclude (list, optional): The list of uids to exclude from the query. Defaults to None, which is treated as an empty list.
# """
# bt.logging.debug("run_step", agent.task.name)
# # Record event start time.
# start_time = time.time()
# # Get the list of uids to query for this step.
# uids = get_random_uids(self, k=k, exclude=exclude or []).to(self.device)
# uids_cpu = uids.cpu().tolist()
# axons = [self.metagraph.axons[uid] for uid in uids]
# # Directly call dendrite and process responses in parallel
# streams_responses = await self.dendrite(
# axons=axons,
# synapse=StreamPromptingSynapse(roles=["user"], messages=[agent.challenge]),
# timeout=timeout,
# deserialize=False,
# streaming=True,
# )
# # Prepare the task for handling stream responses
# handle_stream_responses_task = asyncio.create_task(
# handle_response(responses=dict(zip(uids_cpu, streams_responses)))
# )
# if not agent.task.static_reference:
# reference_generation_task = generate_reference(agent)
# _, stream_results = await asyncio.gather(
# reference_generation_task, handle_stream_responses_task
# )
# else:
# stream_results = await handle_stream_responses_task
# log_stream_results(stream_results)
# all_synapses_results = [stream_result.synapse for stream_result in stream_results]
# # Encapsulate the responses in a response event (dataclass)
# response_event = DendriteResponseEvent(
# responses=all_synapses_results, uids=uids, timeout=timeout
# )
# bt.logging.info(f"Created DendriteResponseEvent:\n {response_event}")
# # Reward the responses and get the reward result (dataclass)
# # This contains a list of RewardEvents but can be exported as a dict (column-wise) for logging etc
# reward_result = RewardResult(
# self.reward_pipeline,
# agent=agent,
# response_event=response_event,
# device=self.device,
# )
# bt.logging.info(f"Created RewardResult:\n {reward_result}")
# # The original idea was that the agent is 'satisfied' when it gets a good enough response (e.g. the reward criteria are met, such as ROUGE>threshold)
# agent.update_progress(
# top_reward=reward_result.rewards.max(),
# top_response=response_event.completions[reward_result.rewards.argmax()],
# )
# self.update_scores(reward_result.rewards, uids)
# stream_results_uids = [stream_result.uid for stream_result in stream_results]
# stream_results_exceptions = [
# serialize_exception_to_string(stream_result.exception)
# for stream_result in stream_results
# ]
# # Log the step event.
# event = {
# "block": self.block,
# "step_time": time.time() - start_time,
# "stream_results_uids": stream_results_uids,
# "stream_results_exceptions": stream_results_exceptions,
# **agent.__state_dict__(full=self.config.neuron.log_full),
# **reward_result.__state_dict__(full=self.config.neuron.log_full),
# **response_event.__state_dict__(),
# }
# return event