Update apis/chat_api.py
apis/chat_api.py  (CHANGED, +34 -119)
@@ -20,16 +20,14 @@ from mocks.stream_chat_mocker import stream_chat_mock
 class ChatAPIApp:
     def __init__(self):
         self.app = FastAPI(
-            docs_url=
+            docs_url=None,  # Hide Swagger UI
+            redoc_url=None,  # Hide ReDoc UI
             title="HuggingFace LLM API",
-            swagger_ui_parameters={"defaultModelsExpandDepth": -1},
             version="1.0",
         )
         self.setup_routes()

     def get_available_models(self):
-        # https://platform.openai.com/docs/api-reference/models/list
-        # ANCHOR[id=available-models]: Available models
         self.available_models = {
             "object": "list",
             "data": [
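FastAPI mounts Swagger UI at /docs and ReDoc at /redoc by default; passing None for either URL unmounts that page, which is all the two new constructor arguments do. A minimal sketch of the effect on a bare app (illustrative, not this repo's code; TestClient requires the httpx package):

from fastapi import FastAPI
from fastapi.testclient import TestClient  # requires httpx

# Same constructor arguments as the updated ChatAPIApp.__init__
app = FastAPI(docs_url=None, redoc_url=None, title="HuggingFace LLM API", version="1.0")

client = TestClient(app)
assert client.get("/docs").status_code == 404   # Swagger UI is no longer mounted
assert client.get("/redoc").status_code == 404  # neither is ReDoc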
@@ -63,58 +61,25 @@ class ChatAPIApp:
             HTTPBearer(auto_error=False)
         ),
     ):
-        api_key = None
-        if credentials:
-            api_key = credentials.credentials
-        else:
-            api_key = os.getenv("HF_TOKEN")
-
-        if api_key:
-            if api_key.startswith("hf_"):
-                return api_key
-            else:
-                logger.warn(f"Invalid HF Token!")
-        else:
-            logger.warn("Not provide HF Token!")
+        api_key = os.getenv("HF_TOKEN") if credentials is None else credentials.credentials
+        if api_key and api_key.startswith("hf_"):
+            return api_key
+        logger.warn("Invalid or missing HF Token!")
         return None

     class ChatCompletionsPostItem(BaseModel):
-        model: str = Field(
-            default="mixtral-8x7b",
-            description="(str) `mixtral-8x7b`",
-        )
-        messages: list = Field(
-            default=[{"role": "user", "content": "Hello, who are you?"}],
-            description="(list) Messages",
-        )
-        temperature: Union[float, None] = Field(
-            default=0.5,
-            description="(float) Temperature",
-        )
-        top_p: Union[float, None] = Field(
-            default=0.95,
-            description="(float) top p",
-        )
-        max_tokens: Union[int, None] = Field(
-            default=-1,
-            description="(int) Max tokens",
-        )
-        use_cache: bool = Field(
-            default=False,
-            description="(bool) Use cache",
-        )
-        stream: bool = Field(
-            default=True,
-            description="(bool) Stream",
-        )
-
-    def chat_completions(
-        self, item: ChatCompletionsPostItem, api_key: str = Depends(extract_api_key)
-    ):
+        model: str = Field(default="mixtral-8x7b", description="(str) `mixtral-8x7b`")
+        messages: list = Field(default=[{"role": "user", "content": "Hello, who are you?"}], description="(list) Messages")
+        temperature: Union[float, None] = Field(default=0.5, description="(float) Temperature")
+        top_p: Union[float, None] = Field(default=0.95, description="(float) top p")
+        max_tokens: Union[int, None] = Field(default=-1, description="(int) Max tokens")
+        use_cache: bool = Field(default=False, description="(bool) Use cache")
+        stream: bool = Field(default=True, description="(bool) Stream")
+
+    def chat_completions(self, item: ChatCompletionsPostItem, api_key: str = Depends(extract_api_key)):
         streamer = MessageStreamer(model=item.model)
         composer = MessageComposer(model=item.model)
         composer.merge(messages=item.messages)
-        # streamer.chat = stream_chat_mock

         stream_response = streamer.chat_response(
             prompt=composer.merged_str,
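The rewritten extract_api_key folds the old branching into one fallback chain: prefer the caller's Bearer token, otherwise fall back to the server's own HF_TOKEN environment variable, and only accept values that look like HuggingFace tokens (the hf_ prefix). A self-contained sketch of the same dependency pattern; the /whoami route is illustrative and not part of this commit:

import os
from typing import Union

from fastapi import Depends, FastAPI
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

app = FastAPI()

def extract_api_key(
    credentials: Union[HTTPAuthorizationCredentials, None] = Depends(
        HTTPBearer(auto_error=False)  # don't reject requests that omit the header
    ),
):
    # Caller's Bearer token wins; otherwise use the server-side HF_TOKEN.
    api_key = os.getenv("HF_TOKEN") if credentials is None else credentials.credentials
    # Only forward keys that look like HuggingFace tokens ("hf_...").
    if api_key and api_key.startswith("hf_"):
        return api_key
    return None

@app.get("/whoami")  # illustrative endpoint
def whoami(api_key: Union[str, None] = Depends(extract_api_key)):
    return {"has_valid_token": api_key is not None}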
@@ -124,80 +89,36 @@ class ChatAPIApp:
             api_key=api_key,
             use_cache=item.use_cache,
         )
-        if item.stream:
-            event_source_response = EventSourceResponse(
-                streamer.chat_return_generator(stream_response),
-                media_type="text/event-stream",
-                ping=2000,
-                ping_message_factory=lambda: ServerSentEvent(**{"comment": ""}),
-            )
-            return event_source_response
-        else:
-            data_response = streamer.chat_return_dict(stream_response)
-            return data_response
+        return EventSourceResponse(
+            streamer.chat_return_generator(stream_response),
+            media_type="text/event-stream",
+            ping=2000,
+            ping_message_factory=lambda: ServerSentEvent(**{"comment": ""}),
+        ) if item.stream else streamer.chat_return_dict(stream_response)

     def get_readme(self):
         readme_path = Path(__file__).parents[1] / "README.md"
         with open(readme_path, "r", encoding="utf-8") as rf:
-            readme_str = rf.read()
-        readme_html = markdown2.markdown(
-            readme_str, extras=["table", "fenced-code-blocks", "highlightjs-lang"]
-        )
-        return readme_html
+            return markdown2.markdown(rf.read(), extras=["table", "fenced-code-blocks", "highlightjs-lang"])

     def setup_routes(self):
+        self.app.get("/", summary="Root endpoint", include_in_schema=False)(lambda: "Hello World!")  # Root route
+
         for prefix in ["", "/v1", "/api", "/api/v1"]:
-            if prefix in ["/api/v1"]:
-                include_in_schema = True
-            else:
-                include_in_schema = False
-
-            self.app.get(
-                prefix + "/models",
-                summary="Get available models",
-                include_in_schema=include_in_schema,
-            )(self.get_available_models)
-
-            self.app.post(
-                prefix + "/chat/completions",
-                summary="Chat completions in conversation session",
-                include_in_schema=include_in_schema,
-            )(self.chat_completions)
-        self.app.get(
-            "/readme",
-            summary="README of HF LLM API",
-            response_class=HTMLResponse,
-            include_in_schema=False,
-        )(self.get_readme)
+            include_in_schema = prefix == "/api/v1"
+
+            self.app.get(prefix + "/models", summary="Get available models", include_in_schema=include_in_schema)(self.get_available_models)
+            self.app.post(prefix + "/chat/completions", summary="Chat completions in conversation session", include_in_schema=include_in_schema)(self.chat_completions)
+
+        self.app.get("/readme", summary="README of HF LLM API", response_class=HTMLResponse, include_in_schema=False)(self.get_readme)


 class ArgParser(argparse.ArgumentParser):
     def __init__(self, *args, **kwargs):
         super(ArgParser, self).__init__(*args, **kwargs)
-
-        self.add_argument(
-            "-s",
-            "--server",
-            type=str,
-            default="0.0.0.0",
-            help="Server IP for HF LLM Chat API",
-        )
-        self.add_argument(
-            "-p",
-            "--port",
-            type=int,
-            default=23333,
-            help="Server Port for HF LLM Chat API",
-        )
-
-        self.add_argument(
-            "-d",
-            "--dev",
-            default=False,
-            action="store_true",
-            help="Run in dev mode",
-        )
-
+        self.add_argument("-s", "--server", type=str, default="0.0.0.0", help="Server IP for HF LLM Chat API")
+        self.add_argument("-p", "--port", type=int, default=23333, help="Server Port for HF LLM Chat API")
+        self.add_argument("-d", "--dev", default=False, action="store_true", help="Run in dev mode")
         self.args = self.parse_args(sys.argv[1:])

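chat_completions now collapses the old if/else into one conditional expression: when item.stream is true it returns an SSE stream (sse_starlette's EventSourceResponse wrapping a generator), otherwise a plain dict. A stripped-down sketch of that shape, with a dummy token generator standing in for streamer.chat_return_generator; the /echo route and EchoItem model are illustrative only:

import time

from fastapi import FastAPI
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse, ServerSentEvent

app = FastAPI()

class EchoItem(BaseModel):  # stand-in for ChatCompletionsPostItem
    text: str = "hello world"
    stream: bool = True

def token_generator(text: str):
    # Stand-in for streamer.chat_return_generator(stream_response):
    # yields one SSE "data" payload per whitespace-separated token.
    for token in text.split():
        time.sleep(0.1)
        yield {"data": token}

@app.post("/echo")
def echo(item: EchoItem):
    # Same conditional-expression shape as the updated chat_completions.
    return EventSourceResponse(
        token_generator(item.text),
        media_type="text/event-stream",
        ping=2000,
        ping_message_factory=lambda: ServerSentEvent(**{"comment": ""}),
    ) if item.stream else {"echo": item.text}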
@@ -205,10 +126,4 @@ app = ChatAPIApp().app

 if __name__ == "__main__":
     args = ArgParser().args
-    if args.dev:
-        uvicorn.run("__main__:app", host=args.server, port=args.port, reload=True)
-    else:
-        uvicorn.run("__main__:app", host=args.server, port=args.port, reload=False)
-
-    # python -m apis.chat_api  # [Docker] on product mode
-    # python -m apis.chat_api -d  # [Dev] on develop mode
+    uvicorn.run("__main__:app", host=args.server, port=args.port, reload=args.dev)
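With these changes the service still starts via python -m apis.chat_api (add -d for auto-reload in dev mode) and listens on 0.0.0.0:23333 by default, with the chat route exposed under the "", "/v1", "/api" and "/api/v1" prefixes. A minimal client sketch against a locally running instance: the request body mirrors ChatCompletionsPostItem, and since the exact SSE payload format depends on MessageStreamer, the example simply prints the raw event lines (the requests package and the hf_... token are assumptions):

import requests  # third-party HTTP client, assumed installed

API_URL = "http://127.0.0.1:23333/api/v1/chat/completions"

payload = {
    "model": "mixtral-8x7b",
    "messages": [{"role": "user", "content": "Hello, who are you?"}],
    "temperature": 0.5,
    "top_p": 0.95,
    "max_tokens": -1,
    "use_cache": False,
    "stream": True,
}

# Optional: forward a HuggingFace token; without it the server falls back to its own HF_TOKEN.
headers = {"Authorization": "Bearer hf_..."}  # placeholder token

with requests.post(API_URL, json=payload, headers=headers, stream=True) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        if line:  # skip SSE keep-alive blank lines
            print(line)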