bsmit1659 committed on
Commit
e2d4dfc
·
1 Parent(s): 48802c6

changing to routing proxy

Browse files
aiproxy/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .proxy import (
2
+ RequestFilterBase,
3
+ ResponseFilterBase,
4
+ RequestFilterException,
5
+ ResponseFilterException
6
+ )
7
+
8
+ from .accesslog import (
9
+ AccessLogBase,
10
+ AccessLog,
11
+ AccessLogWorker
12
+ )
13
+
14
+ from .chatgpt import ChatGPTProxy
aiproxy/__main__.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import argparse
from contextlib import asynccontextmanager
import logging
import os
from fastapi import FastAPI
from aiproxy.chatgpt import ChatGPTProxy
from aiproxy.accesslog import AccessLogWorker
import threading
import uvicorn

# Get API Key from env.
# Fix: the original hardcoded "test" despite the comment; now the environment
# variable is honored, with "test" kept as the fallback for local runs.
env_openai_api_key = os.getenv("OPENAI_API_KEY", "test")

# Get arguments
parser = argparse.ArgumentParser(description="UnaProxy usage")
parser.add_argument("--host", type=str, default="127.0.0.1", required=False, help="hostname or ipaddress")
# Fix: default was the string "7860"; use an int to match type=int
parser.add_argument("--port", type=int, default=7860, required=False, help="port number")
# Fix: help text was a copy-paste of "port number"
parser.add_argument("--base_url", type=str, default="http://localhost:8000/v1/", required=False, help="base URL of the upstream OpenAI-compatible API")
parser.add_argument("--openai_api_key", type=str, default=env_openai_api_key, required=False, help="OpenAI API Key")
args = parser.parse_args()

# Setup root logger with a stream handler
logger = logging.getLogger()
logger.setLevel(logging.INFO)
log_format = logging.Formatter("%(asctime)s %(levelname)8s %(message)s")
streamHandler = logging.StreamHandler()
streamHandler.setFormatter(log_format)
logger.addHandler(streamHandler)

# Setup access log worker (writes queued log items to the DB)
worker = AccessLogWorker()

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Start access log worker in a daemon thread
    threading.Thread(target=worker.run, daemon=True).start()
    yield
    # Stop access log worker: None is the shutdown sentinel its run loop checks for
    worker.queue_client.put(None)

# Setup ChatGPTProxy
proxy = ChatGPTProxy(base_url=args.base_url, api_key=args.openai_api_key, access_logger_queue=worker.queue_client)

# Setup server application; docs endpoints are disabled on purpose for a proxy
app = FastAPI(lifespan=lifespan, docs_url=None, redoc_url=None, openapi_url=None)
proxy.add_route(app, "/v1/chat/completions")
#proxy.add_completion_route(app, "/v1/completions")

uvicorn.run(app, host=args.host, port=args.port)
aiproxy/accesslog.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import abstractmethod
2
+ from datetime import datetime
3
+ import json
4
+ import logging
5
+ from time import sleep
6
+ import traceback
7
+ from sqlalchemy import Column, Integer, String, Float, DateTime, create_engine
8
+ from sqlalchemy.orm import sessionmaker, declarative_base, declared_attr, Session
9
+ from .queueclient import DefaultQueueClient, QueueItemBase, QueueClientBase
10
+
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
class _AccessLogBase:
    """Declarative mixin defining the columns shared by all access-log tables.

    Passed to declarative_base(cls=_AccessLogBase); every column is a
    declared_attr so each concrete subclass gets its own Column objects.
    """

    @declared_attr
    def __tablename__(cls):
        # Table name defaults to the lower-cased subclass name
        return cls.__name__.lower()

    @declared_attr
    def id(cls):
        # Surrogate primary key
        return Column(Integer, primary_key=True)

    @declared_attr
    def request_id(cls):
        # Correlates request/response/error rows of one proxied call
        return Column(String)

    @declared_attr
    def created_at(cls):
        # Row creation timestamp (producers use datetime.utcnow())
        return Column(DateTime)

    @declared_attr
    def direction(cls):
        # "request", "response" or "error" (set by the queue items)
        return Column(String)

    @declared_attr
    def status_code(cls):
        return Column(Integer)

    @declared_attr
    def content(cls):
        # Main text payload: message content or error text
        return Column(String)

    @declared_attr
    def function_call(cls):
        # JSON-serialized function_call, if any
        return Column(String)

    @declared_attr
    def tool_calls(cls):
        # JSON-serialized tool_calls, if any
        return Column(String)

    @declared_attr
    def raw_body(cls):
        # Full JSON body as text
        return Column(String)

    @declared_attr
    def raw_headers(cls):
        # JSON-serialized headers (authorization is masked by producers)
        return Column(String)

    @declared_attr
    def model(cls):
        return Column(String)

    @declared_attr
    def prompt_tokens(cls):
        return Column(Integer)

    @declared_attr
    def completion_tokens(cls):
        return Column(Integer)

    @declared_attr
    def request_time(cls):
        # Total proxy handling time in seconds
        return Column(Float)

    @declared_attr
    def request_time_api(cls):
        # Time spent in the upstream API call, in seconds
        return Column(Float)
79
+
80
+
81
+ # Classes for access log queue item
82
class RequestItemBase(QueueItemBase):
    """Base queue item carrying an incoming request to be logged."""

    def __init__(self, request_id: str, request_json: dict, request_headers: dict) -> None:
        self.request_id = request_id
        self.request_json = request_json
        self.request_headers = request_headers

    @abstractmethod
    def to_accesslog(self, accesslog_cls: _AccessLogBase) -> _AccessLogBase:
        """Convert this item into an access-log row (implemented per provider)."""
        ...
91
+
92
+
93
class ResponseItemBase(QueueItemBase):
    """Base queue item carrying a non-streaming response to be logged."""

    def __init__(self, request_id: str, response_json: dict, response_headers: dict = None, duration: float = 0, duration_api: float = 0, status_code: int = 0) -> None:
        self.request_id = request_id
        self.response_json = response_json
        self.response_headers = response_headers
        # Total proxy handling time / upstream API time, in seconds
        self.duration = duration
        self.duration_api = duration_api
        self.status_code = status_code

    @abstractmethod
    def to_accesslog(self, accesslog_cls: _AccessLogBase) -> _AccessLogBase:
        """Convert this item into an access-log row (implemented per provider)."""
        ...
105
+
106
+
107
class StreamChunkItemBase(QueueItemBase):
    """Base queue item for one SSE chunk of a streaming response.

    Intermediate chunks are enqueued with duration == 0; the final item for a
    request carries the durations and triggers aggregation (see AccessLogWorker).
    """

    def __init__(self, request_id: str, chunk_json: dict = None, response_headers: dict = None, duration: float = 0, duration_api: float = 0, request_json: dict = None, status_code: int = 0) -> None:
        self.request_id = request_id
        self.chunk_json = chunk_json
        self.response_headers = response_headers
        # duration == 0 marks an intermediate chunk that is only buffered
        self.duration = duration
        self.duration_api = duration_api
        # Original request, needed for local prompt-token counting
        self.request_json = request_json
        self.status_code = status_code

    @abstractmethod
    def to_accesslog(self, chunks: list, accesslog_cls: _AccessLogBase) -> _AccessLogBase:
        """Aggregate buffered *chunks* into one access-log row (implemented per provider)."""
        ...
120
+
121
+
122
class ErrorItemBase(QueueItemBase):
    """Queue item describing a failed request; becomes an 'error' access-log row."""

    def __init__(self, request_id: str, exception: Exception, traceback_info: str, response_json: dict = None, response_headers: dict = None, status_code: int = 0) -> None:
        self.request_id = request_id
        self.exception = exception
        self.traceback_info = traceback_info
        self.response_json = response_json
        self.response_headers = response_headers
        self.status_code = status_code

    def to_accesslog(self, accesslog_cls: _AccessLogBase) -> _AccessLogBase:
        """Build an access-log row with direction='error' from this item."""
        # Serialize the response body defensively; fall back to str() on failure
        if isinstance(self.response_json, dict):
            try:
                raw_body = json.dumps(self.response_json, ensure_ascii=False)
            except Exception:
                raw_body = str(self.response_json)
        else:
            raw_body = str(self.response_json)

        raw_headers = None
        if self.response_headers:
            raw_headers = json.dumps(self.response_headers, ensure_ascii=False)

        return accesslog_cls(
            request_id=self.request_id,
            created_at=datetime.utcnow(),
            direction="error",
            content=f"{self.exception}\n{self.traceback_info}",
            raw_body=raw_body,
            raw_headers=raw_headers,
            model="error_handler",
            status_code=self.status_code
        )

    def to_dict(self) -> dict:
        """JSON-safe representation, used when DB persistence fails."""
        return {
            "type": self.__class__.__name__,
            "request_id": self.request_id,
            "exception": str(self.exception),
            "traceback_info": self.traceback_info,
            "response_json": self.response_json,
            "response_headers": self.response_headers
        }
160
+
161
+
162
class WorkerShutdownItem(QueueItemBase):
    """Sentinel queue item that tells AccessLogWorker.run to stop."""
    ...
164
+
165
+
166
# Concrete declarative base built on the shared column mixin;
# subclass this to define a customized access-log table.
AccessLogBase = declarative_base(cls=_AccessLogBase)
167
+
168
+
169
# Default access-log table (tablename "accesslog" via the mixin's __tablename__)
class AccessLog(AccessLogBase): ...
170
+
171
+
172
class AccessLogWorker:
    """Background worker that drains the access-log queue and writes rows to the DB.

    Fixes over the original:
    - The DB session is now closed even when a WorkerShutdownItem/None causes
      the early return inside the item loop (previously it leaked).
    - The bare `except:` in the error-logging fallback is narrowed to Exception.
    """

    def __init__(self, *, connection_str: str = "sqlite:///aiproxy.db", db_engine = None, accesslog_cls = AccessLog, queue_client: QueueClientBase = None):
        """Create the engine (or use the given one), ensure the table exists, and set up the queue."""
        if db_engine:
            self.db_engine = db_engine
        else:
            self.db_engine = create_engine(connection_str)
        self.accesslog_cls = accesslog_cls
        # Create the access-log table if it does not exist yet
        self.accesslog_cls.metadata.create_all(bind=self.db_engine)
        self.get_session = sessionmaker(autocommit=False, autoflush=False, bind=self.db_engine)
        self.queue_client = queue_client or DefaultQueueClient()
        # request_id -> list of buffered StreamChunkItemBase items
        self.chunk_buffer = {}

    def insert_request(self, accesslog: _AccessLogBase, db: Session):
        """Persist one request row."""
        db.add(accesslog)
        db.commit()

    def insert_response(self, accesslog: _AccessLogBase, db: Session):
        """Persist one response/error row."""
        db.add(accesslog)
        db.commit()

    def use_db(self, item: QueueItemBase):
        # Intermediate stream chunks (duration == 0) are only buffered; no DB needed
        return not (isinstance(item, StreamChunkItemBase) and item.duration == 0)

    def process_item(self, item: QueueItemBase, db: Session):
        """Dispatch one queue item to the matching insert/buffer path."""
        try:
            # Request
            if isinstance(item, RequestItemBase):
                self.insert_request(item.to_accesslog(self.accesslog_cls), db)

            # Non-stream response
            elif isinstance(item, ResponseItemBase):
                self.insert_response(item.to_accesslog(self.accesslog_cls), db)

            # Stream response
            elif isinstance(item, StreamChunkItemBase):
                if not self.chunk_buffer.get(item.request_id):
                    self.chunk_buffer[item.request_id] = []

                if item.duration == 0:
                    # Intermediate chunk: buffer until the final item arrives
                    self.chunk_buffer[item.request_id].append(item)
                else:
                    # Last chunk data for specific request_id: aggregate and persist
                    self.insert_response(item.to_accesslog(
                        self.chunk_buffer[item.request_id], self.accesslog_cls
                    ), db)
                    # Remove chunks from buffer
                    del self.chunk_buffer[item.request_id]

            # Error response
            elif isinstance(item, ErrorItemBase):
                self.insert_response(item.to_accesslog(self.accesslog_cls), db)

        except Exception as ex:
            logger.error(f"Error at processing queue item: {ex}\n{traceback.format_exc()}")

    def run(self):
        """Poll the queue and persist items until a shutdown sentinel arrives."""
        while True:
            sleep(self.queue_client.dequeue_interval)
            db = None
            try:
                items = self.queue_client.get()
            except Exception as ex:
                logger.error(f"Error at getting items from queue client: {ex}\n{traceback.format_exc()}")
                continue

            try:
                for item in items:
                    try:
                        # None and WorkerShutdownItem both stop the worker
                        if isinstance(item, WorkerShutdownItem) or item is None:
                            return

                        if db is None and self.use_db(item):
                            # Get db session just once in the loop when the item that uses db found
                            db = self.get_session()

                        self.process_item(item, db)

                    except Exception as pex:
                        logger.error(f"Error at processing loop: {pex}\n{traceback.format_exc()}")
                        # Try to persist data in error log instead
                        try:
                            logger.error(f"data: {item.to_json()}")
                        except Exception:
                            logger.error(f"data(to_json() failed): {str(item)}")
            finally:
                # Always release the session, including on the shutdown return path
                if db is not None:
                    try:
                        db.close()
                    except Exception as dbex:
                        logger.error(f"Error at closing db session: {dbex}\n{traceback.format_exc()}")
aiproxy/aiproxy.db ADDED
Binary file (94.2 kB). View file
 
aiproxy/async_proxy.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, Request
2
+ import httpx
3
+ from starlette.responses import StreamingResponse, JSONResponse
4
+ from starlette.background import BackgroundTask
5
+ import uvicorn
6
+ import json
7
+
8
# Standalone routing-proxy app (debug mode for local development)
app = FastAPI(debug=True)

# Define the base URL of your backend server
BACKEND_BASE_URL = "http://localhost:8000"
# Keep-alive timeout passed to uvicorn, in seconds
TIMEOUT_KEEP_ALIVE = 5.0
# httpx timeouts: 5s default, 60s for reads (generation can be slow)
timeout_config = httpx.Timeout(5.0, read=60.0)
14
+
15
+
16
async def hook(response: httpx.Response) -> None:
    """httpx response hook: surface HTTP error statuses as exceptions."""
    if not response.is_error:
        return
    # The body must be read before raise_for_status on a streamed response
    await response.aread()
    response.raise_for_status()
20
+
21
+
22
@app.get("/{path:path}")
async def forward_get_request(path: str, request: Request):
    """Forward a GET request to the backend and relay body, status and content type.

    Fixes over the original:
    - The body is read fully while the client is still open; the original could
      hand an async iterator to StreamingResponse after the `async with` block
      had already closed the client.
    - A missing Content-Type header no longer raises KeyError (.get is used).
    - The backend status code is propagated instead of always returning 200.
    """
    async with httpx.AsyncClient() as client:
        response = await client.get(f"{BACKEND_BASE_URL}/{path}", params=request.query_params)
        # client.get reads the full body, so .content is available here
        body = response.content
    return StreamingResponse(
        iter([body]),
        status_code=response.status_code,
        media_type=response.headers.get("content-type"),
    )
28
+
29
+
30
@app.post("/{path:path}")
async def forward_post_request(path: str, request: Request):
    """Forward a POST request to the backend, streaming SSE bodies through.

    Fixes over the original:
    - The "stream" flag is read with .get(), so bodies without that key (or
      non-JSON bodies) no longer raise KeyError/ValueError.
    - The httpx client is created without a context manager and closed only
      after the response is fully consumed; the original closed it when the
      `async with` block exited, before StreamingResponse had iterated.
    - Non-streaming responses are JSON-decoded before JSONResponse (the
      original passed raw bytes, which JSON serialization rejects).
    - The response's content type is relayed instead of echoing back the
      incoming request headers.
    """
    # Retrieve the request body
    body = await request.body()

    # Prepare the headers, excluding those that can cause issues
    headers = {k: v for k, v in request.headers.items() if k.lower() not in ["host", "content-length"]}

    # Decide streaming mode up front; tolerate non-JSON bodies
    try:
        is_stream = bool(json.loads(body.decode("utf-8")).get("stream", False))
    except (ValueError, UnicodeDecodeError):
        is_stream = False

    # No `async with`: the client must outlive this function while streaming
    client = httpx.AsyncClient(event_hooks={'response': [hook]}, timeout=timeout_config)
    req = client.build_request("POST", f"{BACKEND_BASE_URL}/{path}", content=body, headers=headers)

    try:
        response = await client.send(req, stream=True)
        response.raise_for_status()

        if is_stream:
            # Custom streaming function; closes response and client when done
            async def stream_response(response):
                try:
                    async for chunk in response.aiter_bytes():
                        yield chunk
                finally:
                    await response.aclose()
                    await client.aclose()

            return StreamingResponse(stream_response(response),
                                     status_code=response.status_code,
                                     media_type=response.headers.get("content-type"))
        else:  # For regular JSON responses
            # For non-streaming responses, read the complete response body
            content = await response.aread()
            await client.aclose()
            return JSONResponse(content=json.loads(content), status_code=response.status_code)
    except httpx.ResponseNotRead as exc:
        await client.aclose()
        print(f"HTTP Exception for {exc.request.url} - {exc}")
62
+
63
+
64
if __name__ == "__main__":
    # Run this routing proxy standalone for local debugging
    uvicorn.run(app,
                host='127.0.0.1',
                port=7860,
                log_level="debug",
                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
70
+
aiproxy/chatgpt.py ADDED
@@ -0,0 +1,531 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ import json
3
+ import logging
4
+ import os
5
+ import time
6
+ import traceback
7
+ from typing import List, Union, AsyncGenerator
8
+ from uuid import uuid4
9
+ from fastapi import FastAPI, Request
10
+ from fastapi.responses import JSONResponse
11
+ from sse_starlette.sse import EventSourceResponse, AsyncContentStream
12
+ from openai import AsyncClient, APIStatusError, APIResponseValidationError, APIError, OpenAIError
13
+ from openai.types.chat import ChatCompletion
14
+ import tiktoken
15
+ from .proxy import ProxyBase, RequestFilterBase, ResponseFilterBase, RequestFilterException, ResponseFilterException
16
+ from .accesslog import _AccessLogBase, RequestItemBase, ResponseItemBase, StreamChunkItemBase, ErrorItemBase
17
+ from .queueclient import QueueClientBase
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
class ChatGPTRequestItem(RequestItemBase):
    """Queue item for an incoming chat-completion request."""

    def to_accesslog(self, accesslog_cls: _AccessLogBase) -> _AccessLogBase:
        """Build a 'request' access-log row; the Authorization header is masked."""
        request_headers_copy = self.request_headers.copy()
        if auth := request_headers_copy.get("authorization"):
            # Keep only a short prefix and the last 2 chars of the credential
            request_headers_copy["authorization"] = auth[:12] + "*****" + auth[-2:]

        # The last message is logged as the representative content
        content = self.request_json["messages"][-1]["content"]
        if isinstance(content, list):
            # Multi-modal content: prefer the first text part;
            # (for-else) fall back to serializing the whole list when no text part exists
            for c in content:
                if c["type"] == "text":
                    content = c["text"]
                    break
            else:
                content = json.dumps(content)

        accesslog = accesslog_cls(
            request_id=self.request_id,
            created_at=datetime.utcnow(),
            direction="request",
            content=content,
            raw_body=json.dumps(self.request_json, ensure_ascii=False),
            raw_headers=json.dumps(request_headers_copy, ensure_ascii=False),
            model=self.request_json.get("model")
        )

        return accesslog
48
+
49
+
50
class ChatGPTResponseItem(ResponseItemBase):
    """Queue item for a non-streaming chat-completion response."""

    def to_accesslog(self, accesslog_cls: _AccessLogBase) -> _AccessLogBase:
        """Build a 'response' access-log row from the completion JSON.

        NOTE(review): assumes choices[0].message and usage are always present
        in the upstream response — TODO confirm for all providers.
        """
        content = self.response_json["choices"][0]["message"].get("content")
        function_call = self.response_json["choices"][0]["message"].get("function_call")
        tool_calls = self.response_json["choices"][0]["message"].get("tool_calls")
        response_headers = json.dumps(dict(self.response_headers.items()),
                                      ensure_ascii=False) if self.response_headers is not None else None
        model = self.response_json["model"]
        prompt_tokens = self.response_json["usage"]["prompt_tokens"]
        completion_tokens = self.response_json["usage"]["completion_tokens"]

        return accesslog_cls(
            request_id=self.request_id,
            created_at=datetime.utcnow(),
            direction="response",
            status_code=self.status_code,
            content=content,
            function_call=json.dumps(function_call, ensure_ascii=False) if function_call is not None else None,
            tool_calls=json.dumps(tool_calls, ensure_ascii=False) if tool_calls is not None else None,
            raw_body=json.dumps(self.response_json, ensure_ascii=False),
            raw_headers=response_headers,
            model=model,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            request_time=self.duration,
            request_time_api=self.duration_api
        )
77
+
78
+
79
# Shared cl100k_base tokenizer used by count_token / count_request_token
token_encoder = tiktoken.get_encoding("cl100k_base")
80
+
81
+
82
def count_token(content: str):
    """Return the number of tokens in *content* under the cl100k_base encoding."""
    return len(token_encoder.encode(content))
84
+
85
+
86
def count_request_token(request_json: dict):
    """Estimate the prompt token count for a chat-completion request.

    Adds a fixed per-message overhead plus the tokens of each message value,
    then the serialized functions/function_call/tools/tool_choice fields.

    Fix over the original: non-string message values — e.g. content=None on
    assistant messages that carry function_call/tool_calls — are skipped
    instead of being passed to the tokenizer (which raises on None).
    """
    tokens_per_message = 3
    tokens_per_name = 1
    token_count = 0

    # messages
    for m in request_json["messages"]:
        token_count += tokens_per_message
        for k, v in m.items():
            if isinstance(v, list):
                # Multi-modal content: only text parts are counted here
                for c in v:
                    if c.get("type") == "text":
                        token_count += count_token(c["text"])
            elif isinstance(v, str):
                token_count += count_token(v)
            # else: skip None / non-text values instead of crashing
            if k == "name":
                token_count += tokens_per_name

    # functions
    if functions := request_json.get("functions"):
        for f in functions:
            token_count += count_token(json.dumps(f))

    # function_call
    if function_call := request_json.get("function_call"):
        if isinstance(function_call, dict):
            token_count += count_token(json.dumps(function_call))
        else:
            token_count += count_token(function_call)

    # tools
    if tools := request_json.get("tools"):
        for t in tools:
            token_count += count_token(json.dumps(t))

    if tool_choice := request_json.get("tool_choice"):
        token_count += count_token(json.dumps(tool_choice))

    # Fixed overhead added per reply
    token_count += 3
    return token_count
126
+
127
+
128
class ChatGPTStreamResponseItem(StreamChunkItemBase):
    """Aggregates buffered SSE chunks into a single 'response' access-log row."""

    def to_accesslog(self, chunks: list, accesslog_cls: _AccessLogBase) -> _AccessLogBase:
        """Reassemble content / function_call / tool_calls from *chunks* and build the row.

        NOTE(review): assumes *chunks* is non-empty (chunk_jsons[0]["model"]
        below would raise IndexError otherwise) — confirm against the worker's
        buffering logic.
        """
        chunk_jsons = []
        response_content = ""
        function_call = None
        tool_calls = None
        prompt_tokens = 0
        completion_tokens = 0

        # Parse info from chunks
        for chunk in chunks:
            chunk_jsons.append(chunk.chunk_json)

            if len(chunk.chunk_json["choices"]) == 0:
                # Azure returns the first delta with empty choices
                continue

            delta = chunk.chunk_json["choices"][0]["delta"]

            # Make tool_calls
            if delta.get("tool_calls"):
                if tool_calls is None:
                    tool_calls = []
                if delta["tool_calls"][0]["function"].get("name"):
                    # A delta carrying a name starts a new tool call
                    tool_calls.append({
                        "type": "function",
                        "function": {
                            "name": delta["tool_calls"][0]["function"]["name"],
                            "arguments": ""
                        }
                    })
                elif delta["tool_calls"][0]["function"].get("arguments"):
                    # Argument fragments are appended to the most recent tool call
                    tool_calls[-1]["function"]["arguments"] += delta["tool_calls"][0]["function"].get("arguments") or ""

            # Make function_call
            elif delta.get("function_call"):
                if function_call is None:
                    function_call = {}
                if delta["function_call"].get("name"):
                    function_call["name"] = delta["function_call"]["name"]
                    function_call["arguments"] = ""
                elif delta["function_call"].get("arguments"):
                    function_call["arguments"] += delta["function_call"]["arguments"]

            # Text content
            else:
                response_content += delta.get("content") or ""

        # Serialize
        function_call_str = json.dumps(function_call, ensure_ascii=False) if function_call is not None else None
        tool_calls_str = json.dumps(tool_calls, ensure_ascii=False) if tool_calls is not None else None
        response_headers = json.dumps(dict(self.response_headers.items()),
                                      ensure_ascii=False) if self.response_headers is not None else None

        # Count tokens locally (no usage field is read from the chunks)
        prompt_tokens = count_request_token(self.request_json)

        if tool_calls_str:
            completion_tokens = count_token(tool_calls_str)
        elif function_call_str:
            completion_tokens = count_token(function_call_str)
        else:
            completion_tokens = count_token(response_content)

        return accesslog_cls(
            request_id=self.request_id,
            created_at=datetime.utcnow(),
            direction="response",
            status_code=self.status_code,
            content=response_content,
            function_call=function_call_str,
            tool_calls=tool_calls_str,
            raw_body=json.dumps(chunk_jsons, ensure_ascii=False),
            raw_headers=response_headers,
            model=chunk_jsons[0]["model"],
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            request_time=self.duration,
            request_time_api=self.duration_api
        )
208
+
209
+
210
class ChatGPTErrorItem(ErrorItemBase):
    """Error queue item for the ChatGPT proxy; inherits the base serialization."""
    ...
212
+
213
+
214
# All queue item classes emitted by ChatGPTProxy
queue_item_types = [ChatGPTRequestItem, ChatGPTResponseItem, ChatGPTStreamResponseItem, ChatGPTErrorItem]
215
+
216
+
217
+ # Reverse aiproxy application for ChatGPT
218
class ChatGPTProxy(ProxyBase):
    """Reverse proxy endpoint for OpenAI-compatible chat completions.

    Applies request/response filters, forwards the call with the configured
    (or client-supplied) API key, and pushes request/response/error items onto
    the access-log queue.

    Fix over the original: the bare `except:` around the APIStatusError body
    parse is narrowed to `except Exception:` so it cannot swallow
    KeyboardInterrupt/SystemExit.
    """

    # Placeholder meaning "no server-side API key configured"
    _empty_openai_api_key = "OPENAI_API_KEY_IS_NOT_SET"

    def __init__(
        self,
        *,
        base_url: str = None,
        api_key: str = None,
        async_client: AsyncClient = None,
        max_retries: int = 0,
        timeout: float = 60.0,
        request_filters: List[RequestFilterBase] = None,
        response_filters: List[ResponseFilterBase] = None,
        request_item_class: type = ChatGPTRequestItem,
        response_item_class: type = ChatGPTResponseItem,
        stream_response_item_class: type = ChatGPTStreamResponseItem,
        error_item_class: type = ChatGPTErrorItem,
        access_logger_queue: QueueClientBase,
    ):
        """Configure the proxy; api_key falls back to the OPENAI_API_KEY env var."""
        super().__init__(
            request_filters=request_filters,
            response_filters=response_filters,
            access_logger_queue=access_logger_queue
        )

        # Log items (overridable for custom access-log schemas)
        self.request_item_class = request_item_class
        self.response_item_class = response_item_class
        self.stream_response_item_class = stream_response_item_class
        self.error_item_class = error_item_class

        # ChatGPT client config
        self.base_url = base_url
        self.api_key = api_key or os.getenv("OPENAI_API_KEY") or self._empty_openai_api_key
        self.max_retries = max_retries
        self.timeout = timeout
        self.async_client = async_client

    async def filter_request(self, request_id: str, request_json: dict, request_headers: dict) -> Union[
        dict, JSONResponse, EventSourceResponse]:
        """Run request filters; a truthy filter result short-circuits with a canned reply."""
        for f in self.request_filters:
            if json_resp := await f.filter(request_id, request_json, request_headers):
                # Return response if filter returns string
                resp_for_log = {
                    "id": "-",
                    "choices": [
                        {"message": {"role": "assistant", "content": json_resp}, "finish_reason": "stop", "index": 0}],
                    "created": 0,
                    "model": "request_filter",
                    "object": "chat.completion",
                    "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
                }
                # Response log
                self.access_logger_queue.put(self.response_item_class(
                    request_id=request_id,
                    response_json=resp_for_log,
                    status_code=200
                ))

                if request_json.get("stream"):
                    # Stream: emit the canned reply as two SSE deltas
                    async def filter_response_stream(content: str):
                        # First delta
                        resp = {
                            "id": "-",
                            "choices": [
                                {"delta": {"role": "assistant", "content": ""}, "finish_reason": None, "index": 0}],
                            "created": 0,
                            "model": "request_filter",
                            "object": "chat.completion",
                            "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
                        }
                        yield json.dumps(resp)
                        # Last delta
                        resp["choices"][0] = {"delta": {"content": content}, "finish_reason": "stop", "index": 0}
                        yield json.dumps(resp)

                    return self.return_response_with_headers(EventSourceResponse(
                        filter_response_stream(json_resp)
                    ), request_id)

                else:
                    # Non-stream
                    return self.return_response_with_headers(JSONResponse(resp_for_log), request_id)

        return request_json

    def get_client(self):
        """Return the injected client, or build a fresh AsyncClient per request."""
        return self.async_client or AsyncClient(
            base_url=self.base_url,
            api_key=self.api_key,
            max_retries=self.max_retries,
            timeout=self.timeout
        )

    async def filter_response(self, request_id: str, response: ChatCompletion) -> ChatCompletion:
        """Run response filters; the first truthy result replaces the response JSON."""
        response_json = response.model_dump()

        for f in self.response_filters:
            if json_resp := await f.filter(request_id, response_json):
                return response.model_validate(json_resp)

        return response.model_validate(response_json)

    def return_response_with_headers(self, resp: JSONResponse, request_id: str):
        """Attach proxy response headers (e.g. the request id) before returning."""
        self.add_response_headers(response=resp, request_id=request_id)
        return resp

    def add_route(self, app: FastAPI, base_url: str):
        """Register the chat-completions POST handler on *app* at *base_url*."""

        @app.post(base_url)
        async def handle_request(request: Request):
            request_id = str(uuid4())
            async_client = None

            try:
                start_time = time.time()
                request_json = await request.json()
                request_headers = dict(request.headers.items())

                # Log request
                self.access_logger_queue.put(self.request_item_class(
                    request_id=request_id,
                    request_json=request_json,
                    request_headers=request_headers
                ))

                # Filter request; a filter may short-circuit with a ready response
                request_json = await self.filter_request(request_id, request_json, request_headers)
                if isinstance(request_json, JSONResponse) or isinstance(request_json, EventSourceResponse):
                    return request_json

                # Call API
                async_client = self.get_client()
                start_time_api = time.time()
                if self.api_key != self._empty_openai_api_key:
                    # Always use server api key if set to client
                    raw_response = await async_client.chat.completions.with_raw_response.create(**request_json)
                elif user_auth_header := request_headers.get("authorization"):  # Lower case from client.
                    raw_response = await async_client.chat.completions.with_raw_response.create(
                        **request_json, extra_headers={"Authorization": user_auth_header}  # Pascal to server
                    )
                else:
                    # Call API anyway ;)
                    raw_response = await async_client.chat.completions.with_raw_response.create(**request_json)

                completion_response = raw_response.parse()
                completion_response_headers = raw_response.headers
                completion_status_code = raw_response.status_code
                if "content-encoding" in completion_response_headers:
                    completion_response_headers.pop(
                        "content-encoding")  # Remove "br" that will be changed by this aiproxy

                # Handling response from API
                if request_json.get("stream"):
                    async def process_stream(stream: AsyncContentStream) -> AsyncGenerator[str, None]:
                        # Async content generator: relay each chunk and log it
                        try:
                            async for chunk in stream:
                                self.access_logger_queue.put(self.stream_response_item_class(
                                    request_id=request_id,
                                    chunk_json=chunk.model_dump()
                                ))
                                if chunk:
                                    yield chunk.model_dump_json()

                        finally:
                            # Close client after reading stream
                            await async_client.close()

                    # Response log (final stream item carrying durations)
                    now = time.time()
                    self.access_logger_queue.put(self.stream_response_item_class(
                        request_id=request_id,
                        response_headers=completion_response_headers,
                        duration=now - start_time,
                        duration_api=now - start_time_api,
                        request_json=request_json,
                        status_code=completion_status_code
                    ))

                    return self.return_response_with_headers(EventSourceResponse(
                        process_stream(completion_response),
                        headers=completion_response_headers
                    ), request_id)

                else:
                    # Close client immediately
                    await async_client.close()

                    duration_api = time.time() - start_time_api

                    # Filter response
                    completion_response = await self.filter_response(request_id, completion_response)

                    # Response log
                    self.access_logger_queue.put(self.response_item_class(
                        request_id=request_id,
                        response_json=completion_response.model_dump(),
                        response_headers=completion_response_headers,
                        duration=time.time() - start_time,
                        duration_api=duration_api,
                        status_code=completion_status_code
                    ))

                    return self.return_response_with_headers(JSONResponse(
                        content=completion_response.model_dump(),
                        headers=completion_response_headers
                    ), request_id)

            # Error handlers
            except RequestFilterException as rfex:
                logger.error(f"Request filter error: {rfex}\n{traceback.format_exc()}")

                resp_json = {
                    "error": {"message": rfex.message, "type": "request_filter_error", "param": None, "code": None}}

                # Error log
                self.access_logger_queue.put(self.error_item_class(
                    request_id=request_id,
                    exception=rfex,
                    traceback_info=traceback.format_exc(),
                    response_json=resp_json,
                    status_code=rfex.status_code
                ))

                return self.return_response_with_headers(JSONResponse(resp_json, status_code=rfex.status_code),
                                                         request_id)

            except ResponseFilterException as rfex:
                logger.error(f"Response filter error: {rfex}\n{traceback.format_exc()}")

                resp_json = {
                    "error": {"message": rfex.message, "type": "response_filter_error", "param": None, "code": None}}

                # Error log
                self.access_logger_queue.put(self.error_item_class(
                    request_id=request_id,
                    exception=rfex,
                    traceback_info=traceback.format_exc(),
                    response_json=resp_json,
                    status_code=rfex.status_code
                ))

                return self.return_response_with_headers(JSONResponse(resp_json, status_code=rfex.status_code),
                                                         request_id)

            except (APIStatusError, APIResponseValidationError) as status_err:
                logger.error(f"APIStatusError from ChatGPT: {status_err}\n{traceback.format_exc()}")

                # Error log; FIX: was a bare `except:` here
                try:
                    resp_json = status_err.response.json()
                except Exception:
                    resp_json = str(status_err.response.content)

                self.access_logger_queue.put(self.error_item_class(
                    request_id=request_id,
                    exception=status_err,
                    traceback_info=traceback.format_exc(),
                    response_json=resp_json,
                    status_code=status_err.status_code
                ))

                return self.return_response_with_headers(JSONResponse(resp_json, status_code=status_err.status_code),
                                                         request_id)

            except APIError as api_err:
                logger.error(f"APIError from ChatGPT: {api_err}\n{traceback.format_exc()}")

                resp_json = {"error": {"message": api_err.message, "type": api_err.type, "param": api_err.param,
                                       "code": api_err.code}}

                # Error log
                self.access_logger_queue.put(self.error_item_class(
                    request_id=request_id,
                    exception=api_err,
                    traceback_info=traceback.format_exc(),
                    response_json=resp_json,
                    status_code=502
                ))

                return self.return_response_with_headers(JSONResponse(resp_json, status_code=502), request_id)

            except OpenAIError as oai_err:
                logger.error(f"OpenAIError: {oai_err}\n{traceback.format_exc()}")

                resp_json = {"error": {"message": str(oai_err), "type": "openai_error", "param": None, "code": None}}

                # Error log
                self.access_logger_queue.put(self.error_item_class(
                    request_id=request_id,
                    exception=oai_err,
                    traceback_info=traceback.format_exc(),
                    response_json=resp_json,
                    status_code=502
                ))

                return self.return_response_with_headers(JSONResponse(resp_json, status_code=502), request_id)

            except Exception as ex:
                logger.error(f"Error at server: {ex}\n{traceback.format_exc()}")

                resp_json = {"error": {"message": "Proxy error", "type": "proxy_error", "param": None, "code": None}}

                # Error log
                self.access_logger_queue.put(self.error_item_class(
                    request_id=request_id,
                    exception=ex,
                    traceback_info=traceback.format_exc(),
                    response_json=resp_json,
                    status_code=502
                ))

                return self.return_response_with_headers(JSONResponse(resp_json, status_code=502), request_id)
aiproxy/proxy.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ import logging
3
+ from typing import List, Union
4
+ from fastapi import FastAPI
5
+ from fastapi.responses import Response
6
+ from aiproxy.queueclient import QueueClientBase
7
+
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
# Filter interfaces: subclass these and register instances on a proxy.
class RequestFilterBase(ABC):
    """Interface for inspecting/rewriting an incoming request before it is
    forwarded upstream."""

    @abstractmethod
    async def filter(self, request_id: str, request_json: dict, request_headers: dict) -> Union[str, None]:
        """Examine the request; return a string to short-circuit with that
        content, or None to let the request pass through unchanged."""
17
+
18
+
19
class ResponseFilterBase(ABC):
    """Interface for inspecting/rewriting an upstream response before it is
    returned to the client."""

    @abstractmethod
    async def filter(self, request_id: str, response_json: dict) -> Union[dict, None]:
        """Examine the response; return a replacement dict, or None to keep
        the upstream response as-is."""
23
+
24
+
25
class FilterException(Exception):
    """Raised by a filter to abort processing with a specific HTTP status.

    Attributes:
        message: Human-readable reason, also available via str(exc).
        status_code: HTTP status the proxy should return (default 400).
    """

    def __init__(self, message: str, status_code: int = 400) -> None:
        # Forward the message to Exception so str(exc), exc.args and
        # tracebacks carry it (the original dropped it, leaving str(exc) "").
        super().__init__(message)
        self.message = message
        self.status_code = status_code


class RequestFilterException(FilterException):
    """FilterException raised from a request filter."""


class ResponseFilterException(FilterException):
    """FilterException raised from a response filter."""
35
+
36
+
37
class ProxyBase(ABC):
    """Shared machinery for concrete API proxies.

    Holds the request/response filter chains and the queue used to hand
    access-log items to the log worker; subclasses implement add_route()
    to attach their endpoints to a FastAPI app.
    """

    def __init__(
        self,
        *,
        request_filters: List[RequestFilterBase] = None,
        response_filters: List[ResponseFilterBase] = None,
        access_logger_queue: QueueClientBase
    ):
        # Filter chains, applied in registration order.
        self.request_filters = request_filters if request_filters else []
        self.response_filters = response_filters if response_filters else []

        # Queue carrying access-log items to the log consumer.
        self.access_logger_queue = access_logger_queue

    def add_filter(self, filter: Union[RequestFilterBase, ResponseFilterBase]):
        """Register a filter on the matching chain; warn on unknown types.

        NOTE(review): the parameter name shadows the builtin ``filter`` but
        is kept so existing keyword-argument callers keep working.
        """
        if isinstance(filter, RequestFilterBase):
            self.request_filters.append(filter)
            logger.info(f"request filter: {filter.__class__.__name__}")
            return
        if isinstance(filter, ResponseFilterBase):
            self.response_filters.append(filter)
            logger.info(f"response filter: {filter.__class__.__name__}")
            return
        logger.warning(f"Invalid filter: {filter.__class__.__name__}")

    def add_response_headers(self, response: Response, request_id: str, headers: dict = None):
        """Stamp the proxy request id (plus any extra headers) on *response*."""
        response.headers["X-AIProxy-Request-Id"] = request_id
        for key, value in (headers or {}).items():
            response.headers[key] = value

    @abstractmethod
    def add_route(self, app: FastAPI, base_url: str):
        """Attach this proxy's endpoint(s) to *app* under *base_url*."""
aiproxy/queueclient.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from abc import ABC, abstractmethod
import json
from queue import Empty, Queue
from typing import Iterator
5
+
6
+
7
class QueueItemBase(ABC):
    """Base class for items carried on the access-log queue.

    Provides dict/JSON round-tripping; a "type" discriminator (the concrete
    class name) is added on serialization and stripped on deserialization.
    """

    def to_dict(self) -> dict:
        """Return the item's fields plus a "type" discriminator.

        Copies __dict__ before adding "type" — the original aliased it, which
        permanently injected a ``type`` attribute into the live instance.
        """
        d = dict(self.__dict__)
        d["type"] = self.__class__.__name__
        return d

    def to_json(self) -> str:
        """Serialize to a JSON string via to_dict()."""
        return json.dumps(self.to_dict())

    @classmethod
    def from_dict(cls, d: dict):
        """Build an instance from *d*, ignoring any "type" discriminator.

        Tolerates dicts without "type" (the original raised KeyError).
        """
        _d = d.copy()
        _d.pop("type", None)
        return cls(**_d)

    @classmethod
    def from_json(cls, json_str: str):
        """Build an instance from a JSON string produced by to_json()."""
        return cls.from_dict(json.loads(json_str))
25
+
26
+
27
class QueueClientBase(ABC):
    """Abstract transport for QueueItemBase items between the proxy
    handlers (producers) and the log consumer."""

    # Seconds the consumer should wait between drain attempts.
    dequeue_interval = 0.5

    @abstractmethod
    def put(self, item: QueueItemBase):
        """Enqueue a single item."""

    @abstractmethod
    def get(self) -> Iterator[QueueItemBase]:
        """Return an iterator over the items queued so far."""
37
+
38
+
39
class DefaultQueueClient(QueueClientBase):
    """In-process queue client backed by a thread-safe queue.Queue."""

    def __init__(self) -> None:
        self.queue = Queue()
        self.dequeue_interval = 0.5  # seconds; mirrors the class default

    def put(self, item: QueueItemBase):
        """Enqueue one item (thread-safe)."""
        self.queue.put(item)

    def get(self) -> Iterator[QueueItemBase]:
        """Drain everything currently queued and return it as an iterator.

        Uses get_nowait() instead of the original empty()/get() pair: that
        check-then-act sequence could block forever if another consumer
        drained the queue between the two calls.
        """
        items = []
        while True:
            try:
                items.append(self.queue.get_nowait())
            except Empty:
                break
        return iter(items)