steffenc committed on
Commit
c60daaf
·
unverified ·
2 Parent(s): fc5ac41 728d41d

Merge pull request #1 from macrocosm-os/stream

Browse files
Files changed (2) hide show
  1. forward.py +244 -0
  2. server.py +71 -2
forward.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import sys
3
+ import asyncio
4
+ import numpy as np
5
+ import bittensor as bt
6
+ import traceback
7
+ from typing import List, Dict, Awaitable
8
+ from prompting.agent import HumanAgent
9
+ from prompting.dendrite import DendriteResponseEvent
10
+ from prompting.conversation import create_task
11
+ from prompting.protocol import StreamPromptingSynapse
12
+ from prompting.rewards import RewardResult
13
+ from prompting.utils.uids import get_random_uids
14
+ from prompting.utils.logging import log_event
15
+ from prompting.utils.misc import async_log, serialize_exception_to_string
16
+ from dataclasses import dataclass
17
+
18
@async_log
async def generate_reference(agent):
    # Run the blocking reference generation in the default thread-pool
    # executor so it does not stall the event loop.
    # NOTE(review): an identical (type-annotated) generate_reference is
    # defined later in this module and shadows this one — this copy is
    # effectively dead code and should probably be removed.
    loop = asyncio.get_running_loop()
    result = await loop.run_in_executor(None, agent.task.generate_reference, agent.llm_pipeline)
    return result
23
+
24
@async_log
async def execute_dendrite_call(dendrite_call):
    """Await a prepared dendrite call and hand back its responses."""
    return await dendrite_call
28
+
29
@dataclass
class StreamResult:
    """Container pairing one streamed miner response with its source UID."""

    # Final synapse recovered from the stream; on failure this is a
    # placeholder synapse with an empty completion rather than None.
    synapse: StreamPromptingSynapse = None
    # Exception raised while consuming the stream, if any; None on success.
    exception: BaseException = None
    # UID of the miner that produced this stream.
    uid: int = None
34
+
35
+
36
async def process_response(uid: int, async_generator: Awaitable):
    """Consume a single miner's async stream and return the final synapse.

    Args:
        uid (int): UID of the miner being queried (used only for logging).
        async_generator (Awaitable): Async iterator yielding streamed chunks;
            the last yielded object is expected to be the synapse itself with
            its completion filled in.

    Returns:
        StreamPromptingSynapse: the miner's synapse on success, or a
        placeholder synapse with an empty completion on any failure.
    """
    try:
        chunk = None  # Guard against a stream that yields nothing.
        async for chunk in async_generator:  # most important loop, as this is where we acquire the final synapse.
            bt.logging.debug(f"\nchunk for uid {uid}: {chunk}")

        if chunk is not None:
            synapse = chunk  # last object yielded is the synapse itself with completion filled

            # Assuming chunk holds the last value yielded which should be a synapse
            if isinstance(synapse, StreamPromptingSynapse):
                return synapse

        bt.logging.debug(
            f"Synapse is not StreamPromptingSynapse. Miner uid {uid} completion set to '' "
        )
    except Exception as e:
        traceback_details = traceback.format_exc()
        bt.logging.error(
            f"Error in generating reference or handling responses for uid {uid}: {e}\n{traceback_details}"
        )

    # Fix: previously only the exception path returned a synapse. A stream
    # that ended without yielding a StreamPromptingSynapse fell through and
    # implicitly returned None, which made handle_response raise ValueError
    # even though the log above promises "completion set to ''". Always
    # return the empty-completion placeholder instead.
    failed_synapse = StreamPromptingSynapse(
        roles=["user"], messages=["failure"], completion=""
    )
    return failed_synapse
65
+
66
+
67
@async_log
async def handle_response(responses: Dict[int, Awaitable]) -> List[StreamResult]:
    """Drive every pending miner stream to completion and wrap the outcomes.

    Creates one asyncio task per awaitable in `responses`, gathers them (order
    preserved, exceptions captured), and pairs each outcome with its UID.

    Args:
        responses (Dict[int, Awaitable]): Mapping of miner UID to the awaitable
            used to acquire that miner's streamed chunks.

    Raises:
        ValueError: if a gathered result is neither a StreamPromptingSynapse
            nor an exception.

    Returns:
        List[StreamResult]: one entry per UID, carrying the synapse and (for
        failures) the exception that occurred.
    """
    uids = list(responses.keys())

    # Launch one consumer per stream; asyncio.gather preserves input order,
    # so results line up with `uids` positionally.
    outcomes = await asyncio.gather(
        *(process_response(uid, responses[uid]) for uid in uids),
        return_exceptions=True,
    )

    stream_results = []
    for uid, outcome in zip(uids, outcomes):
        if isinstance(outcome, StreamPromptingSynapse):
            # Successful stream: keep the synapse as-is, no exception recorded.
            stream_results.append(StreamResult(synapse=outcome, uid=uid))
        elif isinstance(outcome, BaseException):
            # Failed stream: record the exception alongside an
            # empty-completion placeholder synapse.
            placeholder = StreamPromptingSynapse(
                roles=["user"], messages=["failure"], completion=""
            )
            stream_results.append(
                StreamResult(synapse=placeholder, exception=outcome, uid=uid)
            )
        else:
            # Anything else indicates a programming error upstream.
            bt.logging.error(f"Unexpected result type for UID {uid}: {outcome}")
            raise ValueError(f"Unexpected result type for UID {uid}: {outcome}")

    return stream_results
112
+
113
+
114
@async_log
async def generate_reference(agent: HumanAgent):
    # Offload the blocking LLM reference generation to the default thread-pool
    # executor so the event loop stays responsive while it runs.
    # NOTE(review): duplicates an earlier generate_reference in this module;
    # being defined later, this is the binding callers actually use — the
    # earlier copy is dead code.
    loop = asyncio.get_running_loop()
    result = await loop.run_in_executor(
        None, agent.task.generate_reference, agent.llm_pipeline
    )
    return result
121
+
122
+
123
def log_stream_results(stream_results: List[StreamResult]):
    """Summarise a batch of stream results and dump details of any failures."""
    failed_responses = []
    empty_responses = []
    non_empty_responses = []

    # Single pass: bucket each result as failed / empty / non-empty.
    for stream_result in stream_results:
        if stream_result.exception is not None:
            failed_responses.append(stream_result)
        elif stream_result.synapse.completion == "":
            empty_responses.append(stream_result)
        else:
            non_empty_responses.append(stream_result)

    bt.logging.info(f"Total of non_empty responses: ({len(non_empty_responses)})")
    bt.logging.info(f"Total of empty responses: ({len(empty_responses)})")
    bt.logging.info(
        f"Total of failed responses: ({len(failed_responses)}):\n {failed_responses}"
    )

    for failed_response in failed_responses:
        formatted_exception = serialize_exception_to_string(failed_response.exception)
        bt.logging.error(
            f"Failed response for uid {failed_response.uid}: {formatted_exception}"
        )
149
+
150
+
151
async def run_step(
    self, agent: HumanAgent, k: int, timeout: float, exclude: list = None
):
    """Executes a single step of the agent, which consists of:
    - Getting a list of uids to query
    - Querying the network
    - Rewarding the network
    - Updating the scores
    - Logging the event

    Args:
        agent (HumanAgent): The agent to run the step for.
        k (int): The number of uids to query.
        timeout (float): The timeout for the queries.
        exclude (list, optional): The list of uids to exclude from the query. Defaults to [].
    """

    bt.logging.debug("run_step", agent.task.name)

    # Record event start time.
    start_time = time.time()
    # Get the list of uids to query for this step.
    # NOTE(review): get_random_uids appears to return a tensor (it is moved to
    # self.device and has .cpu()) — confirm against prompting.utils.uids.
    uids = get_random_uids(self, k=k, exclude=exclude or []).to(self.device)
    uids_cpu = uids.cpu().tolist()

    axons = [self.metagraph.axons[uid] for uid in uids]

    # Directly call dendrite and process responses in parallel
    streams_responses = await self.dendrite(
        axons=axons,
        synapse=StreamPromptingSynapse(roles=["user"], messages=[agent.challenge]),
        timeout=timeout,
        deserialize=False,
        streaming=True,
    )

    # Prepare the task for handling stream responses
    handle_stream_responses_task = asyncio.create_task(
        handle_response(responses=dict(zip(uids_cpu, streams_responses)))
    )

    if not agent.task.static_reference:
        # Dynamic reference: generate it concurrently with stream handling so
        # neither waits on the other.
        reference_generation_task = generate_reference(agent)
        _, stream_results = await asyncio.gather(
            reference_generation_task, handle_stream_responses_task
        )
    else:
        stream_results = await handle_stream_responses_task

    log_stream_results(stream_results)

    all_synapses_results = [stream_result.synapse for stream_result in stream_results]

    # Encapsulate the responses in a response event (dataclass)
    response_event = DendriteResponseEvent(
        responses=all_synapses_results, uids=uids, timeout=timeout
    )

    bt.logging.info(f"Created DendriteResponseEvent:\n {response_event}")
    # Reward the responses and get the reward result (dataclass)
    # This contains a list of RewardEvents but can be exported as a dict (column-wise) for logging etc
    reward_result = RewardResult(
        self.reward_pipeline,
        agent=agent,
        response_event=response_event,
        device=self.device,
    )
    bt.logging.info(f"Created RewardResult:\n {reward_result}")

    # The original idea was that the agent is 'satisfied' when it gets a good enough response (e.g. reward criteria is met, such as ROUGE>threshold)
    agent.update_progress(
        top_reward=reward_result.rewards.max(),
        top_response=response_event.completions[reward_result.rewards.argmax()],
    )

    self.update_scores(reward_result.rewards, uids)

    stream_results_uids = [stream_result.uid for stream_result in stream_results]
    stream_results_exceptions = [
        serialize_exception_to_string(stream_result.exception)
        for stream_result in stream_results
    ]
    # Log the step event.
    event = {
        "block": self.block,
        "step_time": time.time() - start_time,
        "stream_results_uids": stream_results_uids,
        "stream_results_exceptions": stream_results_exceptions,
        **agent.__state_dict__(full=self.config.neuron.log_full),
        **reward_result.__state_dict__(full=self.config.neuron.log_full),
        **response_event.__state_dict__(),
    }

    return event
server.py CHANGED
@@ -3,6 +3,7 @@
3
 
4
  import os
5
  import re
 
6
  import asyncio
7
  import json
8
  import traceback
@@ -24,6 +25,9 @@ from aiohttp.web_response import Response
24
  curl -X POST http://0.0.0.0:10000/chat/ -H "api_key: hello" -d '{"k": 5, "timeout": 3, "roles": ["user"], "messages": ["hello world"]}'
25
 
26
  curl -X POST http://0.0.0.0:10000/chat/ -H "api_key: hey-michal" -d '{"k": 5, "timeout": 3, "roles": ["user"], "messages": ["on what exact date did the 21st century begin?"]}'
 
 
 
27
  ```
28
 
29
  TROUBLESHOOT
@@ -31,11 +35,17 @@ check if port is open
31
  ```
32
  sudo ufw allow 10000/tcp
33
  sudo ufw allow 10000/tcp
34
- ```
35
  # run
36
  ```
37
  EXPECTED_ACCESS_KEY="hey-michal" pm2 start app.py --interpreter python3 --name app -- --neuron.model_id mock --wallet.name sn1 --wallet.hotkey v1 --netuid 1 --neuron.tasks math --neuron.task_p 1 --neuron.device cpu
38
  ```
 
 
 
 
 
 
39
  """
40
 
41
  EXPECTED_ACCESS_KEY = os.environ.get('EXPECTED_ACCESS_KEY')
@@ -210,6 +220,62 @@ async def chat(request: web.Request) -> Response:
210
 
211
 
212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
 
214
  class ValidatorApplication(web.Application):
215
  def __init__(self, *a, **kw):
@@ -218,7 +284,10 @@ class ValidatorApplication(web.Application):
218
 
219
 
220
  validator_app = ValidatorApplication()
221
- validator_app.add_routes([web.post('/chat/', chat)])
 
 
 
222
 
223
  bt.logging.info("Starting validator application.")
224
  bt.logging.info(validator_app)
 
3
 
4
  import os
5
  import re
6
+ import time
7
  import asyncio
8
  import json
9
  import traceback
 
25
  curl -X POST http://0.0.0.0:10000/chat/ -H "api_key: hello" -d '{"k": 5, "timeout": 3, "roles": ["user"], "messages": ["hello world"]}'
26
 
27
  curl -X POST http://0.0.0.0:10000/chat/ -H "api_key: hey-michal" -d '{"k": 5, "timeout": 3, "roles": ["user"], "messages": ["on what exact date did the 21st century begin?"]}'
28
+
29
+ # stream
30
+ curl --no-buffer -X POST http://129.146.127.82:10000/echo/ -H "api_key: hey-michal" -d '{"k": 3, "timeout": 0.2, "roles": ["user"], "messages": ["i need to tell you something important but first"]}'
31
  ```
32
 
33
  TROUBLESHOOT
 
35
  ```
36
  sudo ufw allow 10000/tcp
37
  sudo ufw allow 10000/tcp
38
+ ```
39
  # run
40
  ```
41
  EXPECTED_ACCESS_KEY="hey-michal" pm2 start app.py --interpreter python3 --name app -- --neuron.model_id mock --wallet.name sn1 --wallet.hotkey v1 --netuid 1 --neuron.tasks math --neuron.task_p 1 --neuron.device cpu
42
  ```
43
+
44
+ basic testing
45
+ ```
46
+ EXPECTED_ACCESS_KEY="hey-michal" python app.py --neuron.model_id mock --wallet.name sn1 --wallet.hotkey v1 --netuid 1 --neuron.tasks math --neuron.task_p 1 --neuron.device cpu
47
+ ```
48
+ add --mock to test the echo stream
49
  """
50
 
51
  EXPECTED_ACCESS_KEY = os.environ.get('EXPECTED_ACCESS_KEY')
 
220
 
221
 
222
 
223
async def echo_stream(request):
    """Stream the request's messages back to the caller, word by word.

    Echoes the concatenated `messages` `k` times, pausing `timeout` seconds
    between words, then appends a JSON payload shaped like a validator chat
    response after a `JSON_RESPONSE_BEGIN:` marker.

    Args:
        request: aiohttp request whose JSON body supplies `messages` (required),
            `k` (default 1) and `timeout` (default 0.2).

    Returns:
        A streamed 200 response; 401 on a bad api_key header; 400 on a
        non-JSON body.
    """
    bt.logging.info(f'echo_stream()')
    # Check access key
    access_key = request.headers.get("api_key")
    if EXPECTED_ACCESS_KEY is not None and access_key != EXPECTED_ACCESS_KEY:
        bt.logging.error(f'Invalid access key: {access_key}')
        return Response(status=401, reason="Invalid access key")

    try:
        request_data = await request.json()
    except ValueError as e:
        # Fix: request_data is unbound when request.json() raises, so logging
        # it here raised NameError; log the parse error itself instead.
        bt.logging.error(f'Invalid request data: {e}')
        return Response(status=400)

    bt.logging.info(f'Request data: {request_data}')
    k = request_data.get('k', 1)
    timeout = request_data.get('timeout', 0.2)
    message = '\n\n'.join(request_data['messages'])

    # Create a StreamResponse
    response = web.StreamResponse(status=200, reason='OK', headers={'Content-Type': 'text/plain'})
    await response.prepare(request)

    completion = ''
    # Echo the message k times with a timeout between each chunk
    for _ in range(k):
        for word in message.split():
            chunk = f'{word} '
            await response.write(chunk.encode('utf-8'))
            completion += chunk
            # Fix: time.sleep() blocked the whole event loop between words,
            # stalling every other request; asyncio.sleep yields control.
            await asyncio.sleep(timeout)
            bt.logging.info(f"Echoed: {chunk}")

    completion = completion.strip()

    # Prepare final JSON chunk mimicking a DendriteResponseEvent
    json_chunk = json.dumps({
        "uids": [0],
        "completion": completion,
        "completions": [completion],
        "timings": [0],
        "status_messages": ['Went well!'],
        "status_codes": [200],
        "completion_is_valid": [True],
        "task_name": 'echo',
        "ensemble_result": {}
    })

    # Send the final JSON as part of the stream
    await response.write(f"\n\nJSON_RESPONSE_BEGIN:\n{json_chunk}".encode('utf-8'))

    # Finalize the response
    await response.write_eof()
    return response
279
 
280
  class ValidatorApplication(web.Application):
281
  def __init__(self, *a, **kw):
 
284
 
285
 
286
  validator_app = ValidatorApplication()
287
+ validator_app.add_routes([
288
+ web.post('/chat/', chat),
289
+ web.post('/echo/', echo_stream)
290
+ ])
291
 
292
  bt.logging.info("Starting validator application.")
293
  bt.logging.info(validator_app)