File size: 2,948 Bytes
82b3169
c4957f4
82b3169
 
df1e6f4
 
 
 
 
 
 
82b3169
151d1dc
e842043
60fe698
c4957f4
fdc8fdb
65526b0
df1e6f4
60fe698
65526b0
 
7ea7f29
65526b0
fdc8fdb
1e9def1
ab063cf
fdc8fdb
 
1e9def1
728d41d
82b3169
48f02b6
65526b0
 
 
df1e6f4
65526b0
60fe698
65526b0
 
1b0bdb5
7ea7f29
32e1e2e
ab063cf
82b3169
ab063cf
 
fdc8fdb
df1e6f4
 
 
fdc8fdb
 
65526b0
 
 
 
 
 
 
a415c67
82b3169
fdc8fdb
a415c67
65526b0
a415c67
 
fdc8fdb
65526b0
 
 
 
 
 
 
 
fdc8fdb
82b3169
 
 
 
ab063cf
 
fdc8fdb
65526b0
fdc8fdb
82b3169
 
 
ab063cf
82b3169
 
 
fdc8fdb
82b3169
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import asyncio

import bittensor as bt
from aiohttp import web
from aiohttp_apispec import (
    docs,
    request_schema,
    response_schema,
    setup_aiohttp_apispec,
    validation_middleware,
)

from common import utils
from common.middlewares import api_key_middleware, json_parsing_middleware
from common.schemas import QueryChatSchema, StreamChunkSchema, StreamErrorSchema
from validators import QueryValidatorParams, S1ValidatorAPI, ValidatorAPI


@docs(tags=["Prompting API"], summary="Chat", description="Chat endpoint.")
@request_schema(QueryChatSchema)
@response_schema(StreamChunkSchema, 200)
@response_schema(StreamErrorSchema, 400)
async def chat(request: web.Request) -> web.StreamResponse:
    """Chat endpoint for the validator.

    Builds ``QueryValidatorParams`` from the incoming request, then forwards
    them to the validator stored on the application under ``"validator"``
    and returns the validator's response unchanged.
    """
    query_params = QueryValidatorParams.from_request(request)

    # The validator is installed on the app by ValidatorApplication.__init__.
    active_validator: ValidatorAPI = request.app["validator"]

    return await active_validator.query_validator(query_params)


@docs(
    tags=["Prompting API"],
    summary="Echo test",
    description="Echo endpoint for testing purposes.",
)
@request_schema(QueryChatSchema)
@response_schema(StreamChunkSchema, 200)
@response_schema(StreamErrorSchema, 400)
async def echo_stream(request: web.Request) -> web.StreamResponse:
    """Test-only endpoint that delegates entirely to ``utils.echo_stream``.

    Presumably streams the request content back to the caller; see
    ``common.utils.echo_stream`` for the actual behavior.
    """
    streamed_response = await utils.echo_stream(request)
    return streamed_response


class ValidatorApplication(web.Application):
    """aiohttp application exposing the Prompting API.

    On construction it stores the validator under the ``"validator"`` key,
    registers the ``/chat/`` and ``/echo/`` POST routes, sets up the OpenAPI
    spec and Swagger UI, and installs the request middlewares.

    Args:
        validator_instance: Validator used by the ``/chat/`` handler. When
            ``None`` (the default) a new ``S1ValidatorAPI`` is created.
        *args, **kwargs: Forwarded unchanged to ``web.Application``.
    """

    def __init__(self, validator_instance=None, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Compare against None explicitly: a caller-supplied validator that is
        # falsy (e.g. defines __bool__ or __len__) must not be silently
        # replaced by a freshly constructed S1ValidatorAPI.
        self["validator"] = (
            validator_instance
            if validator_instance is not None
            else S1ValidatorAPI()
        )

        # Register the HTTP endpoints.
        self.add_routes(
            [
                web.post("/chat/", chat),
                web.post("/echo/", echo_stream),
            ]
        )
        self.setup_openapi()
        self.setup_middlewares()
        # TODO: Enable rewarding and other features

    def setup_middlewares(self):
        """Append the request middlewares to this application.

        aiohttp calls middlewares in list order, so a request passes through
        validation, then JSON parsing, then the API-key check before it
        reaches the handler.
        """
        # NOTE(review): the API-key check runs after request validation, so
        # unauthenticated callers can still trigger validation errors —
        # confirm whether api_key_middleware should be appended first.
        self.middlewares.append(validation_middleware)
        self.middlewares.append(json_parsing_middleware)
        self.middlewares.append(api_key_middleware)

    def setup_openapi(self):
        """Serve the OpenAPI spec at /docs/swagger.json and Swagger UI at /docs."""
        setup_aiohttp_apispec(
            app=self,
            title="Prompting API",
            url="/docs/swagger.json",
            swagger_path="/docs",
        )


def main(run_aio_app=True, test=False, port: int = 10000) -> None:
    """Start the validator HTTP server.

    Args:
        run_aio_app: When True, build a ``ValidatorApplication`` and serve it
            with ``web.run_app``; when False, return immediately.
        test: Currently unused; kept for backward compatibility with callers.
        port: TCP port to listen on (defaults to 10000, the previous
            hard-coded value).
    """
    if not run_aio_app:
        return

    # asyncio.get_event_loop() without a running loop is deprecated (and
    # raises on Python 3.14+), so create the loop explicitly. Install it as
    # the current loop before building the app, matching the original
    # behavior in case construction relies on a current loop.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    bt.logging.info("Starting validator application.")
    validator_app = ValidatorApplication()
    # NOTE(review): validator_app is passed as a second positional argument
    # to bt.logging.success; confirm it is rendered as intended.
    bt.logging.success("Validator app initialized successfully", validator_app)

    try:
        web.run_app(validator_app, port=port, loop=loop)
    except KeyboardInterrupt:
        print("Keyboard interrupt detected. Exiting validator.")


if __name__ == "__main__":
    # Script entry point: run the validator server with default settings.
    main(run_aio_app=True)