p-ferreira committed on
Commit
a9ac6b7
·
unverified ·
2 Parent(s): bcd7f75 f6cfbcb

Merge pull request #22 from macrocosm-os/features/distributed-streaming

Browse files
common/utils.py CHANGED
@@ -153,7 +153,6 @@ async def echo_stream(request: web.Request) -> web.StreamResponse:
153
  k = request_data.get("k", 1)
154
  message = "\n\n".join(request_data["messages"])
155
 
156
-
157
  echo_iterator = EchoAsyncIterator(message, k, delay=0.3)
158
  streamer = AsyncResponseDataStreamer(echo_iterator, selected_uid=0, delay=0.3)
159
 
 
153
  k = request_data.get("k", 1)
154
  message = "\n\n".join(request_data["messages"])
155
 
 
156
  echo_iterator = EchoAsyncIterator(message, k, delay=0.3)
157
  streamer = AsyncResponseDataStreamer(echo_iterator, selected_uid=0, delay=0.3)
158
 
requirements.txt CHANGED
@@ -2,3 +2,4 @@ git+https://github.com/macrocosm-os/prompting.git
2
  aiohttp
3
  deprecated
4
  aiohttp_apispec>=2.2.3
 
 
2
  aiohttp
3
  deprecated
4
  aiohttp_apispec>=2.2.3
5
+ aiofiles
server.py CHANGED
@@ -2,7 +2,13 @@ import asyncio
2
 
3
  import bittensor as bt
4
  from aiohttp import web
5
- from aiohttp_apispec import docs, request_schema, response_schema, setup_aiohttp_apispec, validation_middleware
 
 
 
 
 
 
6
 
7
  from common import utils
8
  from common.middlewares import api_key_middleware, json_parsing_middleware
@@ -10,11 +16,7 @@ from common.schemas import QueryChatSchema, StreamChunkSchema, StreamErrorSchema
10
  from validators import QueryValidatorParams, S1ValidatorAPI, ValidatorAPI
11
 
12
 
13
- @docs(
14
- tags=["Prompting API"],
15
- summary="Chat",
16
- description="Chat endpoint."
17
- )
18
  @request_schema(QueryChatSchema)
19
  @response_schema(StreamChunkSchema, 200)
20
  @response_schema(StreamErrorSchema, 400)
@@ -32,7 +34,7 @@ async def chat(request: web.Request) -> web.StreamResponse:
32
  @docs(
33
  tags=["Prompting API"],
34
  summary="Echo test",
35
- description="Echo endpoint for testing purposes."
36
  )
37
  @request_schema(QueryChatSchema)
38
  @response_schema(StreamChunkSchema, 200)
@@ -45,7 +47,9 @@ class ValidatorApplication(web.Application):
45
  def __init__(self, validator_instance=None, *args, **kwargs):
46
  super().__init__(*args, **kwargs)
47
 
48
- self["validator"] = validator_instance if validator_instance else S1ValidatorAPI()
 
 
49
 
50
  # Add middlewares to application
51
  self.add_routes(
 
2
 
3
  import bittensor as bt
4
  from aiohttp import web
5
+ from aiohttp_apispec import (
6
+ docs,
7
+ request_schema,
8
+ response_schema,
9
+ setup_aiohttp_apispec,
10
+ validation_middleware,
11
+ )
12
 
13
  from common import utils
14
  from common.middlewares import api_key_middleware, json_parsing_middleware
 
16
  from validators import QueryValidatorParams, S1ValidatorAPI, ValidatorAPI
17
 
18
 
19
+ @docs(tags=["Prompting API"], summary="Chat", description="Chat endpoint.")
 
 
 
 
20
  @request_schema(QueryChatSchema)
21
  @response_schema(StreamChunkSchema, 200)
22
  @response_schema(StreamErrorSchema, 400)
 
34
  @docs(
35
  tags=["Prompting API"],
36
  summary="Echo test",
37
+ description="Echo endpoint for testing purposes.",
38
  )
39
  @request_schema(QueryChatSchema)
40
  @response_schema(StreamChunkSchema, 200)
 
47
  def __init__(self, validator_instance=None, *args, **kwargs):
48
  super().__init__(*args, **kwargs)
49
 
50
+ self["validator"] = (
51
+ validator_instance if validator_instance else S1ValidatorAPI()
52
+ )
53
 
54
  # Add middlewares to application
55
  self.add_routes(
validators/__init__.py CHANGED
@@ -1,3 +1,4 @@
1
  from .base import QueryValidatorParams, ValidatorAPI, MockValidator
2
  from .sn1_validator_wrapper import S1ValidatorAPI
3
  from .streamer import AsyncResponseDataStreamer
 
 
1
  from .base import QueryValidatorParams, ValidatorAPI, MockValidator
2
  from .sn1_validator_wrapper import S1ValidatorAPI
3
  from .streamer import AsyncResponseDataStreamer
4
+ from .stream_manager import StreamManager
validators/database.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import json
from typing import List

import aiofiles
import bittensor as bt
from .streamer import ProcessedStreamResponse


class LogDatabase:
    """
    Manage a log database stored as a JSONL (JSON Lines) file.

    Attributes:
        log_database_path (str): The path to the log database file.

    Methods:
        ensure_db_exists(file_path):
            Ensures that the log database file exists; creates an empty file if not.

        add_streams_to_db(stream_responses):
            Asynchronously appends stream responses to the log database.

        append_dicts_to_file(file_path, dictionaries):
            Asynchronously appends a list of dictionaries to the given file.
    """

    def __init__(self, log_database_path: str):
        """
        Initialize the LogDatabase with the given log database file path.

        Args:
            log_database_path (str): The path to the log database file.
        """
        self.log_database_path = log_database_path
        self.ensure_db_exists(log_database_path)

    def ensure_db_exists(self, file_path: str):
        """
        Ensure that the log database file exists. If it doesn't, create an
        empty JSONL file.

        Args:
            file_path (str): The path to the log database file.
        """
        if not os.path.exists(file_path):
            # Create an empty JSONL file.
            with open(file_path, "w"):
                pass
            # Routine bookkeeping; debug level is enough (resolves old TODO).
            bt.logging.debug(f"File '{file_path}' created.")
        else:
            bt.logging.debug(f"File '{file_path}' already exists.")

    async def add_streams_to_db(self, stream_responses: List[ProcessedStreamResponse]):
        """
        Asynchronously add stream responses to the log database.

        Args:
            stream_responses (List[ProcessedStreamResponse]): Processed stream
                responses to append to the log database.

        Raises:
            Exception: Re-raised (with original traceback) if writing fails.
        """
        bt.logging.info("Writing streams to the database...")
        try:
            # Pydantic models iterate as (key, value) pairs, so dict() turns
            # each response into a plain dictionary for JSON serialization.
            stream_responses_dict = [
                dict(stream_response) for stream_response in stream_responses
            ]
            await self.append_dicts_to_file(
                self.log_database_path, stream_responses_dict
            )
            bt.logging.success("Streams added to the database.")
        except Exception as e:
            bt.logging.error(f"Error while adding streams to the database: {e}")
            # Bare raise preserves the original traceback (unlike `raise e`).
            raise

    async def append_dicts_to_file(self, file_path: str, dictionaries: List[dict]):
        """
        Asynchronously append a list of dictionaries to the specified file,
        one JSON object per line (JSONL).

        Args:
            file_path (str): The path to the file to append to.
            dictionaries (List[dict]): Dictionaries to append to the file.
        """
        async with aiofiles.open(file_path, mode="a") as file:
            for dictionary in dictionaries:
                await file.write(json.dumps(dictionary) + "\n")
validators/sn1_validator_wrapper.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import random
2
  import bittensor as bt
3
  from prompting.validator import Validator
@@ -7,13 +8,14 @@ from .base import QueryValidatorParams, ValidatorAPI
7
  from aiohttp.web_response import Response, StreamResponse
8
  from .streamer import AsyncResponseDataStreamer
9
  from .validator_utils import get_top_incentive_uids
 
10
 
11
 
12
  class S1ValidatorAPI(ValidatorAPI):
13
  def __init__(self):
14
  self.validator = Validator()
15
 
16
- def sample_uids(self, params: QueryValidatorParams):
17
  if params.sampling_mode == "random":
18
  uids = get_random_uids(
19
  self.validator, k=params.k_miners, exclude=params.exclude or []
@@ -22,9 +24,11 @@ class S1ValidatorAPI(ValidatorAPI):
22
  if params.sampling_mode == "top_incentive":
23
  metagraph = self.validator.metagraph
24
  vpermit_tao_limit = self.validator.config.neuron.vpermit_tao_limit
25
-
26
- top_uids = get_top_incentive_uids(metagraph, k=params.k_miners, vpermit_tao_limit=vpermit_tao_limit)
27
-
 
 
28
  return top_uids
29
 
30
  async def get_stream_response(self, params: QueryValidatorParams) -> StreamResponse:
@@ -32,11 +36,13 @@ class S1ValidatorAPI(ValidatorAPI):
32
  # task_name = utils.guess_task_name(params.messages[-1])
33
 
34
  # Get the list of uids to query for this step.
35
- uids = self.sample_uids(params)
36
  axons = [self.validator.metagraph.axons[uid] for uid in uids]
37
 
38
  # Make calls to the network with the prompt.
39
- bt.logging.info(f"Sampling dendrite by {params.sampling_mode} with roles {params.roles} and messages {params.messages}")
 
 
40
 
41
  streams_responses = await self.validator.dendrite(
42
  axons=axons,
@@ -47,13 +53,14 @@ class S1ValidatorAPI(ValidatorAPI):
47
  deserialize=False,
48
  streaming=True,
49
  )
50
- uid_stream_dict = dict(zip(uids, streams_responses))
51
- random_uid, random_stream = random.choice(list(uid_stream_dict.items()))
52
-
53
  # Creates a streamer from the selected stream
54
- streamer = AsyncResponseDataStreamer(async_iterator=random_stream, selected_uid=random_uid)
55
- response = await streamer.stream(params.request)
56
- return response
 
 
 
57
 
58
  async def query_validator(self, params: QueryValidatorParams) -> Response:
59
  return await self.get_stream_response(params)
 
1
+ import asyncio
2
  import random
3
  import bittensor as bt
4
  from prompting.validator import Validator
 
8
  from aiohttp.web_response import Response, StreamResponse
9
  from .streamer import AsyncResponseDataStreamer
10
  from .validator_utils import get_top_incentive_uids
11
+ from .stream_manager import StreamManager
12
 
13
 
14
  class S1ValidatorAPI(ValidatorAPI):
15
  def __init__(self):
16
  self.validator = Validator()
17
 
18
+ def sample_uids(self, params: QueryValidatorParams):
19
  if params.sampling_mode == "random":
20
  uids = get_random_uids(
21
  self.validator, k=params.k_miners, exclude=params.exclude or []
 
24
  if params.sampling_mode == "top_incentive":
25
  metagraph = self.validator.metagraph
26
  vpermit_tao_limit = self.validator.config.neuron.vpermit_tao_limit
27
+
28
+ top_uids = get_top_incentive_uids(
29
+ metagraph, k=params.k_miners, vpermit_tao_limit=vpermit_tao_limit
30
+ )
31
+
32
  return top_uids
33
 
34
  async def get_stream_response(self, params: QueryValidatorParams) -> StreamResponse:
 
36
  # task_name = utils.guess_task_name(params.messages[-1])
37
 
38
  # Get the list of uids to query for this step.
39
+ uids = self.sample_uids(params)
40
  axons = [self.validator.metagraph.axons[uid] for uid in uids]
41
 
42
  # Make calls to the network with the prompt.
43
+ bt.logging.info(
44
+ f"Sampling dendrite by {params.sampling_mode} with roles {params.roles} and messages {params.messages}"
45
+ )
46
 
47
  streams_responses = await self.validator.dendrite(
48
  axons=axons,
 
53
  deserialize=False,
54
  streaming=True,
55
  )
56
+
 
 
57
  # Creates a streamer from the selected stream
58
+ stream_manager = StreamManager()
59
+ selected_stream = await stream_manager.process_streams(
60
+ params.request, streams_responses, uids
61
+ )
62
+
63
+ return selected_stream
64
 
65
  async def query_validator(self, params: QueryValidatorParams) -> Response:
66
  return await self.get_stream_response(params)
validators/stream_manager.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import asyncio
import bittensor as bt
from .streamer import AsyncResponseDataStreamer
from .database import LogDatabase
from typing import List, AsyncIterator
from aiohttp.web import Request


class StreamManager:
    """
    Manage the concurrent processing of multiple asynchronous data streams
    and log their responses.

    Attributes:
        log_database (LogDatabase): The log database that stores stream responses.

    Methods:
        process_streams(request, streams_responses, stream_uids):
            Processes multiple asynchronous streams, logs their responses, and
            returns the selected stream response.
    """

    def __init__(self, log_database_path: str = "requests_db.jsonl"):
        """
        Initialize the StreamManager with the given log database file path.

        Args:
            log_database_path (str): Path to the log database file,
                defaults to "requests_db.jsonl".
        """
        self.log_database = LogDatabase(log_database_path)

    async def process_streams(
        self,
        request: Request,
        streams_responses: List[AsyncIterator],
        stream_uids: List[int],
    ):
        """
        Process multiple asynchronous streams, log their responses, and return
        the selected stream response (the stream that produced the first
        non-empty chunk).

        Args:
            request (Request): The web request object.
            streams_responses (List[AsyncIterator]): Asynchronous iterators
                representing the streams.
            stream_uids (List[int]): Unique IDs for the streams, parallel to
                ``streams_responses``.

        Returns:
            ProcessedStreamResponse | None: The response from the selected
            stream, or None if no stream acquired the lock (e.g. every stream
            was empty or failed before writing).
        """
        # Shared lock: only the first streamer to produce a non-empty chunk
        # acquires it and is allowed to write to the client response.
        lock = asyncio.Lock()

        streamers = [
            AsyncResponseDataStreamer(
                async_iterator=stream, selected_uid=stream_uid, lock=lock
            )
            for stream, stream_uid in zip(streams_responses, stream_uids)
        ]
        completed_streams = await asyncio.gather(
            *[streamer.stream(request) for streamer in streamers]
        )

        # BUG FIX: release the lock only if a streamer actually acquired it;
        # asyncio.Lock.release() raises RuntimeError on an unlocked lock,
        # which would mask the real outcome when all streams were empty.
        if lock.locked():
            lock.release()
        bt.logging.info(f"Streams from uids: {stream_uids} processing completed.")

        await self.log_database.add_streams_to_db(completed_streams)
        # Gets the first stream that acquired the lock, meaning the first
        # stream that was able to return a non-empty chunk.
        selected_stream = next(
            (
                completed_stream
                for streamer, completed_stream in zip(streamers, completed_streams)
                if streamer.lock_acquired
            ),
            None,
        )

        return selected_stream
validators/streamer.py CHANGED
@@ -1,15 +1,28 @@
1
  import json
2
  import time
 
3
  import traceback
4
  import bittensor as bt
5
  from pydantic import BaseModel
6
  from datetime import datetime
7
- from typing import AsyncIterator, Optional, List
8
  from aiohttp import web, web_response
9
  from prompting.protocol import StreamPromptingSynapse
10
 
11
 
12
  class StreamChunk(BaseModel):
 
 
 
 
 
 
 
 
 
 
 
 
13
  delta: str
14
  finish_reason: Optional[str]
15
  accumulated_chunks: List[str]
@@ -17,74 +30,207 @@ class StreamChunk(BaseModel):
17
  timestamp: str
18
  sequence_number: int
19
  selected_uid: int
20
-
21
  def encode(self, encoding: str) -> bytes:
 
 
 
 
 
 
 
 
 
22
  data = json.dumps(self.dict(), indent=4)
23
  return data.encode(encoding)
24
 
25
 
26
  class StreamError(BaseModel):
 
 
 
 
 
 
 
 
 
27
  error: str
28
  timestamp: str
29
  sequence_number: int
30
- finish_reason: str = 'error'
31
 
32
  def encode(self, encoding: str) -> bytes:
33
  data = json.dumps(self.dict(), indent=4)
34
  return data.encode(encoding)
35
 
36
 
 
 
 
37
  class AsyncResponseDataStreamer:
38
- def __init__(self, async_iterator: AsyncIterator, selected_uid:int, delay: float = 0.1):
39
- self.async_iterator = async_iterator
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  self.delay = delay
41
  self.selected_uid = selected_uid
42
  self.accumulated_chunks: List[str] = []
43
  self.accumulated_chunks_timings: List[float] = []
44
  self.finish_reason: str = None
45
  self.sequence_number: int = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
- async def stream(self, request: web.Request) -> web_response.StreamResponse:
48
- response = web_response.StreamResponse(status=200, reason="OK")
49
- response.headers["Content-Type"] = "application/json"
50
- await response.prepare(request) # Prepare and send the headers
51
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  try:
53
  start_time = time.time()
 
 
 
54
  async for chunk in self.async_iterator:
55
- if isinstance(chunk, list):
56
- # Chunks are currently returned in string arrays, so we need to concatenate them
57
- concatenated_chunks = "".join(chunk)
58
- self.accumulated_chunks.append(concatenated_chunks)
 
 
59
  self.accumulated_chunks_timings.append(time.time() - start_time)
60
  # Gets new response state
61
  self.sequence_number += 1
62
- new_response_state = self._create_chunk_response(concatenated_chunks)
63
- # Writes the new response state to the response
64
- await response.write(new_response_state.encode('utf-8'))
65
-
 
 
 
 
66
  if chunk is not None and isinstance(chunk, StreamPromptingSynapse):
67
- if len(self.accumulated_chunks) == 0:
68
- self.accumulated_chunks.append(chunk.completion)
69
  self.accumulated_chunks_timings.append(time.time() - start_time)
70
-
71
  self.finish_reason = "completed"
72
  self.sequence_number += 1
73
  # Assuming the last chunk holds the last value yielded which should be a synapse with the completion filled
74
- synapse = chunk
75
- final_state = self._create_chunk_response(synapse.completion)
76
- await response.write(final_state.encode('utf-8'))
77
-
 
 
 
 
 
 
78
  except Exception as e:
79
  bt.logging.error(
80
- f"Encountered an error in {self.__class__.__name__}:get_stream_response:\n{traceback.format_exc()}"
81
  )
82
- response.set_status(500, reason="Internal error")
83
  error_response = self._create_error_response(str(e))
84
- response.write(error_response.encode('utf-8'))
 
 
 
 
 
 
85
  finally:
86
- await response.write_eof() # Ensure to close the response properly
87
- return response
 
 
 
 
 
88
 
89
  def _create_chunk_response(self, chunk: str) -> StreamChunk:
90
  """
@@ -100,7 +246,7 @@ class AsyncResponseDataStreamer:
100
  accumulated_chunks_timings=self.accumulated_chunks_timings,
101
  timestamp=self._current_timestamp(),
102
  sequence_number=self.sequence_number,
103
- selected_uid=self.selected_uid
104
  )
105
 
106
  def _create_error_response(self, error_message: str) -> StreamError:
@@ -113,7 +259,7 @@ class AsyncResponseDataStreamer:
113
  return StreamError(
114
  error=error_message,
115
  timestamp=self._current_timestamp(),
116
- sequence_number=self.sequence_number
117
  )
118
 
119
  def _current_timestamp(self) -> str:
@@ -122,4 +268,4 @@ class AsyncResponseDataStreamer:
122
 
123
  :return: Current timestamp as a string.
124
  """
125
- return datetime.utcnow().isoformat()
 
1
  import json
2
  import time
3
+ import asyncio
4
  import traceback
5
  import bittensor as bt
6
  from pydantic import BaseModel
7
  from datetime import datetime
8
+ from typing import AsyncIterator, Optional, List, Union
9
  from aiohttp import web, web_response
10
  from prompting.protocol import StreamPromptingSynapse
11
 
12
 
13
  class StreamChunk(BaseModel):
14
+ """
15
+ A model representing a chunk of streaming data.
16
+
17
+ Attributes:
18
+ delta (str): The change in the stream.
19
+ finish_reason (Optional[str]): The reason for finishing the stream.
20
+ accumulated_chunks (List[str]): List of accumulated chunks.
21
+ accumulated_chunks_timings (List[float]): Timings for the accumulated chunks.
22
+ timestamp (str): The timestamp of the chunk.
23
+ sequence_number (int): The sequence number of the chunk.
24
+ selected_uid (int): The selected user ID.
25
+ """
26
  delta: str
27
  finish_reason: Optional[str]
28
  accumulated_chunks: List[str]
 
30
  timestamp: str
31
  sequence_number: int
32
  selected_uid: int
33
+
34
  def encode(self, encoding: str) -> bytes:
35
+ """
36
+ Encodes the StreamChunk instance to a JSON-formatted bytes object.
37
+
38
+ Args:
39
+ encoding (str): The encoding to use.
40
+
41
+ Returns:
42
+ bytes: The encoded JSON data.
43
+ """
44
  data = json.dumps(self.dict(), indent=4)
45
  return data.encode(encoding)
46
 
47
 
48
class StreamError(BaseModel):
    """
    A model describing an error raised while streaming data.

    Attributes:
        error (str): The error message.
        timestamp (str): When the error occurred.
        sequence_number (int): The stream sequence number at failure time.
        finish_reason (str): The reason the stream finished; always "error".
    """

    error: str
    timestamp: str
    sequence_number: int
    finish_reason: str = "error"

    def encode(self, encoding: str) -> bytes:
        """Serialize this error as pretty-printed JSON bytes in *encoding*."""
        return json.dumps(self.dict(), indent=4).encode(encoding)
66
 
67
 
68
+ ProcessedStreamResponse = Union[StreamChunk, StreamError]
69
+
70
+
71
  class AsyncResponseDataStreamer:
72
+ """
73
+ A class to manage asynchronous streaming of response data.
74
+
75
+ Attributes:
76
+ async_iterator (AsyncIterator): An asynchronous iterator for streaming data.
77
+ selected_uid (int): The selected user ID.
78
+ lock (asyncio.Lock): An asyncio lock to ensure exclusive access.
79
+ delay (float): Delay between processing chunks, defaults to 0.1 seconds.
80
+ accumulated_chunks (List[str]): List of accumulated chunks.
81
+ accumulated_chunks_timings (List[float]): Timings for the accumulated chunks.
82
+ finish_reason (str): The reason for finishing the stream.
83
+ sequence_number (int): The sequence number of the stream.
84
+ lock_acquired (bool): Flag indicating if the lock was acquired.
85
+ """
86
+ def __init__(
87
+ self,
88
+ async_iterator: AsyncIterator,
89
+ selected_uid: int,
90
+ lock: asyncio.Lock,
91
+ delay: float = 0.1,
92
+ ):
93
+ self.async_iterator = async_iterator
94
  self.delay = delay
95
  self.selected_uid = selected_uid
96
  self.accumulated_chunks: List[str] = []
97
  self.accumulated_chunks_timings: List[float] = []
98
  self.finish_reason: str = None
99
  self.sequence_number: int = 0
100
+ self.lock = lock
101
+ self.lock_acquired = False
102
+
103
+ def ensure_response_is_created(
104
+ self, initiated_response: web.StreamResponse
105
+ ) -> web.StreamResponse:
106
+ """
107
+ Ensures that a StreamResponse is created if it does not already exist.
108
+
109
+ Args:
110
+ initiated_response (web.StreamResponse): The initiated response.
111
+
112
+ Returns:
113
+ web.StreamResponse: The ensured response.
114
+ """
115
+ # Creates response if it was not created
116
+ if initiated_response == None:
117
+ initiated_response = web_response.StreamResponse(status=200, reason="OK")
118
+ initiated_response.headers["Content-Type"] = "application/json"
119
+ return initiated_response
120
+
121
+ return initiated_response
122
+
123
+ async def write_to_stream(
124
+ self,
125
+ request: web.Request,
126
+ initiated_response: web.StreamResponse,
127
+ stream_chunk: StreamChunk,
128
+ lock: asyncio.Lock,
129
+ ) -> web.StreamResponse:
130
+ """
131
+ Writes a stream chunk to the response if the lock is acquired.
132
+
133
+ Args:
134
+ request (web.Request): The web request object.
135
+ initiated_response (web.StreamResponse): The initiated response.
136
+ stream_chunk (StreamChunk): The chunk of stream data to write.
137
+ lock (asyncio.Lock): The lock to ensure exclusive access.
138
 
139
+ Returns:
140
+ web.StreamResponse: The response with the written chunk.
141
+ """
142
+ # Try to acquire the lock and sets the lock_acquired flag. Only the stream that acquires the lock should write to the response
143
+ if lock.locked() == False:
144
+ self.lock_acquired = await lock.acquire()
145
+
146
+ if initiated_response == None and self.lock_acquired:
147
+ initiated_response = self.ensure_response_is_created(initiated_response)
148
+ # Prepare and send the headers
149
+ await initiated_response.prepare(request)
150
+
151
+ if self.lock_acquired:
152
+ await initiated_response.write(stream_chunk.encode("utf-8"))
153
+ else:
154
+ bt.logging.debug(
155
+ f"Stream of uid {stream_chunk.selected_uid} was not the first to return, skipping..."
156
+ )
157
+
158
+ return initiated_response
159
+
160
+ async def stream(self, request: web.Request) -> ProcessedStreamResponse:
161
+ """
162
+ Streams data from the async iterator and writes it to the response.
163
+
164
+ Args:
165
+ request (web.Request): The web request object.
166
+
167
+ Returns:
168
+ ProcessedStreamResponse: The final processed stream response.
169
+
170
+ Raises:
171
+ ValueError: If the stream does not return a valid synapse.
172
+ """
173
  try:
174
  start_time = time.time()
175
+ client_response: web.Response = None
176
+ final_response: ProcessedStreamResponse
177
+
178
  async for chunk in self.async_iterator:
179
+ if isinstance(chunk, str):
180
+ # If chunk is empty, skip
181
+ if not chunk:
182
+ continue
183
+
184
+ self.accumulated_chunks.append(chunk)
185
  self.accumulated_chunks_timings.append(time.time() - start_time)
186
  # Gets new response state
187
  self.sequence_number += 1
188
+ new_response_state = self._create_chunk_response(
189
+ chunk
190
+ )
191
+ # Writes the new response state to the response
192
+ client_response = await self.write_to_stream(
193
+ request, client_response, new_response_state, self.lock
194
+ )
195
+
196
  if chunk is not None and isinstance(chunk, StreamPromptingSynapse):
197
+ if len(self.accumulated_chunks) == 0:
198
+ self.accumulated_chunks.append(chunk.completion)
199
  self.accumulated_chunks_timings.append(time.time() - start_time)
200
+
201
  self.finish_reason = "completed"
202
  self.sequence_number += 1
203
  # Assuming the last chunk holds the last value yielded which should be a synapse with the completion filled
204
+ synapse = chunk
205
+ final_response = self._create_chunk_response(synapse.completion)
206
+
207
+ if synapse.completion:
208
+ client_response = await self.write_to_stream(
209
+ request, client_response, final_response, self.lock
210
+ )
211
+ else:
212
+ raise ValueError("Stream did not return a valid synapse.")
213
+
214
  except Exception as e:
215
  bt.logging.error(
216
+ f"Encountered an error while processing stream for uid {self.selected_uid} get_stream_response:\n{traceback.format_exc()}"
217
  )
 
218
  error_response = self._create_error_response(str(e))
219
+ final_response = error_response
220
+
221
+ # Only the stream that acquires the lock should write the error response
222
+ if self.lock_acquired:
223
+ self.ensure_response_is_created(client_response)
224
+ client_response.set_status(500, reason="Internal error")
225
+ client_response.write(error_response.encode("utf-8"))
226
  finally:
227
+ # Only the stream that acquires the lock should close the response
228
+ if self.lock_acquired:
229
+ self.ensure_response_is_created(client_response)
230
+ # Ensure to close the response properly
231
+ await client_response.write_eof()
232
+
233
+ return final_response
234
 
235
  def _create_chunk_response(self, chunk: str) -> StreamChunk:
236
  """
 
246
  accumulated_chunks_timings=self.accumulated_chunks_timings,
247
  timestamp=self._current_timestamp(),
248
  sequence_number=self.sequence_number,
249
+ selected_uid=self.selected_uid,
250
  )
251
 
252
  def _create_error_response(self, error_message: str) -> StreamError:
 
259
  return StreamError(
260
  error=error_message,
261
  timestamp=self._current_timestamp(),
262
+ sequence_number=self.sequence_number,
263
  )
264
 
265
  def _current_timestamp(self) -> str:
 
268
 
269
  :return: Current timestamp as a string.
270
  """
271
+ return datetime.utcnow().isoformat()
validators/validator_utils.py CHANGED
@@ -3,21 +3,33 @@ from prompting.utils.uids import check_uid_availability
3
 
4
 
5
  def get_top_incentive_uids(metagraph, k: int, vpermit_tao_limit: int) -> List[int]:
6
- miners_uids = list(map(int, filter(lambda uid: check_uid_availability(metagraph, uid, vpermit_tao_limit), metagraph.uids)))
7
-
 
 
 
 
 
 
 
 
8
  # Builds a dictionary of uids and their corresponding incentives
9
  all_miners_incentives = {
10
  "miners_uids": miners_uids,
11
- "incentives": list(map(lambda uid: metagraph.I[uid], miners_uids))
12
  }
13
-
14
  # Zip the uids and their corresponding incentives into a list of tuples
15
- uid_incentive_pairs = list(zip(all_miners_incentives['miners_uids'], all_miners_incentives['incentives']))
 
 
16
 
17
  # Sort the list of tuples by the incentive value in descending order
18
- uid_incentive_pairs_sorted = sorted(uid_incentive_pairs, key=lambda x: x[1], reverse=True)
 
 
19
 
20
  # Extract the top 10 uids
21
  top_k_uids = [uid for uid, incentive in uid_incentive_pairs_sorted[:k]]
22
-
23
- return top_k_uids
 
3
 
4
 
5
def get_top_incentive_uids(metagraph, k: int, vpermit_tao_limit: int) -> List[int]:
    """
    Return the uids of the k available miners with the highest incentive.

    Args:
        metagraph: The network metagraph (provides ``uids`` and incentives ``I``).
        k (int): Number of top-incentive uids to return.
        vpermit_tao_limit (int): Validator-permit TAO limit used when checking
            miner availability.

    Returns:
        List[int]: Uids of the top-k miners, ordered by descending incentive.
    """
    # Keep only the miners that are currently available for querying.
    miners_uids = [
        int(uid)
        for uid in metagraph.uids
        if check_uid_availability(metagraph, uid, vpermit_tao_limit)
    ]

    # Pair each available uid with its incentive value.
    uid_incentive_pairs = [(uid, metagraph.I[uid]) for uid in miners_uids]

    # Highest incentive first (stable sort, same as sorted()).
    uid_incentive_pairs.sort(key=lambda pair: pair[1], reverse=True)

    # Keep only the uids of the top-k pairs.
    return [uid for uid, _incentive in uid_incentive_pairs[:k]]