File size: 1,778 Bytes
e2d4dfc
 
 
 
 
 
 
 
 
 
7c49463
 
 
e2d4dfc
 
 
7c49463
 
 
 
e2d4dfc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import argparse
from contextlib import asynccontextmanager
import logging
import os
from fastapi import FastAPI
from aiproxy.chatgpt import ChatGPTProxy
from aiproxy.accesslog import AccessLogWorker
import threading
import uvicorn

# Get Base URL and API Key from env; the matching command-line flags below
# take precedence over these environment values.
base_url = os.environ.get("BASE_URL")
env_api_key = os.environ.get("API_KEY")

# Get arguments
# host/port default explicitly to uvicorn's own defaults (127.0.0.1:8000).
# Leaving them as None would forward None to uvicorn.run and override those
# defaults: asyncio treats a None host as "listen on every interface", which
# would expose this proxy (and the API key it holds) beyond localhost.
parser = argparse.ArgumentParser(description="UnaProxy usage")
parser.add_argument("--host", type=str, default="127.0.0.1", required=False, help="hostname or ipaddress (default: 127.0.0.1)")
parser.add_argument("--port", type=int, default=8000, required=False, help="port number (default: 8000)")
parser.add_argument("--base_url", type=str, default=base_url, required=False, help="base URL of the upstream API (default: $BASE_URL)")
parser.add_argument("--openai_api_key", type=str, default=env_api_key, required=False, help="OpenAI API Key (default: $API_KEY)")
args = parser.parse_args()

# Setup logger: attach a stderr handler to the root logger and emit records
# at INFO and above as "<timestamp> <LEVEL padded to 8 chars> <message>".
log_format = logging.Formatter("%(asctime)s %(levelname)8s %(message)s")
streamHandler = logging.StreamHandler()
streamHandler.setFormatter(log_format)

logger = logging.getLogger()  # no name -> the root logger
logger.setLevel(logging.INFO)
logger.addHandler(streamHandler)

# Setup access log worker
worker = AccessLogWorker()

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run the access-log worker for the lifetime of the FastAPI app.

    On startup, ``worker.run`` is started on a background daemon thread.
    On shutdown (including abnormal exit of the lifespan), ``None`` is put
    on ``worker.queue_client``. This is presumably the worker's stop
    sentinel; confirm against AccessLogWorker.run. The thread then gets a
    bounded amount of time to finish.

    Args:
        app: The FastAPI application (required by the lifespan protocol;
            not used here).
    """
    import asyncio

    join_timeout_seconds = 10.0

    # Start access log worker
    worker_thread = threading.Thread(target=worker.run, daemon=True)
    worker_thread.start()
    try:
        yield
    finally:
        # Stop access log worker. Running this in `finally` ensures the stop
        # signal is sent even if the lifespan is torn down by an exception.
        worker.queue_client.put(None)
        # The thread is a daemon, so the interpreter would stop it abruptly
        # at exit, possibly before it handles entries still in the queue.
        # Wait for it, off the event loop and with a timeout so shutdown
        # can never hang.
        await asyncio.get_running_loop().run_in_executor(
            None, worker_thread.join, join_timeout_seconds
        )
        if worker_thread.is_alive():
            logging.getLogger(__name__).warning(
                "Access log worker did not stop within %.0f seconds",
                join_timeout_seconds,
            )

# Setup ChatGPTProxy with the upstream settings from the command line and the
# access-log worker's queue, so the proxy can hand access-log entries to it.
proxy = ChatGPTProxy(base_url=args.base_url, api_key=args.openai_api_key, access_logger_queue=worker.queue_client)

# Setup server application. The Swagger UI, ReDoc and OpenAPI schema
# endpoints are all disabled.
app = FastAPI(lifespan=lifespan, docs_url=None, redoc_url=None, openapi_url=None)
proxy.add_route(app, "/v1/chat/completions")
# Disabled: uncomment to also serve the legacy /v1/completions endpoint.
#proxy.add_completion_route(app, "/v1/completions")

# Start the server. Forward only the host/port values that were actually
# provided, so uvicorn's own defaults (127.0.0.1:8000) apply otherwise.
# Passing None explicitly would override them, and asyncio treats a None
# host as "listen on every interface".
_server_options = {
    name: value
    for name, value in (("host", args.host), ("port", args.port))
    if value is not None
}
uvicorn.run(app, **_server_options)