import asyncio
import time
import traceback
from dataclasses import dataclass
from typing import Awaitable, Dict, List, Optional

import bittensor as bt

from prompting.agent import HumanAgent
from prompting.conversation import create_task
from prompting.dendrite import DendriteResponseEvent
from prompting.protocol import StreamPromptingSynapse
from prompting.rewards import RewardResult
from prompting.utils.logging import log_event
from prompting.utils.misc import async_log, serialize_exception_to_string
from prompting.utils.uids import get_random_uids


@async_log
async def generate_reference(agent: HumanAgent) -> str:
    """Generate the reference answer for the agent's task in a thread executor
    so the blocking LLM call does not stall the event loop."""
    loop = asyncio.get_running_loop()
    result = await loop.run_in_executor(
        None, agent.task.generate_reference, agent.llm_pipeline
    )
    return result


@async_log
async def execute_dendrite_call(dendrite_call):
    """Await a dendrite call; wrapped so that @async_log records its timing."""
    responses = await dendrite_call
    return responses


@dataclass
class StreamResult:
    """Pairs the (possibly failed) synapse from one miner stream with its uid."""

    synapse: Optional[StreamPromptingSynapse] = None
    exception: Optional[BaseException] = None
    uid: Optional[int] = None


async def process_response(uid: int, async_generator: Awaitable):
    """Process a single miner response asynchronously.

    Drains the miner's async generator; the last yielded object is expected to
    be the StreamPromptingSynapse itself, with its completion filled in.
    """
    try:
        chunk = None  # Track the last chunk; the final one should be the synapse.
        async for chunk in async_generator:  # most important loop, as this is where we acquire the final synapse.
            bt.logging.debug(f"\nchunk for uid {uid}: {chunk}")

        if isinstance(chunk, StreamPromptingSynapse):
            return chunk

        # The stream ended without yielding a synapse; treat it as an empty completion.
        bt.logging.debug(
            f"Last chunk is not a StreamPromptingSynapse. Miner uid {uid} completion set to ''"
        )
        return StreamPromptingSynapse(
            roles=["user"], messages=["failure"], completion=""
        )
    except Exception as e:
        traceback_details = traceback.format_exc()
        bt.logging.error(
            f"Error in generating reference or handling responses for uid {uid}: {e}\n{traceback_details}"
        )
        failed_synapse = StreamPromptingSynapse(
            roles=["user"], messages=["failure"], completion=""
        )
        return failed_synapse
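
# A minimal sketch of the chunk protocol process_response assumes: a miner's
# async generator yields intermediate chunks (e.g. text deltas) and, last of
# all, the filled StreamPromptingSynapse. `example_miner_stream` is
# hypothetical and only illustrates the expected shape:
#
#     async def example_miner_stream():
#         yield "partial "   # intermediate text chunk(s)
#         yield "completion"
#         yield StreamPromptingSynapse(
#             roles=["user"], messages=["challenge"], completion="partial completion"
#         )  # final chunk: the synapse itself
#
#     # synapse = await process_response(uid=0, async_generator=example_miner_stream())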


@async_log
async def handle_response(responses: Dict[int, Awaitable]) -> List[StreamResult]:
    """Create asyncio tasks that drain each miner's streamed chunks, then pair
    every result with its original uid.

    Args:
        responses (Dict[int, Awaitable]): Maps each miner uid to the awaitable
            used to acquire its streamed chunks.

    Raises:
        ValueError: If a result is neither a synapse nor an exception.

    Returns:
        List[StreamResult]: One StreamResult (synapse, exception, uid) per miner.
    """
    tasks_with_uid = list(responses.items())  # Pair each uid with its awaitable.

    # Start all streams concurrently, preserving order and their associated uids.
    tasks = [process_response(uid, resp) for uid, resp in tasks_with_uid]
    results = await asyncio.gather(*tasks, return_exceptions=True)
mapped_results = []
# Pair each result with its original uid
for (uid, _), result in zip(tasks_with_uid, results):
# If the result is a StreamPromptingSynapse, the response was successful and the stream result is added without exceptions
if isinstance(result, StreamPromptingSynapse):
mapped_results.append(StreamResult(synapse=result, uid=uid))
# If the result is an exception, the response was unsuccessful and the stream result is added with the exception and an empty synapse
elif isinstance(result, BaseException):
failed_synapse = StreamPromptingSynapse(
roles=["user"], messages=["failure"], completion=""
)
mapped_results.append(
StreamResult(synapse=failed_synapse, exception=result, uid=uid)
)
        # If the result is neither an exception nor a StreamPromptingSynapse, log it and raise a ValueError
else:
bt.logging.error(f"Unexpected result type for UID {uid}: {result}")
raise ValueError(f"Unexpected result type for UID {uid}: {result}")
return mapped_results
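
# Hedged usage sketch: handle_response expects a dict mapping miner uid ->
# awaitable stream, e.g. the per-axon async generators returned by a
# streaming dendrite call (the variable names below mirror run_step):
#
#     stream_results = await handle_response(
#         responses=dict(zip(uids_cpu, streams_responses))
#     )
#     for r in stream_results:
#         if r.exception is None:
#             print(r.uid, r.synapse.completion)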


def log_stream_results(stream_results: List[StreamResult]):
failed_responses = [
response for response in stream_results if response.exception is not None
]
empty_responses = [
response
for response in stream_results
if response.exception is None and response.synapse.completion == ""
]
non_empty_responses = [
response
for response in stream_results
if response.exception is None and response.synapse.completion != ""
]
bt.logging.info(f"Total of non_empty responses: ({len(non_empty_responses)})")
bt.logging.info(f"Total of empty responses: ({len(empty_responses)})")
bt.logging.info(
f"Total of failed responses: ({len(failed_responses)}):\n {failed_responses}"
)
for failed_response in failed_responses:
formatted_exception = serialize_exception_to_string(failed_response.exception)
bt.logging.error(
f"Failed response for uid {failed_response.uid}: {formatted_exception}"
)


async def run_step(
    self, agent: HumanAgent, k: int, timeout: float, exclude: list = None
):
    """Executes a single step of the agent, which consists of:
    - Getting a list of uids to query
    - Querying the network
    - Rewarding the network
    - Updating the scores
    - Logging the event

    Args:
        agent (HumanAgent): The agent to run the step for.
        k (int): The number of uids to query.
        timeout (float): The timeout for the queries.
        exclude (list, optional): The list of uids to exclude from the query. Defaults to None.
    """
    bt.logging.debug("run_step", agent.task.name)
# Record event start time.
start_time = time.time()
# Get the list of uids to query for this step.
uids = get_random_uids(self, k=k, exclude=exclude or []).to(self.device)
uids_cpu = uids.cpu().tolist()
axons = [self.metagraph.axons[uid] for uid in uids]
# Directly call dendrite and process responses in parallel
streams_responses = await self.dendrite(
axons=axons,
synapse=StreamPromptingSynapse(roles=["user"], messages=[agent.challenge]),
timeout=timeout,
deserialize=False,
streaming=True,
)
# Prepare the task for handling stream responses
handle_stream_responses_task = asyncio.create_task(
handle_response(responses=dict(zip(uids_cpu, streams_responses)))
)
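
    # If the task needs a model-generated reference, build it concurrently with
    # the miner streams: the blocking LLM call runs in an executor while
    # handle_response drains the network responses.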
if not agent.task.static_reference:
reference_generation_task = generate_reference(agent)
_, stream_results = await asyncio.gather(
reference_generation_task, handle_stream_responses_task
)
else:
stream_results = await handle_stream_responses_task
log_stream_results(stream_results)
all_synapses_results = [stream_result.synapse for stream_result in stream_results]
# Encapsulate the responses in a response event (dataclass)
response_event = DendriteResponseEvent(
responses=all_synapses_results, uids=uids, timeout=timeout
)
bt.logging.info(f"Created DendriteResponseEvent:\n {response_event}")
# Reward the responses and get the reward result (dataclass)
# This contains a list of RewardEvents but can be exported as a dict (column-wise) for logging etc
reward_result = RewardResult(
self.reward_pipeline,
agent=agent,
response_event=response_event,
device=self.device,
)
bt.logging.info(f"Created RewardResult:\n {reward_result}")
    # The original idea was that the agent is 'satisfied' when it gets a good enough response (e.g. a reward criterion is met, such as ROUGE > threshold)
agent.update_progress(
top_reward=reward_result.rewards.max(),
top_response=response_event.completions[reward_result.rewards.argmax()],
)
self.update_scores(reward_result.rewards, uids)
stream_results_uids = [stream_result.uid for stream_result in stream_results]
stream_results_exceptions = [
serialize_exception_to_string(stream_result.exception)
for stream_result in stream_results
]
# Log the step event.
event = {
"block": self.block,
"step_time": time.time() - start_time,
"stream_results_uids": stream_results_uids,
"stream_results_exceptions": stream_results_exceptions,
**agent.__state_dict__(full=self.config.neuron.log_full),
**reward_result.__state_dict__(full=self.config.neuron.log_full),
**response_event.__state_dict__(),
}
return event
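
# Hedged sketch of how run_step might be driven by a validator's forward loop.
# The loop itself is not part of this file; the task name, HumanAgent
# construction, and config fields below are assumptions for illustration:
#
#     task = create_task(llm_pipeline=self.llm_pipeline, task_name="summarization")  # hypothetical task name
#     agent = HumanAgent(task=task, llm_pipeline=self.llm_pipeline)
#     event = await run_step(
#         self,
#         agent,
#         k=self.config.neuron.sample_size,    # assumed config field
#         timeout=self.config.neuron.timeout,  # assumed config field
#         exclude=None,
#     )
#     log_event(self, event)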