p-ferreira commited on
Commit
d7cdabb
·
1 Parent(s): c60daaf

adds initial wrapper for validator + requirements.txt

Browse files
Files changed (4) hide show
  1. README.md +15 -1
  2. requirements.txt +2 -0
  3. server.py +9 -30
  4. validator_wrapper.py +67 -0
README.md CHANGED
@@ -1,4 +1,18 @@
1
  # chattensor-backend
2
  Backend for Chattensor app
3
 
4
- To run, you will need a bittensor wallet which is registered to the relevant subnet (1@mainnet or 61@testnet).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # chattensor-backend
2
  Backend for Chattensor app
3
 
4
+ To run, you will need a bittensor wallet which is registered to the relevant subnet (1@mainnet or 61@testnet).
5
+
6
+
7
+
8
+
9
+ ## Install
10
+ Create a new python environment and install the dependencies with the command
11
+
12
+ ```bash
13
+ pip install -r requirements.txt
14
+ ```
15
+
16
+ > Note: Currently the prompting library is only installable on machines with cuda devices (NVIDIA-GPU).
17
+
18
+
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ git+https://github.com/opentensor/prompting.git@main
2
+ aiohttp
server.py CHANGED
@@ -8,13 +8,8 @@ import asyncio
8
  import json
9
  import traceback
10
  import bittensor as bt
11
-
12
  from collections import Counter
13
-
14
- from neurons.validator import Validator
15
- from prompting.dendrite import DendriteResponseEvent
16
- from prompting.protocol import PromptingSynapse
17
- from prompting.utils.uids import get_random_uids
18
  from prompting.rewards import DateRewardModel, FloatDiffModel
19
  from aiohttp import web
20
  from aiohttp.web_response import Response
@@ -177,31 +172,16 @@ async def chat(request: web.Request) -> Response:
177
  bt.logging.error(f'Invalid request data: {request_data}')
178
  return Response(status=400)
179
 
180
- bt.logging.info(f'Request data: {request_data}')
181
- k = request_data.get('k', 10)
182
- exclude = request_data.get('exclude', [])
183
- timeout = request_data.get('timeout', 10)
184
- prefer = request_data.get('prefer', 'longest')
185
  try:
186
  # Guess the task name of current request
187
  task_name = guess_task_name(request_data['messages'][-1])
188
-
189
  # Get the list of uids to query for this step.
190
- uids = get_random_uids(validator, k=k, exclude=exclude or []).to(validator.device)
191
- axons = [validator.metagraph.axons[uid] for uid in uids]
192
-
193
- # Make calls to the network with the prompt.
194
- bt.logging.info(f'Calling dendrite')
195
- responses = await validator.dendrite(
196
- axons=axons,
197
- synapse=PromptingSynapse(roles=request_data['roles'], messages=request_data['messages']),
198
- timeout=timeout,
199
- )
200
-
201
- bt.logging.info(f"Creating DendriteResponseEvent:\n {responses}")
202
- # Encapsulate the responses in a response event (dataclass)
203
- response_event = DendriteResponseEvent(responses, uids)
204
-
205
  # convert dict to json
206
  response = response_event.__state_dict__()
207
 
@@ -209,11 +189,11 @@ async def chat(request: web.Request) -> Response:
209
  valid_completions = [response['completions'][i] for i, v in enumerate(valid) if v]
210
 
211
  response['task_name'] = task_name
 
212
  response['ensemble_result'] = ensemble_result(valid_completions, task_name=task_name, prefer=prefer)
213
 
214
  bt.logging.info(f"Response:\n {response}")
215
  return Response(status=200, reason="I can't believe it's not butter!", text=json.dumps(response))
216
-
217
  except Exception:
218
  bt.logging.error(f'Encountered in {chat.__name__}:\n{traceback.format_exc()}')
219
  return Response(status=500, reason="Internal error")
@@ -294,7 +274,6 @@ bt.logging.info(validator_app)
294
 
295
 
296
  def main(run_aio_app=True, test=False) -> None:
297
-
298
  loop = asyncio.get_event_loop()
299
 
300
  # port = validator.metagraph.axons[validator.uid].port
@@ -308,5 +287,5 @@ def main(run_aio_app=True, test=False) -> None:
308
  pass
309
 
310
  if __name__ == "__main__":
311
- validator = Validator()
312
  main()
 
8
  import json
9
  import traceback
10
  import bittensor as bt
 
11
  from collections import Counter
12
+ from validator_wrapper import QueryValidatorParams, S1ValidatorWrapper
 
 
 
 
13
  from prompting.rewards import DateRewardModel, FloatDiffModel
14
  from aiohttp import web
15
  from aiohttp.web_response import Response
 
172
  bt.logging.error(f'Invalid request data: {request_data}')
173
  return Response(status=400)
174
 
175
+ bt.logging.info(f'Request data: {request_data}')
176
+
 
 
 
177
  try:
178
  # Guess the task name of current request
179
  task_name = guess_task_name(request_data['messages'][-1])
180
+
181
  # Get the list of uids to query for this step.
182
+ params = QueryValidatorParams.from_dict(request_data)
183
+ response_event = await validator.query_validator(params)
184
+
 
 
 
 
 
 
 
 
 
 
 
 
185
  # convert dict to json
186
  response = response_event.__state_dict__()
187
 
 
189
  valid_completions = [response['completions'][i] for i, v in enumerate(valid) if v]
190
 
191
  response['task_name'] = task_name
192
+ prefer = request_data.get('prefer', 'longest')
193
  response['ensemble_result'] = ensemble_result(valid_completions, task_name=task_name, prefer=prefer)
194
 
195
  bt.logging.info(f"Response:\n {response}")
196
  return Response(status=200, reason="I can't believe it's not butter!", text=json.dumps(response))
 
197
  except Exception:
198
  bt.logging.error(f'Encountered in {chat.__name__}:\n{traceback.format_exc()}')
199
  return Response(status=500, reason="Internal error")
 
274
 
275
 
276
  def main(run_aio_app=True, test=False) -> None:
 
277
  loop = asyncio.get_event_loop()
278
 
279
  # port = validator.metagraph.axons[validator.uid].port
 
287
  pass
288
 
289
  if __name__ == "__main__":
290
+ validator = S1ValidatorWrapper()
291
  main()
validator_wrapper.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import bittensor as bt
2
+ from neurons.validator import Validator
3
+ from prompting.utils.uids import get_random_uids
4
+ from prompting.protocol import PromptingSynapse
5
+ from prompting.dendrite import DendriteResponseEvent
6
+ from abc import ABC, abstractmethod
7
+ from typing import List
8
+ from dataclasses import dataclass
9
+
10
+ @dataclass
11
+ class QueryValidatorParams:
12
+ k_miners: int
13
+ exclude: List[str]
14
+ roles: List[str]
15
+ messages: List[str]
16
+ timeout: int
17
+
18
+ @staticmethod
19
+ def from_dict(data: dict):
20
+ return QueryValidatorParams(
21
+ k_miners=data.get('k', 10),
22
+ exclude=data.get('exclude', []),
23
+ roles=data['roles'],
24
+ messages=data['messages'],
25
+ timeout=data.get('timeout', 10)
26
+ )
27
+
28
+ class ValidatorWrapper(ABC):
29
+ @abstractmethod
30
+ async def query_validator(self, params:QueryValidatorParams):
31
+ pass
32
+
33
+
34
+ class S1ValidatorWrapper(ValidatorWrapper):
35
+ def __init__(self):
36
+ self.validator = Validator()
37
+
38
+ async def query_validator(self, params:QueryValidatorParams) -> DendriteResponseEvent:
39
+ # Get the list of uids to query for this step.
40
+ uids = get_random_uids(
41
+ self.validator,
42
+ k=params.k_miners,
43
+ exclude=params.exclude).to(self.validator.device)
44
+
45
+ axons = [self.validator.metagraph.axons[uid] for uid in uids]
46
+
47
+ # Make calls to the network with the prompt.
48
+ bt.logging.info(f'Calling dendrite')
49
+ responses = await self.validator.dendrite(
50
+ axons=axons,
51
+ synapse=PromptingSynapse(roles=params.roles, messages=params.request_data.messages),
52
+ timeout=params.timeout,
53
+ )
54
+
55
+ # Encapsulate the responses in a response event (dataclass)
56
+ bt.logging.info(f"Creating DendriteResponseEvent:\n {responses}")
57
+ response_event = DendriteResponseEvent(responses, uids)
58
+ return response_event
59
+
60
+
61
+ class MockValidator(ValidatorWrapper):
62
+ def query_validator(self, query: str) -> bool:
63
+ return False
64
+
65
+
66
+
67
+