pedroferreira commited on
Commit
d3ebdae
·
1 Parent(s): 852435a

adds stream manager

Browse files
validators/__init__.py CHANGED
@@ -1,3 +1,4 @@
1
  from .base import QueryValidatorParams, ValidatorAPI, MockValidator
2
  from .sn1_validator_wrapper import S1ValidatorAPI
3
  from .streamer import AsyncResponseDataStreamer
 
 
1
  from .base import QueryValidatorParams, ValidatorAPI, MockValidator
2
  from .sn1_validator_wrapper import S1ValidatorAPI
3
  from .streamer import AsyncResponseDataStreamer
4
+ from .stream_manager import StreamManager
validators/sn1_validator_wrapper.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import random
2
  import bittensor as bt
3
  from prompting.validator import Validator
@@ -7,7 +8,7 @@ from .base import QueryValidatorParams, ValidatorAPI
7
  from aiohttp.web_response import Response, StreamResponse
8
  from .streamer import AsyncResponseDataStreamer
9
  from .validator_utils import get_top_incentive_uids
10
-
11
 
12
  class S1ValidatorAPI(ValidatorAPI):
13
  def __init__(self):
@@ -47,13 +48,18 @@ class S1ValidatorAPI(ValidatorAPI):
47
  deserialize=False,
48
  streaming=True,
49
  )
50
- uid_stream_dict = dict(zip(uids, streams_responses))
51
- random_uid, random_stream = random.choice(list(uid_stream_dict.items()))
52
 
53
  # Creates a streamer from the selected stream
54
- streamer = AsyncResponseDataStreamer(async_iterator=random_stream, selected_uid=random_uid)
55
- response = await streamer.stream(params.request)
56
- return response
 
 
 
 
 
57
 
58
  async def query_validator(self, params: QueryValidatorParams) -> Response:
59
  return await self.get_stream_response(params)
 
1
+ import asyncio
2
  import random
3
  import bittensor as bt
4
  from prompting.validator import Validator
 
8
  from aiohttp.web_response import Response, StreamResponse
9
  from .streamer import AsyncResponseDataStreamer
10
  from .validator_utils import get_top_incentive_uids
11
+ from .stream_manager import StreamManager
12
 
13
  class S1ValidatorAPI(ValidatorAPI):
14
  def __init__(self):
 
48
  deserialize=False,
49
  streaming=True,
50
  )
51
+ # uid_stream_dict = dict(zip(uids, streams_responses))
52
+ # random_uid, random_stream = random.choice(list(uid_stream_dict.items()))
53
 
54
  # Creates a streamer from the selected stream
55
+ stream_manager = StreamManager()
56
+ await stream_manager.process_streams(params.request, streams_responses, uids)
57
+
58
+ # response = await streamer.stream(params.request)
59
+ # return response
60
+
61
+ return None
62
+
63
 
64
  async def query_validator(self, params: QueryValidatorParams) -> Response:
65
  return await self.get_stream_response(params)
validators/stream_manager.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ from .streamer import AsyncResponseDataStreamer
3
+ from typing import List, AsyncIterator
4
+ from aiohttp.web import Request
5
+
6
+ class StreamManager:
7
+
8
+ def __init__(self):
9
+ ...
10
+
11
+
12
+
13
+ async def process_streams(self, request:Request, streams_responses: List[AsyncIterator], stream_uids: List[int]):
14
+
15
+ # create local lock for returning responses to the front-end
16
+ # creates n number of streamers
17
+ # organizes responses
18
+ # logs responses locally
19
+ # returns selected response to the front-end
20
+
21
+ lock = asyncio.Lock()
22
+
23
+ streamers = [AsyncResponseDataStreamer(async_iterator=stream, selected_uid=stream_uid, lock=lock) for stream, stream_uid in zip(streams_responses, stream_uids)]
24
+ completed_streams = await asyncio.gather(*[streamer.stream(request) for streamer in streamers])
25
+
26
+ lock.release()
27
+ print(f"Stream {stream_uids} completed the operation.")
28
+
29
+
30
+
31
+
32
+
validators/streamer.py CHANGED
@@ -1,10 +1,11 @@
1
  import json
2
  import time
 
3
  import traceback
4
  import bittensor as bt
5
  from pydantic import BaseModel
6
  from datetime import datetime
7
- from typing import AsyncIterator, Optional, List
8
  from aiohttp import web, web_response
9
  from prompting.protocol import StreamPromptingSynapse
10
 
@@ -35,7 +36,7 @@ class StreamError(BaseModel):
35
 
36
 
37
  class AsyncResponseDataStreamer:
38
- def __init__(self, async_iterator: AsyncIterator, selected_uid:int, delay: float = 0.1):
39
  self.async_iterator = async_iterator
40
  self.delay = delay
41
  self.selected_uid = selected_uid
@@ -43,16 +44,51 @@ class AsyncResponseDataStreamer:
43
  self.accumulated_chunks_timings: List[float] = []
44
  self.finish_reason: str = None
45
  self.sequence_number: int = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  async def stream(self, request: web.Request) -> web_response.StreamResponse:
48
- response = web_response.StreamResponse(status=200, reason="OK")
49
- response.headers["Content-Type"] = "application/json"
50
- await response.prepare(request) # Prepare and send the headers
51
-
52
  try:
53
  start_time = time.time()
 
 
 
54
  async for chunk in self.async_iterator:
55
- if isinstance(chunk, list):
56
  # Chunks are currently returned in string arrays, so we need to concatenate them
57
  concatenated_chunks = "".join(chunk)
58
  self.accumulated_chunks.append(concatenated_chunks)
@@ -60,8 +96,9 @@ class AsyncResponseDataStreamer:
60
  # Gets new response state
61
  self.sequence_number += 1
62
  new_response_state = self._create_chunk_response(concatenated_chunks)
63
- # Writes the new response state to the response
64
- await response.write(new_response_state.encode('utf-8'))
 
65
 
66
  if chunk is not None and isinstance(chunk, StreamPromptingSynapse):
67
  if len(self.accumulated_chunks) == 0:
@@ -72,19 +109,33 @@ class AsyncResponseDataStreamer:
72
  self.sequence_number += 1
73
  # Assuming the last chunk holds the last value yielded which should be a synapse with the completion filled
74
  synapse = chunk
75
- final_state = self._create_chunk_response(synapse.completion)
76
- await response.write(final_state.encode('utf-8'))
 
 
 
 
77
 
78
  except Exception as e:
79
  bt.logging.error(
80
- f"Encountered an error in {self.__class__.__name__}:get_stream_response:\n{traceback.format_exc()}"
81
  )
82
- response.set_status(500, reason="Internal error")
83
  error_response = self._create_error_response(str(e))
84
- response.write(error_response.encode('utf-8'))
 
 
 
 
 
 
85
  finally:
86
- await response.write_eof() # Ensure to close the response properly
87
- return response
 
 
 
 
 
88
 
89
  def _create_chunk_response(self, chunk: str) -> StreamChunk:
90
  """
 
1
  import json
2
  import time
3
+ import asyncio
4
  import traceback
5
  import bittensor as bt
6
  from pydantic import BaseModel
7
  from datetime import datetime
8
+ from typing import AsyncIterator, Optional, List, Union
9
  from aiohttp import web, web_response
10
  from prompting.protocol import StreamPromptingSynapse
11
 
 
36
 
37
 
38
  class AsyncResponseDataStreamer:
39
+ def __init__(self, async_iterator: AsyncIterator, selected_uid:int, lock: asyncio.Lock, delay: float = 0.1):
40
  self.async_iterator = async_iterator
41
  self.delay = delay
42
  self.selected_uid = selected_uid
 
44
  self.accumulated_chunks_timings: List[float] = []
45
  self.finish_reason: str = None
46
  self.sequence_number: int = 0
47
+ self.lock = lock
48
+ self.lock_acquired = False
49
+
50
+
51
+ def ensure_response_is_created(self, initiated_response: web.StreamResponse) -> web.StreamResponse:
52
+ # Creates response if it was not created
53
+ if initiated_response == None:
54
+ initiated_response = web_response.StreamResponse(status=200, reason="OK")
55
+ initiated_response.headers["Content-Type"] = "application/json"
56
+ return initiated_response
57
+
58
+ return initiated_response
59
+
60
+
61
+ async def write_to_stream(self, request: web.Request, initiated_response: web.StreamResponse, stream_chunk: StreamChunk, lock: asyncio.Lock) -> web.StreamResponse:
62
+ # Try to acquire the lock and sets the lock_acquired flag. Only the stream that acquires the lock should write to the response
63
+ if lock.locked() == False:
64
+ self.lock_acquired = await lock.acquire()
65
+
66
+ if initiated_response == None and self.lock_acquired:
67
+ initiated_response = self.ensure_response_is_created(initiated_response)
68
+ # Prepare and send the headers
69
+ await initiated_response.prepare(request)
70
+
71
+ if self.lock_acquired:
72
+ await initiated_response.write(stream_chunk.encode('utf-8'))
73
+ else:
74
+ bt.logging.info(f"Stream of uid {stream_chunk.selected_uid} was not the first to return, skipping...")
75
+
76
+ return initiated_response
77
+
78
+
79
 
80
  async def stream(self, request: web.Request) -> web_response.StreamResponse:
81
+ # response = web_response.StreamResponse(status=200, reason="OK")
82
+ # response.headers["Content-Type"] = "application/json"
83
+ # await response.prepare(request) # Prepare and send the headers
84
+
85
  try:
86
  start_time = time.time()
87
+ client_response: web.Response = None
88
+ final_response: Union[StreamChunk, StreamError]
89
+
90
  async for chunk in self.async_iterator:
91
+ if isinstance(chunk, str):
92
  # Chunks are currently returned in string arrays, so we need to concatenate them
93
  concatenated_chunks = "".join(chunk)
94
  self.accumulated_chunks.append(concatenated_chunks)
 
96
  # Gets new response state
97
  self.sequence_number += 1
98
  new_response_state = self._create_chunk_response(concatenated_chunks)
99
+ # Writes the new response state to the response
100
+ client_response = await self.write_to_stream(request, client_response, new_response_state, self.lock)
101
+ #await response.write(new_response_state.encode('utf-8'))
102
 
103
  if chunk is not None and isinstance(chunk, StreamPromptingSynapse):
104
  if len(self.accumulated_chunks) == 0:
 
109
  self.sequence_number += 1
110
  # Assuming the last chunk holds the last value yielded which should be a synapse with the completion filled
111
  synapse = chunk
112
+ final_response = self._create_chunk_response(synapse.completion)
113
+
114
+ if synapse.completion:
115
+ client_response = await self.write_to_stream(request, client_response, final_response, self.lock)
116
+ else:
117
+ raise ValueError("Stream did not return a valid synapse.")
118
 
119
  except Exception as e:
120
  bt.logging.error(
121
+ f"Encountered an error while processing stream for uid {self.selected_uid} get_stream_response:\n{traceback.format_exc()}"
122
  )
 
123
  error_response = self._create_error_response(str(e))
124
+ final_response = error_response
125
+
126
+ # Only the stream that acquires the lock should write the error response
127
+ if self.lock.locked():
128
+ self.ensure_response_is_created(client_response)
129
+ client_response.set_status(500, reason="Internal error")
130
+ client_response.write(error_response.encode('utf-8'))
131
  finally:
132
+ # Only the stream that acquires the lock should close the response
133
+ if self.lock.locked():
134
+ self.ensure_response_is_created(client_response)
135
+ # Ensure to close the response properly
136
+ await client_response.write_eof()
137
+
138
+ return final_response
139
 
140
  def _create_chunk_response(self, chunk: str) -> StreamChunk:
141
  """