# NOTE(review): the three lines here ("Spaces:" / "Sleeping" / "Sleeping") are a
# scraped HuggingFace Spaces status banner, not Python source — kept as a comment
# so the module stays importable.
# Standard library
import argparse
import os
import sys
from pathlib import Path
from typing import Union

# Third-party
import markdown2
import uvicorn
from fastapi import Depends, FastAPI
from fastapi.responses import HTMLResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse, ServerSentEvent

# Project-local
from utils.logger import logger
from networks.message_streamer import MessageStreamer
from messagers.message_composer import MessageComposer
from mocks.stream_chat_mocker import stream_chat_mock
class ChatAPIApp:
    """OpenAI-compatible chat API served over FastAPI, backed by HuggingFace
    inference endpoints (model routing handled by `MessageStreamer`).

    The ASGI app itself is exposed as `self.app`; routes are registered at
    construction time by `setup_routes`.
    """

    def __init__(self):
        self.app = FastAPI(
            docs_url=None,  # Hide Swagger UI
            redoc_url=None,  # Hide ReDoc UI
            title="HuggingFace LLM API",
            version="1.0",
        )
        self.setup_routes()

    def get_available_models(self):
        """Return the served models in the OpenAI `/models` list format.

        Also caches the result on `self.available_models` as a side effect.
        """

        def _model_entry(model_id: str, repo: str, owner: str) -> dict:
            # One OpenAI-style model record; the description is a markdown-style
            # reference link to the backing HuggingFace repository.
            return {
                "id": model_id,
                "description": f"[{repo}]: https://huggingface.co/{repo}",
                "object": "model",
                "created": 1700000000,  # static placeholder timestamp
                "owned_by": owner,
            }

        self.available_models = {
            "object": "list",
            "data": [
                _model_entry(
                    "mixtral-8x7b", "mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai"
                ),
                _model_entry(
                    "mistral-7b", "mistralai/Mistral-7B-Instruct-v0.2", "mistralai"
                ),
                _model_entry(
                    "nous-mixtral-8x7b",
                    "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
                    "NousResearch",
                ),
            ],
        }
        return self.available_models

    # NOTE: deliberately a plain function (no `self`) defined in the class
    # namespace — it is referenced below as `Depends(extract_api_key)`, which
    # FastAPI calls as a free dependency, not as a bound method.
    def extract_api_key(
        credentials: HTTPAuthorizationCredentials = Depends(
            HTTPBearer(auto_error=False)
        ),
    ):
        """Resolve the HF API token: bearer credentials win, else `HF_TOKEN` env.

        Returns the token only when it looks like a HuggingFace token
        (`hf_` prefix); otherwise logs and returns None so downstream code can
        proceed unauthenticated.
        """
        if credentials is None:
            api_key = os.getenv("HF_TOKEN")
        else:
            api_key = credentials.credentials
        if api_key and api_key.startswith("hf_"):
            return api_key
        # NOTE(review): `logger.warn` kept as-is — `utils.logger` is project-local
        # and may not expose `.warning`; confirm before renaming.
        logger.warn("Invalid or missing HF Token!")
        return None

    class ChatCompletionsPostItem(BaseModel):
        """Request body for `/chat/completions` (OpenAI-compatible subset)."""

        model: str = Field(default="mixtral-8x7b", description="(str) `mixtral-8x7b`")
        # default_factory avoids sharing one mutable list object as the
        # schema-level default.
        messages: list = Field(
            default_factory=lambda: [{"role": "user", "content": "Hello, who are you?"}],
            description="(list) Messages",
        )
        temperature: Union[float, None] = Field(
            default=0.5, description="(float) Temperature"
        )
        top_p: Union[float, None] = Field(default=0.95, description="(float) top p")
        # -1 conventionally means "no explicit limit" for the backend.
        max_tokens: Union[int, None] = Field(default=-1, description="(int) Max tokens")
        use_cache: bool = Field(default=False, description="(bool) Use cache")
        stream: bool = Field(default=True, description="(bool) Stream")

    def chat_completions(
        self,
        item: ChatCompletionsPostItem,
        api_key: str = Depends(extract_api_key),
    ):
        """Run one chat completion.

        Streams the reply as Server-Sent Events when `item.stream` is true,
        otherwise blocks and returns the full response dict.
        """
        streamer = MessageStreamer(model=item.model)
        composer = MessageComposer(model=item.model)
        composer.merge(messages=item.messages)
        stream_response = streamer.chat_response(
            prompt=composer.merged_str,
            temperature=item.temperature,
            top_p=item.top_p,
            max_new_tokens=item.max_tokens,
            api_key=api_key,
            use_cache=item.use_cache,
        )
        if not item.stream:
            return streamer.chat_return_dict(stream_response)
        return EventSourceResponse(
            streamer.chat_return_generator(stream_response),
            media_type="text/event-stream",
            ping=2000,  # keep-alive ping interval
            # Empty SSE comment frames keep proxies from closing the connection.
            ping_message_factory=lambda: ServerSentEvent(comment=""),
        )

    def get_readme(self):
        """Render the repository README.md (one directory up) to HTML."""
        readme_path = Path(__file__).parents[1] / "README.md"
        content = readme_path.read_text(encoding="utf-8")
        return markdown2.markdown(
            content, extras=["table", "fenced-code-blocks", "highlightjs-lang"]
        )

    def setup_routes(self):
        """Register routes; only the `/api/v1` prefix appears in the schema."""
        self.app.get("/", summary="Root endpoint", include_in_schema=False)(
            lambda: "Hello World!"
        )  # Root route
        for prefix in ["", "/v1", "/api", "/api/v1"]:
            include_in_schema = prefix == "/api/v1"
            self.app.get(
                prefix + "/models",
                summary="Get available models",
                include_in_schema=include_in_schema,
            )(self.get_available_models)
            self.app.post(
                prefix + "/chat/completions",
                summary="Chat completions in conversation session",
                include_in_schema=include_in_schema,
            )(self.chat_completions)
        self.app.get(
            "/readme",
            summary="README of HF LLM API",
            response_class=HTMLResponse,
            include_in_schema=False,
        )(self.get_readme)
class ArgParser(argparse.ArgumentParser):
    """CLI parser for the HF LLM Chat API server.

    Parses `sys.argv[1:]` eagerly at construction; results live on `self.args`.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.add_argument(
            "-s",
            "--server",
            type=str,
            default="0.0.0.0",
            help="Server IP for HF LLM Chat API",
        )
        self.add_argument(
            "-p",
            "--port",
            type=int,
            default=23333,
            help="Server Port for HF LLM Chat API",
        )
        self.add_argument(
            "-d",
            "--dev",
            default=False,
            action="store_true",
            help="Run in dev mode",
        )
        self.args = self.parse_args(sys.argv[1:])
# Module-level ASGI application — the "__main__:app" string below (and any
# external ASGI server) resolves to this object.
app = ChatAPIApp().app

if __name__ == "__main__":
    args = ArgParser().args
    # reload=True only in dev mode (-d); host/port come from the CLI flags.
    uvicorn.run("__main__:app", host=args.server, port=args.port, reload=args.dev)