ronaldaug committed on
Commit
6bc807d
·
verified ·
1 Parent(s): 7903dbc

Update apis/chat_api.py

Browse files
Files changed (1) hide show
  1. apis/chat_api.py +34 -119
apis/chat_api.py CHANGED
@@ -20,16 +20,14 @@ from mocks.stream_chat_mocker import stream_chat_mock
20
  class ChatAPIApp:
21
  def __init__(self):
22
  self.app = FastAPI(
23
- docs_url="/",
 
24
  title="HuggingFace LLM API",
25
- swagger_ui_parameters={"defaultModelsExpandDepth": -1},
26
  version="1.0",
27
  )
28
  self.setup_routes()
29
 
30
  def get_available_models(self):
31
- # https://platform.openai.com/docs/api-reference/models/list
32
- # ANCHOR[id=available-models]: Available models
33
  self.available_models = {
34
  "object": "list",
35
  "data": [
@@ -63,58 +61,25 @@ class ChatAPIApp:
63
  HTTPBearer(auto_error=False)
64
  ),
65
  ):
66
- api_key = None
67
- if credentials:
68
- api_key = credentials.credentials
69
- else:
70
- api_key = os.getenv("HF_TOKEN")
71
-
72
- if api_key:
73
- if api_key.startswith("hf_"):
74
- return api_key
75
- else:
76
- logger.warn(f"Invalid HF Token!")
77
- else:
78
- logger.warn("Not provide HF Token!")
79
  return None
80
 
81
  class ChatCompletionsPostItem(BaseModel):
82
- model: str = Field(
83
- default="mixtral-8x7b",
84
- description="(str) `mixtral-8x7b`",
85
- )
86
- messages: list = Field(
87
- default=[{"role": "user", "content": "Hello, who are you?"}],
88
- description="(list) Messages",
89
- )
90
- temperature: Union[float, None] = Field(
91
- default=0.5,
92
- description="(float) Temperature",
93
- )
94
- top_p: Union[float, None] = Field(
95
- default=0.95,
96
- description="(float) top p",
97
- )
98
- max_tokens: Union[int, None] = Field(
99
- default=-1,
100
- description="(int) Max tokens",
101
- )
102
- use_cache: bool = Field(
103
- default=False,
104
- description="(bool) Use cache",
105
- )
106
- stream: bool = Field(
107
- default=True,
108
- description="(bool) Stream",
109
- )
110
-
111
- def chat_completions(
112
- self, item: ChatCompletionsPostItem, api_key: str = Depends(extract_api_key)
113
- ):
114
  streamer = MessageStreamer(model=item.model)
115
  composer = MessageComposer(model=item.model)
116
  composer.merge(messages=item.messages)
117
- # streamer.chat = stream_chat_mock
118
 
119
  stream_response = streamer.chat_response(
120
  prompt=composer.merged_str,
@@ -124,80 +89,36 @@ class ChatAPIApp:
124
  api_key=api_key,
125
  use_cache=item.use_cache,
126
  )
127
- if item.stream:
128
- event_source_response = EventSourceResponse(
129
- streamer.chat_return_generator(stream_response),
130
- media_type="text/event-stream",
131
- ping=2000,
132
- ping_message_factory=lambda: ServerSentEvent(**{"comment": ""}),
133
- )
134
- return event_source_response
135
- else:
136
- data_response = streamer.chat_return_dict(stream_response)
137
- return data_response
138
 
139
  def get_readme(self):
140
  readme_path = Path(__file__).parents[1] / "README.md"
141
  with open(readme_path, "r", encoding="utf-8") as rf:
142
- readme_str = rf.read()
143
- readme_html = markdown2.markdown(
144
- readme_str, extras=["table", "fenced-code-blocks", "highlightjs-lang"]
145
- )
146
- return readme_html
147
 
148
  def setup_routes(self):
 
 
149
  for prefix in ["", "/v1", "/api", "/api/v1"]:
150
- if prefix in ["/api/v1"]:
151
- include_in_schema = True
152
- else:
153
- include_in_schema = False
154
-
155
- self.app.get(
156
- prefix + "/models",
157
- summary="Get available models",
158
- include_in_schema=include_in_schema,
159
- )(self.get_available_models)
160
-
161
- self.app.post(
162
- prefix + "/chat/completions",
163
- summary="Chat completions in conversation session",
164
- include_in_schema=include_in_schema,
165
- )(self.chat_completions)
166
- self.app.get(
167
- "/readme",
168
- summary="README of HF LLM API",
169
- response_class=HTMLResponse,
170
- include_in_schema=False,
171
- )(self.get_readme)
172
 
173
 
174
  class ArgParser(argparse.ArgumentParser):
175
  def __init__(self, *args, **kwargs):
176
  super(ArgParser, self).__init__(*args, **kwargs)
177
-
178
- self.add_argument(
179
- "-s",
180
- "--server",
181
- type=str,
182
- default="0.0.0.0",
183
- help="Server IP for HF LLM Chat API",
184
- )
185
- self.add_argument(
186
- "-p",
187
- "--port",
188
- type=int,
189
- default=23333,
190
- help="Server Port for HF LLM Chat API",
191
- )
192
-
193
- self.add_argument(
194
- "-d",
195
- "--dev",
196
- default=False,
197
- action="store_true",
198
- help="Run in dev mode",
199
- )
200
-
201
  self.args = self.parse_args(sys.argv[1:])
202
 
203
 
@@ -205,10 +126,4 @@ app = ChatAPIApp().app
205
 
206
  if __name__ == "__main__":
207
  args = ArgParser().args
208
- if args.dev:
209
- uvicorn.run("__main__:app", host=args.server, port=args.port, reload=True)
210
- else:
211
- uvicorn.run("__main__:app", host=args.server, port=args.port, reload=False)
212
-
213
- # python -m apis.chat_api # [Docker] on product mode
214
- # python -m apis.chat_api -d # [Dev] on develop mode
 
20
  class ChatAPIApp:
21
  def __init__(self):
22
  self.app = FastAPI(
23
+ docs_url=None, # Hide Swagger UI
24
+ redoc_url=None, # Hide ReDoc UI
25
  title="HuggingFace LLM API",
 
26
  version="1.0",
27
  )
28
  self.setup_routes()
29
 
30
  def get_available_models(self):
 
 
31
  self.available_models = {
32
  "object": "list",
33
  "data": [
 
61
  HTTPBearer(auto_error=False)
62
  ),
63
  ):
64
+ api_key = os.getenv("HF_TOKEN") if credentials is None else credentials.credentials
65
+ if api_key and api_key.startswith("hf_"):
66
+ return api_key
67
+ logger.warn("Invalid or missing HF Token!")
 
 
 
 
 
 
 
 
 
68
  return None
69
 
70
  class ChatCompletionsPostItem(BaseModel):
71
+ model: str = Field(default="mixtral-8x7b", description="(str) `mixtral-8x7b`")
72
+ messages: list = Field(default=[{"role": "user", "content": "Hello, who are you?"}], description="(list) Messages")
73
+ temperature: Union[float, None] = Field(default=0.5, description="(float) Temperature")
74
+ top_p: Union[float, None] = Field(default=0.95, description="(float) top p")
75
+ max_tokens: Union[int, None] = Field(default=-1, description="(int) Max tokens")
76
+ use_cache: bool = Field(default=False, description="(bool) Use cache")
77
+ stream: bool = Field(default=True, description="(bool) Stream")
78
+
79
+ def chat_completions(self, item: ChatCompletionsPostItem, api_key: str = Depends(extract_api_key)):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  streamer = MessageStreamer(model=item.model)
81
  composer = MessageComposer(model=item.model)
82
  composer.merge(messages=item.messages)
 
83
 
84
  stream_response = streamer.chat_response(
85
  prompt=composer.merged_str,
 
89
  api_key=api_key,
90
  use_cache=item.use_cache,
91
  )
92
+ return EventSourceResponse(
93
+ streamer.chat_return_generator(stream_response),
94
+ media_type="text/event-stream",
95
+ ping=2000,
96
+ ping_message_factory=lambda: ServerSentEvent(**{"comment": ""}),
97
+ ) if item.stream else streamer.chat_return_dict(stream_response)
 
 
 
 
 
98
 
99
  def get_readme(self):
100
  readme_path = Path(__file__).parents[1] / "README.md"
101
  with open(readme_path, "r", encoding="utf-8") as rf:
102
+ return markdown2.markdown(rf.read(), extras=["table", "fenced-code-blocks", "highlightjs-lang"])
 
 
 
 
103
 
104
  def setup_routes(self):
105
+ self.app.get("/", summary="Root endpoint", include_in_schema=False)(lambda: "Hello World!") # Root route
106
+
107
  for prefix in ["", "/v1", "/api", "/api/v1"]:
108
+ include_in_schema = prefix == "/api/v1"
109
+
110
+ self.app.get(prefix + "/models", summary="Get available models", include_in_schema=include_in_schema)(self.get_available_models)
111
+ self.app.post(prefix + "/chat/completions", summary="Chat completions in conversation session", include_in_schema=include_in_schema)(self.chat_completions)
112
+
113
+ self.app.get("/readme", summary="README of HF LLM API", response_class=HTMLResponse, include_in_schema=False)(self.get_readme)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
 
116
class ArgParser(argparse.ArgumentParser):
    """CLI parser for the HF LLM Chat API server; parses sys.argv on construction."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # (flags, options) specs for every supported command-line switch.
        specs = [
            (("-s", "--server"), dict(type=str, default="0.0.0.0", help="Server IP for HF LLM Chat API")),
            (("-p", "--port"), dict(type=int, default=23333, help="Server Port for HF LLM Chat API")),
            (("-d", "--dev"), dict(default=False, action="store_true", help="Run in dev mode")),
        ]
        for flags, options in specs:
            self.add_argument(*flags, **options)
        # Parse eagerly so callers can read `.args` right after construction.
        self.args = self.parse_args(sys.argv[1:])
123
 
124
 
 
126
 
127
  if __name__ == "__main__":
128
  args = ArgParser().args
129
+ uvicorn.run("__main__:app", host=args.server, port=args.port, reload=args.dev)