changing to routing proxy
- aiproxy/__init__.py +14 -0
- aiproxy/__main__.py +49 -0
- aiproxy/accesslog.py +262 -0
- aiproxy/aiproxy.db +0 -0
- aiproxy/async_proxy.py +70 -0
- aiproxy/chatgpt.py +531 -0
- aiproxy/proxy.py +74 -0
- aiproxy/queueclient.py +51 -0
aiproxy/__init__.py
ADDED
@@ -0,0 +1,14 @@
```python
from .proxy import (
    RequestFilterBase,
    ResponseFilterBase,
    RequestFilterException,
    ResponseFilterException
)

from .accesslog import (
    AccessLogBase,
    AccessLog,
    AccessLogWorker
)

from .chatgpt import ChatGPTProxy
```
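
Everything a host application needs is re-exported here. A minimal sketch of embedding the proxy into an existing FastAPI app (the key value and route path are illustrative, not part of this commit):

```python
# Hypothetical embedding example; assumes this package is importable as `aiproxy`.
from fastapi import FastAPI
from aiproxy import ChatGPTProxy, AccessLogWorker

worker = AccessLogWorker()  # defaults to sqlite:///aiproxy.db
proxy = ChatGPTProxy(api_key="sk-...", access_logger_queue=worker.queue_client)

app = FastAPI()
proxy.add_route(app, "/v1/chat/completions")
```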
aiproxy/__main__.py
ADDED
@@ -0,0 +1,49 @@
```python
import argparse
from contextlib import asynccontextmanager
import logging
import os
from fastapi import FastAPI
from aiproxy.chatgpt import ChatGPTProxy
from aiproxy.accesslog import AccessLogWorker
import threading
import uvicorn

# Get API key from env (falls back to "test" for local backends that ignore it)
env_openai_api_key = os.getenv("OPENAI_API_KEY", "test")

# Get arguments
parser = argparse.ArgumentParser(description="UnaProxy usage")
parser.add_argument("--host", type=str, default="127.0.0.1", required=False, help="hostname or ipaddress")
parser.add_argument("--port", type=int, default=7860, required=False, help="port number")
parser.add_argument("--base_url", type=str, default="http://localhost:8000/v1/", required=False, help="base url of the backend API")
parser.add_argument("--openai_api_key", type=str, default=env_openai_api_key, required=False, help="OpenAI API Key")
args = parser.parse_args()

# Setup logger
logger = logging.getLogger()
logger.setLevel(logging.INFO)
log_format = logging.Formatter("%(asctime)s %(levelname)8s %(message)s")
streamHandler = logging.StreamHandler()
streamHandler.setFormatter(log_format)
logger.addHandler(streamHandler)

# Setup access log worker
worker = AccessLogWorker()

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Start access log worker
    threading.Thread(target=worker.run, daemon=True).start()
    yield
    # Stop access log worker
    worker.queue_client.put(None)

# Setup ChatGPTProxy
proxy = ChatGPTProxy(base_url=args.base_url, api_key=args.openai_api_key, access_logger_queue=worker.queue_client)

# Setup server application
app = FastAPI(lifespan=lifespan, docs_url=None, redoc_url=None, openapi_url=None)
proxy.add_route(app, "/v1/chat/completions")
# proxy.add_completion_route(app, "/v1/completions")

uvicorn.run(app, host=args.host, port=args.port)
```
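
With the module running (`python -m aiproxy`), any OpenAI-compatible client can be pointed at the proxy. A hedged sketch, assuming the defaults above and a backend that accepts the "test" key:

```python
# Client-side sketch against the proxy above; the model name is an assumption
# about what the backend at --base_url actually serves.
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:7860/v1", api_key="test")
resp = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
)
print(resp.choices[0].message.content)
```

The response also carries an `X-AIProxy-Request-Id` header (added in proxy.py), which correlates a call with its rows in the access log.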
aiproxy/accesslog.py
ADDED
@@ -0,0 +1,262 @@
```python
from abc import abstractmethod
from datetime import datetime
import json
import logging
from time import sleep
import traceback
from sqlalchemy import Column, Integer, String, Float, DateTime, create_engine
from sqlalchemy.orm import sessionmaker, declarative_base, declared_attr, Session
from .queueclient import DefaultQueueClient, QueueItemBase, QueueClientBase


logger = logging.getLogger(__name__)


class _AccessLogBase:
    @declared_attr
    def __tablename__(cls):
        return cls.__name__.lower()

    @declared_attr
    def id(cls):
        return Column(Integer, primary_key=True)

    @declared_attr
    def request_id(cls):
        return Column(String)

    @declared_attr
    def created_at(cls):
        return Column(DateTime)

    @declared_attr
    def direction(cls):
        return Column(String)

    @declared_attr
    def status_code(cls):
        return Column(Integer)

    @declared_attr
    def content(cls):
        return Column(String)

    @declared_attr
    def function_call(cls):
        return Column(String)

    @declared_attr
    def tool_calls(cls):
        return Column(String)

    @declared_attr
    def raw_body(cls):
        return Column(String)

    @declared_attr
    def raw_headers(cls):
        return Column(String)

    @declared_attr
    def model(cls):
        return Column(String)

    @declared_attr
    def prompt_tokens(cls):
        return Column(Integer)

    @declared_attr
    def completion_tokens(cls):
        return Column(Integer)

    @declared_attr
    def request_time(cls):
        return Column(Float)

    @declared_attr
    def request_time_api(cls):
        return Column(Float)


# Classes for access log queue item
class RequestItemBase(QueueItemBase):
    def __init__(self, request_id: str, request_json: dict, request_headers: dict) -> None:
        self.request_id = request_id
        self.request_json = request_json
        self.request_headers = request_headers

    @abstractmethod
    def to_accesslog(self, accesslog_cls: _AccessLogBase) -> _AccessLogBase:
        ...


class ResponseItemBase(QueueItemBase):
    def __init__(self, request_id: str, response_json: dict, response_headers: dict = None, duration: float = 0, duration_api: float = 0, status_code: int = 0) -> None:
        self.request_id = request_id
        self.response_json = response_json
        self.response_headers = response_headers
        self.duration = duration
        self.duration_api = duration_api
        self.status_code = status_code

    @abstractmethod
    def to_accesslog(self, accesslog_cls: _AccessLogBase) -> _AccessLogBase:
        ...


class StreamChunkItemBase(QueueItemBase):
    def __init__(self, request_id: str, chunk_json: dict = None, response_headers: dict = None, duration: float = 0, duration_api: float = 0, request_json: dict = None, status_code: int = 0) -> None:
        self.request_id = request_id
        self.chunk_json = chunk_json
        self.response_headers = response_headers
        self.duration = duration
        self.duration_api = duration_api
        self.request_json = request_json
        self.status_code = status_code

    @abstractmethod
    def to_accesslog(self, chunks: list, accesslog_cls: _AccessLogBase) -> _AccessLogBase:
        ...


class ErrorItemBase(QueueItemBase):
    def __init__(self, request_id: str, exception: Exception, traceback_info: str, response_json: dict = None, response_headers: dict = None, status_code: int = 0) -> None:
        self.request_id = request_id
        self.exception = exception
        self.traceback_info = traceback_info
        self.response_json = response_json
        self.response_headers = response_headers
        self.status_code = status_code

    def to_accesslog(self, accesslog_cls: _AccessLogBase) -> _AccessLogBase:
        if isinstance(self.response_json, dict):
            try:
                raw_body = json.dumps(self.response_json, ensure_ascii=False)
            except Exception:
                raw_body = str(self.response_json)
        else:
            raw_body = str(self.response_json)

        return accesslog_cls(
            request_id=self.request_id,
            created_at=datetime.utcnow(),
            direction="error",
            content=f"{self.exception}\n{self.traceback_info}",
            raw_body=raw_body,
            raw_headers=json.dumps(self.response_headers, ensure_ascii=False) if self.response_headers else None,
            model="error_handler",
            status_code=self.status_code
        )

    def to_dict(self) -> dict:
        return {
            "type": self.__class__.__name__,
            "request_id": self.request_id,
            "exception": str(self.exception),
            "traceback_info": self.traceback_info,
            "response_json": self.response_json,
            "response_headers": self.response_headers
        }


class WorkerShutdownItem(QueueItemBase):
    ...


AccessLogBase = declarative_base(cls=_AccessLogBase)


class AccessLog(AccessLogBase): ...


class AccessLogWorker:
    def __init__(self, *, connection_str: str = "sqlite:///aiproxy.db", db_engine=None, accesslog_cls=AccessLog, queue_client: QueueClientBase = None):
        if db_engine:
            self.db_engine = db_engine
        else:
            self.db_engine = create_engine(connection_str)
        self.accesslog_cls = accesslog_cls
        self.accesslog_cls.metadata.create_all(bind=self.db_engine)
        self.get_session = sessionmaker(autocommit=False, autoflush=False, bind=self.db_engine)
        self.queue_client = queue_client or DefaultQueueClient()
        self.chunk_buffer = {}

    def insert_request(self, accesslog: _AccessLogBase, db: Session):
        db.add(accesslog)
        db.commit()

    def insert_response(self, accesslog: _AccessLogBase, db: Session):
        db.add(accesslog)
        db.commit()

    def use_db(self, item: QueueItemBase):
        return not (isinstance(item, StreamChunkItemBase) and item.duration == 0)

    def process_item(self, item: QueueItemBase, db: Session):
        try:
            # Request
            if isinstance(item, RequestItemBase):
                self.insert_request(item.to_accesslog(self.accesslog_cls), db)

            # Non-stream response
            elif isinstance(item, ResponseItemBase):
                self.insert_response(item.to_accesslog(self.accesslog_cls), db)

            # Stream response
            elif isinstance(item, StreamChunkItemBase):
                if not self.chunk_buffer.get(item.request_id):
                    self.chunk_buffer[item.request_id] = []

                if item.duration == 0:
                    self.chunk_buffer[item.request_id].append(item)

                else:
                    # Last chunk data for specific request_id
                    self.insert_response(item.to_accesslog(
                        self.chunk_buffer[item.request_id], self.accesslog_cls
                    ), db)
                    # Remove chunks from buffer
                    del self.chunk_buffer[item.request_id]

            # Error response
            elif isinstance(item, ErrorItemBase):
                self.insert_response(item.to_accesslog(self.accesslog_cls), db)

        except Exception as ex:
            logger.error(f"Error at processing queue item: {ex}\n{traceback.format_exc()}")

    def run(self):
        while True:
            sleep(self.queue_client.dequeue_interval)
            db = None
            try:
                items = self.queue_client.get()
            except Exception as ex:
                logger.error(f"Error at getting items from queue client: {ex}\n{traceback.format_exc()}")
                continue

            for item in items:
                try:
                    if isinstance(item, WorkerShutdownItem) or item is None:
                        return

                    if db is None and self.use_db(item):
                        # Open a db session just once per loop, when the first item that needs the db is found
                        db = self.get_session()

                    self.process_item(item, db)

                except Exception as pex:
                    logger.error(f"Error at processing loop: {pex}\n{traceback.format_exc()}")
                    # Try to persist data in error log instead
                    try:
                        logger.error(f"data: {item.to_json()}")
                    except Exception:
                        logger.error(f"data(to_json() failed): {str(item)}")

            if db is not None:
                try:
                    db.close()
                except Exception as dbex:
                    logger.error(f"Error at closing db session: {dbex}\n{traceback.format_exc()}")
```
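
Since the worker persists everything through SQLAlchemy, the log can be read back with the same model. A minimal sketch, assuming the default SQLite database:

```python
# Read-side sketch for the access log; assumes the default connection string.
from sqlalchemy import create_engine, select
from sqlalchemy.orm import sessionmaker
from aiproxy.accesslog import AccessLog

engine = create_engine("sqlite:///aiproxy.db")
with sessionmaker(bind=engine)() as session:
    # Ten most recent rows, newest first; direction is "request", "response" or "error"
    for row in session.scalars(select(AccessLog).order_by(AccessLog.created_at.desc()).limit(10)):
        print(row.request_id, row.direction, row.model, row.content)
```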
aiproxy/aiproxy.db
ADDED
Binary file (94.2 kB)
aiproxy/async_proxy.py
ADDED
@@ -0,0 +1,70 @@
```python
from fastapi import FastAPI, Request
import httpx
from starlette.responses import StreamingResponse, JSONResponse
from starlette.background import BackgroundTask
import uvicorn
import json

app = FastAPI(debug=True)

# Define the base URL of your backend server
BACKEND_BASE_URL = "http://localhost:8000"
TIMEOUT_KEEP_ALIVE = 5.0
timeout_config = httpx.Timeout(5.0, read=60.0)


async def hook(response: httpx.Response) -> None:
    if response.is_error:
        await response.aread()  # Read the body so raise_for_status can include it
        response.raise_for_status()


@app.get("/{path:path}")
async def forward_get_request(path: str, request: Request):
    async with httpx.AsyncClient() as client:
        response = await client.get(f"{BACKEND_BASE_URL}/{path}", params=request.query_params)
        content = response.aiter_bytes() if response.is_stream_consumed else response.content
        return StreamingResponse(content, media_type=response.headers["Content-Type"])


@app.post("/{path:path}")
async def forward_post_request(path: str, request: Request):
    # Retrieve the request body
    body = await request.body()

    # Prepare the headers, excluding those that can cause issues
    headers = {k: v for k, v in request.headers.items() if k.lower() not in ["host", "content-length"]}

    async with httpx.AsyncClient(event_hooks={"response": [hook]}, timeout=timeout_config) as client:
        # Send the request and get the response as a stream
        req = client.build_request("POST", f"{BACKEND_BASE_URL}/{path}", content=body, headers=headers)

        try:
            response = await client.send(req, stream=True)
            response.raise_for_status()

            if json.loads(body.decode("utf-8")).get("stream"):
                # Custom streaming function
                async def stream_response(response):
                    async for chunk in response.aiter_bytes():
                        yield chunk
                    await response.aclose()  # Ensure the response is closed after streaming

                return StreamingResponse(stream_response(response),
                                         status_code=response.status_code,
                                         headers=headers)
            else:
                # For non-streaming responses, read the complete response body
                content = await response.aread()
                return JSONResponse(content=json.loads(content), status_code=response.status_code)
        except httpx.HTTPStatusError as exc:
            print(f"HTTP Exception for {exc.request.url} - {exc}")


if __name__ == "__main__":
    uvicorn.run(app,
                host="127.0.0.1",
                port=7860,
                log_level="debug",
                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
```
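
This file is a generic pass-through forwarder, independent of the ChatGPT-specific proxy below. A sketch of exercising it with httpx, assuming it is running on 127.0.0.1:7860 and the backend exposes an OpenAI-compatible endpoint:

```python
# Illustrative client for the generic forwarder; endpoint path and model are assumptions.
import httpx

payload = {"model": "local-model", "messages": [{"role": "user", "content": "hi"}], "stream": True}
with httpx.stream("POST", "http://127.0.0.1:7860/v1/chat/completions", json=payload, timeout=60.0) as r:
    for line in r.iter_lines():
        print(line)
```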
aiproxy/chatgpt.py
ADDED
@@ -0,0 +1,531 @@
```python
from datetime import datetime
import json
import logging
import os
import time
import traceback
from typing import List, Union, AsyncGenerator
from uuid import uuid4
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from sse_starlette.sse import EventSourceResponse, AsyncContentStream
from openai import AsyncClient, APIStatusError, APIResponseValidationError, APIError, OpenAIError
from openai.types.chat import ChatCompletion
import tiktoken
from .proxy import ProxyBase, RequestFilterBase, ResponseFilterBase, RequestFilterException, ResponseFilterException
from .accesslog import _AccessLogBase, RequestItemBase, ResponseItemBase, StreamChunkItemBase, ErrorItemBase
from .queueclient import QueueClientBase

logger = logging.getLogger(__name__)


class ChatGPTRequestItem(RequestItemBase):
    def to_accesslog(self, accesslog_cls: _AccessLogBase) -> _AccessLogBase:
        request_headers_copy = self.request_headers.copy()
        if auth := request_headers_copy.get("authorization"):
            request_headers_copy["authorization"] = auth[:12] + "*****" + auth[-2:]

        content = self.request_json["messages"][-1]["content"]
        if isinstance(content, list):
            for c in content:
                if c["type"] == "text":
                    content = c["text"]
                    break
            else:
                content = json.dumps(content)

        accesslog = accesslog_cls(
            request_id=self.request_id,
            created_at=datetime.utcnow(),
            direction="request",
            content=content,
            raw_body=json.dumps(self.request_json, ensure_ascii=False),
            raw_headers=json.dumps(request_headers_copy, ensure_ascii=False),
            model=self.request_json.get("model")
        )

        return accesslog


class ChatGPTResponseItem(ResponseItemBase):
    def to_accesslog(self, accesslog_cls: _AccessLogBase) -> _AccessLogBase:
        content = self.response_json["choices"][0]["message"].get("content")
        function_call = self.response_json["choices"][0]["message"].get("function_call")
        tool_calls = self.response_json["choices"][0]["message"].get("tool_calls")
        response_headers = json.dumps(dict(self.response_headers.items()), ensure_ascii=False) if self.response_headers is not None else None
        model = self.response_json["model"]
        prompt_tokens = self.response_json["usage"]["prompt_tokens"]
        completion_tokens = self.response_json["usage"]["completion_tokens"]

        return accesslog_cls(
            request_id=self.request_id,
            created_at=datetime.utcnow(),
            direction="response",
            status_code=self.status_code,
            content=content,
            function_call=json.dumps(function_call, ensure_ascii=False) if function_call is not None else None,
            tool_calls=json.dumps(tool_calls, ensure_ascii=False) if tool_calls is not None else None,
            raw_body=json.dumps(self.response_json, ensure_ascii=False),
            raw_headers=response_headers,
            model=model,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            request_time=self.duration,
            request_time_api=self.duration_api
        )


token_encoder = tiktoken.get_encoding("cl100k_base")


def count_token(content: str):
    return len(token_encoder.encode(content))


def count_request_token(request_json: dict):
    tokens_per_message = 3
    tokens_per_name = 1
    token_count = 0

    # messages
    for m in request_json["messages"]:
        token_count += tokens_per_message
        for k, v in m.items():
            if isinstance(v, list):
                for c in v:
                    if c.get("type") == "text":
                        token_count += count_token(c["text"])
            else:
                token_count += count_token(v)
            if k == "name":
                token_count += tokens_per_name

    # functions
    if functions := request_json.get("functions"):
        for f in functions:
            token_count += count_token(json.dumps(f))

    # function_call
    if function_call := request_json.get("function_call"):
        if isinstance(function_call, dict):
            token_count += count_token(json.dumps(function_call))
        else:
            token_count += count_token(function_call)

    # tools
    if tools := request_json.get("tools"):
        for t in tools:
            token_count += count_token(json.dumps(t))

    if tool_choice := request_json.get("tool_choice"):
        token_count += count_token(json.dumps(tool_choice))

    token_count += 3
    return token_count


class ChatGPTStreamResponseItem(StreamChunkItemBase):
    def to_accesslog(self, chunks: list, accesslog_cls: _AccessLogBase) -> _AccessLogBase:
        chunk_jsons = []
        response_content = ""
        function_call = None
        tool_calls = None
        prompt_tokens = 0
        completion_tokens = 0

        # Parse info from chunks
        for chunk in chunks:
            chunk_jsons.append(chunk.chunk_json)

            if len(chunk.chunk_json["choices"]) == 0:
                # Azure returns the first delta with empty choices
                continue

            delta = chunk.chunk_json["choices"][0]["delta"]

            # Make tool_calls
            if delta.get("tool_calls"):
                if tool_calls is None:
                    tool_calls = []
                if delta["tool_calls"][0]["function"].get("name"):
                    tool_calls.append({
                        "type": "function",
                        "function": {
                            "name": delta["tool_calls"][0]["function"]["name"],
                            "arguments": ""
                        }
                    })
                elif delta["tool_calls"][0]["function"].get("arguments"):
                    tool_calls[-1]["function"]["arguments"] += delta["tool_calls"][0]["function"].get("arguments") or ""

            # Make function_call
            elif delta.get("function_call"):
                if function_call is None:
                    function_call = {}
                if delta["function_call"].get("name"):
                    function_call["name"] = delta["function_call"]["name"]
                    function_call["arguments"] = ""
                elif delta["function_call"].get("arguments"):
                    function_call["arguments"] += delta["function_call"]["arguments"]

            # Text content
            else:
                response_content += delta.get("content") or ""

        # Serialize
        function_call_str = json.dumps(function_call, ensure_ascii=False) if function_call is not None else None
        tool_calls_str = json.dumps(tool_calls, ensure_ascii=False) if tool_calls is not None else None
        response_headers = json.dumps(dict(self.response_headers.items()), ensure_ascii=False) if self.response_headers is not None else None

        # Count tokens
        prompt_tokens = count_request_token(self.request_json)

        if tool_calls_str:
            completion_tokens = count_token(tool_calls_str)
        elif function_call_str:
            completion_tokens = count_token(function_call_str)
        else:
            completion_tokens = count_token(response_content)

        return accesslog_cls(
            request_id=self.request_id,
            created_at=datetime.utcnow(),
            direction="response",
            status_code=self.status_code,
            content=response_content,
            function_call=function_call_str,
            tool_calls=tool_calls_str,
            raw_body=json.dumps(chunk_jsons, ensure_ascii=False),
            raw_headers=response_headers,
            model=chunk_jsons[0]["model"],
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            request_time=self.duration,
            request_time_api=self.duration_api
        )


class ChatGPTErrorItem(ErrorItemBase):
    ...


queue_item_types = [ChatGPTRequestItem, ChatGPTResponseItem, ChatGPTStreamResponseItem, ChatGPTErrorItem]


# Reverse proxy application for ChatGPT
class ChatGPTProxy(ProxyBase):
    _empty_openai_api_key = "OPENAI_API_KEY_IS_NOT_SET"

    def __init__(
        self,
        *,
        base_url: str = None,
        api_key: str = None,
        async_client: AsyncClient = None,
        max_retries: int = 0,
        timeout: float = 60.0,
        request_filters: List[RequestFilterBase] = None,
        response_filters: List[ResponseFilterBase] = None,
        request_item_class: type = ChatGPTRequestItem,
        response_item_class: type = ChatGPTResponseItem,
        stream_response_item_class: type = ChatGPTStreamResponseItem,
        error_item_class: type = ChatGPTErrorItem,
        access_logger_queue: QueueClientBase,
    ):
        super().__init__(
            request_filters=request_filters,
            response_filters=response_filters,
            access_logger_queue=access_logger_queue
        )

        # Log items
        self.request_item_class = request_item_class
        self.response_item_class = response_item_class
        self.stream_response_item_class = stream_response_item_class
        self.error_item_class = error_item_class

        # ChatGPT client config
        self.base_url = base_url
        self.api_key = api_key or os.getenv("OPENAI_API_KEY") or self._empty_openai_api_key
        self.max_retries = max_retries
        self.timeout = timeout
        self.async_client = async_client

    async def filter_request(self, request_id: str, request_json: dict, request_headers: dict) -> Union[dict, JSONResponse, EventSourceResponse]:
        for f in self.request_filters:
            if json_resp := await f.filter(request_id, request_json, request_headers):
                # Return a canned response if the filter returns a string
                resp_for_log = {
                    "id": "-",
                    "choices": [{"message": {"role": "assistant", "content": json_resp}, "finish_reason": "stop", "index": 0}],
                    "created": 0,
                    "model": "request_filter",
                    "object": "chat.completion",
                    "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
                }
                # Response log
                self.access_logger_queue.put(self.response_item_class(
                    request_id=request_id,
                    response_json=resp_for_log,
                    status_code=200
                ))

                if request_json.get("stream"):
                    # Stream
                    async def filter_response_stream(content: str):
                        # First delta
                        resp = {
                            "id": "-",
                            "choices": [{"delta": {"role": "assistant", "content": ""}, "finish_reason": None, "index": 0}],
                            "created": 0,
                            "model": "request_filter",
                            "object": "chat.completion",
                            "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
                        }
                        yield json.dumps(resp)
                        # Last delta
                        resp["choices"][0] = {"delta": {"content": content}, "finish_reason": "stop", "index": 0}
                        yield json.dumps(resp)

                    return self.return_response_with_headers(EventSourceResponse(
                        filter_response_stream(json_resp)
                    ), request_id)

                else:
                    # Non-stream
                    return self.return_response_with_headers(JSONResponse(resp_for_log), request_id)

        return request_json

    def get_client(self):
        return self.async_client or AsyncClient(
            base_url=self.base_url,
            api_key=self.api_key,
            max_retries=self.max_retries,
            timeout=self.timeout
        )

    async def filter_response(self, request_id: str, response: ChatCompletion) -> ChatCompletion:
        response_json = response.model_dump()

        for f in self.response_filters:
            if json_resp := await f.filter(request_id, response_json):
                return response.model_validate(json_resp)

        return response.model_validate(response_json)

    def return_response_with_headers(self, resp: JSONResponse, request_id: str):
        self.add_response_headers(response=resp, request_id=request_id)
        return resp

    def add_route(self, app: FastAPI, base_url: str):
        @app.post(base_url)
        async def handle_request(request: Request):
            request_id = str(uuid4())
            async_client = None

            try:
                start_time = time.time()
                request_json = await request.json()
                request_headers = dict(request.headers.items())

                # Log request
                self.access_logger_queue.put(self.request_item_class(
                    request_id=request_id,
                    request_json=request_json,
                    request_headers=request_headers
                ))

                # Filter request
                request_json = await self.filter_request(request_id, request_json, request_headers)
                if isinstance(request_json, JSONResponse) or isinstance(request_json, EventSourceResponse):
                    return request_json

                # Call API
                async_client = self.get_client()
                start_time_api = time.time()
                if self.api_key != self._empty_openai_api_key:
                    # Always use the server-side API key if one is configured
                    raw_response = await async_client.chat.completions.with_raw_response.create(**request_json)
                elif user_auth_header := request_headers.get("authorization"):  # Header names from the client are lower-cased
                    raw_response = await async_client.chat.completions.with_raw_response.create(
                        **request_json, extra_headers={"Authorization": user_auth_header}  # Forward with Pascal-case name to the server
                    )
                else:
                    # Call API anyway ;)
                    raw_response = await async_client.chat.completions.with_raw_response.create(**request_json)

                completion_response = raw_response.parse()
                completion_response_headers = raw_response.headers
                completion_status_code = raw_response.status_code
                if "content-encoding" in completion_response_headers:
                    completion_response_headers.pop("content-encoding")  # Remove "br" etc.; this proxy re-encodes the body

                # Handling response from API
                if request_json.get("stream"):
                    async def process_stream(stream: AsyncContentStream) -> AsyncGenerator[str, None]:
                        # Async content generator
                        try:
                            async for chunk in stream:
                                self.access_logger_queue.put(self.stream_response_item_class(
                                    request_id=request_id,
                                    chunk_json=chunk.model_dump()
                                ))
                                if chunk:
                                    yield chunk.model_dump_json()

                        finally:
                            # Close client after reading stream
                            await async_client.close()

                            # Response log
                            now = time.time()
                            self.access_logger_queue.put(self.stream_response_item_class(
                                request_id=request_id,
                                response_headers=completion_response_headers,
                                duration=now - start_time,
                                duration_api=now - start_time_api,
                                request_json=request_json,
                                status_code=completion_status_code
                            ))

                    return self.return_response_with_headers(EventSourceResponse(
                        process_stream(completion_response),
                        headers=completion_response_headers
                    ), request_id)

                else:
                    # Close client immediately
                    await async_client.close()

                    duration_api = time.time() - start_time_api

                    # Filter response
                    completion_response = await self.filter_response(request_id, completion_response)

                    # Response log
                    self.access_logger_queue.put(self.response_item_class(
                        request_id=request_id,
                        response_json=completion_response.model_dump(),
                        response_headers=completion_response_headers,
                        duration=time.time() - start_time,
                        duration_api=duration_api,
                        status_code=completion_status_code
                    ))

                    return self.return_response_with_headers(JSONResponse(
                        content=completion_response.model_dump(),
                        headers=completion_response_headers
                    ), request_id)

            # Error handlers
            except RequestFilterException as rfex:
                logger.error(f"Request filter error: {rfex}\n{traceback.format_exc()}")

                resp_json = {"error": {"message": rfex.message, "type": "request_filter_error", "param": None, "code": None}}

                # Error log
                self.access_logger_queue.put(self.error_item_class(
                    request_id=request_id,
                    exception=rfex,
                    traceback_info=traceback.format_exc(),
                    response_json=resp_json,
                    status_code=rfex.status_code
                ))

                return self.return_response_with_headers(JSONResponse(resp_json, status_code=rfex.status_code), request_id)

            except ResponseFilterException as rfex:
                logger.error(f"Response filter error: {rfex}\n{traceback.format_exc()}")

                resp_json = {"error": {"message": rfex.message, "type": "response_filter_error", "param": None, "code": None}}

                # Error log
                self.access_logger_queue.put(self.error_item_class(
                    request_id=request_id,
                    exception=rfex,
                    traceback_info=traceback.format_exc(),
                    response_json=resp_json,
                    status_code=rfex.status_code
                ))

                return self.return_response_with_headers(JSONResponse(resp_json, status_code=rfex.status_code), request_id)

            except (APIStatusError, APIResponseValidationError) as status_err:
                logger.error(f"APIStatusError from ChatGPT: {status_err}\n{traceback.format_exc()}")

                # Error log
                try:
                    resp_json = status_err.response.json()
                except Exception:
                    resp_json = str(status_err.response.content)

                self.access_logger_queue.put(self.error_item_class(
                    request_id=request_id,
                    exception=status_err,
                    traceback_info=traceback.format_exc(),
                    response_json=resp_json,
                    status_code=status_err.status_code
                ))

                return self.return_response_with_headers(JSONResponse(resp_json, status_code=status_err.status_code), request_id)

            except APIError as api_err:
                logger.error(f"APIError from ChatGPT: {api_err}\n{traceback.format_exc()}")

                resp_json = {"error": {"message": api_err.message, "type": api_err.type, "param": api_err.param, "code": api_err.code}}

                # Error log
                self.access_logger_queue.put(self.error_item_class(
                    request_id=request_id,
                    exception=api_err,
                    traceback_info=traceback.format_exc(),
                    response_json=resp_json,
                    status_code=502
                ))

                return self.return_response_with_headers(JSONResponse(resp_json, status_code=502), request_id)

            except OpenAIError as oai_err:
                logger.error(f"OpenAIError: {oai_err}\n{traceback.format_exc()}")

                resp_json = {"error": {"message": str(oai_err), "type": "openai_error", "param": None, "code": None}}

                # Error log
                self.access_logger_queue.put(self.error_item_class(
                    request_id=request_id,
                    exception=oai_err,
                    traceback_info=traceback.format_exc(),
                    response_json=resp_json,
                    status_code=502
                ))

                return self.return_response_with_headers(JSONResponse(resp_json, status_code=502), request_id)

            except Exception as ex:
                logger.error(f"Error at server: {ex}\n{traceback.format_exc()}")

                resp_json = {"error": {"message": "Proxy error", "type": "proxy_error", "param": None, "code": None}}

                # Error log
                self.access_logger_queue.put(self.error_item_class(
                    request_id=request_id,
                    exception=ex,
                    traceback_info=traceback.format_exc(),
                    response_json=resp_json,
                    status_code=502
                ))

                return self.return_response_with_headers(JSONResponse(resp_json, status_code=502), request_id)
```
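
The filter hooks are the main extension point. When a request filter returns a string, `filter_request` short-circuits the upstream call and wraps that string into a canned `chat.completion` (or a two-delta SSE stream when `"stream"` is true). A minimal sketch of such a filter; the word list is illustrative:

```python
# Hypothetical request filter built on RequestFilterBase (see proxy.py below).
from typing import Union
from aiproxy import RequestFilterBase

class BannedWordFilter(RequestFilterBase):
    BANNED = ["forbidden_topic"]  # illustrative

    async def filter(self, request_id: str, request_json: dict, request_headers: dict) -> Union[str, None]:
        last_content = str(request_json["messages"][-1].get("content", ""))
        if any(w in last_content for w in self.BANNED):
            # The returned string becomes the assistant's canned reply
            return "I can't help with that topic."
        return None  # None lets the request pass through to the API

# proxy.add_filter(BannedWordFilter())
```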
aiproxy/proxy.py
ADDED
@@ -0,0 +1,74 @@
```python
from abc import ABC, abstractmethod
import logging
from typing import List, Union
from fastapi import FastAPI
from fastapi.responses import Response
from aiproxy.queueclient import QueueClientBase


logger = logging.getLogger(__name__)


# Classes for filter
class RequestFilterBase(ABC):
    @abstractmethod
    async def filter(self, request_id: str, request_json: dict, request_headers: dict) -> Union[str, None]:
        ...


class ResponseFilterBase(ABC):
    @abstractmethod
    async def filter(self, request_id: str, response_json: dict) -> Union[dict, None]:
        ...


class FilterException(Exception):
    def __init__(self, message: str, status_code: int = 400) -> None:
        self.message = message
        self.status_code = status_code


class RequestFilterException(FilterException): ...


class ResponseFilterException(FilterException): ...


class ProxyBase(ABC):
    def __init__(
        self,
        *,
        request_filters: List[RequestFilterBase] = None,
        response_filters: List[ResponseFilterBase] = None,
        access_logger_queue: QueueClientBase
    ):
        # Filters
        self.request_filters = request_filters or []
        self.response_filters = response_filters or []

        # Access logger queue
        self.access_logger_queue = access_logger_queue

    def add_filter(self, filter: Union[RequestFilterBase, ResponseFilterBase]):
        if isinstance(filter, RequestFilterBase):
            self.request_filters.append(filter)
            logger.info(f"request filter: {filter.__class__.__name__}")
        elif isinstance(filter, ResponseFilterBase):
            self.response_filters.append(filter)
            logger.info(f"response filter: {filter.__class__.__name__}")
        else:
            logger.warning(f"Invalid filter: {filter.__class__.__name__}")

    def add_response_headers(self, response: Response, request_id: str, headers: dict = None):
        response.headers["X-AIProxy-Request-Id"] = request_id
        if headers:
            for k, v in headers.items():
                response.headers[k] = v

    @abstractmethod
    def add_route(self, app: FastAPI, base_url: str):
        ...

    # @abstractmethod
    # def add_completion_route(self, app: FastAPI, base_url: str):
    #     ...
```
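
A filter can also reject a request outright by raising `RequestFilterException`; the handler in chatgpt.py converts it into a JSON error body with the given status code. A sketch, with an illustrative model allow-list:

```python
# Hypothetical filter that enforces a model allow-list.
from typing import Union
from aiproxy import RequestFilterBase, RequestFilterException

class ModelAllowListFilter(RequestFilterBase):
    ALLOWED = {"gpt-3.5-turbo"}  # illustrative

    async def filter(self, request_id: str, request_json: dict, request_headers: dict) -> Union[str, None]:
        if request_json.get("model") not in self.ALLOWED:
            # Becomes a 403 JSON error response at the proxy
            raise RequestFilterException("Model not allowed", status_code=403)
        return None
```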
aiproxy/queueclient.py
ADDED
@@ -0,0 +1,51 @@
```python
from abc import ABC, abstractmethod
import json
from queue import Queue
from typing import Iterator


class QueueItemBase(ABC):
    def to_dict(self) -> dict:
        # Work on a copy so the instance __dict__ is not polluted with "type"
        d = self.__dict__.copy()
        d["type"] = self.__class__.__name__
        return d

    def to_json(self) -> str:
        return json.dumps(self.to_dict())

    @classmethod
    def from_dict(cls, d: dict):
        _d = d.copy()
        del _d["type"]
        return cls(**_d)

    @classmethod
    def from_json(cls, json_str: str):
        return cls.from_dict(json.loads(json_str))


class QueueClientBase(ABC):
    dequeue_interval = 0.5

    @abstractmethod
    def put(self, item: QueueItemBase):
        ...

    @abstractmethod
    def get(self) -> Iterator[QueueItemBase]:
        ...


class DefaultQueueClient(QueueClientBase):
    def __init__(self) -> None:
        self.queue = Queue()
        self.dequeue_interval = 0.5

    def put(self, item: QueueItemBase):
        self.queue.put(item)

    def get(self) -> Iterator[QueueItemBase]:
        items = []
        while not self.queue.empty():
            items.append(self.queue.get())
        return iter(items)
```
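
`DefaultQueueClient` is an in-process `queue.Queue`; the abstract interface exists so the log transport can be swapped (e.g. for an external broker) without touching the worker. A minimal sketch of an alternative client, using only the stdlib:

```python
# Illustrative alternative queue client; implements only the interface above.
from collections import deque
from threading import Lock
from typing import Iterator
from aiproxy.queueclient import QueueClientBase, QueueItemBase

class DequeQueueClient(QueueClientBase):
    dequeue_interval = 0.2  # polled by AccessLogWorker.run()

    def __init__(self) -> None:
        self._items = deque()
        self._lock = Lock()

    def put(self, item: QueueItemBase):
        self._items.append(item)

    def get(self) -> Iterator[QueueItemBase]:
        # Drain atomically so items enqueued mid-drain go to the next batch
        with self._lock:
            batch = list(self._items)
            self._items.clear()
        return iter(batch)

# Usage: AccessLogWorker(queue_client=DequeQueueClient())
```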