pedroferreira committed on
Commit
fdc8fdb
·
1 Parent(s): 1e9def1

runs black

Browse files
forward.py CHANGED
@@ -15,17 +15,22 @@ from prompting.utils.logging import log_event
15
  from prompting.utils.misc import async_log, serialize_exception_to_string
16
  from dataclasses import dataclass
17
 
 
18
  @async_log
19
  async def generate_reference(agent):
20
  loop = asyncio.get_running_loop()
21
- result = await loop.run_in_executor(None, agent.task.generate_reference, agent.llm_pipeline)
 
 
22
  return result
23
 
 
24
  @async_log
25
  async def execute_dendrite_call(dendrite_call):
26
  responses = await dendrite_call
27
  return responses
28
 
 
29
  @dataclass
30
  class StreamResult:
31
  synapse: StreamPromptingSynapse = None
@@ -82,20 +87,19 @@ async def handle_response(responses: Dict[int, Awaitable]) -> List[StreamResult]
82
  (uid, responses[uid]) for uid, _ in responses.items()
83
  ] # Pair UIDs with their tasks
84
 
85
- #Start tasks, preserving order and their associated UIDs
86
  tasks = [process_response(uid, resp) for uid, resp in tasks_with_uid]
87
 
88
  results = await asyncio.gather(*tasks, return_exceptions=True)
89
 
90
  mapped_results = []
91
- #Pair each result with its original uid
92
  for (uid, _), result in zip(tasks_with_uid, results):
93
-
94
- #If the result is a StreamPromptingSynapse, the response was successful and the stream result is added without exceptions
95
  if isinstance(result, StreamPromptingSynapse):
96
  mapped_results.append(StreamResult(synapse=result, uid=uid))
97
 
98
- #If the result is an exception, the response was unsuccessful and the stream result is added with the exception and an empty synapse
99
  elif isinstance(result, BaseException):
100
  failed_synapse = StreamPromptingSynapse(
101
  roles=["user"], messages=["failure"], completion=""
@@ -104,7 +108,7 @@ async def handle_response(responses: Dict[int, Awaitable]) -> List[StreamResult]
104
  StreamResult(synapse=failed_synapse, exception=result, uid=uid)
105
  )
106
 
107
- #If the result is neither an error or a StreamSynapse, log the error and raise a ValueError
108
  else:
109
  bt.logging.error(f"Unexpected result type for UID {uid}: {result}")
110
  raise ValueError(f"Unexpected result type for UID {uid}: {result}")
 
15
  from prompting.utils.misc import async_log, serialize_exception_to_string
16
  from dataclasses import dataclass
17
 
18
+
19
  @async_log
20
  async def generate_reference(agent):
21
  loop = asyncio.get_running_loop()
22
+ result = await loop.run_in_executor(
23
+ None, agent.task.generate_reference, agent.llm_pipeline
24
+ )
25
  return result
26
 
27
+
28
  @async_log
29
  async def execute_dendrite_call(dendrite_call):
30
  responses = await dendrite_call
31
  return responses
32
 
33
+
34
  @dataclass
35
  class StreamResult:
36
  synapse: StreamPromptingSynapse = None
 
87
  (uid, responses[uid]) for uid, _ in responses.items()
88
  ] # Pair UIDs with their tasks
89
 
90
+ # Start tasks, preserving order and their associated UIDs
91
  tasks = [process_response(uid, resp) for uid, resp in tasks_with_uid]
92
 
93
  results = await asyncio.gather(*tasks, return_exceptions=True)
94
 
95
  mapped_results = []
96
+ # Pair each result with its original uid
97
  for (uid, _), result in zip(tasks_with_uid, results):
98
+ # If the result is a StreamPromptingSynapse, the response was successful and the stream result is added without exceptions
 
99
  if isinstance(result, StreamPromptingSynapse):
100
  mapped_results.append(StreamResult(synapse=result, uid=uid))
101
 
102
+ # If the result is an exception, the response was unsuccessful and the stream result is added with the exception and an empty synapse
103
  elif isinstance(result, BaseException):
104
  failed_synapse = StreamPromptingSynapse(
105
  roles=["user"], messages=["failure"], completion=""
 
108
  StreamResult(synapse=failed_synapse, exception=result, uid=uid)
109
  )
110
 
111
+ # If the result is neither an error or a StreamSynapse, log the error and raise a ValueError
112
  else:
113
  bt.logging.error(f"Unexpected result type for UID {uid}: {result}")
114
  raise ValueError(f"Unexpected result type for UID {uid}: {result}")
middlewares.py CHANGED
@@ -3,29 +3,31 @@ import json
3
  import bittensor as bt
4
  from aiohttp.web import Request, Response, middleware
5
 
6
- EXPECTED_ACCESS_KEY = os.environ.get('EXPECTED_ACCESS_KEY')
 
7
 
8
  @middleware
9
- async def api_key_middleware(request: Request, handler):
10
  # Logging the request
11
  bt.logging.info(f"Handling {request.method} request to {request.path}")
12
 
13
  # Check access key
14
  access_key = request.headers.get("api_key")
15
  if EXPECTED_ACCESS_KEY is not None and access_key != EXPECTED_ACCESS_KEY:
16
- bt.logging.error(f'Invalid access key: {access_key}')
17
  return Response(status=401, reason="Invalid access key")
18
 
19
  # Continue to the next handler if the API key is valid
20
- return await handler(request)
 
21
 
22
  @middleware
23
- async def json_parsing_middleware(request: Request, handler):
24
  try:
25
  # Parsing JSON data from the request
26
- request['data'] = await request.json()
27
  except json.JSONDecodeError as e:
28
- bt.logging.error(f'Invalid JSON data: {str(e)}')
29
  return Response(status=400, text="Invalid JSON")
30
 
31
  # Continue to the next handler if JSON is successfully parsed
 
3
  import bittensor as bt
4
  from aiohttp.web import Request, Response, middleware
5
 
6
+ EXPECTED_ACCESS_KEY = os.environ.get("EXPECTED_ACCESS_KEY")
7
+
8
 
9
  @middleware
10
+ async def api_key_middleware(request: Request, handler):
11
  # Logging the request
12
  bt.logging.info(f"Handling {request.method} request to {request.path}")
13
 
14
  # Check access key
15
  access_key = request.headers.get("api_key")
16
  if EXPECTED_ACCESS_KEY is not None and access_key != EXPECTED_ACCESS_KEY:
17
+ bt.logging.error(f"Invalid access key: {access_key}")
18
  return Response(status=401, reason="Invalid access key")
19
 
20
  # Continue to the next handler if the API key is valid
21
+ return await handler(request)
22
+
23
 
24
  @middleware
25
+ async def json_parsing_middleware(request: Request, handler):
26
  try:
27
  # Parsing JSON data from the request
28
+ request["data"] = await request.json()
29
  except json.JSONDecodeError as e:
30
+ bt.logging.error(f"Invalid JSON data: {str(e)}")
31
  return Response(status=400, text="Invalid JSON")
32
 
33
  # Continue to the next handler if JSON is successfully parsed
server.py CHANGED
@@ -34,53 +34,53 @@ EXPECTED_ACCESS_KEY="hey-michal" python app.py --neuron.model_id mock --wallet.n
34
  ```
35
  add --mock to test the echo stream
36
  """
 
 
37
  async def chat(request: web.Request) -> Response:
38
  """
39
  Chat endpoint for the validator.
40
- """
41
- params = QueryValidatorParams.from_request(request)
42
 
43
-
44
  # Access the validator from the application context
45
- validator: ValidatorAPI = request.app['validator']
46
-
47
  response = await validator.query_validator(params)
48
  return response
49
 
50
 
51
- async def echo_stream(request, request_data):
52
- request_data = request['data']
53
  return await utils.echo_stream(request_data)
54
 
55
 
56
-
57
  class ValidatorApplication(web.Application):
58
  def __init__(self, validator_instance=None, *args, **kwargs):
59
  super().__init__(*args, **kwargs)
60
-
61
- self['validator'] = validator_instance if validator_instance else S1ValidatorAPI()
62
-
63
- # Add middlewares to application
64
- self.add_routes([
65
- web.post('/chat/', chat),
66
- web.post('/echo/', echo_stream)
67
- ])
68
  self.setup_middlewares()
69
  # TODO: Enable rewarding and other features
70
-
71
  def setup_middlewares(self):
72
  self.middlewares.append(json_parsing_middleware)
73
  self.middlewares.append(api_key_middleware)
74
-
 
75
  def main(run_aio_app=True, test=False) -> None:
76
  loop = asyncio.get_event_loop()
77
  port = 10000
78
  if run_aio_app:
79
  # Instantiate the application with the actual validator
80
  bt.logging.info("Starting validator application.")
81
- validator_app = ValidatorApplication()
82
- bt.logging.success(f'Validator app initialized successfully', validator_app)
83
-
84
  try:
85
  web.run_app(validator_app, port=port, loop=loop)
86
  except KeyboardInterrupt:
@@ -88,5 +88,6 @@ def main(run_aio_app=True, test=False) -> None:
88
  finally:
89
  pass
90
 
 
91
  if __name__ == "__main__":
92
  main()
 
34
  ```
35
  add --mock to test the echo stream
36
  """
37
+
38
+
39
  async def chat(request: web.Request) -> Response:
40
  """
41
  Chat endpoint for the validator.
42
+ """
43
+ params = QueryValidatorParams.from_request(request)
44
 
 
45
  # Access the validator from the application context
46
+ validator: ValidatorAPI = request.app["validator"]
47
+
48
  response = await validator.query_validator(params)
49
  return response
50
 
51
 
52
+ async def echo_stream(request, request_data):
53
+ request_data = request["data"]
54
  return await utils.echo_stream(request_data)
55
 
56
 
 
57
  class ValidatorApplication(web.Application):
58
  def __init__(self, validator_instance=None, *args, **kwargs):
59
  super().__init__(*args, **kwargs)
60
+
61
+ self["validator"] = (
62
+ validator_instance if validator_instance else S1ValidatorAPI()
63
+ )
64
+
65
+ # Add middlewares to application
66
+ self.add_routes([web.post("/chat/", chat), web.post("/echo/", echo_stream)])
 
67
  self.setup_middlewares()
68
  # TODO: Enable rewarding and other features
69
+
70
  def setup_middlewares(self):
71
  self.middlewares.append(json_parsing_middleware)
72
  self.middlewares.append(api_key_middleware)
73
+
74
+
75
  def main(run_aio_app=True, test=False) -> None:
76
  loop = asyncio.get_event_loop()
77
  port = 10000
78
  if run_aio_app:
79
  # Instantiate the application with the actual validator
80
  bt.logging.info("Starting validator application.")
81
+ validator_app = ValidatorApplication()
82
+ bt.logging.success(f"Validator app initialized successfully", validator_app)
83
+
84
  try:
85
  web.run_app(validator_app, port=port, loop=loop)
86
  except KeyboardInterrupt:
 
88
  finally:
89
  pass
90
 
91
+
92
  if __name__ == "__main__":
93
  main()
test.py CHANGED
@@ -1,10 +1,10 @@
1
  import pytest
2
 
3
 
4
-
5
  def test_query_network():
6
  pass
7
 
 
8
  def test_filter_completions():
9
  pass
10
 
@@ -12,5 +12,6 @@ def test_filter_completions():
12
  def test_guess_task_name():
13
  pass
14
 
 
15
  def test_ensemble_completions():
16
- pass
 
1
  import pytest
2
 
3
 
 
4
  def test_query_network():
5
  pass
6
 
7
+
8
  def test_filter_completions():
9
  pass
10
 
 
12
  def test_guess_task_name():
13
  pass
14
 
15
+
16
  def test_ensemble_completions():
17
+ pass
utils.py CHANGED
@@ -6,13 +6,26 @@ from aiohttp import web
6
  from collections import Counter
7
  from prompting.rewards import DateRewardModel, FloatDiffModel
8
 
9
- UNSUCCESSFUL_RESPONSE_PATTERNS = ["I'm sorry", "unable to", "I cannot", "I can't", "I am unable", "I am sorry", "I can not", "don't know", "not sure", "don't understand", "not capable"]
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  reward_models = {
12
- 'date_qa': DateRewardModel(),
13
- 'math': FloatDiffModel(),
14
  }
15
 
 
16
  def completion_is_valid(completion: str):
17
  """
18
  Get the completion statuses from the completions.
@@ -20,13 +33,15 @@ def completion_is_valid(completion: str):
20
  if not completion.strip():
21
  return False
22
 
23
- patt = re.compile(r'\b(?:' + '|'.join(UNSUCCESSFUL_RESPONSE_PATTERNS) + r')\b', re.IGNORECASE)
24
- if not len(re.findall(r'\w+',completion)) or patt.search(completion):
 
 
25
  return False
26
  return True
27
 
28
 
29
- def ensemble_result(completions: list, task_name: str, prefer: str = 'longest'):
30
  """
31
  Ensemble completions from multiple models.
32
  # TODO: Measure agreement
@@ -37,11 +52,11 @@ def ensemble_result(completions: list, task_name: str, prefer: str = 'longest'):
37
  return None
38
 
39
  answer = None
40
- if task_name in ('qa', 'summarization'):
41
  # No special handling for QA or summarization
42
  supporting_completions = completions
43
 
44
- elif task_name == 'date_qa':
45
  # filter the completions to be the ones that contain valid dates and if there are multiple dates, select the most common one (with support > 1)
46
  dates = list(map(reward_models[task_name].parse_dates_from_text, completions))
47
  bt.logging.info(f"Unprocessed dates: {dates}")
@@ -58,9 +73,11 @@ def ensemble_result(completions: list, task_name: str, prefer: str = 'longest'):
58
  if count == 1:
59
  supporting_completions = valid_completions
60
  else:
61
- supporting_completions = [c for i, c in enumerate(valid_completions) if dates[i]==most_common]
 
 
62
 
63
- elif task_name == 'math':
64
  # filter the completions to be the ones that contain valid numbers and if there are multiple values, select the most common one (with support > 1)
65
  # TODO: use the median instead of the most common value
66
  vals = list(map(reward_models[task_name].extract_number, completions))
@@ -74,57 +91,67 @@ def ensemble_result(completions: list, task_name: str, prefer: str = 'longest'):
74
  if count == 1:
75
  supporting_completions = completions
76
  else:
77
- supporting_completions = [c for i, c in enumerate(completions) if vals[i]==most_common]
78
-
 
79
 
80
  bt.logging.info(f"Supporting completions: {supporting_completions}")
81
- if prefer == 'longest':
82
  preferred_completion = sorted(supporting_completions, key=len)[-1]
83
- elif prefer == 'shortest':
84
  preferred_completion = sorted(supporting_completions, key=len)[0]
85
- elif prefer == 'most_common':
86
- preferred_completion = max(set(supporting_completions), key=supporting_completions.count)
 
 
87
  else:
88
  raise ValueError(f"Unknown ensemble preference: {prefer}")
89
 
90
  return {
91
- 'completion': preferred_completion,
92
- 'accepted_answer': answer,
93
- 'support': len(supporting_completions),
94
- 'support_indices': [completions.index(c) for c in supporting_completions],
95
- 'method': f'Selected the {prefer.replace("_", " ")} completion'
96
  }
97
 
 
98
  def guess_task_name(challenge: str):
99
  # TODO: use a pre-trained classifier to guess the task name
100
  categories = {
101
- 'summarization': re.compile('summar|quick rundown|overview'),
102
- 'date_qa': re.compile('exact date|tell me when|on what date|on what day|was born?|died?'),
103
- 'math': re.compile('math|solve|solution| sum |problem|geometric|vector|calculate|degrees|decimal|factorial'),
 
 
 
 
104
  }
105
  for task_name, patt in categories.items():
106
  if patt.search(challenge):
107
  return task_name
108
 
109
- return 'qa'
110
 
111
 
112
  async def echo_stream(request_data: dict):
113
- k = request_data.get('k', 1)
114
- exclude = request_data.get('exclude', [])
115
- timeout = request_data.get('timeout', 0.2)
116
- message = '\n\n'.join(request_data['messages'])
117
 
118
  # Create a StreamResponse
119
- response = web.StreamResponse(status=200, reason='OK', headers={'Content-Type': 'text/plain'})
 
 
120
  await response.prepare()
121
 
122
- completion = ''
123
  # Echo the message k times with a timeout between each chunk
124
  for _ in range(k):
125
  for word in message.split():
126
- chunk = f'{word} '
127
- await response.write(chunk.encode('utf-8'))
128
  completion += chunk
129
  time.sleep(timeout)
130
  bt.logging.info(f"Echoed: {chunk}")
@@ -132,21 +159,23 @@ async def echo_stream(request_data: dict):
132
  completion = completion.strip()
133
 
134
  # Prepare final JSON chunk
135
- json_chunk = json.dumps({
136
- "uids": [0],
137
- "completion": completion,
138
- "completions": [completion.strip()],
139
- "timings": [0],
140
- "status_messages": ['Went well!'],
141
- "status_codes": [200],
142
- "completion_is_valid": [True],
143
- "task_name": 'echo',
144
- "ensemble_result": {}
145
- })
 
 
146
 
147
  # Send the final JSON as part of the stream
148
- await response.write(f"\n\nJSON_RESPONSE_BEGIN:\n{json_chunk}".encode('utf-8'))
149
 
150
  # Finalize the response
151
  await response.write_eof()
152
- return response
 
6
  from collections import Counter
7
  from prompting.rewards import DateRewardModel, FloatDiffModel
8
 
9
+ UNSUCCESSFUL_RESPONSE_PATTERNS = [
10
+ "I'm sorry",
11
+ "unable to",
12
+ "I cannot",
13
+ "I can't",
14
+ "I am unable",
15
+ "I am sorry",
16
+ "I can not",
17
+ "don't know",
18
+ "not sure",
19
+ "don't understand",
20
+ "not capable",
21
+ ]
22
 
23
  reward_models = {
24
+ "date_qa": DateRewardModel(),
25
+ "math": FloatDiffModel(),
26
  }
27
 
28
+
29
  def completion_is_valid(completion: str):
30
  """
31
  Get the completion statuses from the completions.
 
33
  if not completion.strip():
34
  return False
35
 
36
+ patt = re.compile(
37
+ r"\b(?:" + "|".join(UNSUCCESSFUL_RESPONSE_PATTERNS) + r")\b", re.IGNORECASE
38
+ )
39
+ if not len(re.findall(r"\w+", completion)) or patt.search(completion):
40
  return False
41
  return True
42
 
43
 
44
+ def ensemble_result(completions: list, task_name: str, prefer: str = "longest"):
45
  """
46
  Ensemble completions from multiple models.
47
  # TODO: Measure agreement
 
52
  return None
53
 
54
  answer = None
55
+ if task_name in ("qa", "summarization"):
56
  # No special handling for QA or summarization
57
  supporting_completions = completions
58
 
59
+ elif task_name == "date_qa":
60
  # filter the completions to be the ones that contain valid dates and if there are multiple dates, select the most common one (with support > 1)
61
  dates = list(map(reward_models[task_name].parse_dates_from_text, completions))
62
  bt.logging.info(f"Unprocessed dates: {dates}")
 
73
  if count == 1:
74
  supporting_completions = valid_completions
75
  else:
76
+ supporting_completions = [
77
+ c for i, c in enumerate(valid_completions) if dates[i] == most_common
78
+ ]
79
 
80
+ elif task_name == "math":
81
  # filter the completions to be the ones that contain valid numbers and if there are multiple values, select the most common one (with support > 1)
82
  # TODO: use the median instead of the most common value
83
  vals = list(map(reward_models[task_name].extract_number, completions))
 
91
  if count == 1:
92
  supporting_completions = completions
93
  else:
94
+ supporting_completions = [
95
+ c for i, c in enumerate(completions) if vals[i] == most_common
96
+ ]
97
 
98
  bt.logging.info(f"Supporting completions: {supporting_completions}")
99
+ if prefer == "longest":
100
  preferred_completion = sorted(supporting_completions, key=len)[-1]
101
+ elif prefer == "shortest":
102
  preferred_completion = sorted(supporting_completions, key=len)[0]
103
+ elif prefer == "most_common":
104
+ preferred_completion = max(
105
+ set(supporting_completions), key=supporting_completions.count
106
+ )
107
  else:
108
  raise ValueError(f"Unknown ensemble preference: {prefer}")
109
 
110
  return {
111
+ "completion": preferred_completion,
112
+ "accepted_answer": answer,
113
+ "support": len(supporting_completions),
114
+ "support_indices": [completions.index(c) for c in supporting_completions],
115
+ "method": f'Selected the {prefer.replace("_", " ")} completion',
116
  }
117
 
118
+
119
  def guess_task_name(challenge: str):
120
  # TODO: use a pre-trained classifier to guess the task name
121
  categories = {
122
+ "summarization": re.compile("summar|quick rundown|overview"),
123
+ "date_qa": re.compile(
124
+ "exact date|tell me when|on what date|on what day|was born?|died?"
125
+ ),
126
+ "math": re.compile(
127
+ "math|solve|solution| sum |problem|geometric|vector|calculate|degrees|decimal|factorial"
128
+ ),
129
  }
130
  for task_name, patt in categories.items():
131
  if patt.search(challenge):
132
  return task_name
133
 
134
+ return "qa"
135
 
136
 
137
  async def echo_stream(request_data: dict):
138
+ k = request_data.get("k", 1)
139
+ exclude = request_data.get("exclude", [])
140
+ timeout = request_data.get("timeout", 0.2)
141
+ message = "\n\n".join(request_data["messages"])
142
 
143
  # Create a StreamResponse
144
+ response = web.StreamResponse(
145
+ status=200, reason="OK", headers={"Content-Type": "text/plain"}
146
+ )
147
  await response.prepare()
148
 
149
+ completion = ""
150
  # Echo the message k times with a timeout between each chunk
151
  for _ in range(k):
152
  for word in message.split():
153
+ chunk = f"{word} "
154
+ await response.write(chunk.encode("utf-8"))
155
  completion += chunk
156
  time.sleep(timeout)
157
  bt.logging.info(f"Echoed: {chunk}")
 
159
  completion = completion.strip()
160
 
161
  # Prepare final JSON chunk
162
+ json_chunk = json.dumps(
163
+ {
164
+ "uids": [0],
165
+ "completion": completion,
166
+ "completions": [completion.strip()],
167
+ "timings": [0],
168
+ "status_messages": ["Went well!"],
169
+ "status_codes": [200],
170
+ "completion_is_valid": [True],
171
+ "task_name": "echo",
172
+ "ensemble_result": {},
173
+ }
174
+ )
175
 
176
  # Send the final JSON as part of the stream
177
+ await response.write(f"\n\nJSON_RESPONSE_BEGIN:\n{json_chunk}".encode("utf-8"))
178
 
179
  # Finalize the response
180
  await response.write_eof()
181
+ return response
validators/__init__.py CHANGED
@@ -1,2 +1,2 @@
1
  from .base import QueryValidatorParams, ValidatorAPI, MockValidator
2
- from .sn1_validator_wrapper import S1ValidatorAPI
 
1
  from .base import QueryValidatorParams, ValidatorAPI, MockValidator
2
+ from .sn1_validator_wrapper import S1ValidatorAPI
validators/base.py CHANGED
@@ -3,6 +3,7 @@ from typing import List
3
  from dataclasses import dataclass
4
  from aiohttp.web import Response, Request
5
 
 
6
  @dataclass
7
  class QueryValidatorParams:
8
  k_miners: int
@@ -12,28 +13,28 @@ class QueryValidatorParams:
12
  timeout: int
13
  prefer: str
14
  request: Request
15
-
16
  @staticmethod
17
  def from_request(request: Request):
18
- data = request['data']
19
-
20
- return QueryValidatorParams(
21
- k_miners=data.get('k', 10),
22
- exclude=data.get('exclude', []),
23
- roles=data['roles'],
24
- messages=data['messages'],
25
- timeout=data.get('timeout', 10),
26
- prefer=data.get('prefer', 'longest'),
27
- request=request
28
  )
29
 
 
30
  class ValidatorAPI(ABC):
31
  @abstractmethod
32
- async def query_validator(self, params:QueryValidatorParams) -> Response:
33
  pass
34
-
35
-
36
- class MockValidator(ValidatorAPI):
37
- async def query_validator(self, params:QueryValidatorParams) -> Response:
38
  ...
39
-
 
3
  from dataclasses import dataclass
4
  from aiohttp.web import Response, Request
5
 
6
+
7
  @dataclass
8
  class QueryValidatorParams:
9
  k_miners: int
 
13
  timeout: int
14
  prefer: str
15
  request: Request
16
+
17
  @staticmethod
18
  def from_request(request: Request):
19
+ data = request["data"]
20
+
21
+ return QueryValidatorParams(
22
+ k_miners=data.get("k", 10),
23
+ exclude=data.get("exclude", []),
24
+ roles=data["roles"],
25
+ messages=data["messages"],
26
+ timeout=data.get("timeout", 10),
27
+ prefer=data.get("prefer", "longest"),
28
+ request=request,
29
  )
30
 
31
+
32
  class ValidatorAPI(ABC):
33
  @abstractmethod
34
+ async def query_validator(self, params: QueryValidatorParams) -> Response:
35
  pass
36
+
37
+
38
+ class MockValidator(ValidatorAPI):
39
+ async def query_validator(self, params: QueryValidatorParams) -> Response:
40
  ...
 
validators/sn1_validator_wrapper.py CHANGED
@@ -13,23 +13,27 @@ from .base import QueryValidatorParams, ValidatorAPI
13
  from aiohttp.web_response import Response, StreamResponse
14
  from deprecated import deprecated
15
 
 
16
  class S1ValidatorAPI(ValidatorAPI):
17
  def __init__(self):
18
- self.validator = Validator()
19
-
20
 
21
- @deprecated(reason="This function is deprecated. Validators use stream synapse now, use get_stream_response instead.")
22
- async def get_response(self, params:QueryValidatorParams) -> Response:
 
 
23
  try:
24
  # Guess the task name of current request
25
  task_name = utils.guess_task_name(params.messages[-1])
26
 
27
  # Get the list of uids to query for this step.
28
- uids = get_random_uids(self.validator, k=params.k_miners, exclude=params.exclude or []).tolist()
 
 
29
  axons = [self.validator.metagraph.axons[uid] for uid in uids]
30
 
31
  # Make calls to the network with the prompt.
32
- bt.logging.info(f'Calling dendrite')
33
  responses = await self.validator.dendrite(
34
  axons=axons,
35
  synapse=PromptingSynapse(roles=params.roles, messages=params.messages),
@@ -38,89 +42,113 @@ class S1ValidatorAPI(ValidatorAPI):
38
 
39
  bt.logging.info(f"Creating DendriteResponseEvent:\n {responses}")
40
  # Encapsulate the responses in a response event (dataclass)
41
- response_event = DendriteResponseEvent(responses, torch.LongTensor(uids), params.timeout)
 
 
42
 
43
  # convert dict to json
44
  response = response_event.__state_dict__()
45
 
46
- response['completion_is_valid'] = valid = list(map(utils.completion_is_valid, response['completions']))
47
- valid_completions = [response['completions'][i] for i, v in enumerate(valid) if v]
 
 
 
 
48
 
49
- response['task_name'] = task_name
50
- response['ensemble_result'] = utils.ensemble_result(valid_completions, task_name=task_name, prefer=params.prefer)
 
 
51
 
52
  bt.logging.info(f"Response:\n {response}")
53
- return Response(status=200, reason="I can't believe it's not butter!", text=json.dumps(response))
 
 
 
 
54
 
55
  except Exception:
56
- bt.logging.error(f'Encountered in {self.__class__.__name__}:get_response:\n{traceback.format_exc()}')
 
 
57
  return Response(status=500, reason="Internal error")
58
-
59
-
60
- async def process_response(self, response: StreamResponse, uid: int, async_generator: Awaitable):
 
61
  """Process a single response asynchronously."""
62
  try:
63
  chunk = None # Initialize chunk with a default value
64
  async for chunk in async_generator: # most important loop, as this is where we acquire the final synapse.
65
  bt.logging.debug(f"\nchunk for uid {uid}: {chunk}")
66
-
67
  # TODO: SET PROPER IMPLEMENTATION TO RETURN CHUNK
68
  if chunk is not None:
69
  json_data = json.dumps(chunk)
70
- await response.write(json_data.encode('utf-8'))
71
-
72
  except Exception as e:
73
- bt.logging.error(f'Encountered an error in {self.__class__.__name__}:get_stream_response:\n{traceback.format_exc()}')
 
 
74
  response.set_status(500, reason="Internal error")
75
- await response.write(json.dumps({'error': str(e)}).encode('utf-8'))
76
  finally:
77
  await response.write_eof() # Ensure to close the response properly
78
-
79
- async def get_stream_response(self, params:QueryValidatorParams) -> StreamResponse:
80
  response = StreamResponse(status=200, reason="OK")
81
- response.headers['Content-Type'] = 'application/json'
82
 
83
  await response.prepare(params.request) # Prepare and send the headers
84
-
85
  try:
86
  # Guess the task name of current request
87
  task_name = utils.guess_task_name(params.messages[-1])
88
 
89
  # Get the list of uids to query for this step.
90
- uids = get_random_uids(self.validator, k=params.k_miners, exclude=params.exclude or []).tolist()
 
 
91
  axons = [self.validator.metagraph.axons[uid] for uid in uids]
92
 
93
  # Make calls to the network with the prompt.
94
- bt.logging.info(f'Calling dendrite')
95
  streams_responses = await self.validator.dendrite(
96
  axons=axons,
97
- synapse=StreamPromptingSynapse(roles=params.roles, messages=params.messages),
 
 
98
  timeout=params.timeout,
99
  deserialize=False,
100
  streaming=True,
101
  )
102
-
103
- tasks = [self.process_response(uid, res) for uid, res in dict(zip(uids, streams_responses))]
 
 
 
104
  results = await asyncio.gather(*tasks, return_exceptions=True)
105
-
106
- # TODO: Continue implementation, business decision needs to be made on how to handle the results
107
  except Exception as e:
108
- bt.logging.error(f'Encountered an error in {self.__class__.__name__}:get_stream_response:\n{traceback.format_exc()}')
 
 
109
  response.set_status(500, reason="Internal error")
110
- await response.write(json.dumps({'error': str(e)}).encode('utf-8'))
111
  finally:
112
  await response.write_eof() # Ensure to close the response properly
113
 
114
  return response
115
 
116
-
117
- async def query_validator(self, params:QueryValidatorParams) -> Response:
118
  # TODO: SET STREAM AS DEFAULT
119
- stream = params.request.get('stream', False)
120
-
121
  if stream:
122
  return await self.get_stream_response(params)
123
  else:
124
  # DEPRECATED
125
  return await self.get_response(params)
126
-
 
13
  from aiohttp.web_response import Response, StreamResponse
14
  from deprecated import deprecated
15
 
16
+
17
  class S1ValidatorAPI(ValidatorAPI):
18
  def __init__(self):
19
+ self.validator = Validator()
 
20
 
21
+ @deprecated(
22
+ reason="This function is deprecated. Validators use stream synapse now, use get_stream_response instead."
23
+ )
24
+ async def get_response(self, params: QueryValidatorParams) -> Response:
25
  try:
26
  # Guess the task name of current request
27
  task_name = utils.guess_task_name(params.messages[-1])
28
 
29
  # Get the list of uids to query for this step.
30
+ uids = get_random_uids(
31
+ self.validator, k=params.k_miners, exclude=params.exclude or []
32
+ ).tolist()
33
  axons = [self.validator.metagraph.axons[uid] for uid in uids]
34
 
35
  # Make calls to the network with the prompt.
36
+ bt.logging.info(f"Calling dendrite")
37
  responses = await self.validator.dendrite(
38
  axons=axons,
39
  synapse=PromptingSynapse(roles=params.roles, messages=params.messages),
 
42
 
43
  bt.logging.info(f"Creating DendriteResponseEvent:\n {responses}")
44
  # Encapsulate the responses in a response event (dataclass)
45
+ response_event = DendriteResponseEvent(
46
+ responses, torch.LongTensor(uids), params.timeout
47
+ )
48
 
49
  # convert dict to json
50
  response = response_event.__state_dict__()
51
 
52
+ response["completion_is_valid"] = valid = list(
53
+ map(utils.completion_is_valid, response["completions"])
54
+ )
55
+ valid_completions = [
56
+ response["completions"][i] for i, v in enumerate(valid) if v
57
+ ]
58
 
59
+ response["task_name"] = task_name
60
+ response["ensemble_result"] = utils.ensemble_result(
61
+ valid_completions, task_name=task_name, prefer=params.prefer
62
+ )
63
 
64
  bt.logging.info(f"Response:\n {response}")
65
+ return Response(
66
+ status=200,
67
+ reason="I can't believe it's not butter!",
68
+ text=json.dumps(response),
69
+ )
70
 
71
  except Exception:
72
+ bt.logging.error(
73
+ f"Encountered in {self.__class__.__name__}:get_response:\n{traceback.format_exc()}"
74
+ )
75
  return Response(status=500, reason="Internal error")
76
+
77
+ async def process_response(
78
+ self, response: StreamResponse, uid: int, async_generator: Awaitable
79
+ ):
80
  """Process a single response asynchronously."""
81
  try:
82
  chunk = None # Initialize chunk with a default value
83
  async for chunk in async_generator: # most important loop, as this is where we acquire the final synapse.
84
  bt.logging.debug(f"\nchunk for uid {uid}: {chunk}")
85
+
86
  # TODO: SET PROPER IMPLEMENTATION TO RETURN CHUNK
87
  if chunk is not None:
88
  json_data = json.dumps(chunk)
89
+ await response.write(json_data.encode("utf-8"))
90
+
91
  except Exception as e:
92
+ bt.logging.error(
93
+ f"Encountered an error in {self.__class__.__name__}:get_stream_response:\n{traceback.format_exc()}"
94
+ )
95
  response.set_status(500, reason="Internal error")
96
+ await response.write(json.dumps({"error": str(e)}).encode("utf-8"))
97
  finally:
98
  await response.write_eof() # Ensure to close the response properly
99
+
100
+ async def get_stream_response(self, params: QueryValidatorParams) -> StreamResponse:
101
  response = StreamResponse(status=200, reason="OK")
102
+ response.headers["Content-Type"] = "application/json"
103
 
104
  await response.prepare(params.request) # Prepare and send the headers
105
+
106
  try:
107
  # Guess the task name of current request
108
  task_name = utils.guess_task_name(params.messages[-1])
109
 
110
  # Get the list of uids to query for this step.
111
+ uids = get_random_uids(
112
+ self.validator, k=params.k_miners, exclude=params.exclude or []
113
+ ).tolist()
114
  axons = [self.validator.metagraph.axons[uid] for uid in uids]
115
 
116
  # Make calls to the network with the prompt.
117
+ bt.logging.info(f"Calling dendrite")
118
  streams_responses = await self.validator.dendrite(
119
  axons=axons,
120
+ synapse=StreamPromptingSynapse(
121
+ roles=params.roles, messages=params.messages
122
+ ),
123
  timeout=params.timeout,
124
  deserialize=False,
125
  streaming=True,
126
  )
127
+
128
+ tasks = [
129
+ self.process_response(uid, res)
130
+ for uid, res in dict(zip(uids, streams_responses))
131
+ ]
132
  results = await asyncio.gather(*tasks, return_exceptions=True)
133
+
134
+ # TODO: Continue implementation, business decision needs to be made on how to handle the results
135
  except Exception as e:
136
+ bt.logging.error(
137
+ f"Encountered an error in {self.__class__.__name__}:get_stream_response:\n{traceback.format_exc()}"
138
+ )
139
  response.set_status(500, reason="Internal error")
140
+ await response.write(json.dumps({"error": str(e)}).encode("utf-8"))
141
  finally:
142
  await response.write_eof() # Ensure to close the response properly
143
 
144
  return response
145
 
146
+ async def query_validator(self, params: QueryValidatorParams) -> Response:
 
147
  # TODO: SET STREAM AS DEFAULT
148
+ stream = params.request.get("stream", False)
149
+
150
  if stream:
151
  return await self.get_stream_response(params)
152
  else:
153
  # DEPRECATED
154
  return await self.get_response(params)