File size: 5,092 Bytes
158b61b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
#!/usr/bin/env python
import configargparse
from flask import Flask, jsonify, request
from waitress import serve
from onmt.translate import TranslationServer, ServerModelError
import logging
from logging.handlers import RotatingFileHandler
# Values written to the "status" field of the JSON payloads returned by the
# route handlers defined in start().
STATUS_OK = "ok"
STATUS_ERROR = "error"
def start(config_file,
          url_root="/translator",
          host="0.0.0.0",
          port=5000,
          debug=False):
    """Create the translation REST application and serve it with waitress.

    Args:
        config_file: passed unchanged to ``TranslationServer.start``.
        url_root: prefix prepended verbatim to every route path registered
            below (``url_root + '/models'`` and so on). It must start with
            ``/``: a value without a leading slash (such as
            ``"./translator"``) is rejected by Werkzeug when the routes are
            registered.
        host: interface handed to ``waitress.serve``.
        port: port handed to ``waitress.serve``.
        debug: when true, the ``/translate`` handler logs each request and
            response through a rotating ``debug_requests.log`` file handler.

    The function ends by calling ``serve(app, host=host, port=port)``.
    """
    def prefix_route(route_function, prefix='', mask='{0}{1}'):
        """Wrap ``route_function`` so every route is prefixed by ``prefix``."""
        def newroute(route, *args, **kwargs):
            return route_function(mask.format(prefix, route), *args, **kwargs)
        return newroute

    if debug:
        # NOTE(review): no level is set on this logger here, so INFO records
        # reach the file only if the effective level inherited from the
        # parent loggers allows it -- confirm it is configured elsewhere.
        logger = logging.getLogger("main")
        log_format = logging.Formatter(
            "[%(asctime)s %(levelname)s] %(message)s")
        file_handler = RotatingFileHandler(
            "debug_requests.log",
            maxBytes=1000000, backupCount=10)
        file_handler.setFormatter(log_format)
        logger.addHandler(file_handler)

    app = Flask(__name__)
    # Every @app.route below goes through the prefixing wrapper.
    app.route = prefix_route(app.route, url_root)
    translation_server = TranslationServer()
    translation_server.start(config_file)

    @app.route('/models', methods=['GET'])
    def get_models():
        """Return the result of ``translation_server.list_models()``."""
        out = translation_server.list_models()
        return jsonify(out)

    @app.route('/health', methods=['GET'])
    def health():
        """Liveness probe: always answers with an "ok" status."""
        out = {}
        out['status'] = STATUS_OK
        return jsonify(out)

    @app.route('/clone_model/<int:model_id>', methods=['POST'])
    def clone_model(model_id):
        """Clone model ``model_id`` using the JSON body's options."""
        out = {}
        data = request.get_json(force=True)
        # "timeout" is removed from the body; -1 is used when it is absent.
        timeout = data.pop('timeout', -1)
        opt = data.get('opt', None)
        try:
            model_id, load_time = translation_server.clone_model(
                model_id, opt, timeout)
        except ServerModelError as e:
            out['status'] = STATUS_ERROR
            out['error'] = str(e)
        else:
            out['status'] = STATUS_OK
            out['model_id'] = model_id
            out['load_time'] = load_time
        return jsonify(out)

    @app.route('/unload_model/<int:model_id>', methods=['GET'])
    def unload_model(model_id):
        """Unload model ``model_id``; any failure is reported in the JSON."""
        out = {"model_id": model_id}
        try:
            translation_server.unload_model(model_id)
            out['status'] = STATUS_OK
        except Exception as e:
            out['status'] = STATUS_ERROR
            out['error'] = str(e)
        return jsonify(out)

    @app.route('/translate', methods=['POST'])
    def translate():
        """Translate the JSON list in the request body.

        On success the response is a list of ``n_best`` lists of result
        dicts. On ``ServerModelError`` it is a dict carrying ``error`` and
        ``status``.
        """
        inputs = request.get_json(force=True)
        if debug:
            logger.info(inputs)
        try:
            # Only the server call can raise ServerModelError, so only it is
            # kept inside the try body.
            trans, scores, n_best, _, aligns = translation_server.run(inputs)
        except ServerModelError as e:
            model_id = inputs[0].get("id")
            if debug:
                logger.warning("Unload model #%s because of an error",
                               model_id)
            # The id may be missing from the request or refer to a model
            # that is not loaded; a failed lookup must not replace the error
            # payload with an unhandled exception.
            try:
                failed_model = translation_server.models[model_id]
            except (KeyError, IndexError, TypeError):
                failed_model = None
            if failed_model is not None:
                failed_model.unload()
            out = {}
            out['error'] = str(e)
            out['status'] = STATUS_ERROR
        else:
            # Sanity checks on the server output: one entry per
            # (input, hypothesis) pair.
            assert len(trans) == len(inputs) * n_best
            assert len(scores) == len(inputs) * n_best
            assert len(aligns) == len(inputs) * n_best

            out = [[] for _ in range(n_best)]
            for i, (tgt, score, align) in enumerate(
                    zip(trans, scores, aligns)):
                response = {"src": inputs[i // n_best]['src'], "tgt": tgt,
                            "n_best": n_best, "pred_score": score}
                if len(align) > 0 and align[0] is not None:
                    response["align"] = align
                # Entry i is filed under bucket i % n_best.
                out[i % n_best].append(response)

        if debug:
            logger.info(out)
        return jsonify(out)

    @app.route('/to_cpu/<int:model_id>', methods=['GET'])
    def to_cpu(model_id):
        """Call ``to_cpu()`` on model ``model_id``; report failures as JSON."""
        out = {'model_id': model_id}
        # Same error reporting as /unload_model, so an unknown id yields a
        # JSON error payload instead of an unhandled exception.
        try:
            translation_server.models[model_id].to_cpu()
            out['status'] = STATUS_OK
        except Exception as e:
            out['status'] = STATUS_ERROR
            out['error'] = str(e)
        return jsonify(out)

    @app.route('/to_gpu/<int:model_id>', methods=['GET'])
    def to_gpu(model_id):
        """Call ``to_gpu()`` on model ``model_id``; report failures as JSON."""
        out = {'model_id': model_id}
        # Same error reporting as /unload_model, so an unknown id yields a
        # JSON error payload instead of an unhandled exception.
        try:
            translation_server.models[model_id].to_gpu()
            out['status'] = STATUS_OK
        except Exception as e:
            out['status'] = STATUS_ERROR
            out['error'] = str(e)
        return jsonify(out)

    serve(app, host=host, port=port)
def _get_parser():
    """Build the command-line/config-file parser for the REST server.

    Returns:
        A ``configargparse.ArgumentParser`` (YAML config-file syntax)
        defining ``--ip``, ``--port``, ``--url_root``, ``--debug``/``-d``
        and ``--config``/``-c``.
    """
    parser = configargparse.ArgumentParser(
        config_file_parser_class=configargparse.YAMLConfigFileParser,
        description="OpenNMT-py REST Server")
    parser.add_argument("--ip", type=str, default="0.0.0.0")
    # The default is an int to match the declared type, rather than a string
    # that the parser has to convert implicitly.
    parser.add_argument("--port", type=int, default=5000)
    parser.add_argument("--url_root", type=str, default="/translator")
    parser.add_argument("--debug", "-d", action="store_true")
    parser.add_argument("--config", "-c", type=str,
                        default="./available_models/conf.json")
    return parser
def main():
    """Parse the command-line options and launch the REST server."""
    options = _get_parser().parse_args()
    start(
        options.config,
        url_root=options.url_root,
        host=options.ip,
        port=options.port,
        debug=options.debug,
    )
if __name__ == "__main__":
    # Start the server only when the file is executed as a script, so that
    # importing the module has no side effects.
    main()
|