pedroferreira commited on
Commit
a415c67
·
1 Parent(s): ab063cf

minor integration adjustments

Browse files
middlewares.py CHANGED
@@ -1,34 +1,32 @@
1
  import os
2
  import json
3
  import bittensor as bt
4
- from aiohttp.web import Response
5
 
6
  EXPECTED_ACCESS_KEY = os.environ.get('EXPECTED_ACCESS_KEY')
7
 
8
- async def api_key_middleware(app, handler):
9
- async def middleware_handler(request):
10
- # Logging the request
11
- bt.logging.info(f"Handling {request.method} request to {request.path}")
12
 
13
- # Check access key
14
- access_key = request.headers.get("api_key")
15
- if EXPECTED_ACCESS_KEY is not None and access_key != EXPECTED_ACCESS_KEY:
16
- bt.logging.error(f'Invalid access key: {access_key}')
17
- return Response(status=401, reason="Invalid access key")
18
 
19
- # Continue to the next handler if the API key is valid
20
- return await handler(request)
21
- return middleware_handler
22
 
23
- async def json_parsing_middleware(app, handler):
24
- async def middleware_handler(request):
25
- try:
26
- # Parsing JSON data from the request
27
- request['data'] = await request.json()
28
- except json.JSONDecodeError as e:
29
- bt.logging.error(f'Invalid JSON data: {str(e)}')
30
- return Response(status=400, text="Invalid JSON")
31
 
32
- # Continue to the next handler if JSON is successfully parsed
33
- return await handler(request)
34
- return middleware_handler
 
1
  import os
2
  import json
3
  import bittensor as bt
4
+ from aiohttp.web import Request, Response, middleware
5
 
6
  EXPECTED_ACCESS_KEY = os.environ.get('EXPECTED_ACCESS_KEY')
7
 
8
+ @middleware
9
+ async def api_key_middleware(request: Request, handler):
10
+ # Logging the request
11
+ bt.logging.info(f"Handling {request.method} request to {request.path}")
12
 
13
+ # Check access key
14
+ access_key = request.headers.get("api_key")
15
+ if EXPECTED_ACCESS_KEY is not None and access_key != EXPECTED_ACCESS_KEY:
16
+ bt.logging.error(f'Invalid access key: {access_key}')
17
+ return Response(status=401, reason="Invalid access key")
18
 
19
+ # Continue to the next handler if the API key is valid
20
+ return await handler(request)
 
21
 
22
+ @middleware
23
+ async def json_parsing_middleware(request: Request, handler):
24
+ try:
25
+ # Parsing JSON data from the request
26
+ request['data'] = await request.json()
27
+ except json.JSONDecodeError as e:
28
+ bt.logging.error(f'Invalid JSON data: {str(e)}')
29
+ return Response(status=400, text="Invalid JSON")
30
 
31
+ # Continue to the next handler if JSON is successfully parsed
32
+ return await handler(request)
 
server.py CHANGED
@@ -34,8 +34,6 @@ EXPECTED_ACCESS_KEY="hey-michal" python app.py --neuron.model_id mock --wallet.n
34
  ```
35
  add --mock to test the echo stream
36
  """
37
- @api_key_middleware
38
- @json_parsing_middleware
39
  async def chat(request: web.Request) -> Response:
40
  """
41
  Chat endpoint for the validator.
@@ -43,7 +41,7 @@ async def chat(request: web.Request) -> Response:
43
  request_data = request['data']
44
  params = QueryValidatorParams.from_dict(request_data)
45
  # TODO: SET STREAM AS DEFAULT
46
- stream = request_data.get('stream', False)
47
 
48
  # Access the validator from the application context
49
  validator: ValidatorAPI = request.app['validator']
@@ -52,29 +50,29 @@ async def chat(request: web.Request) -> Response:
52
  return response
53
 
54
 
55
- @api_key_middleware
56
- @json_parsing_middleware
57
  async def echo_stream(request, request_data):
58
  request_data = request['data']
59
  return await utils.echo_stream(request_data)
60
 
61
 
 
62
  class ValidatorApplication(web.Application):
63
  def __init__(self, validator_instance=None, *args, **kwargs):
64
  super().__init__(*args, **kwargs)
65
 
66
  self['validator'] = validator_instance if validator_instance else S1ValidatorAPI()
67
 
68
- # Add middlewares to application
69
- self.middlewares.append(api_key_middleware)
70
- self.middlewares.append(json_parsing_middleware)
71
-
72
  self.add_routes([
73
  web.post('/chat/', chat),
74
  web.post('/echo/', echo_stream)
75
  ])
 
76
  # TODO: Enable rewarding and other features
77
 
 
 
 
78
 
79
  def main(run_aio_app=True, test=False) -> None:
80
  loop = asyncio.get_event_loop()
 
34
  ```
35
  add --mock to test the echo stream
36
  """
 
 
37
  async def chat(request: web.Request) -> Response:
38
  """
39
  Chat endpoint for the validator.
 
41
  request_data = request['data']
42
  params = QueryValidatorParams.from_dict(request_data)
43
  # TODO: SET STREAM AS DEFAULT
44
+ stream = request_data.get('stream', True)
45
 
46
  # Access the validator from the application context
47
  validator: ValidatorAPI = request.app['validator']
 
50
  return response
51
 
52
 
 
 
53
  async def echo_stream(request, request_data):
54
  request_data = request['data']
55
  return await utils.echo_stream(request_data)
56
 
57
 
58
+
59
  class ValidatorApplication(web.Application):
60
  def __init__(self, validator_instance=None, *args, **kwargs):
61
  super().__init__(*args, **kwargs)
62
 
63
  self['validator'] = validator_instance if validator_instance else S1ValidatorAPI()
64
 
65
+ # Add middlewares to application
 
 
 
66
  self.add_routes([
67
  web.post('/chat/', chat),
68
  web.post('/echo/', echo_stream)
69
  ])
70
+ self.setup_middlewares()
71
  # TODO: Enable rewarding and other features
72
 
73
+ def setup_middlewares(self):
74
+ self.middlewares.append(json_parsing_middleware)
75
+ self.middlewares.append(api_key_middleware)
76
 
77
  def main(run_aio_app=True, test=False) -> None:
78
  loop = asyncio.get_event_loop()
validators/__init__.py CHANGED
@@ -1,2 +1,2 @@
1
- from base import QueryValidatorParams, ValidatorAPI, MockValidator
2
- from sn1_validator_wrapper import S1ValidatorAPI
 
1
+ from .base import QueryValidatorParams, ValidatorAPI, MockValidator
2
+ from .sn1_validator_wrapper import S1ValidatorAPI
validators/sn1_validator_wrapper.py CHANGED
@@ -1,14 +1,13 @@
1
  import json
2
  import utils
 
3
  import traceback
4
  import bittensor as bt
5
- import asyncio
6
- from prompting.forward import handle_response
7
  from prompting.validator import Validator
8
  from prompting.utils.uids import get_random_uids
9
  from prompting.protocol import PromptingSynapse, StreamPromptingSynapse
10
  from prompting.dendrite import DendriteResponseEvent
11
- from base import QueryValidatorParams, ValidatorAPI
12
  from aiohttp.web_response import Response, StreamResponse
13
  from deprecated import deprecated
14
 
@@ -16,7 +15,7 @@ class S1ValidatorAPI(ValidatorAPI):
16
  def __init__(self):
17
  self.validator = Validator()
18
 
19
-
20
  @deprecated(reason="This function is deprecated. Validators use stream synapse now, use get_stream_response instead.")
21
  async def get_response(self, params:QueryValidatorParams) -> Response:
22
  try:
@@ -37,7 +36,7 @@ class S1ValidatorAPI(ValidatorAPI):
37
 
38
  bt.logging.info(f"Creating DendriteResponseEvent:\n {responses}")
39
  # Encapsulate the responses in a response event (dataclass)
40
- response_event = DendriteResponseEvent(responses, uids)
41
 
42
  # convert dict to json
43
  response = response_event.__state_dict__()
 
1
  import json
2
  import utils
3
+ import torch
4
  import traceback
5
  import bittensor as bt
 
 
6
  from prompting.validator import Validator
7
  from prompting.utils.uids import get_random_uids
8
  from prompting.protocol import PromptingSynapse, StreamPromptingSynapse
9
  from prompting.dendrite import DendriteResponseEvent
10
+ from .base import QueryValidatorParams, ValidatorAPI
11
  from aiohttp.web_response import Response, StreamResponse
12
  from deprecated import deprecated
13
 
 
15
  def __init__(self):
16
  self.validator = Validator()
17
 
18
+
19
  @deprecated(reason="This function is deprecated. Validators use stream synapse now, use get_stream_response instead.")
20
  async def get_response(self, params:QueryValidatorParams) -> Response:
21
  try:
 
36
 
37
  bt.logging.info(f"Creating DendriteResponseEvent:\n {responses}")
38
  # Encapsulate the responses in a response event (dataclass)
39
+ response_event = DendriteResponseEvent(responses, torch.LongTensor(uids), params.timeout)
40
 
41
  # convert dict to json
42
  response = response_event.__state_dict__()