steffenc committed on
Commit
1321e38
·
1 Parent(s): fc5ac41

Add stream forward

Browse files
Files changed (1) hide show
  1. forward.py +244 -0
forward.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Standard library
import asyncio
import sys
import time
import traceback
from dataclasses import dataclass
from typing import Awaitable, Dict, List, Optional

# Third-party
import bittensor as bt
import numpy as np

# Local
from prompting.agent import HumanAgent
from prompting.conversation import create_task
from prompting.dendrite import DendriteResponseEvent
from prompting.protocol import StreamPromptingSynapse
from prompting.rewards import RewardResult
from prompting.utils.logging import log_event
from prompting.utils.misc import async_log, serialize_exception_to_string
from prompting.utils.uids import get_random_uids
17
+
18
@async_log
async def generate_reference(agent: HumanAgent):
    """Generate the task's reference answer off the event loop.

    Runs the blocking ``agent.task.generate_reference`` call in the default
    thread-pool executor so the event loop keeps servicing the miner streams
    while the LLM pipeline produces the reference.

    NOTE(review): this function is redefined verbatim (with a type hint)
    further down in this file; one of the two definitions should be removed.

    Args:
        agent (HumanAgent): Agent whose ``task`` and ``llm_pipeline`` are used.

    Returns:
        Whatever ``agent.task.generate_reference`` returns — presumably the
        reference completion string; confirm against the task implementation.
    """
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(
        None, agent.task.generate_reference, agent.llm_pipeline
    )
23
+
24
@async_log
async def execute_dendrite_call(dendrite_call):
    """Await a dendrite call and hand back its responses.

    Thin wrapper whose only purpose is to run the awaitable under the
    ``async_log`` timing/logging decorator.
    """
    return await dendrite_call
28
+
29
@dataclass
class StreamResult:
    """Outcome of streaming from a single miner.

    All fields are optional so a bare instance can be built and filled
    incrementally; in practice ``handle_response`` always sets ``synapse``
    and ``uid``, and ``exception`` only on failure.
    """

    # Final synapse recovered from the stream, or a placeholder "failure"
    # synapse (completion == "") when the stream errored.
    synapse: Optional[StreamPromptingSynapse] = None
    # Exception raised while consuming the stream, if any.
    exception: Optional[BaseException] = None
    # Uid of the miner that produced (or failed to produce) the response.
    uid: Optional[int] = None
34
+
35
+
36
async def process_response(uid: int, async_generator: Awaitable):
    """Consume a single miner's response stream and return the final synapse.

    The stream is expected to yield chunks and terminate with the filled
    ``StreamPromptingSynapse`` as its last item. On any failure — an
    exception while streaming, an empty stream, or a final item that is not
    a ``StreamPromptingSynapse`` — a placeholder synapse with an empty
    completion is returned, so callers always receive a synapse.

    Args:
        uid (int): Uid of the queried miner (used for logging only).
        async_generator (Awaitable): Async generator yielding streamed chunks.

    Returns:
        StreamPromptingSynapse: The final synapse, or a "failure" synapse
        with ``completion == ""``.
    """
    try:
        chunk = None  # Default in case the stream yields nothing at all.
        async for chunk in async_generator:  # most important loop, as this is where we acquire the final synapse.
            bt.logging.debug(f"\nchunk for uid {uid}: {chunk}")

        if chunk is not None:
            synapse = chunk  # last object yielded is the synapse itself with completion filled

            # Assuming chunk holds the last value yielded which should be a synapse
            if isinstance(synapse, StreamPromptingSynapse):
                return synapse

        bt.logging.debug(
            f"Synapse is not StreamPromptingSynapse. Miner uid {uid} completion set to '' "
        )
    except Exception as e:
        traceback_details = traceback.format_exc()
        bt.logging.error(
            f"Error in generating reference or handling responses for uid {uid}: {e}\n{traceback_details}"
        )

    # BUG FIX: previously only the exception path returned a synapse; a stream
    # that ended in a non-synapse chunk (or yielded nothing) fell through and
    # returned None, which made handle_response raise ValueError for that uid.
    # Always return an empty-completion failure synapse instead, as the debug
    # message above already promised.
    return StreamPromptingSynapse(
        roles=["user"], messages=["failure"], completion=""
    )
65
+
66
+
67
@async_log
async def handle_response(responses: Dict[int, Awaitable]) -> List[StreamResult]:
    """Drive every miner stream to completion and pair results with uids.

    Spawns one ``process_response`` coroutine per awaitable in ``responses``,
    gathers them with exceptions captured rather than propagated, and wraps
    each outcome in a ``StreamResult``.

    Args:
        responses (Dict[int, Awaitable]): Mapping of miner uid to the
            awaitable that yields that miner's streamed chunks.

    Raises:
        ValueError: If a gathered result is neither a
            ``StreamPromptingSynapse`` nor an exception.

    Returns:
        List[StreamResult]: One entry per uid, in the mapping's order.
    """
    uids = list(responses.keys())

    # Launch all stream consumers; outcome order mirrors the uid list.
    outcomes = await asyncio.gather(
        *(process_response(uid, responses[uid]) for uid in uids),
        return_exceptions=True,
    )

    stream_results = []
    for uid, outcome in zip(uids, outcomes):
        if isinstance(outcome, StreamPromptingSynapse):
            # Successful stream: record the synapse with no exception.
            stream_results.append(StreamResult(synapse=outcome, uid=uid))
        elif isinstance(outcome, BaseException):
            # Failed stream: attach the exception alongside a placeholder
            # synapse carrying an empty completion.
            placeholder = StreamPromptingSynapse(
                roles=["user"], messages=["failure"], completion=""
            )
            stream_results.append(
                StreamResult(synapse=placeholder, exception=outcome, uid=uid)
            )
        else:
            # Anything else indicates a programming error upstream.
            bt.logging.error(f"Unexpected result type for UID {uid}: {outcome}")
            raise ValueError(f"Unexpected result type for UID {uid}: {outcome}")

    return stream_results
112
+
113
+
114
@async_log
async def generate_reference(agent: HumanAgent):
    """Generate the task's reference answer without blocking the event loop.

    NOTE(review): duplicate of the earlier ``generate_reference`` in this
    file; being defined later, this typed version is the one in effect —
    the other definition should be removed.

    Args:
        agent (HumanAgent): Agent whose blocking ``task.generate_reference``
            is run in the default executor with ``agent.llm_pipeline`` as its
            argument.

    Returns:
        Whatever ``agent.task.generate_reference`` returns — presumably the
        reference completion string; confirm against the task implementation.
    """
    loop = asyncio.get_running_loop()
    result = await loop.run_in_executor(
        None, agent.task.generate_reference, agent.llm_pipeline
    )
    return result
121
+
122
+
123
def log_stream_results(stream_results: List[StreamResult]):
    """Log summary counts for a batch of stream results.

    Splits the results into failed (carrying an exception), empty (no
    exception but empty completion) and non-empty buckets, logs the size of
    each bucket, then logs the serialized exception of every failed response.
    """
    failed_responses = []
    empty_responses = []
    non_empty_responses = []
    # Single pass classification instead of three list comprehensions.
    for result in stream_results:
        if result.exception is not None:
            failed_responses.append(result)
        elif result.synapse.completion == "":
            empty_responses.append(result)
        else:
            non_empty_responses.append(result)

    bt.logging.info(f"Total of non_empty responses: ({len(non_empty_responses)})")
    bt.logging.info(f"Total of empty responses: ({len(empty_responses)})")
    bt.logging.info(
        f"Total of failed responses: ({len(failed_responses)}):\n {failed_responses}"
    )

    for failed_response in failed_responses:
        formatted_exception = serialize_exception_to_string(failed_response.exception)
        bt.logging.error(
            f"Failed response for uid {failed_response.uid}: {formatted_exception}"
        )
149
+
150
+
151
async def run_step(
    self, agent: HumanAgent, k: int, timeout: float, exclude: list = None
):
    """Executes a single step of the agent, which consists of:
    - Getting a list of uids to query
    - Querying the network
    - Rewarding the network
    - Updating the scores
    - Logging the event

    Args:
        agent (HumanAgent): The agent to run the step for.
        k (int): The number of uids to query.
        timeout (float): The timeout for the queries.
        exclude (list, optional): The list of uids to exclude from the query.
            Defaults to None (no exclusions).

    Returns:
        dict: The event dictionary assembled for this step.
    """

    bt.logging.debug("run_step", agent.task.name)

    # Record event start time.
    start_time = time.time()
    # Get the list of uids to query for this step.
    # NOTE(review): uids appears to be a torch tensor (.to(device)/.cpu())
    # — confirm against get_random_uids.
    uids = get_random_uids(self, k=k, exclude=exclude or []).to(self.device)
    uids_cpu = uids.cpu().tolist()

    axons = [self.metagraph.axons[uid] for uid in uids]

    # Directly call dendrite and process responses in parallel.
    # deserialize=False + streaming=True: we receive per-miner async
    # generators of chunks rather than finished synapses.
    streams_responses = await self.dendrite(
        axons=axons,
        synapse=StreamPromptingSynapse(roles=["user"], messages=[agent.challenge]),
        timeout=timeout,
        deserialize=False,
        streaming=True,
    )

    # Prepare the task for handling stream responses; zip relies on
    # streams_responses preserving the order of `axons` (hence `uids_cpu`).
    handle_stream_responses_task = asyncio.create_task(
        handle_response(responses=dict(zip(uids_cpu, streams_responses)))
    )

    if not agent.task.static_reference:
        # Reference must be produced by the LLM: overlap its generation with
        # stream consumption so neither waits on the other.
        reference_generation_task = generate_reference(agent)
        _, stream_results = await asyncio.gather(
            reference_generation_task, handle_stream_responses_task
        )
    else:
        stream_results = await handle_stream_responses_task

    log_stream_results(stream_results)

    all_synapses_results = [stream_result.synapse for stream_result in stream_results]

    # Encapsulate the responses in a response event (dataclass)
    response_event = DendriteResponseEvent(
        responses=all_synapses_results, uids=uids, timeout=timeout
    )

    bt.logging.info(f"Created DendriteResponseEvent:\n {response_event}")
    # Reward the responses and get the reward result (dataclass)
    # This contains a list of RewardEvents but can be exported as a dict (column-wise) for logging etc
    reward_result = RewardResult(
        self.reward_pipeline,
        agent=agent,
        response_event=response_event,
        device=self.device,
    )
    bt.logging.info(f"Created RewardResult:\n {reward_result}")

    # The original idea was that the agent is 'satisfied' when it gets a good enough response (e.g. reward critera is met, such as ROUGE>threshold)
    agent.update_progress(
        top_reward=reward_result.rewards.max(),
        top_response=response_event.completions[reward_result.rewards.argmax()],
    )

    # Side effect: feeds the step's rewards into the validator's moving scores.
    self.update_scores(reward_result.rewards, uids)

    stream_results_uids = [stream_result.uid for stream_result in stream_results]
    stream_results_exceptions = [
        serialize_exception_to_string(stream_result.exception)
        for stream_result in stream_results
    ]
    # Log the step event.
    event = {
        "block": self.block,
        "step_time": time.time() - start_time,
        "stream_results_uids": stream_results_uids,
        "stream_results_exceptions": stream_results_exceptions,
        **agent.__state_dict__(full=self.config.neuron.log_full),
        **reward_result.__state_dict__(full=self.config.neuron.log_full),
        **response_event.__state_dict__(),
    }

    return event