File size: 1,903 Bytes
424919d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# An OctoAI container should expose two server routes:
#   - a route for inference requests (e.g. "/predict"), which must accept JSON inputs and return JSON outputs;
#   - a route for health checks (e.g. "/healthcheck").
# The number of workers is optional. Typical best practice is to make it some function of the number of CPU cores that the server has access to and should use.

"""HTTP Inference serving interface using sanic."""
import logging
import os

from custommodel import CustomModel
from sanic import Request, Sanic, response

# Port used by main() when the SERVING_PORT environment variable is not set.
_DEFAULT_PORT = 8000
"""Default port to serve inference on."""

# Load and initialize the model once, at import time, so that every request
# handler reuses the same instance instead of building a new one per request.
# NOTE(review): the checkpoint path is relative, so it presumably resolves
# against the process working directory — confirm against the container setup.
model_instance = CustomModel(ckpt='models/distil-240709-tot-model_epoch_20.pth')
"""Global instance of the model to serve."""

# Sanic application object; the @server.route decorators in this module
# register their handlers on it and main() starts it.
server = Sanic("server")
"""Global instance of the web server."""


@server.route("/healthcheck", methods=["GET"])
def healthcheck(_: Request) -> response.JSONResponse:
    """Answer a healthcheck probe with a fixed JSON payload.

    :param _: the incoming healthcheck request; its contents are not used.
    :return: a JSON response whose body is ``{"healthy": "yes"}``.
    """
    status_payload = {"healthy": "yes"}
    return response.json(status_payload)


@server.route("/reconstruction/predict", methods=["POST"])
def predict(request: Request) -> response.JSONResponse:
    """Run model inference on the JSON body of the request.

    :param request: the incoming request; its body must be a JSON document
        holding the inputs passed to ``model_instance.predict``.
    :return: the model output serialized as JSON on success. On failure a
        JSON object ``{"error": <message>}`` is returned with status 400 when
        the body is missing or is not valid JSON, and with status 500 when
        inference or serialization of its result fails.
    """
    # Parse the body separately from inference so that a malformed payload is
    # reported as a client error (400) instead of a server error (500).
    # ``Exception`` is caught on this single line because the concrete error
    # class Sanic raises for bad JSON differs between Sanic versions.
    try:
        inputs = request.json
    except Exception as e:
        return response.json({'error': str(e)}, status=400)

    # Sanic yields None when the body is empty (or is the JSON literal
    # ``null``); there are no inputs to run the model on in that case.
    if inputs is None:
        return response.json(
            {'error': 'Request body must contain JSON inputs.'}, status=400
        )

    try:
        output = model_instance.predict(inputs)
        return response.json(output)
    except Exception as e:
        # Top-level request boundary: keep the traceback in the server log
        # before reporting the failure to the client.
        logging.getLogger(__name__).exception("Inference request failed")
        return response.json({'error': str(e)}, status=500)


def main():
    """Start the web server on all interfaces with a single worker.

    The listening port is read from the ``SERVING_PORT`` environment
    variable; ``_DEFAULT_PORT`` is used when the variable is not set.
    """
    configured_port = os.environ.get("SERVING_PORT", _DEFAULT_PORT)
    server.run(host="0.0.0.0", port=int(configured_port), workers=1)


if __name__ == "__main__":
    main()