steffenc committed on
Commit
89b6d9e
·
unverified ·
2 Parent(s): c60daaf fdc8fdb

Merge pull request #2 from macrocosm-os/features/mock-validator-integration

Browse files

Back-end development: code refactoring, middlewares and validator abstraction

Files changed (11) hide show
  1. .gitignore +169 -0
  2. README.md +15 -1
  3. forward.py +10 -5
  4. middlewares.py +34 -0
  5. requirements.txt +3 -0
  6. server.py +31 -250
  7. test.py +17 -0
  8. utils.py +181 -0
  9. validators/__init__.py +2 -0
  10. validators/base.py +40 -0
  11. validators/sn1_validator_wrapper.py +154 -0
.gitignore ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ .DS_Store
6
+ **/.DS_Store
7
+
8
+
9
+ # C extensions
10
+ *.so
11
+
12
+ # Distribution / packaging
13
+ .Python
14
+ build/
15
+ develop-eggs/
16
+ dist/
17
+ downloads/
18
+ eggs/
19
+ .eggs/
20
+ lib/
21
+ lib64/
22
+ parts/
23
+ sdist/
24
+ var/
25
+ wheels/
26
+ share/python-wheels/
27
+ *.egg-info/
28
+ .installed.cfg
29
+ *.egg
30
+ MANIFEST
31
+
32
+ # PyInstaller
33
+ # Usually these files are written by a python script from a template
34
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
35
+ *.manifest
36
+ *.spec
37
+
38
+ # Installer logs
39
+ pip-log.txt
40
+ pip-delete-this-directory.txt
41
+
42
+ # Unit test / coverage reports
43
+ htmlcov/
44
+ .tox/
45
+ .nox/
46
+ .coverage
47
+ .coverage.*
48
+ .cache
49
+ nosetests.xml
50
+ coverage.xml
51
+ *.cover
52
+ *.py,cover
53
+ .hypothesis/
54
+ .pytest_cache/
55
+ cover/
56
+
57
+ # Translations
58
+ *.mo
59
+ *.pot
60
+
61
+ # Django stuff:
62
+ *.log
63
+ local_settings.py
64
+ db.sqlite3
65
+ db.sqlite3-journal
66
+
67
+ # Flask stuff:
68
+ instance/
69
+ .webassets-cache
70
+
71
+ # Scrapy stuff:
72
+ .scrapy
73
+
74
+ # Sphinx documentation
75
+ docs/_build/
76
+
77
+ # PyBuilder
78
+ .pybuilder/
79
+ target/
80
+
81
+ # Jupyter Notebook
82
+ .ipynb_checkpoints
83
+
84
+ # IPython
85
+ profile_default/
86
+ ipython_config.py
87
+
88
+ # pyenv
89
+ # For a library or package, you might want to ignore these files since the code is
90
+ # intended to run in multiple environments; otherwise, check them in:
91
+ # .python-version
92
+
93
+ # pipenv
94
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
96
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
97
+ # install all needed dependencies.
98
+ #Pipfile.lock
99
+
100
+ # poetry
101
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
102
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
103
+ # commonly ignored for libraries.
104
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
105
+ #poetry.lock
106
+
107
+ # pdm
108
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
109
+ #pdm.lock
110
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
111
+ # in version control.
112
+ # https://pdm.fming.dev/#use-with-ide
113
+ .pdm.toml
114
+
115
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
116
+ __pypackages__/
117
+
118
+ # Celery stuff
119
+ celerybeat-schedule
120
+ celerybeat.pid
121
+
122
+ # SageMath parsed files
123
+ *.sage.py
124
+
125
+ # Environments
126
+ .env
127
+ .venv
128
+ env/
129
+ venv/
130
+ ENV/
131
+ env.bak/
132
+ venv.bak/
133
+
134
+ # Spyder project settings
135
+ .spyderproject
136
+ .spyproject
137
+
138
+ # Rope project settings
139
+ .ropeproject
140
+
141
+ # mkdocs documentation
142
+ /site
143
+
144
+ # mypy
145
+ .mypy_cache/
146
+ .dmypy.json
147
+ dmypy.json
148
+
149
+ # Pyre type checker
150
+ .pyre/
151
+
152
+ # pytype static type analyzer
153
+ .pytype/
154
+
155
+ # Cython debug symbols
156
+ cython_debug/
157
+
158
+ # PyCharm
159
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
160
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
161
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
162
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
163
+ #.idea/
164
+
165
+ testing/
166
+ core
167
+ app.config.js
168
+ wandb
169
+ .vscode
README.md CHANGED
@@ -1,4 +1,18 @@
1
  # chattensor-backend
2
  Backend for Chattensor app
3
 
4
- To run, you will need a bittensor wallet which is registered to the relevant subnet (1@mainnet or 61@testnet).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # chattensor-backend
2
  Backend for Chattensor app
3
 
4
+ To run, you will need a bittensor wallet which is registered to the relevant subnet (1@mainnet or 61@testnet).
5
+
6
+
7
+
8
+
9
+ ## Install
10
+ Create a new python environment and install the dependencies with the command
11
+
12
+ ```bash
13
+ pip install -r requirements.txt
14
+ ```
15
+
16
+ > Note: Currently the prompting library is only installable on machines with CUDA devices (NVIDIA GPUs).
17
+
18
+
forward.py CHANGED
@@ -15,17 +15,22 @@ from prompting.utils.logging import log_event
15
  from prompting.utils.misc import async_log, serialize_exception_to_string
16
  from dataclasses import dataclass
17
 
 
18
  @async_log
19
- async def generate_reference(agent):
20
  loop = asyncio.get_running_loop()
21
- result = await loop.run_in_executor(None, agent.task.generate_reference, agent.llm_pipeline)
22
- return result
 
 
 
23
 
24
  @async_log
25
  async def execute_dendrite_call(dendrite_call):
26
  responses = await dendrite_call
27
  return responses
28
 
 
29
  @dataclass
30
  class StreamResult:
31
  synapse: StreamPromptingSynapse = None
@@ -199,8 +204,8 @@ async def run_step(
199
 
200
  log_stream_results(stream_results)
201
 
202
- all_synapses_results = [stream_result.synapse for stream_result in stream_results]
203
-
204
  # Encapsulate the responses in a response event (dataclass)
205
  response_event = DendriteResponseEvent(
206
  responses=all_synapses_results, uids=uids, timeout=timeout
 
15
  from prompting.utils.misc import async_log, serialize_exception_to_string
16
  from dataclasses import dataclass
17
 
18
+
19
  @async_log
20
+ async def generate_reference(agent):
21
  loop = asyncio.get_running_loop()
22
+ result = await loop.run_in_executor(
23
+ None, agent.task.generate_reference, agent.llm_pipeline
24
+ )
25
+ return result
26
+
27
 
28
  @async_log
29
  async def execute_dendrite_call(dendrite_call):
30
  responses = await dendrite_call
31
  return responses
32
 
33
+
34
  @dataclass
35
  class StreamResult:
36
  synapse: StreamPromptingSynapse = None
 
204
 
205
  log_stream_results(stream_results)
206
 
207
+ all_synapses_results = [stream_result.synapse for stream_result in stream_results]
208
+
209
  # Encapsulate the responses in a response event (dataclass)
210
  response_event = DendriteResponseEvent(
211
  responses=all_synapses_results, uids=uids, timeout=timeout
middlewares.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import bittensor as bt
4
+ from aiohttp.web import Request, Response, middleware
5
+
6
+ EXPECTED_ACCESS_KEY = os.environ.get("EXPECTED_ACCESS_KEY")
7
+
8
+
9
@middleware
async def api_key_middleware(request: Request, handler):
    """Reject requests whose ``api_key`` header does not match the expected key.

    When EXPECTED_ACCESS_KEY is unset (None), the check is disabled and all
    requests pass through.
    """
    # Log every incoming request before authentication
    bt.logging.info(f"Handling {request.method} request to {request.path}")

    # Check access key
    access_key = request.headers.get("api_key")
    key_check_enabled = EXPECTED_ACCESS_KEY is not None
    if key_check_enabled and access_key != EXPECTED_ACCESS_KEY:
        bt.logging.error(f"Invalid access key: {access_key}")
        return Response(status=401, reason="Invalid access key")

    # Key accepted (or checking disabled): hand off to the next handler
    return await handler(request)
22
+
23
+
24
@middleware
async def json_parsing_middleware(request: Request, handler):
    """Parse the request body as JSON and expose it as ``request["data"]``.

    Returns a 400 response when the body is not valid JSON; otherwise stores
    the parsed payload on the request and invokes the next handler.
    """
    try:
        payload = await request.json()
    except json.JSONDecodeError as e:
        bt.logging.error(f"Invalid JSON data: {str(e)}")
        return Response(status=400, text="Invalid JSON")

    # Make the parsed body available to downstream handlers
    request["data"] = payload
    return await handler(request)
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ git+https://github.com/opentensor/prompting.git@features/move-validator-into-prompting
2
+ aiohttp
3
+ deprecated
server.py CHANGED
@@ -1,23 +1,10 @@
1
-
2
-
3
-
4
- import os
5
- import re
6
- import time
7
  import asyncio
8
- import json
9
- import traceback
10
  import bittensor as bt
11
-
12
- from collections import Counter
13
-
14
- from neurons.validator import Validator
15
- from prompting.dendrite import DendriteResponseEvent
16
- from prompting.protocol import PromptingSynapse
17
- from prompting.utils.uids import get_random_uids
18
- from prompting.rewards import DateRewardModel, FloatDiffModel
19
  from aiohttp import web
20
  from aiohttp.web_response import Response
 
 
21
 
22
  """
23
  # test
@@ -43,270 +30,64 @@ EXPECTED_ACCESS_KEY="hey-michal" pm2 start app.py --interpreter python3 --name a
43
 
44
  basic testing
45
  ```
46
- EXPECTED_ACCESS_KEY="hey-michal" python app.py --neuron.model_id mock --wallet.name sn1 --wallet.hotkey v1 --netuid 1 --neuron.tasks math --neuron.task_p 1 --neuron.device cpu
47
  ```
48
  add --mock to test the echo stream
49
  """
50
 
51
- EXPECTED_ACCESS_KEY = os.environ.get('EXPECTED_ACCESS_KEY')
52
-
53
- validator = None
54
- reward_models = {
55
- 'date_qa': DateRewardModel(),
56
- 'math': FloatDiffModel(),
57
- }
58
-
59
- def completion_is_valid(completion: str):
60
- """
61
- Get the completion statuses from the completions.
62
- """
63
- patt = re.compile(r'I\'m sorry|unable to|I cannot|I can\'t|I am unable|I am sorry|I can not|don\'t know|not sure|don\'t understand')
64
- if not len(re.findall(r'\w+',completion)) or patt.search(completion):
65
- return False
66
- return True
67
-
68
-
69
- def ensemble_result(completions: list, task_name: str, prefer: str = 'longest'):
70
- """
71
- Ensemble completions from multiple models.
72
- # TODO: Measure agreement
73
- # TODO: Figure out how to mitigate the cabal effect (large groups will appear to be more credible)
74
- # TODO: Reward pipeline
75
- """
76
- if not completions:
77
- return None
78
-
79
-
80
- answer = None
81
- if task_name in ('qa', 'summarization'):
82
- # No special handling for QA or summarization
83
- supporting_completions = completions
84
-
85
- elif task_name == 'date_qa':
86
- # filter the completions to be the ones that contain valid dates and if there are multiple dates, select the most common one (with support > 1)
87
- dates = list(map(reward_models[task_name].parse_dates_from_text, completions))
88
- bt.logging.info(f"Unprocessed dates: {dates}")
89
- valid_date_indices = [i for i, d in enumerate(dates) if d]
90
- valid_completions = [completions[i] for i in valid_date_indices]
91
- valid_dates = [dates[i] for i in valid_date_indices]
92
- dates = [f"{d[0].strftime('%-d %B')} {d[1]}" for d in valid_dates]
93
- if not dates:
94
- return None
95
-
96
- counter = Counter(dates)
97
- most_common, count = counter.most_common()[0]
98
- answer = most_common
99
- if count == 1:
100
- supporting_completions = valid_completions
101
- else:
102
- supporting_completions = [c for i, c in enumerate(valid_completions) if dates[i]==most_common]
103
-
104
- elif task_name == 'math':
105
- # filter the completions to be the ones that contain valid numbers and if there are multiple values, select the most common one (with support > 1)
106
- # TODO: use the median instead of the most common value
107
- vals = list(map(reward_models[task_name].extract_number, completions))
108
- vals = [val for val in vals if val]
109
- if not vals:
110
- return None
111
-
112
- most_common, count = Counter(dates).most_common()[0]
113
- bt.logging.info(f"Most common value: {most_common}, count: {count}")
114
- answer = most_common
115
- if count == 1:
116
- supporting_completions = completions
117
- else:
118
- supporting_completions = [c for i, c in enumerate(completions) if vals[i]==most_common]
119
-
120
-
121
- bt.logging.info(f"Supporting completions: {supporting_completions}")
122
- if prefer == 'longest':
123
- preferred_completion = sorted(supporting_completions, key=len)[-1]
124
- elif prefer == 'shortest':
125
- preferred_completion = sorted(supporting_completions, key=len)[0]
126
- elif prefer == 'most_common':
127
- preferred_completion = max(set(supporting_completions), key=supporting_completions.count)
128
- else:
129
- raise ValueError(f"Unknown ensemble preference: {prefer}")
130
-
131
- return {
132
- 'completion': preferred_completion,
133
- 'accepted_answer': answer,
134
- 'support': len(supporting_completions),
135
- 'support_indices': [completions.index(c) for c in supporting_completions],
136
- 'method': f'Selected the {prefer.replace("_", " ")} completion'
137
- }
138
-
139
- def guess_task_name(challenge: str):
140
- categories = {
141
- 'summarization': re.compile('summar|quick rundown|overview'),
142
- 'date_qa': re.compile('exact date|tell me when|on what date|on what day|was born?|died?'),
143
- 'math': re.compile('math|solve|solution| sum |problem|geometric|vector|calculate|degrees|decimal|factorial'),
144
- }
145
- for task_name, patt in categories.items():
146
- if patt.search(challenge):
147
- return task_name
148
-
149
- return 'qa'
150
 
151
  async def chat(request: web.Request) -> Response:
152
  """
153
  Chat endpoint for the validator.
154
-
155
- Required headers:
156
- - api_key: The access key for the validator.
157
-
158
- Required body:
159
- - roles: The list of roles to query.
160
- - messages: The list of messages to query.
161
- Optional body:
162
- - k: The number of nodes to query.
163
- - exclude: The list of nodes to exclude from the query.
164
- - timeout: The timeout for the query.
165
  """
 
166
 
167
- bt.logging.info(f'chat()')
168
- # Check access key
169
- access_key = request.headers.get("api_key")
170
- if EXPECTED_ACCESS_KEY is not None and access_key != EXPECTED_ACCESS_KEY:
171
- bt.logging.error(f'Invalid access key: {access_key}')
172
- return Response(status=401, reason="Invalid access key")
173
-
174
- try:
175
- request_data = await request.json()
176
- except ValueError:
177
- bt.logging.error(f'Invalid request data: {request_data}')
178
- return Response(status=400)
179
-
180
- bt.logging.info(f'Request data: {request_data}')
181
- k = request_data.get('k', 10)
182
- exclude = request_data.get('exclude', [])
183
- timeout = request_data.get('timeout', 10)
184
- prefer = request_data.get('prefer', 'longest')
185
- try:
186
- # Guess the task name of current request
187
- task_name = guess_task_name(request_data['messages'][-1])
188
-
189
- # Get the list of uids to query for this step.
190
- uids = get_random_uids(validator, k=k, exclude=exclude or []).to(validator.device)
191
- axons = [validator.metagraph.axons[uid] for uid in uids]
192
-
193
- # Make calls to the network with the prompt.
194
- bt.logging.info(f'Calling dendrite')
195
- responses = await validator.dendrite(
196
- axons=axons,
197
- synapse=PromptingSynapse(roles=request_data['roles'], messages=request_data['messages']),
198
- timeout=timeout,
199
- )
200
 
201
- bt.logging.info(f"Creating DendriteResponseEvent:\n {responses}")
202
- # Encapsulate the responses in a response event (dataclass)
203
- response_event = DendriteResponseEvent(responses, uids)
204
-
205
- # convert dict to json
206
- response = response_event.__state_dict__()
207
-
208
- response['completion_is_valid'] = valid = list(map(completion_is_valid, response['completions']))
209
- valid_completions = [response['completions'][i] for i, v in enumerate(valid) if v]
210
-
211
- response['task_name'] = task_name
212
- response['ensemble_result'] = ensemble_result(valid_completions, task_name=task_name, prefer=prefer)
213
-
214
- bt.logging.info(f"Response:\n {response}")
215
- return Response(status=200, reason="I can't believe it's not butter!", text=json.dumps(response))
216
-
217
- except Exception:
218
- bt.logging.error(f'Encountered in {chat.__name__}:\n{traceback.format_exc()}')
219
- return Response(status=500, reason="Internal error")
220
-
221
-
222
-
223
- async def echo_stream(request):
224
-
225
- bt.logging.info(f'echo_stream()')
226
- # Check access key
227
- access_key = request.headers.get("api_key")
228
- if EXPECTED_ACCESS_KEY is not None and access_key != EXPECTED_ACCESS_KEY:
229
- bt.logging.error(f'Invalid access key: {access_key}')
230
- return Response(status=401, reason="Invalid access key")
231
-
232
- try:
233
- request_data = await request.json()
234
- except ValueError:
235
- bt.logging.error(f'Invalid request data: {request_data}')
236
- return Response(status=400)
237
-
238
- bt.logging.info(f'Request data: {request_data}')
239
- k = request_data.get('k', 1)
240
- exclude = request_data.get('exclude', [])
241
- timeout = request_data.get('timeout', 0.2)
242
- message = '\n\n'.join(request_data['messages'])
243
-
244
- # Create a StreamResponse
245
- response = web.StreamResponse(status=200, reason='OK', headers={'Content-Type': 'text/plain'})
246
- await response.prepare(request)
247
-
248
- completion = ''
249
- # Echo the message k times with a timeout between each chunk
250
- for _ in range(k):
251
- for word in message.split():
252
- chunk = f'{word} '
253
- await response.write(chunk.encode('utf-8'))
254
- completion += chunk
255
- time.sleep(timeout)
256
- bt.logging.info(f"Echoed: {chunk}")
257
 
258
- completion = completion.strip()
259
 
260
- # Prepare final JSON chunk
261
- json_chunk = json.dumps({
262
- "uids": [0],
263
- "completion": completion,
264
- "completions": [completion.strip()],
265
- "timings": [0],
266
- "status_messages": ['Went well!'],
267
- "status_codes": [200],
268
- "completion_is_valid": [True],
269
- "task_name": 'echo',
270
- "ensemble_result": {}
271
- })
272
-
273
- # Send the final JSON as part of the stream
274
- await response.write(f"\n\nJSON_RESPONSE_BEGIN:\n{json_chunk}".encode('utf-8'))
275
 
276
- # Finalize the response
277
- await response.write_eof()
278
- return response
279
 
280
  class ValidatorApplication(web.Application):
281
- def __init__(self, *a, **kw):
282
- super().__init__(*a, **kw)
283
- # TODO: Enable rewarding and other features
284
 
 
 
 
285
 
286
- validator_app = ValidatorApplication()
287
- validator_app.add_routes([
288
- web.post('/chat/', chat),
289
- web.post('/echo/', echo_stream)
290
- ])
291
 
292
- bt.logging.info("Starting validator application.")
293
- bt.logging.info(validator_app)
 
294
 
295
 
296
  def main(run_aio_app=True, test=False) -> None:
297
-
298
  loop = asyncio.get_event_loop()
299
-
300
- # port = validator.metagraph.axons[validator.uid].port
301
  port = 10000
302
  if run_aio_app:
 
 
 
 
 
303
  try:
304
  web.run_app(validator_app, port=port, loop=loop)
305
  except KeyboardInterrupt:
306
- bt.logging.info("Keyboard interrupt detected. Exiting validator.")
307
  finally:
308
  pass
309
 
 
310
  if __name__ == "__main__":
311
- validator = Validator()
312
  main()
 
 
 
 
 
 
 
1
  import asyncio
2
+ import utils
 
3
  import bittensor as bt
 
 
 
 
 
 
 
 
4
  from aiohttp import web
5
  from aiohttp.web_response import Response
6
+ from validators import S1ValidatorAPI, QueryValidatorParams, ValidatorAPI
7
+ from middlewares import api_key_middleware, json_parsing_middleware
8
 
9
  """
10
  # test
 
30
 
31
  basic testing
32
  ```
33
+ EXPECTED_ACCESS_KEY="hey-michal" python app.py --neuron.model_id mock --wallet.name sn1 --wallet.hotkey v1 --netuid 1 --neuron.tasks math --neuron.task_p 1 --neuron.device cpu
34
  ```
35
  add --mock to test the echo stream
36
  """
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
async def chat(request: web.Request) -> Response:
    """
    Chat endpoint for the validator.

    Builds query parameters from the (middleware-parsed) request body and
    delegates the query to the ValidatorAPI instance stored in the
    application context under the "validator" key.
    """
    params = QueryValidatorParams.from_request(request)

    # Access the validator from the application context
    validator: ValidatorAPI = request.app["validator"]
    return await validator.query_validator(params)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
 
51
 
52
async def echo_stream(request):
    """Echo endpoint: stream the request's messages back to the caller.

    BUG FIX: aiohttp invokes route handlers with the request object only;
    the previous extra required ``request_data`` parameter made this handler
    raise TypeError on every request. The parsed body is read from
    ``request["data"]`` (populated by json_parsing_middleware) instead.
    """
    request_data = request["data"]
    return await utils.echo_stream(request_data)
 
 
 
 
 
 
 
 
 
 
 
 
55
 
 
 
 
56
 
57
class ValidatorApplication(web.Application):
    """aiohttp application hosting the validator HTTP API.

    The validator is stored in the application context under "validator" so
    route handlers can retrieve it via ``request.app["validator"]``.
    """

    def __init__(self, validator_instance=None, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Explicit None check so a falsy-but-valid validator instance is not
        # silently replaced by a fresh S1ValidatorAPI.
        self["validator"] = (
            validator_instance if validator_instance is not None else S1ValidatorAPI()
        )

        # Add routes and middlewares to application
        self.add_routes([web.post("/chat/", chat), web.post("/echo/", echo_stream)])
        self.setup_middlewares()
        # TODO: Enable rewarding and other features

    def setup_middlewares(self):
        """Register middlewares; aiohttp applies them in list order (first is outermost)."""
        # BUG FIX: authenticate before parsing the (untrusted) request body,
        # so unauthenticated requests are rejected without JSON processing.
        # Previously the JSON middleware ran first.
        self.middlewares.append(api_key_middleware)
        self.middlewares.append(json_parsing_middleware)
73
 
74
 
75
def main(run_aio_app=True, test=False) -> None:
    """Start the validator web application on port 10000.

    Args:
        run_aio_app: When True, build the application and serve it.
        test: Currently unused; kept for interface compatibility.
    """
    loop = asyncio.get_event_loop()
    port = 10000
    if run_aio_app:
        # Instantiate the application with the actual validator
        bt.logging.info("Starting validator application.")
        validator_app = ValidatorApplication()
        # BUG FIX: the f-string had no placeholders and the app instance was
        # passed as a stray extra positional argument to bt.logging.success.
        bt.logging.success("Validator app initialized successfully")

        try:
            web.run_app(validator_app, port=port, loop=loop)
        except KeyboardInterrupt:
            # Use the project logger (not print) for consistency
            bt.logging.info("Keyboard interrupt detected. Exiting validator.")
        finally:
            pass
90
 
91
+
92
  if __name__ == "__main__":
 
93
  main()
test.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+
3
+
4
def test_query_network():
    """Placeholder: network querying is not covered yet."""


def test_filter_completions():
    """Placeholder: completion filtering is not covered yet."""


def test_guess_task_name():
    """Placeholder: task-name guessing is not covered yet."""


def test_ensemble_completions():
    """Placeholder: completion ensembling is not covered yet."""
utils.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import bittensor as bt
3
+ import time
4
+ import json
5
+ from aiohttp import web
6
+ from collections import Counter
7
+ from prompting.rewards import DateRewardModel, FloatDiffModel
8
+
9
+ UNSUCCESSFUL_RESPONSE_PATTERNS = [
10
+ "I'm sorry",
11
+ "unable to",
12
+ "I cannot",
13
+ "I can't",
14
+ "I am unable",
15
+ "I am sorry",
16
+ "I can not",
17
+ "don't know",
18
+ "not sure",
19
+ "don't understand",
20
+ "not capable",
21
+ ]
22
+
23
+ reward_models = {
24
+ "date_qa": DateRewardModel(),
25
+ "math": FloatDiffModel(),
26
+ }
27
+
28
+
29
def completion_is_valid(completion: str):
    """Return True when *completion* looks like a substantive answer.

    A completion is rejected when it is blank, contains no word characters,
    or matches one of the known refusal phrases (case-insensitive, from
    UNSUCCESSFUL_RESPONSE_PATTERNS).
    """
    if not completion.strip():
        return False

    refusal_pattern = re.compile(
        r"\b(?:" + "|".join(UNSUCCESSFUL_RESPONSE_PATTERNS) + r")\b",
        re.IGNORECASE,
    )
    has_words = bool(re.findall(r"\w+", completion))
    return has_words and refusal_pattern.search(completion) is None
42
+
43
+
44
def ensemble_result(completions: list, task_name: str, prefer: str = "longest"):
    """Ensemble completions from multiple models.

    Args:
        completions: Candidate completion strings from different miners.
        task_name: Task category ("qa", "summarization", "date_qa", "math");
            any other value falls back to using all completions as support.
        prefer: Tie-breaking strategy for picking the preferred completion:
            "longest", "shortest" or "most_common".

    Returns:
        A dict describing the chosen completion and its support, or None when
        there are no completions / no parseable answers.

    Raises:
        ValueError: If *prefer* is not a known strategy.

    # TODO: Measure agreement
    # TODO: Figure out how to mitigate the cabal effect (large groups will appear to be more credible)
    # TODO: Reward pipeline
    """
    if not completions:
        return None

    answer = None
    if task_name in ("qa", "summarization"):
        # No special handling for QA or summarization
        supporting_completions = completions

    elif task_name == "date_qa":
        # Keep only completions with parseable dates; when several dates are
        # found, select the most common one (with support > 1).
        dates = list(map(reward_models[task_name].parse_dates_from_text, completions))
        bt.logging.info(f"Unprocessed dates: {dates}")
        valid_date_indices = [i for i, d in enumerate(dates) if d]
        valid_completions = [completions[i] for i in valid_date_indices]
        valid_dates = [dates[i] for i in valid_date_indices]
        dates = [f"{d[0].strftime('%-d %B')} {d[1]}" for d in valid_dates]
        if not dates:
            return None

        counter = Counter(dates)
        most_common, count = counter.most_common()[0]
        answer = most_common
        if count == 1:
            supporting_completions = valid_completions
        else:
            supporting_completions = [
                c for i, c in enumerate(valid_completions) if dates[i] == most_common
            ]

    elif task_name == "math":
        # Keep only completions with parseable numbers; when several values
        # are found, select the most common one (with support > 1).
        # TODO: use the median instead of the most common value
        vals = list(map(reward_models[task_name].extract_number, completions))
        valid_val_indices = [i for i, v in enumerate(vals) if v]
        if not valid_val_indices:
            return None

        valid_completions = [completions[i] for i in valid_val_indices]
        valid_vals = [vals[i] for i in valid_val_indices]

        # BUG FIX: this branch previously counted `dates` (undefined here,
        # causing a NameError) and indexed the filtered value list with
        # unfiltered enumerate indices. Count the valid values and keep the
        # completions aligned with them instead.
        most_common, count = Counter(valid_vals).most_common()[0]
        bt.logging.info(f"Most common value: {most_common}, count: {count}")
        answer = most_common
        if count == 1:
            supporting_completions = valid_completions
        else:
            supporting_completions = [
                c for c, v in zip(valid_completions, valid_vals) if v == most_common
            ]

    else:
        # BUG FIX: an unrecognized task name previously left
        # `supporting_completions` undefined and crashed below; fall back to
        # treating every completion as support.
        supporting_completions = completions

    bt.logging.info(f"Supporting completions: {supporting_completions}")
    if prefer == "longest":
        preferred_completion = sorted(supporting_completions, key=len)[-1]
    elif prefer == "shortest":
        preferred_completion = sorted(supporting_completions, key=len)[0]
    elif prefer == "most_common":
        preferred_completion = max(
            set(supporting_completions), key=supporting_completions.count
        )
    else:
        raise ValueError(f"Unknown ensemble preference: {prefer}")

    return {
        "completion": preferred_completion,
        "accepted_answer": answer,
        "support": len(supporting_completions),
        "support_indices": [completions.index(c) for c in supporting_completions],
        "method": f'Selected the {prefer.replace("_", " ")} completion',
    }
117
+
118
+
119
def guess_task_name(challenge: str):
    """Heuristically classify *challenge* into a task category.

    Returns the first category whose keyword pattern matches, checked in
    order: "summarization", "date_qa", "math". Defaults to "qa".
    """
    # TODO: use a pre-trained classifier to guess the task name
    summarization_patt = re.compile("summar|quick rundown|overview")
    date_qa_patt = re.compile(
        "exact date|tell me when|on what date|on what day|was born?|died?"
    )
    math_patt = re.compile(
        "math|solve|solution| sum |problem|geometric|vector|calculate|degrees|decimal|factorial"
    )

    ordered_categories = (
        ("summarization", summarization_patt),
        ("date_qa", date_qa_patt),
        ("math", math_patt),
    )
    for name, patt in ordered_categories:
        if patt.search(challenge):
            return name
    return "qa"
135
+
136
+
137
async def echo_stream(request_data: dict, request=None):
    """Echo the request's messages back as a plain-text stream.

    The concatenated messages are streamed word by word, `k` times, with a
    pause of `timeout` seconds between chunks, followed by a JSON summary
    after a ``JSON_RESPONSE_BEGIN:`` marker.

    Args:
        request_data: Parsed JSON body; reads "messages" (required),
            "k" (default 1) and "timeout" (default 0.2).
        request: The originating aiohttp request. StreamResponse.prepare()
            requires it, so callers should always pass it (kept optional
            only for backward compatibility with existing call sites).

    Returns:
        The finalized aiohttp StreamResponse.
    """
    k = request_data.get("k", 1)
    timeout = request_data.get("timeout", 0.2)
    message = "\n\n".join(request_data["messages"])

    # Create a StreamResponse
    response = web.StreamResponse(
        status=200, reason="OK", headers={"Content-Type": "text/plain"}
    )
    # BUG FIX: prepare() must receive the request being responded to; it was
    # previously called with no arguments, which raises a TypeError.
    await response.prepare(request)

    completion = ""
    # Echo the message k times with a timeout between each chunk
    for _ in range(k):
        for word in message.split():
            chunk = f"{word} "
            await response.write(chunk.encode("utf-8"))
            completion += chunk
            # NOTE(review): time.sleep blocks the event loop; consider
            # switching to asyncio.sleep once callers tolerate the change.
            time.sleep(timeout)
            bt.logging.info(f"Echoed: {chunk}")

    completion = completion.strip()

    # Prepare final JSON chunk
    json_chunk = json.dumps(
        {
            "uids": [0],
            "completion": completion,
            "completions": [completion.strip()],
            "timings": [0],
            "status_messages": ["Went well!"],
            "status_codes": [200],
            "completion_is_valid": [True],
            "task_name": "echo",
            "ensemble_result": {},
        }
    )

    # Send the final JSON as part of the stream
    await response.write(f"\n\nJSON_RESPONSE_BEGIN:\n{json_chunk}".encode("utf-8"))

    # Finalize the response
    await response.write_eof()
    return response
validators/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .base import QueryValidatorParams, ValidatorAPI, MockValidator
2
+ from .sn1_validator_wrapper import S1ValidatorAPI
validators/base.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import List
3
+ from dataclasses import dataclass
4
+ from aiohttp.web import Response, Request
5
+
6
+
7
@dataclass
class QueryValidatorParams:
    """Parameters extracted from a /chat request.

    Attributes mirror the JSON payload keys; ``request`` keeps a reference
    to the originating aiohttp request.
    """

    k_miners: int
    exclude: List[str]
    roles: List[str]
    messages: List[str]
    timeout: int
    prefer: str
    request: Request

    @staticmethod
    def from_request(request: Request):
        """Build params from a request whose JSON body was parsed into request["data"]."""
        data = request["data"]

        optional_fields = {
            "k_miners": data.get("k", 10),
            "exclude": data.get("exclude", []),
            "timeout": data.get("timeout", 10),
            "prefer": data.get("prefer", "longest"),
        }
        return QueryValidatorParams(
            roles=data["roles"],
            messages=data["messages"],
            request=request,
            **optional_fields,
        )
30
+
31
+
32
class ValidatorAPI(ABC):
    """Abstract interface for querying a validator over HTTP."""

    @abstractmethod
    async def query_validator(self, params: QueryValidatorParams) -> Response:
        """Execute the query described by *params* and return the HTTP response."""
36
+
37
+
38
class MockValidator(ValidatorAPI):
    """No-op validator used for testing; query_validator does nothing."""

    async def query_validator(self, params: QueryValidatorParams) -> Response:
        ...
validators/sn1_validator_wrapper.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import utils
3
+ import torch
4
+ import traceback
5
+ import asyncio
6
+ import bittensor as bt
7
+ from typing import Awaitable
8
+ from prompting.validator import Validator
9
+ from prompting.utils.uids import get_random_uids
10
+ from prompting.protocol import PromptingSynapse, StreamPromptingSynapse
11
+ from prompting.dendrite import DendriteResponseEvent
12
+ from .base import QueryValidatorParams, ValidatorAPI
13
+ from aiohttp.web_response import Response, StreamResponse
14
+ from deprecated import deprecated
15
+
16
+
17
+ class S1ValidatorAPI(ValidatorAPI):
18
+ def __init__(self):
19
+ self.validator = Validator()
20
+
21
+ @deprecated(
22
+ reason="This function is deprecated. Validators use stream synapse now, use get_stream_response instead."
23
+ )
24
+ async def get_response(self, params: QueryValidatorParams) -> Response:
25
+ try:
26
+ # Guess the task name of current request
27
+ task_name = utils.guess_task_name(params.messages[-1])
28
+
29
+ # Get the list of uids to query for this step.
30
+ uids = get_random_uids(
31
+ self.validator, k=params.k_miners, exclude=params.exclude or []
32
+ ).tolist()
33
+ axons = [self.validator.metagraph.axons[uid] for uid in uids]
34
+
35
+ # Make calls to the network with the prompt.
36
+ bt.logging.info(f"Calling dendrite")
37
+ responses = await self.validator.dendrite(
38
+ axons=axons,
39
+ synapse=PromptingSynapse(roles=params.roles, messages=params.messages),
40
+ timeout=params.timeout,
41
+ )
42
+
43
+ bt.logging.info(f"Creating DendriteResponseEvent:\n {responses}")
44
+ # Encapsulate the responses in a response event (dataclass)
45
+ response_event = DendriteResponseEvent(
46
+ responses, torch.LongTensor(uids), params.timeout
47
+ )
48
+
49
+ # convert dict to json
50
+ response = response_event.__state_dict__()
51
+
52
+ response["completion_is_valid"] = valid = list(
53
+ map(utils.completion_is_valid, response["completions"])
54
+ )
55
+ valid_completions = [
56
+ response["completions"][i] for i, v in enumerate(valid) if v
57
+ ]
58
+
59
+ response["task_name"] = task_name
60
+ response["ensemble_result"] = utils.ensemble_result(
61
+ valid_completions, task_name=task_name, prefer=params.prefer
62
+ )
63
+
64
+ bt.logging.info(f"Response:\n {response}")
65
+ return Response(
66
+ status=200,
67
+ reason="I can't believe it's not butter!",
68
+ text=json.dumps(response),
69
+ )
70
+
71
+ except Exception:
72
+ bt.logging.error(
73
+ f"Encountered in {self.__class__.__name__}:get_response:\n{traceback.format_exc()}"
74
+ )
75
+ return Response(status=500, reason="Internal error")
76
+
77
+ async def process_response(
78
+ self, response: StreamResponse, uid: int, async_generator: Awaitable
79
+ ):
80
+ """Process a single response asynchronously."""
81
+ try:
82
+ chunk = None # Initialize chunk with a default value
83
+ async for chunk in async_generator: # most important loop, as this is where we acquire the final synapse.
84
+ bt.logging.debug(f"\nchunk for uid {uid}: {chunk}")
85
+
86
+ # TODO: SET PROPER IMPLEMENTATION TO RETURN CHUNK
87
+ if chunk is not None:
88
+ json_data = json.dumps(chunk)
89
+ await response.write(json_data.encode("utf-8"))
90
+
91
+ except Exception as e:
92
+ bt.logging.error(
93
+ f"Encountered an error in {self.__class__.__name__}:get_stream_response:\n{traceback.format_exc()}"
94
+ )
95
+ response.set_status(500, reason="Internal error")
96
+ await response.write(json.dumps({"error": str(e)}).encode("utf-8"))
97
+ finally:
98
+ await response.write_eof() # Ensure to close the response properly
99
+
100
+ async def get_stream_response(self, params: QueryValidatorParams) -> StreamResponse:
101
+ response = StreamResponse(status=200, reason="OK")
102
+ response.headers["Content-Type"] = "application/json"
103
+
104
+ await response.prepare(params.request) # Prepare and send the headers
105
+
106
+ try:
107
+ # Guess the task name of current request
108
+ task_name = utils.guess_task_name(params.messages[-1])
109
+
110
+ # Get the list of uids to query for this step.
111
+ uids = get_random_uids(
112
+ self.validator, k=params.k_miners, exclude=params.exclude or []
113
+ ).tolist()
114
+ axons = [self.validator.metagraph.axons[uid] for uid in uids]
115
+
116
+ # Make calls to the network with the prompt.
117
+ bt.logging.info(f"Calling dendrite")
118
+ streams_responses = await self.validator.dendrite(
119
+ axons=axons,
120
+ synapse=StreamPromptingSynapse(
121
+ roles=params.roles, messages=params.messages
122
+ ),
123
+ timeout=params.timeout,
124
+ deserialize=False,
125
+ streaming=True,
126
+ )
127
+
128
+ tasks = [
129
+ self.process_response(uid, res)
130
+ for uid, res in dict(zip(uids, streams_responses))
131
+ ]
132
+ results = await asyncio.gather(*tasks, return_exceptions=True)
133
+
134
+ # TODO: Continue implementation, business decision needs to be made on how to handle the results
135
+ except Exception as e:
136
+ bt.logging.error(
137
+ f"Encountered an error in {self.__class__.__name__}:get_stream_response:\n{traceback.format_exc()}"
138
+ )
139
+ response.set_status(500, reason="Internal error")
140
+ await response.write(json.dumps({"error": str(e)}).encode("utf-8"))
141
+ finally:
142
+ await response.write_eof() # Ensure to close the response properly
143
+
144
+ return response
145
+
146
+ async def query_validator(self, params: QueryValidatorParams) -> Response:
147
+ # TODO: SET STREAM AS DEFAULT
148
+ stream = params.request.get("stream", False)
149
+
150
+ if stream:
151
+ return await self.get_stream_response(params)
152
+ else:
153
+ # DEPRECATED
154
+ return await self.get_response(params)