"""Validator forward-step helpers: query miners with streaming synapses, collect and score their responses."""
import time | |
import sys | |
import asyncio | |
import numpy as np | |
import bittensor as bt | |
import traceback | |
from typing import List, Dict, Awaitable | |
from prompting.agent import HumanAgent | |
from prompting.dendrite import DendriteResponseEvent | |
from prompting.conversation import create_task | |
from prompting.protocol import StreamPromptingSynapse | |
from prompting.rewards import RewardResult | |
from prompting.utils.uids import get_random_uids | |
from prompting.utils.logging import log_event | |
from prompting.utils.misc import async_log, serialize_exception_to_string | |
from dataclasses import dataclass | |
async def generate_reference(agent):
    """Generate the task's reference answer without blocking the event loop.

    Runs the (blocking) `agent.task.generate_reference` call in the default
    thread-pool executor so the asyncio loop stays responsive.

    NOTE(review): this definition is duplicated verbatim (with a type
    annotation) later in this file; this first definition is shadowed by the
    second at import time — consider deleting one of the two.

    Args:
        agent: Agent whose task produces the reference; presumably a
            HumanAgent — confirm against the later annotated duplicate.

    Returns:
        The value produced by `agent.task.generate_reference(agent.llm_pipeline)`.
    """
    loop = asyncio.get_running_loop()
    result = await loop.run_in_executor(
        None, agent.task.generate_reference, agent.llm_pipeline
    )
    return result
async def execute_dendrite_call(dendrite_call):
    """Await a pending dendrite call and return whatever it resolves to."""
    return await dendrite_call
@dataclass
class StreamResult:
    """Outcome of consuming one miner's response stream.

    Bug fix: the `@dataclass` decorator was missing, so the keyword
    construction used in `handle_response` (e.g.
    `StreamResult(synapse=..., uid=...)`) raised
    `TypeError: object() takes no arguments`. `dataclass` was already
    imported at the top of the file but never applied.
    """

    # Final synapse yielded by the stream; a placeholder with completion=""
    # when the stream failed. (Annotation quoted so the class is importable
    # without the protocol module loaded.)
    synapse: "StreamPromptingSynapse" = None
    # Exception captured while consuming the stream, if any.
    exception: BaseException = None
    # UID of the miner this result belongs to.
    uid: int = None
async def process_response(uid: int, async_generator: Awaitable):
    """Consume a single miner's response stream and return the final synapse.

    Args:
        uid (int): UID of the miner being queried.
        async_generator (Awaitable): Async generator of streamed chunks; by
            protocol convention the last yielded object is the filled
            StreamPromptingSynapse itself.

    Returns:
        StreamPromptingSynapse: The final synapse from the stream, or a
        placeholder synapse with an empty completion when the stream raises
        or does not end with a StreamPromptingSynapse.
    """
    try:
        chunk = None  # Initialize chunk with a default value
        async for chunk in async_generator:  # most important loop, as this is where we acquire the final synapse.
            bt.logging.debug(f"\nchunk for uid {uid}: {chunk}")

        if chunk is not None:
            synapse = chunk  # last object yielded is the synapse itself with completion filled
            if isinstance(synapse, StreamPromptingSynapse):
                return synapse

        bt.logging.debug(
            f"Synapse is not StreamPromptingSynapse. Miner uid {uid} completion set to '' "
        )
        # Fix: this path previously fell through and returned None, which made
        # handle_response() raise ValueError and abort the entire step. Return
        # the empty-completion placeholder the log message above promises, so
        # a misbehaving miner is scored 0 instead of crashing the validator.
        return StreamPromptingSynapse(
            roles=["user"], messages=["failure"], completion=""
        )
    except Exception as e:
        traceback_details = traceback.format_exc()
        bt.logging.error(
            f"Error in generating reference or handling responses for uid {uid}: {e}\n{traceback_details}"
        )
        failed_synapse = StreamPromptingSynapse(
            roles=["user"], messages=["failure"], completion=""
        )
        return failed_synapse
async def handle_response(responses: Dict[int, Awaitable]) -> List[StreamResult]:
    """Consume every miner stream concurrently and pair results with UIDs.

    Args:
        responses (Dict[int, Awaitable]): Mapping of miner UID to the
            awaitable that yields that miner's streamed chunks.

    Raises:
        ValueError: If a gathered result is neither a StreamPromptingSynapse
            nor an exception.

    Returns:
        List[StreamResult]: One entry per UID holding the synapse and, on
        failure, the captured exception.
    """
    uids = list(responses.keys())
    # Launch all stream consumers at once; return_exceptions=True keeps one
    # failing miner from cancelling the rest of the gather.
    outcomes = await asyncio.gather(
        *(process_response(uid, responses[uid]) for uid in uids),
        return_exceptions=True,
    )

    stream_results: List[StreamResult] = []
    # gather preserves input order, so uids and outcomes line up positionally.
    for uid, outcome in zip(uids, outcomes):
        if isinstance(outcome, StreamPromptingSynapse):
            # Successful stream: final synapse arrived intact.
            stream_results.append(StreamResult(synapse=outcome, uid=uid))
        elif isinstance(outcome, BaseException):
            # Failed stream: keep the exception and attach an empty placeholder synapse.
            placeholder = StreamPromptingSynapse(
                roles=["user"], messages=["failure"], completion=""
            )
            stream_results.append(
                StreamResult(synapse=placeholder, exception=outcome, uid=uid)
            )
        else:
            # Neither a synapse nor an exception: this should be impossible.
            bt.logging.error(f"Unexpected result type for UID {uid}: {outcome}")
            raise ValueError(f"Unexpected result type for UID {uid}: {outcome}")
    return stream_results
async def generate_reference(agent: HumanAgent):
    """Produce the task's reference answer off the event loop.

    The blocking `agent.task.generate_reference` call is shipped to the
    default thread-pool executor so asyncio can keep servicing other tasks.

    Args:
        agent (HumanAgent): Agent whose task generates the reference using
            the agent's LLM pipeline.

    Returns:
        The reference produced by the task.
    """
    event_loop = asyncio.get_running_loop()
    return await event_loop.run_in_executor(
        None, agent.task.generate_reference, agent.llm_pipeline
    )
def log_stream_results(stream_results: List[StreamResult]):
    """Summarize stream outcomes in the logs: non-empty, empty, and failed.

    Failed responses additionally get their serialized exception logged
    one-by-one at error level.
    """
    # Single pass: bucket each result as failed, empty, or non-empty.
    failed_responses = []
    empty_responses = []
    non_empty_responses = []
    for result in stream_results:
        if result.exception is not None:
            failed_responses.append(result)
        elif result.synapse.completion == "":
            empty_responses.append(result)
        else:
            non_empty_responses.append(result)

    bt.logging.info(f"Total of non_empty responses: ({len(non_empty_responses)})")
    bt.logging.info(f"Total of empty responses: ({len(empty_responses)})")
    bt.logging.info(
        f"Total of failed responses: ({len(failed_responses)}):\n {failed_responses}"
    )

    for failed in failed_responses:
        formatted_exception = serialize_exception_to_string(failed.exception)
        bt.logging.error(
            f"Failed response for uid {failed.uid}: {formatted_exception}"
        )
async def run_step(
    self, agent: HumanAgent, k: int, timeout: float, exclude: list = None
):
    """Executes a single step of the agent, which consists of:
    - Getting a list of uids to query
    - Querying the network
    - Rewarding the network
    - Updating the scores
    - Logging the event

    Args:
        agent (HumanAgent): The agent to run the step for.
        k (int): The number of uids to query.
        timeout (float): The timeout for the queries.
        exclude (list, optional): The list of uids to exclude from the query. Defaults to [].

    Returns:
        dict: The step event payload — timings, queried uids, per-stream
        exceptions, plus the agent/reward/response state dicts merged in.
    """
    bt.logging.debug("run_step", agent.task.name)

    # Record event start time.
    start_time = time.time()

    # Get the list of uids to query for this step.
    # NOTE(review): uids appears to be a tensor (it supports .to(device) and
    # .cpu().tolist()) — confirm against get_random_uids.
    uids = get_random_uids(self, k=k, exclude=exclude or []).to(self.device)
    uids_cpu = uids.cpu().tolist()

    axons = [self.metagraph.axons[uid] for uid in uids]

    # Directly call dendrite and process responses in parallel.
    # deserialize=False + streaming=True hand back raw awaitable streams,
    # one per axon, in the same order as `axons`.
    streams_responses = await self.dendrite(
        axons=axons,
        synapse=StreamPromptingSynapse(roles=["user"], messages=[agent.challenge]),
        timeout=timeout,
        deserialize=False,
        streaming=True,
    )

    # Prepare the task for handling stream responses; zip relies on
    # streams_responses being ordered like uids_cpu.
    handle_stream_responses_task = asyncio.create_task(
        handle_response(responses=dict(zip(uids_cpu, streams_responses)))
    )

    if not agent.task.static_reference:
        # The reference must be generated by the LLM; overlap it with the
        # network round-trip so neither blocks the other.
        reference_generation_task = generate_reference(agent)
        _, stream_results = await asyncio.gather(
            reference_generation_task, handle_stream_responses_task
        )
    else:
        stream_results = await handle_stream_responses_task

    log_stream_results(stream_results)

    all_synapses_results = [stream_result.synapse for stream_result in stream_results]

    # Encapsulate the responses in a response event (dataclass)
    response_event = DendriteResponseEvent(
        responses=all_synapses_results, uids=uids, timeout=timeout
    )

    bt.logging.info(f"Created DendriteResponseEvent:\n {response_event}")
    # Reward the responses and get the reward result (dataclass)
    # This contains a list of RewardEvents but can be exported as a dict (column-wise) for logging etc
    reward_result = RewardResult(
        self.reward_pipeline,
        agent=agent,
        response_event=response_event,
        device=self.device,
    )
    bt.logging.info(f"Created RewardResult:\n {reward_result}")

    # The original idea was that the agent is 'satisfied' when it gets a good enough response (e.g. reward critera is met, such as ROUGE>threshold)
    agent.update_progress(
        top_reward=reward_result.rewards.max(),
        top_response=response_event.completions[reward_result.rewards.argmax()],
    )

    self.update_scores(reward_result.rewards, uids)

    stream_results_uids = [stream_result.uid for stream_result in stream_results]
    # Serialize exceptions so the event dict stays JSON/log friendly.
    stream_results_exceptions = [
        serialize_exception_to_string(stream_result.exception)
        for stream_result in stream_results
    ]
    # Log the step event.
    event = {
        "block": self.block,
        "step_time": time.time() - start_time,
        "stream_results_uids": stream_results_uids,
        "stream_results_exceptions": stream_results_exceptions,
        **agent.__state_dict__(full=self.config.neuron.log_full),
        **reward_result.__state_dict__(full=self.config.neuron.log_full),
        **response_event.__state_dict__(),
    }
    return event