Niansuh committed on
Commit
ca47a97
·
verified ·
1 Parent(s): 19df56c

Update api/utils.py

Browse files
Files changed (1) hide show
  1. api/utils.py +49 -43
api/utils.py CHANGED
@@ -1,31 +1,35 @@
 
 
1
  from datetime import datetime
2
  import json
3
  import uuid
4
  import asyncio
5
  import random
6
  import string
7
- from typing import Any, Dict, Optional
8
 
9
  import httpx
10
  from fastapi import HTTPException
11
  from api.config import (
12
- MODEL_ALIASES,
13
  get_headers_api_chat,
14
  get_headers_chat,
15
  BASE_URL,
16
  AGENT_MODE,
17
  TRENDING_AGENT_MODE,
18
  MODEL_PREFIXES,
19
- API_ENDPOINT,
20
- generate_id,
21
- MODELS
22
  )
23
- from api.models import ChatRequest
24
  from api.logger import setup_logger
25
- from api.image import ImageResponse, to_data_uri # Assuming image utilities are here
26
 
27
  logger = setup_logger(__name__)
28
 
 
 
 
 
 
29
  # Helper function to create chat completion data
30
  def create_chat_completion_data(
31
  content: str, model: str, timestamp: int, finish_reason: Optional[str] = None
@@ -46,10 +50,18 @@ def create_chat_completion_data(
46
  }
47
 
48
  # Function to convert message to dictionary format, ensuring base64 data and optional model prefix
49
- def message_to_dict(message, model_prefix: Optional[str] = None):
50
- content = message.content if isinstance(message.content, str) else message.content[0]["text"]
 
 
 
 
 
 
51
  if model_prefix:
52
  content = f"{model_prefix} {content}"
 
 
53
  if isinstance(message.content, list) and len(message.content) == 2 and "image_url" in message.content[1]:
54
  # Ensure base64 images are always included for all models
55
  return {
@@ -73,25 +85,24 @@ def strip_model_prefix(content: str, model_prefix: Optional[str] = None) -> str:
73
 
74
  # Function to get the correct referer URL for logging
75
  def get_referer_url(chat_id: str, model: str) -> str:
76
- """Generate the referer URL based on specific models listed in MODELS."""
77
- return f"{BASE_URL}/chat/{chat_id}?model={model}"
 
 
78
 
79
  # Process streaming response with headers from config.py
80
- async def process_streaming_response(request: ChatRequest, web_search: bool = False):
81
- chat_id = generate_id()
82
  referer_url = get_referer_url(chat_id, request.model)
83
  logger.info(f"Generated Chat ID: {chat_id} - Model: {request.model} - URL: {referer_url}")
84
 
85
- # Resolve model aliases
86
- resolved_model = MODEL_ALIASES.get(request.model, request.model)
 
87
 
88
- agent_mode = AGENT_MODE.get(resolved_model, {})
89
- trending_agent_mode = TRENDING_AGENT_MODE.get(resolved_model, {})
90
- model_prefix = MODEL_PREFIXES.get(resolved_model, "")
91
 
92
- headers_api_chat = get_headers_api_chat()
93
-
94
- if resolved_model == 'o1-preview':
95
  delay_seconds = random.randint(1, 60)
96
  logger.info(f"Introducing a delay of {delay_seconds} seconds for model 'o1-preview' (Chat ID: {chat_id})")
97
  await asyncio.sleep(delay_seconds)
@@ -114,18 +125,17 @@ async def process_streaming_response(request: ChatRequest, web_search: bool = Fa
114
  "previewToken": None,
115
  "trendingAgentMode": trending_agent_mode,
116
  "userId": None,
117
- "userSelectedModel": resolved_model if resolved_model in MODELS else None,
118
  "userSystemPrompt": None,
119
  "validated": "69783381-2ce4-4dbd-ac78-35e9063feabc",
120
  "visitFromDelta": False,
121
- "webSearchMode": web_search, # Include web search mode
122
  }
123
 
124
  async with httpx.AsyncClient() as client:
125
  try:
126
  async with client.stream(
127
  "POST",
128
- API_ENDPOINT,
129
  headers=headers_api_chat,
130
  json=json_data,
131
  timeout=100,
@@ -138,9 +148,9 @@ async def process_streaming_response(request: ChatRequest, web_search: bool = Fa
138
  if content.startswith("$@$v=undefined-rv1$@$"):
139
  content = content[21:]
140
  cleaned_content = strip_model_prefix(content, model_prefix)
141
- yield f"data: {json.dumps(create_chat_completion_data(cleaned_content, resolved_model, timestamp))}\n\n"
142
 
143
- yield f"data: {json.dumps(create_chat_completion_data('', resolved_model, timestamp, 'stop'))}\n\n"
144
  yield "data: [DONE]\n\n"
145
  except httpx.HTTPStatusError as e:
146
  logger.error(f"HTTP error occurred for Chat ID {chat_id}: {e}")
@@ -150,21 +160,19 @@ async def process_streaming_response(request: ChatRequest, web_search: bool = Fa
150
  raise HTTPException(status_code=500, detail=str(e))
151
 
152
  # Process non-streaming response with headers from config.py
153
- async def process_non_streaming_response(request: ChatRequest, web_search: bool = False):
154
- chat_id = generate_id()
155
  referer_url = get_referer_url(chat_id, request.model)
156
  logger.info(f"Generated Chat ID: {chat_id} - Model: {request.model} - URL: {referer_url}")
157
 
158
- # Resolve model aliases
159
- resolved_model = MODEL_ALIASES.get(request.model, request.model)
160
-
161
- agent_mode = AGENT_MODE.get(resolved_model, {})
162
- trending_agent_mode = TRENDING_AGENT_MODE.get(resolved_model, {})
163
- model_prefix = MODEL_PREFIXES.get(resolved_model, "")
164
 
165
- headers_api_chat = get_headers_api_chat()
 
166
 
167
- if resolved_model == 'o1-preview':
168
  delay_seconds = random.randint(20, 60)
169
  logger.info(f"Introducing a delay of {delay_seconds} seconds for model 'o1-preview' (Chat ID: {chat_id})")
170
  await asyncio.sleep(delay_seconds)
@@ -187,11 +195,10 @@ async def process_non_streaming_response(request: ChatRequest, web_search: bool
187
  "previewToken": None,
188
  "trendingAgentMode": trending_agent_mode,
189
  "userId": None,
190
- "userSelectedModel": resolved_model if resolved_model in MODELS else None,
191
  "userSystemPrompt": None,
192
  "validated": "69783381-2ce4-4dbd-ac78-35e9063feabc",
193
  "visitFromDelta": False,
194
- "webSearchMode": web_search, # Include web search mode
195
  }
196
 
197
  full_response = ""
@@ -199,10 +206,9 @@ async def process_non_streaming_response(request: ChatRequest, web_search: bool
199
  try:
200
  async with client.stream(
201
  method="POST",
202
- url=API_ENDPOINT,
203
  headers=headers_api_chat,
204
- json=json_data,
205
- timeout=100,
206
  ) as response:
207
  response.raise_for_status()
208
  async for chunk in response.aiter_text():
@@ -213,7 +219,7 @@ async def process_non_streaming_response(request: ChatRequest, web_search: bool
213
  except httpx.RequestError as e:
214
  logger.error(f"Error occurred during request for Chat ID {chat_id}: {e}")
215
  raise HTTPException(status_code=500, detail=str(e))
216
-
217
  if full_response.startswith("$@$v=undefined-rv1$@$"):
218
  full_response = full_response[21:]
219
 
@@ -223,7 +229,7 @@ async def process_non_streaming_response(request: ChatRequest, web_search: bool
223
  "id": f"chatcmpl-{uuid.uuid4()}",
224
  "object": "chat.completion",
225
  "created": int(datetime.now().timestamp()),
226
- "model": resolved_model,
227
  "choices": [
228
  {
229
  "index": 0,
 
1
+ # api/utils.py
2
+
3
  from datetime import datetime
4
  import json
5
  import uuid
6
  import asyncio
7
  import random
8
  import string
9
+ from typing import Any, Dict, Optional, List, Union
10
 
11
  import httpx
12
  from fastapi import HTTPException
13
  from api.config import (
14
+ MODEL_MAPPING,
15
  get_headers_api_chat,
16
  get_headers_chat,
17
  BASE_URL,
18
  AGENT_MODE,
19
  TRENDING_AGENT_MODE,
20
  MODEL_PREFIXES,
21
+ MODEL_REFERERS
 
 
22
  )
23
+ from api.models import ChatRequest, Message
24
  from api.logger import setup_logger
 
25
 
26
  logger = setup_logger(__name__)
27
 
28
# Helper function to create a random alphanumeric chat ID
def generate_chat_id(length: int = 7) -> str:
    """Return a random alphanumeric chat identifier *length* characters long."""
    alphabet = string.ascii_letters + string.digits
    return ''.join(random.choice(alphabet) for _ in range(length))
32
+
33
  # Helper function to create chat completion data
34
  def create_chat_completion_data(
35
  content: str, model: str, timestamp: int, finish_reason: Optional[str] = None
 
50
  }
51
 
52
  # Function to convert message to dictionary format, ensuring base64 data and optional model prefix
53
+ def message_to_dict(message: Message, model_prefix: Optional[str] = None) -> Dict[str, Any]:
54
+ if isinstance(message.content, str):
55
+ content = message.content
56
+ elif isinstance(message.content, list) and len(message.content) > 0:
57
+ content = message.content[0].get("text", "")
58
+ else:
59
+ content = ""
60
+
61
  if model_prefix:
62
  content = f"{model_prefix} {content}"
63
+
64
+ # Handle image content
65
  if isinstance(message.content, list) and len(message.content) == 2 and "image_url" in message.content[1]:
66
  # Ensure base64 images are always included for all models
67
  return {
 
85
 
86
# Function to get the correct referer URL for logging
def get_referer_url(chat_id: str, model: str) -> str:
    """Return the referer URL for a chat session.

    Models present in MODEL_REFERERS get a chat-specific URL embedding the
    chat ID and model; every other model falls back to the bare BASE_URL.
    """
    if model not in MODEL_REFERERS:
        return BASE_URL
    return f"{BASE_URL}/chat/{chat_id}?model={model}"
92
 
93
  # Process streaming response with headers from config.py
94
+ async def process_streaming_response(request: ChatRequest):
95
+ chat_id = generate_chat_id()
96
  referer_url = get_referer_url(chat_id, request.model)
97
  logger.info(f"Generated Chat ID: {chat_id} - Model: {request.model} - URL: {referer_url}")
98
 
99
+ agent_mode = AGENT_MODE.get(request.model, {})
100
+ trending_agent_mode = TRENDING_AGENT_MODE.get(request.model, {})
101
+ model_prefix = MODEL_PREFIXES.get(request.model, "")
102
 
103
+ headers_api_chat = get_headers_api_chat(referer_url)
 
 
104
 
105
+ if request.model == 'o1-preview':
 
 
106
  delay_seconds = random.randint(1, 60)
107
  logger.info(f"Introducing a delay of {delay_seconds} seconds for model 'o1-preview' (Chat ID: {chat_id})")
108
  await asyncio.sleep(delay_seconds)
 
125
  "previewToken": None,
126
  "trendingAgentMode": trending_agent_mode,
127
  "userId": None,
128
+ "userSelectedModel": MODEL_MAPPING.get(request.model, request.model),
129
  "userSystemPrompt": None,
130
  "validated": "69783381-2ce4-4dbd-ac78-35e9063feabc",
131
  "visitFromDelta": False,
 
132
  }
133
 
134
  async with httpx.AsyncClient() as client:
135
  try:
136
  async with client.stream(
137
  "POST",
138
+ f"{BASE_URL}/api/chat",
139
  headers=headers_api_chat,
140
  json=json_data,
141
  timeout=100,
 
148
  if content.startswith("$@$v=undefined-rv1$@$"):
149
  content = content[21:]
150
  cleaned_content = strip_model_prefix(content, model_prefix)
151
+ yield f"data: {json.dumps(create_chat_completion_data(cleaned_content, request.model, timestamp))}\n\n"
152
 
153
+ yield f"data: {json.dumps(create_chat_completion_data('', request.model, timestamp, 'stop'))}\n\n"
154
  yield "data: [DONE]\n\n"
155
  except httpx.HTTPStatusError as e:
156
  logger.error(f"HTTP error occurred for Chat ID {chat_id}: {e}")
 
160
  raise HTTPException(status_code=500, detail=str(e))
161
 
162
  # Process non-streaming response with headers from config.py
163
+ async def process_non_streaming_response(request: ChatRequest):
164
+ chat_id = generate_chat_id()
165
  referer_url = get_referer_url(chat_id, request.model)
166
  logger.info(f"Generated Chat ID: {chat_id} - Model: {request.model} - URL: {referer_url}")
167
 
168
+ agent_mode = AGENT_MODE.get(request.model, {})
169
+ trending_agent_mode = TRENDING_AGENT_MODE.get(request.model, {})
170
+ model_prefix = MODEL_PREFIXES.get(request.model, "")
 
 
 
171
 
172
+ headers_api_chat = get_headers_api_chat(referer_url)
173
+ headers_chat = get_headers_chat(referer_url, next_action=str(uuid.uuid4()), next_router_state_tree=json.dumps([""]))
174
 
175
+ if request.model == 'o1-preview':
176
  delay_seconds = random.randint(20, 60)
177
  logger.info(f"Introducing a delay of {delay_seconds} seconds for model 'o1-preview' (Chat ID: {chat_id})")
178
  await asyncio.sleep(delay_seconds)
 
195
  "previewToken": None,
196
  "trendingAgentMode": trending_agent_mode,
197
  "userId": None,
198
+ "userSelectedModel": MODEL_MAPPING.get(request.model, request.model),
199
  "userSystemPrompt": None,
200
  "validated": "69783381-2ce4-4dbd-ac78-35e9063feabc",
201
  "visitFromDelta": False,
 
202
  }
203
 
204
  full_response = ""
 
206
  try:
207
  async with client.stream(
208
  method="POST",
209
+ url=f"{BASE_URL}/api/chat",
210
  headers=headers_api_chat,
211
+ json=json_data
 
212
  ) as response:
213
  response.raise_for_status()
214
  async for chunk in response.aiter_text():
 
219
  except httpx.RequestError as e:
220
  logger.error(f"Error occurred during request for Chat ID {chat_id}: {e}")
221
  raise HTTPException(status_code=500, detail=str(e))
222
+
223
  if full_response.startswith("$@$v=undefined-rv1$@$"):
224
  full_response = full_response[21:]
225
 
 
229
  "id": f"chatcmpl-{uuid.uuid4()}",
230
  "object": "chat.completion",
231
  "created": int(datetime.now().timestamp()),
232
+ "model": request.model,
233
  "choices": [
234
  {
235
  "index": 0,