|
import argparse |
|
import uvicorn |
|
import sys |
|
import json |
|
|
|
from fastapi import FastAPI |
|
from fastapi.encoders import jsonable_encoder |
|
from fastapi.responses import JSONResponse |
|
from pydantic import BaseModel, Field |
|
from sse_starlette.sse import EventSourceResponse |
|
from utils.logger import logger |
|
from networks.message_streamer import MessageStreamer |
|
from messagers.message_composer import MessageComposer |
|
|
|
|
|
class ChatAPIApp:
    """FastAPI application exposing the language-list and translate endpoints.

    Every route is registered twice: once at the root and once under ``/v1``
    (see ``setup_routes``).
    """

    def __init__(self):
        # Swagger UI is served at "/", and the schema ("models") section of the
        # docs page is collapsed via defaultModelsExpandDepth=-1.
        self.app = FastAPI(
            docs_url="/",
            title="HuggingFace LLM API",
            swagger_ui_parameters={"defaultModelsExpandDepth": -1},
            version="1.0",
        )
        self.setup_routes()

    def get_available_models(self):
        """Return the parsed contents of ``apis/lang_name.json``.

        The path is resolved against the process's current working directory,
        and the file is re-read on every call. The parsed JSON is also stored on
        ``self.available_models``.
        """
        # Context manager guarantees the handle is closed even if parsing
        # raises (the previous bare open() leaked it). The encoding is pinned
        # to UTF-8 rather than the platform default so non-ASCII language names
        # decode the same way on every OS.
        with open("apis/lang_name.json", "r", encoding="utf-8") as f:
            self.available_models = json.load(f)
        return self.available_models

    class ChatCompletionsPostItem(BaseModel):
        """Request body for the ``/translate`` endpoint."""

        # Source language code; "auto" is the default.
        from_language: str = Field(
            default="auto",
            description="(str) `Detect`",
        )
        # Target language code.
        to_language: str = Field(
            default="en",
            description="(str) `en`",
        )
        # Text to translate.
        text: str = Field(
            default="Hello",
            description="(str) `Text for translate`",
        )

    def chat_completions(self, item: ChatCompletionsPostItem):
        """Handle a ``/translate`` request.

        Echoes ``from_language``, ``to_language`` and ``text`` back as JSON.

        NOTE(review): no translation is performed here yet; the ``translate``
        field is always an empty string. ``MessageStreamer`` and
        ``MessageComposer`` are imported by this file but unused, so presumably
        they are meant to fill it in. TODO confirm.
        """
        item_response = {
            "from_language": item.from_language,
            "to_language": item.to_language,
            "text": item.text,
            "translate": "",
        }
        json_compatible_item_data = jsonable_encoder(item_response)
        return JSONResponse(content=json_compatible_item_data)

    def setup_routes(self):
        """Register the GET ``/models`` and POST ``/translate`` routes.

        Each route is registered both unprefixed and under ``/v1``.
        """
        for prefix in ["", "/v1"]:
            self.app.get(
                prefix + "/models",
                summary="Get available languages",
            )(self.get_available_models)

            self.app.post(
                prefix + "/translate",
                summary="translate text",
            )(self.chat_completions)
|
|
|
|
|
class ArgParser(argparse.ArgumentParser):
    """Command-line parser for the API server.

    Construction registers the ``--server``, ``--port`` and ``--dev`` options.
    It then immediately parses ``sys.argv[1:]`` and stores the resulting
    namespace on ``self.args``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # (flags, options) pairs, registered in this order so that --help
        # lists server, port and dev in the same sequence as before.
        option_specs = (
            (
                ("-s", "--server"),
                dict(
                    type=str,
                    default="0.0.0.0",
                    help="Server IP for HF LLM Chat API",
                ),
            ),
            (
                ("-p", "--port"),
                dict(
                    type=int,
                    default=23333,
                    help="Server Port for HF LLM Chat API",
                ),
            ),
            (
                ("-d", "--dev"),
                dict(
                    default=False,
                    action="store_true",
                    help="Run in dev mode",
                ),
            ),
        )
        for flags, options in option_specs:
            self.add_argument(*flags, **options)

        self.args = self.parse_args(sys.argv[1:])
|
|
|
|
|
# Module-level ASGI application; the "__main__:app" import string passed to
# uvicorn below refers to this object.
app = ChatAPIApp().app


if __name__ == "__main__":
    args = ArgParser().args
    # --dev (a store_true flag, so always a bool) switches uvicorn's
    # auto-reload on; without it the server runs with reload disabled.
    uvicorn.run(
        "__main__:app",
        host=args.server,
        port=args.port,
        reload=args.dev,
    )
|
|
|
|
|
|
|
|