Niansuh commited on
Commit
7e61071
·
verified ·
1 Parent(s): 65e5192

Update api/utils.py

Browse files
Files changed (1) hide show
  1. api/utils.py +41 -52
api/utils.py CHANGED
@@ -5,8 +5,6 @@ import asyncio
5
  import random
6
  import string
7
  from typing import Any, Dict, Optional
8
- import base64
9
- import os
10
 
11
  import httpx
12
  from fastapi import HTTPException
@@ -48,39 +46,29 @@ def create_chat_completion_data(
48
  "usage": None,
49
  }
50
 
51
- # Function to convert message to dictionary format, ensuring base64 data and optional model prefix
52
  def message_to_dict(message, model_prefix: Optional[str] = None):
53
  content = message.content if isinstance(message.content, str) else message.content[0]["text"]
54
  if model_prefix:
55
  content = f"{model_prefix} {content}"
56
-
57
- # Handle image data in the message
58
- if isinstance(message.content, list) and len(message.content) == 2 and "image_url" in message.content[1]:
59
- # Ensure base64 images are always included for all models
60
- return {
61
- "role": message.role,
62
- "content": content,
63
- "data": {
64
- "imagesData": [
65
- {
66
- "filePath": message.content[1].get("filePath", ""),
67
- "contents": message.content[1].get("contents", "")
68
- }
69
- ],
70
- "fileText": "",
71
- "title": "snapshot",
72
- },
73
- }
74
- return {"role": message.role, "content": content}
75
 
76
- # Function to convert image file to base64
77
- def image_to_base64(image_path: str) -> str:
78
- try:
79
- with open(image_path, "rb") as image_file:
80
- return base64.b64encode(image_file.read()).decode('utf-8')
81
- except Exception as e:
82
- logger.error(f"Error reading image {image_path}: {e}")
83
- return ""
 
 
 
 
 
 
 
 
 
84
 
85
  # Process streaming response with headers from config.py
86
  async def process_streaming_response(request: ChatRequest):
@@ -106,24 +94,23 @@ async def process_streaming_response(request: ChatRequest):
106
  logger.error("Failed to retrieve h-value for validation.")
107
  raise HTTPException(status_code=500, detail="Validation failed due to missing h-value.")
108
 
109
- # Prepare the image payload (if any)
110
- messages = []
111
- for msg in request.messages:
112
- message_dict = message_to_dict(msg, model_prefix=model_prefix)
113
- messages.append(message_dict)
114
 
115
  json_data = {
116
  "agentMode": agent_mode,
117
  "clickedAnswer2": False,
118
  "clickedAnswer3": False,
119
  "clickedForceWebSearch": False,
120
- "codeModelMode": False,
121
  "githubToken": None,
122
- "id": None, # Using request_id instead of chat_id
123
  "isChromeExt": False,
124
  "isMicMode": False,
125
  "maxTokens": request.max_tokens,
126
- "messages": messages,
127
  "mobileClient": False,
128
  "playgroundTemperature": request.temperature,
129
  "playgroundTopP": request.top_p,
@@ -135,6 +122,7 @@ async def process_streaming_response(request: ChatRequest):
135
  "validated": h_value, # Dynamically set the validated field
136
  "visitFromDelta": False,
137
  "webSearchModePrompt": False,
 
138
  }
139
 
140
  async with httpx.AsyncClient() as client:
@@ -196,24 +184,23 @@ async def process_non_streaming_response(request: ChatRequest):
196
  logger.error("Failed to retrieve h-value for validation.")
197
  raise HTTPException(status_code=500, detail="Validation failed due to missing h-value.")
198
 
199
- # Prepare the image payload (if any)
200
- messages = []
201
- for msg in request.messages:
202
- message_dict = message_to_dict(msg, model_prefix=model_prefix)
203
- messages.append(message_dict)
204
 
205
  json_data = {
206
  "agentMode": agent_mode,
207
  "clickedAnswer2": False,
208
  "clickedAnswer3": False,
209
  "clickedForceWebSearch": False,
210
- "codeModelMode": False,
211
  "githubToken": None,
212
- "id": None, # Using request_id instead of chat_id
213
  "isChromeExt": False,
214
  "isMicMode": False,
215
  "maxTokens": request.max_tokens,
216
- "messages": messages,
217
  "mobileClient": False,
218
  "playgroundTemperature": request.temperature,
219
  "playgroundTopP": request.top_p,
@@ -225,27 +212,29 @@ async def process_non_streaming_response(request: ChatRequest):
225
  "validated": h_value, # Dynamically set the validated field
226
  "visitFromDelta": False,
227
  "webSearchModePrompt": False,
 
228
  }
229
 
 
230
  async with httpx.AsyncClient() as client:
231
  try:
232
- async with client.post(
233
- f"{BASE_URL}/api/chat", headers=headers_api_chat, json=json_data
234
  ) as response:
235
  response.raise_for_status()
236
- full_response = await response.text()
237
-
238
  except httpx.HTTPStatusError as e:
239
  logger.error(f"HTTP error occurred for Request ID {request_id}: {e}")
240
  raise HTTPException(status_code=e.response.status_code, detail=str(e))
241
  except httpx.RequestError as e:
242
  logger.error(f"Error occurred during request for Request ID {request_id}: {e}")
243
  raise HTTPException(status_code=500, detail=str(e))
244
-
245
- # Clean up the response and return it
246
  if full_response.startswith("$@$v=undefined-rv1$@$"):
247
  full_response = full_response[21:]
248
 
 
249
  if BLOCKED_MESSAGE in full_response:
250
  logger.info(f"Blocked message detected in response for Request ID {request_id}.")
251
  full_response = full_response.replace(BLOCKED_MESSAGE, '').strip()
 
5
  import random
6
  import string
7
  from typing import Any, Dict, Optional
 
 
8
 
9
  import httpx
10
  from fastapi import HTTPException
 
46
  "usage": None,
47
  }
48
 
49
+ # Function to convert message to dictionary format, including data and id if present
50
  def message_to_dict(message, model_prefix: Optional[str] = None):
51
  content = message.content if isinstance(message.content, str) else message.content[0]["text"]
52
  if model_prefix:
53
  content = f"{model_prefix} {content}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
+ message_dict = {"role": message.role, "content": content}
56
+
57
+ if hasattr(message, 'id') and message.id:
58
+ message_dict['id'] = message.id
59
+
60
+ if hasattr(message, 'data') and message.data:
61
+ message_dict['data'] = message.data
62
+
63
+ return message_dict
64
+
65
+ # Function to strip model prefix from content if present
66
+ def strip_model_prefix(content: str, model_prefix: Optional[str] = None) -> str:
67
+ """Remove the model prefix from the response content if present."""
68
+ if model_prefix and content.startswith(model_prefix):
69
+ logger.debug(f"Stripping prefix '{model_prefix}' from content.")
70
+ return content[len(model_prefix):].strip()
71
+ return content
72
 
73
  # Process streaming response with headers from config.py
74
  async def process_streaming_response(request: ChatRequest):
 
94
  logger.error("Failed to retrieve h-value for validation.")
95
  raise HTTPException(status_code=500, detail="Validation failed due to missing h-value.")
96
 
97
+ # Determine if images are included in messages
98
+ code_model_mode = any(
99
+ hasattr(msg, 'data') and msg.data and 'imagesData' in msg.data for msg in request.messages
100
+ )
 
101
 
102
  json_data = {
103
  "agentMode": agent_mode,
104
  "clickedAnswer2": False,
105
  "clickedAnswer3": False,
106
  "clickedForceWebSearch": False,
107
+ "codeModelMode": code_model_mode,
108
  "githubToken": None,
109
+ "id": request_id,
110
  "isChromeExt": False,
111
  "isMicMode": False,
112
  "maxTokens": request.max_tokens,
113
+ "messages": [message_to_dict(msg, model_prefix=model_prefix) for msg in request.messages],
114
  "mobileClient": False,
115
  "playgroundTemperature": request.temperature,
116
  "playgroundTopP": request.top_p,
 
122
  "validated": h_value, # Dynamically set the validated field
123
  "visitFromDelta": False,
124
  "webSearchModePrompt": False,
125
+ "imageGenerationMode": False,
126
  }
127
 
128
  async with httpx.AsyncClient() as client:
 
184
  logger.error("Failed to retrieve h-value for validation.")
185
  raise HTTPException(status_code=500, detail="Validation failed due to missing h-value.")
186
 
187
+ # Determine if images are included in messages
188
+ code_model_mode = any(
189
+ hasattr(msg, 'data') and msg.data and 'imagesData' in msg.data for msg in request.messages
190
+ )
 
191
 
192
  json_data = {
193
  "agentMode": agent_mode,
194
  "clickedAnswer2": False,
195
  "clickedAnswer3": False,
196
  "clickedForceWebSearch": False,
197
+ "codeModelMode": code_model_mode,
198
  "githubToken": None,
199
+ "id": request_id,
200
  "isChromeExt": False,
201
  "isMicMode": False,
202
  "maxTokens": request.max_tokens,
203
+ "messages": [message_to_dict(msg, model_prefix=model_prefix) for msg in request.messages],
204
  "mobileClient": False,
205
  "playgroundTemperature": request.temperature,
206
  "playgroundTopP": request.top_p,
 
212
  "validated": h_value, # Dynamically set the validated field
213
  "visitFromDelta": False,
214
  "webSearchModePrompt": False,
215
+ "imageGenerationMode": False,
216
  }
217
 
218
+ full_response = ""
219
  async with httpx.AsyncClient() as client:
220
  try:
221
+ async with client.stream(
222
+ method="POST", url=f"{BASE_URL}/api/chat", headers=headers_api_chat, json=json_data
223
  ) as response:
224
  response.raise_for_status()
225
+ async for chunk in response.aiter_text():
226
+ full_response += chunk
227
  except httpx.HTTPStatusError as e:
228
  logger.error(f"HTTP error occurred for Request ID {request_id}: {e}")
229
  raise HTTPException(status_code=e.response.status_code, detail=str(e))
230
  except httpx.RequestError as e:
231
  logger.error(f"Error occurred during request for Request ID {request_id}: {e}")
232
  raise HTTPException(status_code=500, detail=str(e))
233
+
 
234
  if full_response.startswith("$@$v=undefined-rv1$@$"):
235
  full_response = full_response[21:]
236
 
237
+ # Remove the blocked message if present
238
  if BLOCKED_MESSAGE in full_response:
239
  logger.info(f"Blocked message detected in response for Request ID {request_id}.")
240
  full_response = full_response.replace(BLOCKED_MESSAGE, '').strip()