|
from typing import Dict, List, Optional |
|
import asyncio |
|
import os |
|
|
|
import httpx |
|
from fastchat.protocol.chat_completion import ( |
|
ChatCompletionRequest, |
|
ChatCompletionResponse, |
|
) |
|
|
|
# Server root used by ChatCompletion. A non-empty FASTCHAT_BASE_URL environment
# variable takes precedence; an unset or empty one falls back to the local default.
_BASE_URL = os.environ.get("FASTCHAT_BASE_URL") or "http://localhost:8000"
|
|
|
|
|
def set_baseurl(base_url: str) -> None:
    """Replace the module-wide base URL used for subsequent requests.

    Args:
        base_url: New server root URL; stored as given, without validation.
    """
    global _BASE_URL
    _BASE_URL = base_url
|
|
|
|
|
class ChatCompletionClient:
    """Minimal async HTTP client for a ``/v1/chat/completions`` endpoint.

    A fresh ``httpx.AsyncClient`` is opened for every request, so an instance
    only remembers the server's base URL.
    """

    def __init__(self, base_url: str):
        """Remember the server root URL.

        Args:
            base_url: Root URL of the server, e.g. ``"http://localhost:8000"``.
                Stored unchanged on ``self.base_url``.
        """
        self.base_url = base_url

    def _completions_url(self) -> str:
        """Return the full URL of the chat-completions endpoint."""
        # Strip trailing slashes so a base URL such as "http://host/" does not
        # produce "http://host//v1/chat/completions".
        return f"{self.base_url.rstrip('/')}/v1/chat/completions"

    async def request_completion(
        self, request: ChatCompletionRequest, timeout: Optional[float] = None
    ) -> ChatCompletionResponse:
        """POST ``request`` to the server and parse the reply.

        Args:
            request: The request model; sent as the JSON body via
                ``request.dict()``.
            timeout: Passed straight through to ``httpx``'s ``post``. Per the
                httpx documentation, ``None`` disables the timeout.

        Returns:
            The response JSON parsed into a ``ChatCompletionResponse``.

        Raises:
            httpx.HTTPStatusError: Raised by ``raise_for_status()`` when the
                server answers with an error status.
        """
        async with httpx.AsyncClient() as client:
            response = await client.post(
                self._completions_url(),
                json=request.dict(),
                timeout=timeout,
            )
            response.raise_for_status()
            return ChatCompletionResponse.parse_obj(response.json())
|
|
|
|
|
class ChatCompletion:
    """Class-level entry points for requesting chat completions.

    Both methods are classmethods; no instance is needed. The server is the
    module-level ``_BASE_URL``, read each time a request is made.
    """

    OBJECT_NAME = "chat.completions"

    @classmethod
    def create(cls, *args, **kwargs) -> ChatCompletionResponse:
        """Blocking variant of ``acreate``; accepts the same arguments.

        Runs the coroutine with ``asyncio.run``, which, per the asyncio
        documentation, cannot be used while another event loop is running in
        the same thread. Use ``acreate`` from async code instead.
        """
        pending = cls.acreate(*args, **kwargs)
        return asyncio.run(pending)

    @classmethod
    async def acreate(
        cls,
        model: str,
        messages: List[Dict[str, str]],
        temperature: Optional[float] = 0.7,
        n: int = 1,
        max_tokens: Optional[int] = None,
        stop: Optional[str] = None,
        timeout: Optional[float] = None,
    ) -> ChatCompletionResponse:
        """Request a chat completion for ``messages`` from ``model``.

        Args:
            model: Forwarded as the request's ``model`` field.
            messages: Forwarded as the request's ``messages`` field.
            temperature: Forwarded as ``temperature``; defaults to 0.7.
            n: Forwarded as ``n``; defaults to 1.
            max_tokens: Forwarded as ``max_tokens``.
            stop: Forwarded as ``stop``.
            timeout: Not part of the request body; handed to the HTTP client
                call as its timeout.

        Returns:
            The ``ChatCompletionResponse`` produced by the client.
        """
        payload = ChatCompletionRequest(
            model=model,
            messages=messages,
            temperature=temperature,
            n=n,
            max_tokens=max_tokens,
            stop=stop,
        )
        # _BASE_URL is looked up here, at call time, so an earlier
        # set_baseurl() call is honoured.
        http = ChatCompletionClient(_BASE_URL)
        return await http.request_completion(payload, timeout=timeout)
|
|