File size: 3,335 Bytes
04ffec9
 
 
 
 
 
 
 
 
 
 
 
05dbe5d
 
04ffec9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05dbe5d
 
 
04ffec9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from fastapi import FastAPI
from starlette.responses import HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from speakers.server.utils import MakeFastAPIOffline
from speakers.server.model.result import BaseResponse
from speakers.server.servlet.document import page_index, document
from speakers.server.servlet.runner import (submit_async,
                                            get_task_async,
                                            post_task_update_async,
                                            result_async)
from speakers.server.bootstrap.bootstrap_register import bootstrap_register
from speakers.server.bootstrap.base import Bootstrap
from speakers.common.registry import registry
from fastapi.staticfiles import StaticFiles
import uvicorn
import threading


@bootstrap_register.register_bootstrap("runner_bootstrap_web")
class RunnerBootstrapBaseWeb(Bootstrap):
    """
    Bootstrap that serves the runner HTTP API with FastAPI and uvicorn.

    ``run`` builds the FastAPI application, registers its routes and starts
    a uvicorn server on a background thread. ``destroy`` asks that server to
    exit and waits for the background thread to finish.
    """
    app: FastAPI
    # uvicorn server instance; kept so ``destroy`` can request a graceful exit.
    server: uvicorn.Server
    # Background thread executing ``server.run``.
    server_thread: threading.Thread

    def __init__(self, host: str, port: int):
        """
        :param host: address the HTTP server binds to.
        :param port: TCP port the HTTP server listens on.
        """
        super().__init__()

        self.host = host
        self.port = port

    @classmethod
    def from_config(cls, cfg=None):
        """
        Build an instance from a configuration mapping.

        :param cfg: object exposing ``.get("host")`` and ``.get("port")``;
            missing keys yield ``None``. ``None`` itself is not accepted
            (``cfg.get`` would raise).
        :return: a new ``RunnerBootstrapBaseWeb``.
        """
        host = cfg.get("host")
        port = cfg.get("port")
        return cls(host=host, port=port)

    async def run(self):
        """
        Create the FastAPI app, register the routes and start serving on a
        background thread. Returns as soon as the thread has been started.
        """
        self.app = FastAPI(
            title="API Server",
            version=self.version
        )
        MakeFastAPIOffline(self.app)
        self.app.mount("/static",
                       StaticFiles(directory=f"{registry.get_path('server_library_root')}/static/static"),
                       name="static")
        # CORS is opened unconditionally: every origin, method and header is
        # allowed, and credentials are permitted.
        # NOTE(review): wildcard origins together with allow_credentials=True
        # is very permissive; confirm this is intended for deployment.
        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

        # Demo home page.
        self.app.get("/",
                     response_model=BaseResponse,
                     summary="演示首页")(page_index)
        # Swagger documentation page.
        self.app.get("/docs",
                     response_model=BaseResponse,
                     summary="swagger 文档")(document)
        # Submit a runner job.
        self.app.post("/runner/submit",
                      tags=["Runner"],
                      summary="提交调度Runner")(submit_async)
        # Internal endpoint: fetch a scheduled runner task.
        self.app.get("/runner/task-internal",
                     tags=["Runner"],
                     summary="内部获取调度Runner")(get_task_async)
        # Internal endpoint: sync a runner's task state.
        self.app.post("/runner/task-update-internal",
                      tags=["Runner"],
                      summary="内部同步调度RunnerStat")(post_task_update_async)
        # Fetch a task's result.
        self.app.get("/runner/result",
                     tags=["Runner"],
                     summary="获取任务结果")(result_async)

        # Build the server explicitly (this is what uvicorn.run does
        # internally) so a handle is kept that destroy() can use to stop it.
        config = uvicorn.Config(self.app, host=self.host, port=self.port)
        self.server = uvicorn.Server(config)
        self.server_thread = threading.Thread(target=self.server.run)
        self.server_thread.start()

    async def destroy(self):
        """
        Request a graceful shutdown of the uvicorn server and wait for its
        thread to exit. Safe to call when ``run`` was never invoked.
        """
        import asyncio

        server = getattr(self, "server", None)
        server_thread = getattr(self, "server_thread", None)
        if server is not None:
            # uvicorn's main loop polls this flag and then shuts down.
            server.should_exit = True
        if server_thread is not None and server_thread.is_alive():
            # Join in an executor so the calling event loop is not blocked
            # while uvicorn finishes its shutdown sequence.
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, server_thread.join)