import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from sse_starlette import EventSourceResponse
from typing import List, Tuple

from llmtuner.extras.misc import torch_gc
from llmtuner.chat import ChatModel
from llmtuner.api.protocol import (
    Role,
    Finish,
    ModelCard,
    ModelList,
    ChatMessage,
    DeltaMessage,
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionStreamResponse,
    ChatCompletionResponseChoice,
    ChatCompletionResponseStreamChoice,
    ChatCompletionResponseUsage
)

@asynccontextmanager
async def lifespan(app: FastAPI):  # collects GPU memory on shutdown
    yield
    torch_gc()

def create_app(chat_model: ChatModel) -> FastAPI:
    app = FastAPI(lifespan=lifespan)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    @app.get("/v1/models", response_model=ModelList)
    async def list_models():
        model_card = ModelCard(id="gpt-3.5-turbo")
        return ModelList(data=[model_card])

    @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
    async def create_chat_completion(request: ChatCompletionRequest):
        if len(request.messages) < 1 or request.messages[-1].role != Role.USER:
            raise HTTPException(status_code=400, detail="Invalid request")
        query = request.messages[-1].content

        prev_messages = request.messages[:-1]
        if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM:
            system = prev_messages.pop(0).content
        else:
            system = None

        # pair up the remaining messages as (user, assistant) turns
        history = []
        if len(prev_messages) % 2 == 0:
            for i in range(0, len(prev_messages), 2):
                if prev_messages[i].role == Role.USER and prev_messages[i + 1].role == Role.ASSISTANT:
                    history.append([prev_messages[i].content, prev_messages[i + 1].content])

        if request.stream:
            generate = predict(query, history, system, request)
            return EventSourceResponse(generate, media_type="text/event-stream")

        response, (prompt_length, response_length) = chat_model.chat(
            query, history, system,
            temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
        )

        usage = ChatCompletionResponseUsage(
            prompt_tokens=prompt_length,
            completion_tokens=response_length,
            total_tokens=prompt_length + response_length
        )

        choice_data = ChatCompletionResponseChoice(
            index=0,
            message=ChatMessage(role=Role.ASSISTANT, content=response),
            finish_reason=Finish.STOP
        )

        return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage)

    async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
        # first chunk carries only the assistant role
        choice_data = ChatCompletionResponseStreamChoice(
            index=0,
            delta=DeltaMessage(role=Role.ASSISTANT),
            finish_reason=None
        )
        chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
        yield chunk.json(exclude_unset=True, ensure_ascii=False)

        # stream content deltas as the model generates them
        for new_text in chat_model.stream_chat(
            query, history, system,
            temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
        ):
            if len(new_text) == 0:
                continue

            choice_data = ChatCompletionResponseStreamChoice(
                index=0,
                delta=DeltaMessage(content=new_text),
                finish_reason=None
            )
            chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
            yield chunk.json(exclude_unset=True, ensure_ascii=False)

        # final chunk reports the finish reason, followed by the SSE terminator
        choice_data = ChatCompletionResponseStreamChoice(
            index=0,
            delta=DeltaMessage(),
            finish_reason=Finish.STOP
        )
        chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
        yield chunk.json(exclude_unset=True, ensure_ascii=False)
        yield "[DONE]"

    return app

if __name__ == "__main__":
    chat_model = ChatModel()
    app = create_app(chat_model)
    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
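
# Example client call (an illustrative sketch, not part of the original script): it assumes the
# server above is running locally on port 8000 and that the legacy `openai` Python package
# (pre-1.0 interface) is installed. The model name is not used for routing here; the server
# only echoes it back in the response.
#
#   import openai
#
#   openai.api_base = "http://localhost:8000/v1"
#   openai.api_key = "none"  # this server does not check credentials
#
#   result = openai.ChatCompletion.create(
#       model="gpt-3.5-turbo",
#       messages=[{"role": "user", "content": "Hello!"}],
#   )
#   print(result.choices[0].message.content)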