gpt-4 / apis /chat_api.py
Hansimov's picture
:gem: [Feature] Support openai streaming response format
e4f31be
raw
history blame
5.74 kB
import json

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field

from conversations import (
    ConversationConnector,
    ConversationCreator,
    ConversationSession,
)
from networks import StreamResponseConstructor
class ChatAPIApp:
    """FastAPI application exposing the Bing Chat endpoints.

    Instantiating the class builds the FastAPI app (with the interactive
    docs mounted at ``/``) and registers every route on it. The resulting
    application object is available as ``self.app``.
    """

    def __init__(self):
        self.app = FastAPI(
            docs_url="/",
            title="Bing Chat API",
            version="1.0",
        )
        self.setup_routes()

    def get_available_models(self):
        """Return the supported conversation styles.

        Each entry is a dict with an ``id`` and a human-readable
        ``description``. The list is also stored on
        ``self.available_models``.
        """
        self.available_models = [
            {
                "id": "precise",
                "description": "Bing (Precise): Concise and straightforward.",
            },
            {
                "id": "balanced",
                "description": "Bing (Balanced): Informative and friendly.",
            },
            {
                "id": "creative",
                "description": "Bing (Creative): Original and imaginative.",
            },
            {
                "id": "precise-offline",
                "description": "Bing (Precise): (No Internet) Concise and straightforward.",
            },
            {
                "id": "balanced-offline",
                "description": "Bing (Balanced): (No Internet) Informative and friendly.",
            },
            {
                "id": "creative-offline",
                "description": "Bing (Creative): (No Internet) Original and imaginative.",
            },
        ]
        return self.available_models

    class CreateConversationSessionPostItem(BaseModel):
        """Request body for ``POST /create``."""

        model: str = Field(
            default="precise",
            description="(str) `precise`, `balanced`, `creative`, `precise-offline`, `balanced-offline`, `creative-offline`",
        )

    def create_conversation_session(self, item: CreateConversationSessionPostItem):
        """Create a new conversation and return its identifying credentials.

        Returns:
            A dict echoing the requested ``model`` together with the
            ``sec_access_token``, ``client_id`` and ``conversation_id``
            read from the ``ConversationCreator`` after ``create()``.
        """
        creator = ConversationCreator()
        creator.create()
        return {
            "model": item.model,
            "sec_access_token": creator.sec_access_token,
            "client_id": creator.client_id,
            "conversation_id": creator.conversation_id,
        }

    class ChatPostItem(BaseModel):
        """Request body for ``POST /chat``."""

        prompt: str = Field(
            default="Hello, who are you?",
            description="(str) Prompt",
        )
        model: str = Field(
            default="precise",
            description="(str) `precise`, `balanced`, `creative`, `precise-offline`, `balanced-offline`, `creative-offline`",
        )
        sec_access_token: str = Field(
            default="",
            description="(str) Sec Access Token",
        )
        client_id: str = Field(
            default="",
            description="(str) Client ID",
        )
        conversation_id: str = Field(
            default="",
            description="(str) Conversation ID",
        )
        invocation_id: int = Field(
            default=0,
            description="(int) Invocation ID",
        )

    def chat(self, item: ChatPostItem):
        """Send ``item.prompt`` through a ``ConversationSession``.

        The session is used as a context manager around a single
        ``session.chat`` call.

        NOTE(review): this handler has no ``return`` statement, so nothing
        produced by ``session.chat`` is forwarded in the HTTP response --
        confirm that this is intended.
        """
        connector = ConversationConnector(
            conversation_style=item.model,
            sec_access_token=item.sec_access_token,
            client_id=item.client_id,
            conversation_id=item.conversation_id,
            invocation_id=item.invocation_id,
        )
        session = ConversationSession(
            conversation_style=item.model,
            connector=connector,
        )
        with session:
            session.chat(prompt=item.prompt)

    class ChatCompletionsPostItem(BaseModel):
        """Request body for ``POST /chat/completions``."""

        model: str = Field(
            default="precise",
            description="(str) `precise`, `balanced`, `creative`, `precise-offline`, `balanced-offline`, `creative-offline`",
        )
        messages: list = Field(
            default=[{"role": "user", "content": "Hello, who are you?"}],
            description="(list) Messages",
        )
        sec_access_token: str = Field(
            default="",
            description="(str) Sec Access Token",
        )
        client_id: str = Field(
            default="",
            description="(str) Client ID",
        )
        conversation_id: str = Field(
            default="",
            description="(str) Conversation ID",
        )
        invocation_id: int = Field(
            default=0,
            description="(int) Invocation ID",
        )

    async def chat_completions(self, item: ChatCompletionsPostItem):
        """Stream the reply to the latest message in ``item.messages``.

        Only the ``content`` of the last message is used as the prompt.

        Returns:
            A ``StreamingResponse`` (``text/plain``) wrapping
            ``connector.stream_chat(prompt=..., yield_output=True)``.

        Raises:
            HTTPException: with status 400 when ``messages`` is empty or
                its last entry has no ``"content"`` key. Previously these
                inputs surfaced as an unhandled exception (HTTP 500).
        """
        # Validate the client-supplied payload before touching the connector.
        if not item.messages:
            raise HTTPException(
                status_code=400,
                detail="`messages` must contain at least one message",
            )
        try:
            # TODO: History Messages Merger -- until earlier messages are
            # merged for the first invocation (invocation_id == 0), every
            # invocation simply forwards the latest message.
            prompt = item.messages[-1]["content"]
        except (KeyError, TypeError) as err:
            raise HTTPException(
                status_code=400,
                detail="the last entry of `messages` must be an object with a `content` key",
            ) from err

        connector = ConversationConnector(
            conversation_style=item.model,
            sec_access_token=item.sec_access_token,
            client_id=item.client_id,
            conversation_id=item.conversation_id,
            invocation_id=item.invocation_id,
        )
        return StreamingResponse(
            connector.stream_chat(prompt=prompt, yield_output=True),
            media_type="text/plain",
        )

    def setup_routes(self):
        """Register every endpoint handler on ``self.app``."""
        self.app.get(
            "/models",
            summary="Get available models",
        )(self.get_available_models)

        self.app.post(
            "/create",
            summary="Create a conversation session",
        )(self.create_conversation_session)

        self.app.post(
            "/chat",
            summary="Chat in conversation session",
        )(self.chat)

        self.app.post(
            "/chat/completions",
            summary="Chat completions in conversation session",
        )(self.chat_completions)
# Module-level ASGI application; this is the object the "__main__:app"
# import string below resolves to.
app = ChatAPIApp().app


if __name__ == "__main__":
    # Start the server directly when the file is executed as a script.
    uvicorn.run(
        "__main__:app",
        host="0.0.0.0",
        port=22222,
        reload=True,
    )