Spaces:
Sleeping
Sleeping
Commit
·
a415c67
1
Parent(s):
ab063cf
minor integration adjustments
Browse files- middlewares.py +22 -24
- server.py +7 -9
- validators/__init__.py +2 -2
- validators/sn1_validator_wrapper.py +4 -5
middlewares.py
CHANGED
@@ -1,34 +1,32 @@
|
|
1 |
import os
|
2 |
import json
|
3 |
import bittensor as bt
|
4 |
-
from aiohttp.web import Response
|
5 |
|
6 |
EXPECTED_ACCESS_KEY = os.environ.get('EXPECTED_ACCESS_KEY')
|
7 |
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
|
19 |
-
|
20 |
-
|
21 |
-
return middleware_handler
|
22 |
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
|
32 |
-
|
33 |
-
|
34 |
-
return middleware_handler
|
|
|
1 |
import os
|
2 |
import json
|
3 |
import bittensor as bt
|
4 |
+
from aiohttp.web import Request, Response, middleware
|
5 |
|
6 |
EXPECTED_ACCESS_KEY = os.environ.get('EXPECTED_ACCESS_KEY')
|
7 |
|
8 |
+
@middleware
|
9 |
+
async def api_key_middleware(request: Request, handler):
|
10 |
+
# Logging the request
|
11 |
+
bt.logging.info(f"Handling {request.method} request to {request.path}")
|
12 |
|
13 |
+
# Check access key
|
14 |
+
access_key = request.headers.get("api_key")
|
15 |
+
if EXPECTED_ACCESS_KEY is not None and access_key != EXPECTED_ACCESS_KEY:
|
16 |
+
bt.logging.error(f'Invalid access key: {access_key}')
|
17 |
+
return Response(status=401, reason="Invalid access key")
|
18 |
|
19 |
+
# Continue to the next handler if the API key is valid
|
20 |
+
return await handler(request)
|
|
|
21 |
|
22 |
+
@middleware
|
23 |
+
async def json_parsing_middleware(request: Request, handler):
|
24 |
+
try:
|
25 |
+
# Parsing JSON data from the request
|
26 |
+
request['data'] = await request.json()
|
27 |
+
except json.JSONDecodeError as e:
|
28 |
+
bt.logging.error(f'Invalid JSON data: {str(e)}')
|
29 |
+
return Response(status=400, text="Invalid JSON")
|
30 |
|
31 |
+
# Continue to the next handler if JSON is successfully parsed
|
32 |
+
return await handler(request)
|
|
server.py
CHANGED
@@ -34,8 +34,6 @@ EXPECTED_ACCESS_KEY="hey-michal" python app.py --neuron.model_id mock --wallet.n
|
|
34 |
```
|
35 |
add --mock to test the echo stream
|
36 |
"""
|
37 |
-
@api_key_middleware
|
38 |
-
@json_parsing_middleware
|
39 |
async def chat(request: web.Request) -> Response:
|
40 |
"""
|
41 |
Chat endpoint for the validator.
|
@@ -43,7 +41,7 @@ async def chat(request: web.Request) -> Response:
|
|
43 |
request_data = request['data']
|
44 |
params = QueryValidatorParams.from_dict(request_data)
|
45 |
# TODO: SET STREAM AS DEFAULT
|
46 |
-
stream = request_data.get('stream',
|
47 |
|
48 |
# Access the validator from the application context
|
49 |
validator: ValidatorAPI = request.app['validator']
|
@@ -52,29 +50,29 @@ async def chat(request: web.Request) -> Response:
|
|
52 |
return response
|
53 |
|
54 |
|
55 |
-
@api_key_middleware
|
56 |
-
@json_parsing_middleware
|
57 |
async def echo_stream(request, request_data):
|
58 |
request_data = request['data']
|
59 |
return await utils.echo_stream(request_data)
|
60 |
|
61 |
|
|
|
62 |
class ValidatorApplication(web.Application):
|
63 |
def __init__(self, validator_instance=None, *args, **kwargs):
|
64 |
super().__init__(*args, **kwargs)
|
65 |
|
66 |
self['validator'] = validator_instance if validator_instance else S1ValidatorAPI()
|
67 |
|
68 |
-
# Add middlewares to application
|
69 |
-
self.middlewares.append(api_key_middleware)
|
70 |
-
self.middlewares.append(json_parsing_middleware)
|
71 |
-
|
72 |
self.add_routes([
|
73 |
web.post('/chat/', chat),
|
74 |
web.post('/echo/', echo_stream)
|
75 |
])
|
|
|
76 |
# TODO: Enable rewarding and other features
|
77 |
|
|
|
|
|
|
|
78 |
|
79 |
def main(run_aio_app=True, test=False) -> None:
|
80 |
loop = asyncio.get_event_loop()
|
|
|
34 |
```
|
35 |
add --mock to test the echo stream
|
36 |
"""
|
|
|
|
|
37 |
async def chat(request: web.Request) -> Response:
|
38 |
"""
|
39 |
Chat endpoint for the validator.
|
|
|
41 |
request_data = request['data']
|
42 |
params = QueryValidatorParams.from_dict(request_data)
|
43 |
# TODO: SET STREAM AS DEFAULT
|
44 |
+
stream = request_data.get('stream', True)
|
45 |
|
46 |
# Access the validator from the application context
|
47 |
validator: ValidatorAPI = request.app['validator']
|
|
|
50 |
return response
|
51 |
|
52 |
|
|
|
|
|
53 |
async def echo_stream(request, request_data):
|
54 |
request_data = request['data']
|
55 |
return await utils.echo_stream(request_data)
|
56 |
|
57 |
|
58 |
+
|
59 |
class ValidatorApplication(web.Application):
|
60 |
def __init__(self, validator_instance=None, *args, **kwargs):
|
61 |
super().__init__(*args, **kwargs)
|
62 |
|
63 |
self['validator'] = validator_instance if validator_instance else S1ValidatorAPI()
|
64 |
|
65 |
+
# Add middlewares to application
|
|
|
|
|
|
|
66 |
self.add_routes([
|
67 |
web.post('/chat/', chat),
|
68 |
web.post('/echo/', echo_stream)
|
69 |
])
|
70 |
+
self.setup_middlewares()
|
71 |
# TODO: Enable rewarding and other features
|
72 |
|
73 |
+
def setup_middlewares(self):
|
74 |
+
self.middlewares.append(json_parsing_middleware)
|
75 |
+
self.middlewares.append(api_key_middleware)
|
76 |
|
77 |
def main(run_aio_app=True, test=False) -> None:
|
78 |
loop = asyncio.get_event_loop()
|
validators/__init__.py
CHANGED
@@ -1,2 +1,2 @@
|
|
1 |
-
from base import QueryValidatorParams, ValidatorAPI, MockValidator
|
2 |
-
from sn1_validator_wrapper import S1ValidatorAPI
|
|
|
1 |
+
from .base import QueryValidatorParams, ValidatorAPI, MockValidator
|
2 |
+
from .sn1_validator_wrapper import S1ValidatorAPI
|
validators/sn1_validator_wrapper.py
CHANGED
@@ -1,14 +1,13 @@
|
|
1 |
import json
|
2 |
import utils
|
|
|
3 |
import traceback
|
4 |
import bittensor as bt
|
5 |
-
import asyncio
|
6 |
-
from prompting.forward import handle_response
|
7 |
from prompting.validator import Validator
|
8 |
from prompting.utils.uids import get_random_uids
|
9 |
from prompting.protocol import PromptingSynapse, StreamPromptingSynapse
|
10 |
from prompting.dendrite import DendriteResponseEvent
|
11 |
-
from base import QueryValidatorParams, ValidatorAPI
|
12 |
from aiohttp.web_response import Response, StreamResponse
|
13 |
from deprecated import deprecated
|
14 |
|
@@ -16,7 +15,7 @@ class S1ValidatorAPI(ValidatorAPI):
|
|
16 |
def __init__(self):
|
17 |
self.validator = Validator()
|
18 |
|
19 |
-
|
20 |
@deprecated(reason="This function is deprecated. Validators use stream synapse now, use get_stream_response instead.")
|
21 |
async def get_response(self, params:QueryValidatorParams) -> Response:
|
22 |
try:
|
@@ -37,7 +36,7 @@ class S1ValidatorAPI(ValidatorAPI):
|
|
37 |
|
38 |
bt.logging.info(f"Creating DendriteResponseEvent:\n {responses}")
|
39 |
# Encapsulate the responses in a response event (dataclass)
|
40 |
-
response_event = DendriteResponseEvent(responses, uids)
|
41 |
|
42 |
# convert dict to json
|
43 |
response = response_event.__state_dict__()
|
|
|
1 |
import json
|
2 |
import utils
|
3 |
+
import torch
|
4 |
import traceback
|
5 |
import bittensor as bt
|
|
|
|
|
6 |
from prompting.validator import Validator
|
7 |
from prompting.utils.uids import get_random_uids
|
8 |
from prompting.protocol import PromptingSynapse, StreamPromptingSynapse
|
9 |
from prompting.dendrite import DendriteResponseEvent
|
10 |
+
from .base import QueryValidatorParams, ValidatorAPI
|
11 |
from aiohttp.web_response import Response, StreamResponse
|
12 |
from deprecated import deprecated
|
13 |
|
|
|
15 |
def __init__(self):
|
16 |
self.validator = Validator()
|
17 |
|
18 |
+
|
19 |
@deprecated(reason="This function is deprecated. Validators use stream synapse now, use get_stream_response instead.")
|
20 |
async def get_response(self, params:QueryValidatorParams) -> Response:
|
21 |
try:
|
|
|
36 |
|
37 |
bt.logging.info(f"Creating DendriteResponseEvent:\n {responses}")
|
38 |
# Encapsulate the responses in a response event (dataclass)
|
39 |
+
response_event = DendriteResponseEvent(responses, torch.LongTensor(uids), params.timeout)
|
40 |
|
41 |
# convert dict to json
|
42 |
response = response_event.__state_dict__()
|