Update main.py
Browse files
main.py
CHANGED
|
@@ -10,8 +10,8 @@ from aiohttp import ClientSession, ClientTimeout, ClientError
|
|
| 10 |
from fastapi import FastAPI, HTTPException, Request, Depends, Header, status
|
| 11 |
from fastapi.responses import StreamingResponse, JSONResponse
|
| 12 |
from fastapi.middleware.cors import CORSMiddleware
|
| 13 |
-
from pydantic import BaseModel, Field
|
| 14 |
-
from typing import List, Dict, Any, Optional, AsyncGenerator
|
| 15 |
from datetime import datetime
|
| 16 |
from slowapi import Limiter, _rate_limit_exceeded_handler
|
| 17 |
from slowapi.util import get_remote_address
|
|
@@ -100,7 +100,7 @@ def to_data_uri(image: Any) -> str:
|
|
| 100 |
return "data:image/png;base64,..." # Replace with actual base64 data if needed
|
| 101 |
|
| 102 |
# Token Counting using tiktoken
|
| 103 |
-
def count_tokens(messages: List[Dict[str,
|
| 104 |
"""
|
| 105 |
Counts the number of tokens in the messages using tiktoken.
|
| 106 |
Adjust the encoding based on the model.
|
|
@@ -111,7 +111,14 @@ def count_tokens(messages: List[Dict[str, str]], model: str) -> int:
|
|
| 111 |
encoding = tiktoken.get_encoding("cl100k_base") # Default encoding
|
| 112 |
tokens = 0
|
| 113 |
for message in messages:
|
| 114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
return tokens
|
| 116 |
|
| 117 |
# Blackbox Class: Handles interaction with the external AI service
|
|
@@ -235,7 +242,7 @@ class Blackbox:
|
|
| 235 |
async def create_async_generator(
|
| 236 |
cls,
|
| 237 |
model: str,
|
| 238 |
-
messages: List[Dict[str,
|
| 239 |
proxy: Optional[str] = None,
|
| 240 |
image: Any = None,
|
| 241 |
image_name: Optional[str] = None,
|
|
@@ -269,22 +276,33 @@ class Blackbox:
|
|
| 269 |
|
| 270 |
if model in cls.model_prefixes:
|
| 271 |
prefix = cls.model_prefixes[model]
|
| 272 |
-
if
|
| 273 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
messages[0]['content'] = f"{prefix} {messages[0]['content']}"
|
| 275 |
-
|
| 276 |
random_id = ''.join(random.choices(string.ascii_letters + string.digits, k=7))
|
| 277 |
-
|
| 278 |
-
messages
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
'
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 288 |
data = {
|
| 289 |
"messages": messages,
|
| 290 |
"id": random_id,
|
|
@@ -337,36 +355,13 @@ class Blackbox:
|
|
| 337 |
logger.error("Image URL not found in the response.")
|
| 338 |
raise Exception("Image URL not found in the response")
|
| 339 |
else:
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
if
|
| 345 |
-
decoded_chunk
|
| 346 |
-
|
| 347 |
-
if decoded_chunk.strip():
|
| 348 |
-
if '$~~~$' in decoded_chunk:
|
| 349 |
-
search_results_json += decoded_chunk
|
| 350 |
-
else:
|
| 351 |
-
full_response += decoded_chunk
|
| 352 |
-
yield decoded_chunk
|
| 353 |
-
logger.info("Finished streaming response chunks.")
|
| 354 |
-
except Exception as e:
|
| 355 |
-
logger.exception("Error while iterating over response chunks.")
|
| 356 |
-
raise e
|
| 357 |
-
if data["webSearchMode"] and search_results_json:
|
| 358 |
-
match = re.search(r'\$~~~\$(.*?)\$~~~\$', search_results_json, re.DOTALL)
|
| 359 |
-
if match:
|
| 360 |
-
try:
|
| 361 |
-
search_results = json.loads(match.group(1))
|
| 362 |
-
formatted_results = "\n\n**Sources:**\n"
|
| 363 |
-
for i, result in enumerate(search_results[:5], 1):
|
| 364 |
-
formatted_results += f"{i}. [{result['title']}]({result['link']})\n"
|
| 365 |
-
logger.info("Formatted search results.")
|
| 366 |
-
yield formatted_results
|
| 367 |
-
except json.JSONDecodeError as je:
|
| 368 |
-
logger.error("Failed to parse search results JSON.")
|
| 369 |
-
raise je
|
| 370 |
except ClientError as ce:
|
| 371 |
logger.error(f"Client error occurred: {ce}. Retrying attempt {attempt + 1}/{retry_attempts}")
|
| 372 |
if attempt == retry_attempts - 1:
|
|
@@ -381,9 +376,28 @@ class Blackbox:
|
|
| 381 |
raise HTTPException(status_code=500, detail=str(e))
|
| 382 |
|
| 383 |
# Pydantic Models
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 384 |
class Message(BaseModel):
|
| 385 |
role: str = Field(..., description="The role of the message author.")
|
| 386 |
-
content: str = Field(..., description="The content of the message.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 387 |
|
| 388 |
class ChatRequest(BaseModel):
|
| 389 |
model: str = Field(..., description="ID of the model to use.")
|
|
@@ -431,12 +445,26 @@ async def chat_completions(
|
|
| 431 |
):
|
| 432 |
logger.info(f"Received chat completions request: {chat_request}")
|
| 433 |
try:
|
| 434 |
-
|
| 435 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 436 |
|
| 437 |
async_generator = Blackbox.create_async_generator(
|
| 438 |
model=chat_request.model,
|
| 439 |
-
messages=
|
| 440 |
image=None, # Adjust if image handling is required
|
| 441 |
image_name=None,
|
| 442 |
webSearchMode=chat_request.webSearchMode
|
|
|
|
| 10 |
from fastapi import FastAPI, HTTPException, Request, Depends, Header, status
|
| 11 |
from fastapi.responses import StreamingResponse, JSONResponse
|
| 12 |
from fastapi.middleware.cors import CORSMiddleware
|
| 13 |
+
from pydantic import BaseModel, Field, validator
|
| 14 |
+
from typing import List, Dict, Any, Optional, Union, AsyncGenerator
|
| 15 |
from datetime import datetime
|
| 16 |
from slowapi import Limiter, _rate_limit_exceeded_handler
|
| 17 |
from slowapi.util import get_remote_address
|
|
|
|
| 100 |
return "data:image/png;base64,..." # Replace with actual base64 data if needed
|
| 101 |
|
| 102 |
# Token Counting using tiktoken
|
| 103 |
+
def count_tokens(messages: List[Dict[str, Any]], model: str) -> int:
|
| 104 |
"""
|
| 105 |
Counts the number of tokens in the messages using tiktoken.
|
| 106 |
Adjust the encoding based on the model.
|
|
|
|
| 111 |
encoding = tiktoken.get_encoding("cl100k_base") # Default encoding
|
| 112 |
tokens = 0
|
| 113 |
for message in messages:
|
| 114 |
+
if isinstance(message['content'], list):
|
| 115 |
+
for content_part in message['content']:
|
| 116 |
+
if content_part.get('type') == 'text':
|
| 117 |
+
tokens += len(encoding.encode(content_part['text']))
|
| 118 |
+
elif content_part.get('type') == 'image_url':
|
| 119 |
+
tokens += len(encoding.encode(content_part['image_url']['url']))
|
| 120 |
+
else:
|
| 121 |
+
tokens += len(encoding.encode(message['content']))
|
| 122 |
return tokens
|
| 123 |
|
| 124 |
# Blackbox Class: Handles interaction with the external AI service
|
|
|
|
| 242 |
async def create_async_generator(
|
| 243 |
cls,
|
| 244 |
model: str,
|
| 245 |
+
messages: List[Dict[str, Any]],
|
| 246 |
proxy: Optional[str] = None,
|
| 247 |
image: Any = None,
|
| 248 |
image_name: Optional[str] = None,
|
|
|
|
| 276 |
|
| 277 |
if model in cls.model_prefixes:
|
| 278 |
prefix = cls.model_prefixes[model]
|
| 279 |
+
if messages and isinstance(messages[0]['content'], list):
|
| 280 |
+
# Prepend prefix to the first text message
|
| 281 |
+
for content_part in messages[0]['content']:
|
| 282 |
+
if content_part.get('type') == 'text' and not content_part['text'].startswith(prefix):
|
| 283 |
+
logger.debug(f"Adding prefix '{prefix}' to the first text message.")
|
| 284 |
+
content_part['text'] = f"{prefix} {content_part['text']}"
|
| 285 |
+
break
|
| 286 |
+
elif messages and isinstance(messages[0]['content'], str) and not messages[0]['content'].startswith(prefix):
|
| 287 |
messages[0]['content'] = f"{prefix} {messages[0]['content']}"
|
| 288 |
+
|
| 289 |
random_id = ''.join(random.choices(string.ascii_letters + string.digits, k=7))
|
| 290 |
+
# Assuming the last message is from the user
|
| 291 |
+
if messages:
|
| 292 |
+
last_message = messages[-1]
|
| 293 |
+
if isinstance(last_message['content'], list):
|
| 294 |
+
for content_part in last_message['content']:
|
| 295 |
+
if content_part.get('type') == 'text':
|
| 296 |
+
content_part['role'] = 'user'
|
| 297 |
+
else:
|
| 298 |
+
last_message['id'] = random_id
|
| 299 |
+
last_message['role'] = 'user'
|
| 300 |
|
| 301 |
+
if image is not None:
|
| 302 |
+
# Process image if required
|
| 303 |
+
# This implementation assumes that image URLs are handled by the external service
|
| 304 |
+
pass # Implement as needed
|
| 305 |
+
|
| 306 |
data = {
|
| 307 |
"messages": messages,
|
| 308 |
"id": random_id,
|
|
|
|
| 355 |
logger.error("Image URL not found in the response.")
|
| 356 |
raise Exception("Image URL not found in the response")
|
| 357 |
else:
|
| 358 |
+
async for chunk in response.content.iter_chunks():
|
| 359 |
+
if chunk:
|
| 360 |
+
decoded_chunk = chunk.decode(errors='ignore')
|
| 361 |
+
decoded_chunk = re.sub(r'\$@\$v=[^$]+\$@\$', '', decoded_chunk)
|
| 362 |
+
if decoded_chunk.strip():
|
| 363 |
+
yield decoded_chunk
|
| 364 |
+
break # Exit the retry loop if successful
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 365 |
except ClientError as ce:
|
| 366 |
logger.error(f"Client error occurred: {ce}. Retrying attempt {attempt + 1}/{retry_attempts}")
|
| 367 |
if attempt == retry_attempts - 1:
|
|
|
|
| 376 |
raise HTTPException(status_code=500, detail=str(e))
|
| 377 |
|
| 378 |
# Pydantic Models
|
| 379 |
+
class TextContent(BaseModel):
|
| 380 |
+
type: str = Field(..., description="Type of content, e.g., 'text'.")
|
| 381 |
+
text: str = Field(..., description="The text content.")
|
| 382 |
+
|
| 383 |
+
class ImageURLContent(BaseModel):
|
| 384 |
+
type: str = Field(..., description="Type of content, e.g., 'image_url'.")
|
| 385 |
+
image_url: Dict[str, str] = Field(..., description="Dictionary containing the image URL.")
|
| 386 |
+
|
| 387 |
+
Content = Union[TextContent, ImageURLContent]
|
| 388 |
+
|
| 389 |
class Message(BaseModel):
|
| 390 |
role: str = Field(..., description="The role of the message author.")
|
| 391 |
+
content: Union[str, List[Content]] = Field(..., description="The content of the message. Can be a string or a list of content parts.")
|
| 392 |
+
|
| 393 |
+
@validator('content', pre=True)
|
| 394 |
+
def validate_content(cls, v):
|
| 395 |
+
if isinstance(v, list):
|
| 396 |
+
return [Content(**item) for item in v]
|
| 397 |
+
elif isinstance(v, str):
|
| 398 |
+
return v
|
| 399 |
+
else:
|
| 400 |
+
raise ValueError("Content must be either a string or a list of content parts.")
|
| 401 |
|
| 402 |
class ChatRequest(BaseModel):
|
| 403 |
model: str = Field(..., description="ID of the model to use.")
|
|
|
|
| 445 |
):
|
| 446 |
logger.info(f"Received chat completions request: {chat_request}")
|
| 447 |
try:
|
| 448 |
+
# Process messages for token counting and sending to Blackbox
|
| 449 |
+
processed_messages = []
|
| 450 |
+
for msg in chat_request.messages:
|
| 451 |
+
if isinstance(msg.content, list):
|
| 452 |
+
# Convert list of content parts to a structured format
|
| 453 |
+
combined_content = []
|
| 454 |
+
for part in msg.content:
|
| 455 |
+
if isinstance(part, TextContent):
|
| 456 |
+
combined_content.append({"type": part.type, "text": part.text})
|
| 457 |
+
elif isinstance(part, ImageURLContent):
|
| 458 |
+
combined_content.append({"type": part.type, "image_url": part.image_url})
|
| 459 |
+
processed_messages.append({"role": msg.role, "content": combined_content})
|
| 460 |
+
else:
|
| 461 |
+
processed_messages.append({"role": msg.role, "content": msg.content})
|
| 462 |
+
|
| 463 |
+
prompt_tokens = count_tokens(processed_messages, chat_request.model)
|
| 464 |
|
| 465 |
async_generator = Blackbox.create_async_generator(
|
| 466 |
model=chat_request.model,
|
| 467 |
+
messages=processed_messages,
|
| 468 |
image=None, # Adjust if image handling is required
|
| 469 |
image_name=None,
|
| 470 |
webSearchMode=chat_request.webSearchMode
|