# Source: gpt-4/apis/chat_api.py (Hansimov), 5.69 kB
# Commit 82b42e7: ":recycle: [Refactor] Remove unused imports"
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse

from conversations import (
    ConversationConnector,
    ConversationCreator,
    ConversationSession,
)
class ChatAPIApp:
    """FastAPI application wrapper exposing the "Bing Chat API" endpoints.

    Routes registered by `setup_routes`:

    * ``GET  /models``            -> `get_available_models`
    * ``POST /create``            -> `create_conversation_session`
    * ``POST /chat``              -> `chat`
    * ``POST /chat/completions``  -> `chat_completions`

    The underlying `FastAPI` instance is available as ``self.app``.
    """

    def __init__(self) -> None:
        """Create the FastAPI instance (docs served at ``/``) and register routes."""
        self.app = FastAPI(
            docs_url="/",
            title="Bing Chat API",
            version="1.0",
        )
        self.setup_routes()

    def get_available_models(self):
        """Return the list of supported conversation styles.

        Each entry is a dict with an ``id`` and a human-readable
        ``description``. The list is also stored on ``self.available_models``.
        """
        self.available_models = [
            {
                "id": "precise",
                "description": "Bing (Precise): Concise and straightforward.",
            },
            {
                "id": "balanced",
                "description": "Bing (Balanced): Informative and friendly.",
            },
            {
                "id": "creative",
                "description": "Bing (Creative): Original and imaginative.",
            },
            {
                "id": "precise-offline",
                "description": "Bing (Precise): (No Internet) Concise and straightforward.",
            },
            {
                "id": "balanced-offline",
                "description": "Bing (Balanced): (No Internet) Informative and friendly.",
            },
            {
                "id": "creative-offline",
                "description": "Bing (Creative): (No Internet) Original and imaginative.",
            },
        ]
        return self.available_models

    class CreateConversationSessionPostItem(BaseModel):
        """Request body for ``POST /create``."""

        model: str = Field(
            default="precise",
            description="(str) `precise`, `balanced`, `creative`, `precise-offline`, `balanced-offline`, `creative-offline`",
        )

    def create_conversation_session(self, item: CreateConversationSessionPostItem):
        """Create a new conversation and return its identifying credentials.

        Runs `ConversationCreator.create()` and returns a dict containing the
        requested ``model`` together with the creator's ``sec_access_token``,
        ``client_id`` and ``conversation_id``.
        """
        creator = ConversationCreator()
        creator.create()
        return {
            "model": item.model,
            "sec_access_token": creator.sec_access_token,
            "client_id": creator.client_id,
            "conversation_id": creator.conversation_id,
        }

    class ChatPostItem(BaseModel):
        """Request body for ``POST /chat``."""

        prompt: str = Field(
            default="Hello, who are you?",
            description="(str) Prompt",
        )
        model: str = Field(
            default="precise",
            description="(str) `precise`, `balanced`, `creative`, `precise-offline`, `balanced-offline`, `creative-offline`",
        )
        sec_access_token: str = Field(
            default="",
            description="(str) Sec Access Token",
        )
        client_id: str = Field(
            default="",
            description="(str) Client ID",
        )
        conversation_id: str = Field(
            default="",
            description="(str) Conversation ID",
        )
        invocation_id: int = Field(
            default=0,
            description="(int) Invocation ID",
        )

    @staticmethod
    def _build_connector(item) -> ConversationConnector:
        """Build a `ConversationConnector` from the session fields of `item`.

        `item` is any request model carrying ``model``, ``sec_access_token``,
        ``client_id``, ``conversation_id`` and ``invocation_id``.
        """
        return ConversationConnector(
            conversation_style=item.model,
            sec_access_token=item.sec_access_token,
            client_id=item.client_id,
            conversation_id=item.conversation_id,
            invocation_id=item.invocation_id,
        )

    def chat(self, item: ChatPostItem):
        """Send ``item.prompt`` through a `ConversationSession`.

        The session is used as a context manager around a single
        ``session.chat`` call. The result of that call is not returned, so
        this handler itself returns ``None``.
        """
        session = ConversationSession(
            conversation_style=item.model,
            connector=self._build_connector(item),
        )
        with session:
            session.chat(prompt=item.prompt)

    class ChatCompletionsPostItem(BaseModel):
        """Request body for ``POST /chat/completions``."""

        model: str = Field(
            default="precise",
            description="(str) `precise`, `balanced`, `creative`, `precise-offline`, `balanced-offline`, `creative-offline`",
        )
        messages: list = Field(
            default=[{"role": "user", "content": "Hello, who are you?"}],
            description="(list) Messages",
        )
        sec_access_token: str = Field(
            default="",
            description="(str) Sec Access Token",
        )
        client_id: str = Field(
            default="",
            description="(str) Client ID",
        )
        conversation_id: str = Field(
            default="",
            description="(str) Conversation ID",
        )
        invocation_id: int = Field(
            default=0,
            description="(int) Invocation ID",
        )

    def chat_completions(self, item: ChatCompletionsPostItem):
        """Stream a reply to the last message of ``item.messages``.

        Returns an `EventSourceResponse` wrapping
        ``connector.stream_chat(prompt=..., yield_output=True)``.

        Raises:
            HTTPException: status 422 when ``messages`` is empty or its last
                element has no ``"content"`` entry. Previously this escaped as
                an unhandled IndexError/KeyError/TypeError.
        """
        connector = self._build_connector(item)

        # TODO: History Messages Merger -- intended for the first invocation
        # (invocation_id == 0). Until it exists, every invocation uses only
        # the content of the last message as the prompt.
        try:
            prompt = item.messages[-1]["content"]
        except (IndexError, KeyError, TypeError) as err:
            raise HTTPException(
                status_code=422,
                detail="`messages` must be a non-empty list whose last item has a `content` field",
            ) from err

        return EventSourceResponse(
            connector.stream_chat(prompt=prompt, yield_output=True),
            media_type="text/event-stream",
        )

    def setup_routes(self) -> None:
        """Register all endpoint handlers on ``self.app``."""
        self.app.get(
            "/models",
            summary="Get available models",
        )(self.get_available_models)

        self.app.post(
            "/create",
            summary="Create a conversation session",
        )(self.create_conversation_session)

        self.app.post(
            "/chat",
            summary="Chat in conversation session",
        )(self.chat)

        self.app.post(
            "/chat/completions",
            summary="Chat completions in conversation session",
        )(self.chat_completions)
# Module-level ASGI application object; referenced by the import string
# passed to uvicorn below.
app = ChatAPIApp().app

if __name__ == "__main__":
    # Run the server directly when this file is executed as a script.
    uvicorn.run(
        "__main__:app",
        host="0.0.0.0",
        port=22222,
        reload=True,
    )