pedroferreira commited on
Commit
cd059dd
·
1 Parent(s): ba0d874

adds sampling mode to back-end

Browse files
validators/base.py CHANGED
@@ -13,6 +13,7 @@ class QueryValidatorParams:
13
  timeout: int
14
  prefer: str
15
  request: Request
 
16
 
17
  @staticmethod
18
  def from_request(request: Request):
@@ -26,6 +27,7 @@ class QueryValidatorParams:
26
  timeout=data.get("timeout", 10),
27
  prefer=data.get("prefer", "longest"),
28
  request=request,
 
29
  )
30
 
31
 
 
13
  timeout: int
14
  prefer: str
15
  request: Request
16
+ sampling_mode: str
17
 
18
  @staticmethod
19
  def from_request(request: Request):
 
27
  timeout=data.get("timeout", 10),
28
  prefer=data.get("prefer", "longest"),
29
  request=request,
30
+ sampling_mode=data.get("sampling_mode", "random"),
31
  )
32
 
33
 
validators/sn1_validator_wrapper.py CHANGED
@@ -5,31 +5,34 @@ from prompting.utils.uids import get_random_uids
5
  from prompting.protocol import StreamPromptingSynapse
6
  from .base import QueryValidatorParams, ValidatorAPI
7
  from aiohttp.web_response import Response, StreamResponse
8
- from dataclasses import dataclass
9
- from typing import List
10
  from .streamer import AsyncResponseDataStreamer
11
-
12
-
13
- @dataclass
14
- class ProcessedStreamResponse:
15
- streamed_chunks: List[str]
16
- streamed_chunks_timings: List[float]
17
- synapse: StreamPromptingSynapse
18
 
19
 
20
  class S1ValidatorAPI(ValidatorAPI):
21
  def __init__(self):
22
  self.validator = Validator()
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  async def get_stream_response(self, params: QueryValidatorParams) -> StreamResponse:
26
  # Guess the task name of current request
27
  # task_name = utils.guess_task_name(params.messages[-1])
28
 
29
  # Get the list of uids to query for this step.
30
- uids = get_random_uids(
31
- self.validator, k=params.k_miners, exclude=params.exclude or []
32
- ).tolist()
33
  axons = [self.validator.metagraph.axons[uid] for uid in uids]
34
 
35
  # Make calls to the network with the prompt.
 
5
  from prompting.protocol import StreamPromptingSynapse
6
  from .base import QueryValidatorParams, ValidatorAPI
7
  from aiohttp.web_response import Response, StreamResponse
 
 
8
  from .streamer import AsyncResponseDataStreamer
9
+ from .validator_utils import get_top_incentive_uids
 
 
 
 
 
 
10
 
11
 
12
  class S1ValidatorAPI(ValidatorAPI):
13
  def __init__(self):
14
  self.validator = Validator()
15
 
16
+ def sample_uids(self, params: QueryValidatorParams):
17
+ if params.sampling_mode == "random":
18
+ uids = get_random_uids(
19
+ self.validator, k=params.k_miners, exclude=params.exclude or []
20
+ ).tolist()
21
+ return uids
22
+ if params.sampling_mode == "top_incentive":
23
+ metagraph = self.validator.metagraph
24
+ vpermit_tao_limit = self.validator.config.neuron.vpermit_tao_limit
25
+
26
+ top_uids = get_top_incentive_uids(metagraph, k=params.k_miners, vpermit_tao_limit=vpermit_tao_limit)
27
+
28
+ return top_uids
29
 
30
  async def get_stream_response(self, params: QueryValidatorParams) -> StreamResponse:
31
  # Guess the task name of current request
32
  # task_name = utils.guess_task_name(params.messages[-1])
33
 
34
  # Get the list of uids to query for this step.
35
+ uids = self.sample_uids(params)
 
 
36
  axons = [self.validator.metagraph.axons[uid] for uid in uids]
37
 
38
  # Make calls to the network with the prompt.
validators/validator_utils.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ from prompting.utils.uids import check_uid_availability
3
+
4
+
5
+ def get_top_incentive_uids(metagraph, k: int, vpermit_tao_limit: int) -> List[int]:
6
+ miners_uids = list(map(int, filter(lambda uid: check_uid_availability(metagraph, uid, vpermit_tao_limit), metagraph.uids)))
7
+
8
+ # Builds a dictionary of uids and their corresponding incentives
9
+ all_miners_incentives = {
10
+ "miners_uids": miners_uids,
11
+ "incentives": list(map(lambda uid: metagraph.I[uid], miners_uids))
12
+ }
13
+
14
+ # Zip the uids and their corresponding incentives into a list of tuples
15
+ uid_incentive_pairs = list(zip(all_miners_incentives['miners_uids'], all_miners_incentives['incentives']))
16
+
17
+ # Sort the list of tuples by the incentive value in descending order
18
+ uid_incentive_pairs_sorted = sorted(uid_incentive_pairs, key=lambda x: x[1], reverse=True)
19
+
20
+ # Extract the top 10 uids
21
+ top_k_uids = [uid for uid, incentive in uid_incentive_pairs_sorted[:k]]
22
+
23
+ return top_k_uids