pedroferreira committed
Commit 1e9def1 · 1 Parent(s): a415c67

integration fixes

Files changed (4):
  1. forward.py +245 -244
  2. server.py +4 -6
  3. validators/base.py +7 -3
  4. validators/sn1_validator_wrapper.py +33 -12
forward.py CHANGED
@@ -1,244 +1,245 @@
- # import time
- # import sys
- # import asyncio
- # import numpy as np
- # import bittensor as bt
- # import traceback
- # from typing import List, Dict, Awaitable
- # from prompting.agent import HumanAgent
- # from prompting.dendrite import DendriteResponseEvent
- # from prompting.conversation import create_task
- # from prompting.protocol import StreamPromptingSynapse
- # from prompting.rewards import RewardResult
- # from prompting.utils.uids import get_random_uids
- # from prompting.utils.logging import log_event
- # from prompting.utils.misc import async_log, serialize_exception_to_string
- # from dataclasses import dataclass
-
- # @async_log
- # async def generate_reference(agent):
- #     loop = asyncio.get_running_loop()
- #     result = await loop.run_in_executor(None, agent.task.generate_reference, agent.llm_pipeline)
- #     return result
-
- # @async_log
- # async def execute_dendrite_call(dendrite_call):
- #     responses = await dendrite_call
- #     return responses
-
- # @dataclass
- # class StreamResult:
- #     synapse: StreamPromptingSynapse = None
- #     exception: BaseException = None
- #     uid: int = None
-
-
- # async def process_response(uid: int, async_generator: Awaitable):
- #     """Process a single response asynchronously."""
- #     try:
- #         chunk = None # Initialize chunk with a default value
- #         async for chunk in async_generator: # most important loop, as this is where we acquire the final synapse.
- #             bt.logging.debug(f"\nchunk for uid {uid}: {chunk}")
-
- #         if chunk is not None:
- #             synapse = chunk # last object yielded is the synapse itself with completion filled
-
- #             # Assuming chunk holds the last value yielded which should be a synapse
- #             if isinstance(synapse, StreamPromptingSynapse):
- #                 return synapse
-
- #         bt.logging.debug(
- #             f"Synapse is not StreamPromptingSynapse. Miner uid {uid} completion set to '' "
- #         )
- #     except Exception as e:
- #         # bt.logging.error(f"Error in generating reference or handling responses: {e}", exc_info=True)
- #         traceback_details = traceback.format_exc()
- #         bt.logging.error(
- #             f"Error in generating reference or handling responses for uid {uid}: {e}\n{traceback_details}"
- #         )
-
- #         failed_synapse = StreamPromptingSynapse(
- #             roles=["user"], messages=["failure"], completion=""
- #         )
-
- #         return failed_synapse
-
-
- # @async_log
- # async def handle_response(responses: Dict[int, Awaitable]) -> List[StreamResult]:
- #     """The handle_response function is responsible for creating asyncio tasks around acquiring streamed miner chunks
- #     and processing them asynchronously. It then pairs the results with their original UIDs and returns a list of StreamResults.
-
- #     Args:
- #         responses (Dict[int, Awaitable]): Responses contains awaitables that are used to acquire streamed miner chunks.
-
- #     Raises:
- #         ValueError
-
- #     Returns:
- #         List[StreamResult]: DataClass containing the synapse, exception, and uid
- #     """
- #     tasks_with_uid = [
- #         (uid, responses[uid]) for uid, _ in responses.items()
- #     ] # Pair UIDs with their tasks
-
- #     # Start tasks, preserving order and their associated UIDs
- #     tasks = [process_response(uid, resp) for uid, resp in tasks_with_uid]
-
- #     results = await asyncio.gather(*tasks, return_exceptions=True)
-
- #     mapped_results = []
- #     # Pair each result with its original uid
- #     for (uid, _), result in zip(tasks_with_uid, results):
- #         # If the result is a StreamPromptingSynapse, the response was successful and the stream result is added without exceptions
- #         if isinstance(result, StreamPromptingSynapse):
- #             mapped_results.append(StreamResult(synapse=result, uid=uid))
-
- #         # If the result is an exception, the response was unsuccessful and the stream result is added with the exception and an empty synapse
- #         elif isinstance(result, BaseException):
- #             failed_synapse = StreamPromptingSynapse(
- #                 roles=["user"], messages=["failure"], completion=""
- #             )
- #             mapped_results.append(
- #                 StreamResult(synapse=failed_synapse, exception=result, uid=uid)
- #             )
-
- #         # If the result is neither an error nor a StreamSynapse, log the error and raise a ValueError
- #         else:
- #             bt.logging.error(f"Unexpected result type for UID {uid}: {result}")
- #             raise ValueError(f"Unexpected result type for UID {uid}: {result}")
-
- #     return mapped_results
-
-
- # @async_log
- # async def generate_reference(agent: HumanAgent):
- #     loop = asyncio.get_running_loop()
- #     result = await loop.run_in_executor(
- #         None, agent.task.generate_reference, agent.llm_pipeline
- #     )
- #     return result
-
-
- # def log_stream_results(stream_results: List[StreamResult]):
- #     failed_responses = [
- #         response for response in stream_results if response.exception is not None
- #     ]
- #     empty_responses = [
- #         response
- #         for response in stream_results
- #         if response.exception is None and response.synapse.completion == ""
- #     ]
- #     non_empty_responses = [
- #         response
- #         for response in stream_results
- #         if response.exception is None and response.synapse.completion != ""
- #     ]
-
- #     bt.logging.info(f"Total of non_empty responses: ({len(non_empty_responses)})")
- #     bt.logging.info(f"Total of empty responses: ({len(empty_responses)})")
- #     bt.logging.info(
- #         f"Total of failed responses: ({len(failed_responses)}):\n {failed_responses}"
- #     )
-
- #     for failed_response in failed_responses:
- #         formatted_exception = serialize_exception_to_string(failed_response.exception)
- #         bt.logging.error(
- #             f"Failed response for uid {failed_response.uid}: {formatted_exception}"
- #         )
-
-
- # async def run_step(
- #     self, agent: HumanAgent, k: int, timeout: float, exclude: list = None
- # ):
- #     """Executes a single step of the agent, which consists of:
- #     - Getting a list of uids to query
- #     - Querying the network
- #     - Rewarding the network
- #     - Updating the scores
- #     - Logging the event
-
- #     Args:
- #         agent (HumanAgent): The agent to run the step for.
- #         k (int): The number of uids to query.
- #         timeout (float): The timeout for the queries.
- #         exclude (list, optional): The list of uids to exclude from the query. Defaults to [].
- #     """
-
- #     bt.logging.debug("run_step", agent.task.name)
-
- #     # Record event start time.
- #     start_time = time.time()
- #     # Get the list of uids to query for this step.
- #     uids = get_random_uids(self, k=k, exclude=exclude or []).to(self.device)
- #     uids_cpu = uids.cpu().tolist()
-
- #     axons = [self.metagraph.axons[uid] for uid in uids]
-
- #     # Directly call dendrite and process responses in parallel
- #     streams_responses = await self.dendrite(
- #         axons=axons,
- #         synapse=StreamPromptingSynapse(roles=["user"], messages=[agent.challenge]),
- #         timeout=timeout,
- #         deserialize=False,
- #         streaming=True,
- #     )
-
- #     # Prepare the task for handling stream responses
- #     handle_stream_responses_task = asyncio.create_task(
- #         handle_response(responses=dict(zip(uids_cpu, streams_responses)))
- #     )
-
- #     if not agent.task.static_reference:
- #         reference_generation_task = generate_reference(agent)
- #         _, stream_results = await asyncio.gather(
- #             reference_generation_task, handle_stream_responses_task
- #         )
- #     else:
- #         stream_results = await handle_stream_responses_task
-
- #     log_stream_results(stream_results)
-
- #     all_synapses_results = [stream_result.synapse for stream_result in stream_results]
-
- #     # Encapsulate the responses in a response event (dataclass)
- #     response_event = DendriteResponseEvent(
- #         responses=all_synapses_results, uids=uids, timeout=timeout
- #     )
-
- #     bt.logging.info(f"Created DendriteResponseEvent:\n {response_event}")
- #     # Reward the responses and get the reward result (dataclass)
- #     # This contains a list of RewardEvents but can be exported as a dict (column-wise) for logging etc
- #     reward_result = RewardResult(
- #         self.reward_pipeline,
- #         agent=agent,
- #         response_event=response_event,
- #         device=self.device,
- #     )
- #     bt.logging.info(f"Created RewardResult:\n {reward_result}")
-
- #     # The original idea was that the agent is 'satisfied' when it gets a good enough response (e.g. reward criteria is met, such as ROUGE>threshold)
- #     agent.update_progress(
- #         top_reward=reward_result.rewards.max(),
- #         top_response=response_event.completions[reward_result.rewards.argmax()],
- #     )
-
- #     self.update_scores(reward_result.rewards, uids)
-
- #     stream_results_uids = [stream_result.uid for stream_result in stream_results]
- #     stream_results_exceptions = [
- #         serialize_exception_to_string(stream_result.exception)
- #         for stream_result in stream_results
- #     ]
- #     # Log the step event.
- #     event = {
- #         "block": self.block,
- #         "step_time": time.time() - start_time,
- #         "stream_results_uids": stream_results_uids,
- #         "stream_results_exceptions": stream_results_exceptions,
- #         **agent.__state_dict__(full=self.config.neuron.log_full),
- #         **reward_result.__state_dict__(full=self.config.neuron.log_full),
- #         **response_event.__state_dict__(),
- #     }
-
- #     return event
+ import time
+ import sys
+ import asyncio
+ import numpy as np
+ import bittensor as bt
+ import traceback
+ from typing import List, Dict, Awaitable
+ from prompting.agent import HumanAgent
+ from prompting.dendrite import DendriteResponseEvent
+ from prompting.conversation import create_task
+ from prompting.protocol import StreamPromptingSynapse
+ from prompting.rewards import RewardResult
+ from prompting.utils.uids import get_random_uids
+ from prompting.utils.logging import log_event
+ from prompting.utils.misc import async_log, serialize_exception_to_string
+ from dataclasses import dataclass
+
+ @async_log
+ async def generate_reference(agent):
+     loop = asyncio.get_running_loop()
+     result = await loop.run_in_executor(None, agent.task.generate_reference, agent.llm_pipeline)
+     return result
+
+ @async_log
+ async def execute_dendrite_call(dendrite_call):
+     responses = await dendrite_call
+     return responses
+
+ @dataclass
+ class StreamResult:
+     synapse: StreamPromptingSynapse = None
+     exception: BaseException = None
+     uid: int = None
+
+
+ async def process_response(uid: int, async_generator: Awaitable):
+     """Process a single response asynchronously."""
+     try:
+         chunk = None # Initialize chunk with a default value
+         async for chunk in async_generator: # most important loop, as this is where we acquire the final synapse.
+             bt.logging.debug(f"\nchunk for uid {uid}: {chunk}")
+
+         if chunk is not None:
+             synapse = chunk # last object yielded is the synapse itself with completion filled
+
+             # Assuming chunk holds the last value yielded which should be a synapse
+             if isinstance(synapse, StreamPromptingSynapse):
+                 return synapse
+
+         bt.logging.debug(
+             f"Synapse is not StreamPromptingSynapse. Miner uid {uid} completion set to '' "
+         )
+     except Exception as e:
+         # bt.logging.error(f"Error in generating reference or handling responses: {e}", exc_info=True)
+         traceback_details = traceback.format_exc()
+         bt.logging.error(
+             f"Error in generating reference or handling responses for uid {uid}: {e}\n{traceback_details}"
+         )
+
+         failed_synapse = StreamPromptingSynapse(
+             roles=["user"], messages=["failure"], completion=""
+         )
+
+         return failed_synapse
+
+
+ @async_log
+ async def handle_response(responses: Dict[int, Awaitable]) -> List[StreamResult]:
+     """The handle_response function is responsible for creating asyncio tasks around acquiring streamed miner chunks
+     and processing them asynchronously. It then pairs the results with their original UIDs and returns a list of StreamResults.
+
+     Args:
+         responses (Dict[int, Awaitable]): Responses contains awaitables that are used to acquire streamed miner chunks.
+
+     Raises:
+         ValueError
+
+     Returns:
+         List[StreamResult]: DataClass containing the synapse, exception, and uid
+     """
+     tasks_with_uid = [
+         (uid, responses[uid]) for uid, _ in responses.items()
+     ] # Pair UIDs with their tasks
+
+     # Start tasks, preserving order and their associated UIDs
+     tasks = [process_response(uid, resp) for uid, resp in tasks_with_uid]
+
+     results = await asyncio.gather(*tasks, return_exceptions=True)
+
+     mapped_results = []
+     # Pair each result with its original uid
+     for (uid, _), result in zip(tasks_with_uid, results):
+
+         # If the result is a StreamPromptingSynapse, the response was successful and the stream result is added without exceptions
+         if isinstance(result, StreamPromptingSynapse):
+             mapped_results.append(StreamResult(synapse=result, uid=uid))
+
+         # If the result is an exception, the response was unsuccessful and the stream result is added with the exception and an empty synapse
+         elif isinstance(result, BaseException):
+             failed_synapse = StreamPromptingSynapse(
+                 roles=["user"], messages=["failure"], completion=""
+             )
+             mapped_results.append(
+                 StreamResult(synapse=failed_synapse, exception=result, uid=uid)
+             )
+
+         # If the result is neither an error nor a StreamSynapse, log the error and raise a ValueError
+         else:
+             bt.logging.error(f"Unexpected result type for UID {uid}: {result}")
+             raise ValueError(f"Unexpected result type for UID {uid}: {result}")
+
+     return mapped_results
+
+
+ @async_log
+ async def generate_reference(agent: HumanAgent):
+     loop = asyncio.get_running_loop()
+     result = await loop.run_in_executor(
+         None, agent.task.generate_reference, agent.llm_pipeline
+     )
+     return result
+
+
+ def log_stream_results(stream_results: List[StreamResult]):
+     failed_responses = [
+         response for response in stream_results if response.exception is not None
+     ]
+     empty_responses = [
+         response
+         for response in stream_results
+         if response.exception is None and response.synapse.completion == ""
+     ]
+     non_empty_responses = [
+         response
+         for response in stream_results
+         if response.exception is None and response.synapse.completion != ""
+     ]
+
+     bt.logging.info(f"Total of non_empty responses: ({len(non_empty_responses)})")
+     bt.logging.info(f"Total of empty responses: ({len(empty_responses)})")
+     bt.logging.info(
+         f"Total of failed responses: ({len(failed_responses)}):\n {failed_responses}"
+     )
+
+     for failed_response in failed_responses:
+         formatted_exception = serialize_exception_to_string(failed_response.exception)
+         bt.logging.error(
+             f"Failed response for uid {failed_response.uid}: {formatted_exception}"
+         )
+
+
+ async def run_step(
+     self, agent: HumanAgent, k: int, timeout: float, exclude: list = None
+ ):
+     """Executes a single step of the agent, which consists of:
+     - Getting a list of uids to query
+     - Querying the network
+     - Rewarding the network
+     - Updating the scores
+     - Logging the event
+
+     Args:
+         agent (HumanAgent): The agent to run the step for.
+         k (int): The number of uids to query.
+         timeout (float): The timeout for the queries.
+         exclude (list, optional): The list of uids to exclude from the query. Defaults to [].
+     """
+
+     bt.logging.debug("run_step", agent.task.name)
+
+     # Record event start time.
+     start_time = time.time()
+     # Get the list of uids to query for this step.
+     uids = get_random_uids(self, k=k, exclude=exclude or []).to(self.device)
+     uids_cpu = uids.cpu().tolist()
+
+     axons = [self.metagraph.axons[uid] for uid in uids]
+
+     # Directly call dendrite and process responses in parallel
+     streams_responses = await self.dendrite(
+         axons=axons,
+         synapse=StreamPromptingSynapse(roles=["user"], messages=[agent.challenge]),
+         timeout=timeout,
+         deserialize=False,
+         streaming=True,
+     )
+
+     # Prepare the task for handling stream responses
+     handle_stream_responses_task = asyncio.create_task(
+         handle_response(responses=dict(zip(uids_cpu, streams_responses)))
+     )
+
+     if not agent.task.static_reference:
+         reference_generation_task = generate_reference(agent)
+         _, stream_results = await asyncio.gather(
+             reference_generation_task, handle_stream_responses_task
+         )
+     else:
+         stream_results = await handle_stream_responses_task
+
+     log_stream_results(stream_results)
+
+     all_synapses_results = [stream_result.synapse for stream_result in stream_results]
+
+     # Encapsulate the responses in a response event (dataclass)
+     response_event = DendriteResponseEvent(
+         responses=all_synapses_results, uids=uids, timeout=timeout
+     )
+
+     bt.logging.info(f"Created DendriteResponseEvent:\n {response_event}")
+     # Reward the responses and get the reward result (dataclass)
+     # This contains a list of RewardEvents but can be exported as a dict (column-wise) for logging etc
+     reward_result = RewardResult(
+         self.reward_pipeline,
+         agent=agent,
+         response_event=response_event,
+         device=self.device,
+     )
+     bt.logging.info(f"Created RewardResult:\n {reward_result}")
+
+     # The original idea was that the agent is 'satisfied' when it gets a good enough response (e.g. reward criteria is met, such as ROUGE>threshold)
+     agent.update_progress(
+         top_reward=reward_result.rewards.max(),
+         top_response=response_event.completions[reward_result.rewards.argmax()],
+     )
+
+     self.update_scores(reward_result.rewards, uids)
+
+     stream_results_uids = [stream_result.uid for stream_result in stream_results]
+     stream_results_exceptions = [
+         serialize_exception_to_string(stream_result.exception)
+         for stream_result in stream_results
+     ]
+     # Log the step event.
+     event = {
+         "block": self.block,
+         "step_time": time.time() - start_time,
+         "stream_results_uids": stream_results_uids,
+         "stream_results_exceptions": stream_results_exceptions,
+         **agent.__state_dict__(full=self.config.neuron.log_full),
+         **reward_result.__state_dict__(full=self.config.neuron.log_full),
+         **response_event.__state_dict__(),
+     }
+
+     return event
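
Note on the error-handling pattern in handle_response above: asyncio.gather(..., return_exceptions=True) returns exceptions as ordinary values in the results list, which is why each result is then dispatched with isinstance checks. A minimal, self-contained sketch of that behaviour (illustrative only, not part of this commit):

    import asyncio

    async def ok(uid: int) -> str:
        return f"completion-{uid}"

    async def boom(uid: int) -> str:
        raise RuntimeError(f"miner {uid} failed")

    async def main():
        uids = [1, 2, 3]
        # Exceptions are returned in place of results instead of propagating here.
        results = await asyncio.gather(ok(1), boom(2), ok(3), return_exceptions=True)
        for uid, result in zip(uids, results):
            if isinstance(result, BaseException):
                print(f"uid {uid}: failed with {result!r}")
            else:
                print(f"uid {uid}: {result}")

    asyncio.run(main())
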
server.py CHANGED
@@ -37,16 +37,14 @@ add --mock to test the echo stream
  async def chat(request: web.Request) -> Response:
      """
      Chat endpoint for the validator.
-     """
-     request_data = request['data']
-     params = QueryValidatorParams.from_dict(request_data)
-     # TODO: SET STREAM AS DEFAULT
-     stream = request_data.get('stream', True)
+     """
+     params = QueryValidatorParams.from_request(request)
+

      # Access the validator from the application context
      validator: ValidatorAPI = request.app['validator']

-     response = await validator.query_validator(params, stream=stream)
+     response = await validator.query_validator(params)
      return response

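
The new chat handler relies on the validator being stored in the aiohttp application context (request.app['validator']). A minimal sketch of how server.py could wire this up; the '/chat' route, the create_app name, and the stub handler body are assumptions for illustration:

    from aiohttp import web

    async def chat(request: web.Request) -> web.StreamResponse:
        # Stand-in for the chat handler shown in the diff above.
        validator = request.app['validator']
        return web.json_response({'validator': type(validator).__name__})

    def create_app(validator) -> web.Application:
        app = web.Application()
        app['validator'] = validator  # looked up by the handler via request.app['validator']
        app.add_routes([web.post('/chat', chat)])  # route path is an assumption
        return app

    # web.run_app(create_app(validator=...), port=8080)  # hypothetical entry point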
 
validators/base.py CHANGED
@@ -1,7 +1,7 @@
  from abc import ABC, abstractmethod
  from typing import List
  from dataclasses import dataclass
- from aiohttp.web_response import Response
+ from aiohttp.web import Response, Request

  @dataclass
  class QueryValidatorParams:
@@ -11,16 +11,20 @@ class QueryValidatorParams:
      messages: List[str]
      timeout: int
      prefer: str
+     request: Request

      @staticmethod
-     def from_dict(data: dict):
+     def from_request(request: Request):
+         data = request['data']
+
          return QueryValidatorParams(
              k_miners=data.get('k', 10),
              exclude=data.get('exclude', []),
              roles=data['roles'],
              messages=data['messages'],
              timeout=data.get('timeout', 10),
-             prefer=data.get('prefer', 'longest')
+             prefer=data.get('prefer', 'longest'),
+             request=request
          )

  class ValidatorAPI(ABC):
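
QueryValidatorParams.from_request reads request['data'], so something upstream must have parsed the request body and stored it on the aiohttp request mapping. A possible sketch of such a step, assuming a JSON body (this middleware is hypothetical and not part of this commit):

    from aiohttp import web

    @web.middleware
    async def parse_json_middleware(request: web.Request, handler):
        # Hypothetical: populate request['data'] so from_request(request) can read it.
        request['data'] = await request.json() if request.can_read_body else {}
        return await handler(request)

    # app = web.Application(middlewares=[parse_json_middleware])
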
validators/sn1_validator_wrapper.py CHANGED
@@ -2,7 +2,9 @@ import json
  import utils
  import torch
  import traceback
+ import asyncio
  import bittensor as bt
+ from typing import Awaitable
  from prompting.validator import Validator
  from prompting.utils.uids import get_random_uids
  from prompting.protocol import PromptingSynapse, StreamPromptingSynapse
@@ -55,11 +57,30 @@ class S1ValidatorAPI(ValidatorAPI):
          return Response(status=500, reason="Internal error")


+     async def process_response(self, response: StreamResponse, uid: int, async_generator: Awaitable):
+         """Process a single response asynchronously."""
+         try:
+             chunk = None # Initialize chunk with a default value
+             async for chunk in async_generator: # most important loop, as this is where we acquire the final synapse.
+                 bt.logging.debug(f"\nchunk for uid {uid}: {chunk}")
+
+             # TODO: SET PROPER IMPLEMENTATION TO RETURN CHUNK
+             if chunk is not None:
+                 json_data = json.dumps(chunk)
+                 await response.write(json_data.encode('utf-8'))
+
+         except Exception as e:
+             bt.logging.error(f'Encountered an error in {self.__class__.__name__}:get_stream_response:\n{traceback.format_exc()}')
+             response.set_status(500, reason="Internal error")
+             await response.write(json.dumps({'error': str(e)}).encode('utf-8'))
+         finally:
+             await response.write_eof() # Ensure to close the response properly
+
      async def get_stream_response(self, params: QueryValidatorParams) -> StreamResponse:
          response = StreamResponse(status=200, reason="OK")
          response.headers['Content-Type'] = 'application/json'

-         await response.prepare() # Prepare and send the headers
+         await response.prepare(params.request) # Prepare and send the headers

          try:
              # Guess the task name of current request
@@ -78,14 +99,11 @@ class S1ValidatorAPI(ValidatorAPI):
                  deserialize=False,
                  streaming=True,
              )
-
-             # Asynchronous iteration over streaming responses
-             async for stream_result in streams_responses:
-                 if stream_result is not None:
-                     # Convert stream result to JSON and write to the response stream
-                     json_data = json.dumps(stream_result)
-                     await response.write(json_data.encode('utf-8'))
-
+
+             tasks = [self.process_response(response, uid, res) for uid, res in zip(uids, streams_responses)]
+             results = await asyncio.gather(*tasks, return_exceptions=True)
+
+             # TODO: Continue implementation, business decision needs to be made on how to handle the results
          except Exception as e:
              bt.logging.error(f'Encountered an error in {self.__class__.__name__}:get_stream_response:\n{traceback.format_exc()}')
              response.set_status(500, reason="Internal error")
@@ -94,12 +112,15 @@ class S1ValidatorAPI(ValidatorAPI):
              await response.write_eof() # Ensure to close the response properly

          return response
-

-     async def query_validator(self, params: QueryValidatorParams, stream: bool = True) -> Response:
+
+     async def query_validator(self, params: QueryValidatorParams) -> Response:
+         # TODO: SET STREAM AS DEFAULT
+         stream = params.request.get('stream', False)
+
          if stream:
              return await self.get_stream_response(params)
          else:
              # DEPRECATED
              return await self.get_response(params)
-
+
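
The streaming path above leans on aiohttp's StreamResponse lifecycle: prepare(request) sends the headers (which is why the request object was added to QueryValidatorParams), write() pushes encoded chunks, and write_eof() closes the stream. A minimal standalone handler showing that lifecycle (illustrative only; the chunk values are placeholders):

    import json
    from aiohttp import web

    async def stream_demo(request: web.Request) -> web.StreamResponse:
        response = web.StreamResponse(status=200, reason="OK")
        response.headers['Content-Type'] = 'application/json'
        await response.prepare(request)  # headers are sent here; the request object is required
        for chunk in ("first", "second", "third"):  # stand-in for streamed miner chunks
            await response.write(json.dumps({'chunk': chunk}).encode('utf-8'))
        await response.write_eof()  # finish the body and close the stream
        return response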