Update api/utils.py
Browse files- api/utils.py +51 -17
api/utils.py
CHANGED
@@ -147,6 +147,7 @@ def create_chat_completion_data(
|
|
147 |
prompt_tokens: int = 0,
|
148 |
completion_tokens: int = 0,
|
149 |
finish_reason: Optional[str] = None,
|
|
|
150 |
) -> Dict[str, Any]:
|
151 |
usage = None
|
152 |
if finish_reason == "stop":
|
@@ -161,15 +162,28 @@ def create_chat_completion_data(
|
|
161 |
"created": timestamp,
|
162 |
"model": model,
|
163 |
"system_fingerprint": system_fingerprint,
|
164 |
-
"choices": [{
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
"usage": usage,
|
166 |
}
|
167 |
|
168 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
"""
|
170 |
Convert a ChatRequest message to a dict for the request payload.
|
171 |
-
Supports
|
172 |
-
Prepends model_prefix to text content if specified.
|
173 |
"""
|
174 |
content = ""
|
175 |
images_data = []
|
@@ -242,13 +256,12 @@ async def process_streaming_response(request: ChatRequest):
|
|
242 |
logger.error("No h-value for validation.")
|
243 |
raise HTTPException(status_code=500, detail="Missing h-value.")
|
244 |
|
245 |
-
messages = [message_to_dict(msg, model_prefix=model_prefix) for msg in request.messages]
|
246 |
|
247 |
json_data = {
|
248 |
"agentMode": agent_mode,
|
249 |
"clickedAnswer2": False,
|
250 |
"clickedAnswer3": False,
|
251 |
-
"reasoningMode": False,
|
252 |
"clickedForceWebSearch": False,
|
253 |
"codeInterpreterMode": False,
|
254 |
"codeModelMode": True,
|
@@ -275,15 +288,16 @@ async def process_streaming_response(request: ChatRequest):
|
|
275 |
"visitFromDelta": False,
|
276 |
"webSearchModePrompt": False,
|
277 |
"vscodeClient": False,
|
278 |
-
"designerMode": False,
|
279 |
-
"workspaceId": "",
|
280 |
-
"beastMode": False,
|
281 |
"customProfile": {"name": "", "occupation": "", "traits": [], "additionalInfo": "", "enableNewChats": False},
|
282 |
"webSearchModeOption": {"autoMode": False, "webMode": False, "offlineMode": True},
|
283 |
"session": {
|
284 |
-
"user": {"name": random_name, "email":
|
285 |
"expires": datetime.now(timezone.utc).isoformat(timespec='milliseconds').replace('+00:00', 'Z'),
|
286 |
"subscriptionCache": {"customerId": random_customer_id, "status": "PREMIUM", "isTrialSubscription": "False", "expiryTimestamp": 1744652408, "lastChecked": int(time.time() * 1000)},
|
|
|
|
|
|
|
|
|
287 |
},
|
288 |
}
|
289 |
|
@@ -320,7 +334,27 @@ async def process_streaming_response(request: ChatRequest):
|
|
320 |
final_snapzion_links.extend(snapzion_urls)
|
321 |
cleaned_content = strip_model_prefix(chunk, model_prefix)
|
322 |
completion_tokens += calculate_tokens(cleaned_content, request.model)
|
323 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
324 |
yield "data: " + json.dumps(create_chat_completion_data("", request.model, timestamp, request_id, system_fingerprint, prompt_tokens, completion_tokens, "stop")) + "\n\n"
|
325 |
yield "data: [DONE]\n\n"
|
326 |
except httpx.HTTPStatusError as e:
|
@@ -374,13 +408,12 @@ async def process_non_streaming_response(request: ChatRequest):
|
|
374 |
logger.error("Failed to retrieve h-value.")
|
375 |
raise HTTPException(status_code=500, detail="Missing h-value.")
|
376 |
|
377 |
-
messages = [message_to_dict(msg, model_prefix=model_prefix) for msg in request.messages]
|
378 |
|
379 |
json_data = {
|
380 |
"agentMode": agent_mode,
|
381 |
"clickedAnswer2": False,
|
382 |
"clickedAnswer3": False,
|
383 |
-
"reasoningMode": False,
|
384 |
"clickedForceWebSearch": False,
|
385 |
"codeInterpreterMode": False,
|
386 |
"codeModelMode": True,
|
@@ -407,15 +440,16 @@ async def process_non_streaming_response(request: ChatRequest):
|
|
407 |
"visitFromDelta": False,
|
408 |
"webSearchModePrompt": False,
|
409 |
"vscodeClient": False,
|
410 |
-
"designerMode": False,
|
411 |
-
"workspaceId": "",
|
412 |
-
"beastMode": False,
|
413 |
"customProfile": {"name": "", "occupation": "", "traits": [], "additionalInfo": "", "enableNewChats": False},
|
414 |
"webSearchModeOption": {"autoMode": False, "webMode": False, "offlineMode": True},
|
415 |
"session": {
|
416 |
-
"user": {"name": random_name, "email":
|
417 |
"expires": datetime.now(timezone.utc).isoformat(timespec='milliseconds').replace('+00:00', 'Z'),
|
418 |
"subscriptionCache": {"customerId": random_customer_id, "status": "PREMIUM", "isTrialSubscription": "False", "expiryTimestamp": 1744652408, "lastChecked": int(time.time() * 1000)},
|
|
|
|
|
|
|
|
|
419 |
},
|
420 |
}
|
421 |
|
|
|
147 |
prompt_tokens: int = 0,
|
148 |
completion_tokens: int = 0,
|
149 |
finish_reason: Optional[str] = None,
|
150 |
+
function_call: Optional[Dict] = None,
|
151 |
) -> Dict[str, Any]:
|
152 |
usage = None
|
153 |
if finish_reason == "stop":
|
|
|
162 |
"created": timestamp,
|
163 |
"model": model,
|
164 |
"system_fingerprint": system_fingerprint,
|
165 |
+
"choices": [{
|
166 |
+
"index": 0,
|
167 |
+
"delta": {
|
168 |
+
"content": content if not function_call else None,
|
169 |
+
"role": "assistant",
|
170 |
+
"function_call": function_call
|
171 |
+
},
|
172 |
+
"finish_reason": finish_reason
|
173 |
+
}],
|
174 |
"usage": usage,
|
175 |
}
|
176 |
|
177 |
+
async def handle_function_call(request: ChatRequest, function_call_details: Dict) -> Dict[str, Any]:
|
178 |
+
# Placeholder for function calling logic
|
179 |
+
logger.info(f"Handling function call for model: {request.model}")
|
180 |
+
logger.info(f"Function call details: {function_call_details}")
|
181 |
+
return {"message": "Function call handled successfully", "details": function_call_details}
|
182 |
+
|
183 |
+
def message_to_dict(message, model_prefix: Optional[str] = None, tools: Optional[List[Dict]] = None) -> Dict[str, Any]:
|
184 |
"""
|
185 |
Convert a ChatRequest message to a dict for the request payload.
|
186 |
+
Supports function calling, images, and model prefixes.
|
|
|
187 |
"""
|
188 |
content = ""
|
189 |
images_data = []
|
|
|
256 |
logger.error("No h-value for validation.")
|
257 |
raise HTTPException(status_code=500, detail="Missing h-value.")
|
258 |
|
259 |
+
messages = [message_to_dict(msg, model_prefix=model_prefix, tools=request.tools) for msg in request.messages]
|
260 |
|
261 |
json_data = {
|
262 |
"agentMode": agent_mode,
|
263 |
"clickedAnswer2": False,
|
264 |
"clickedAnswer3": False,
|
|
|
265 |
"clickedForceWebSearch": False,
|
266 |
"codeInterpreterMode": False,
|
267 |
"codeModelMode": True,
|
|
|
288 |
"visitFromDelta": False,
|
289 |
"webSearchModePrompt": False,
|
290 |
"vscodeClient": False,
|
|
|
|
|
|
|
291 |
"customProfile": {"name": "", "occupation": "", "traits": [], "additionalInfo": "", "enableNewChats": False},
|
292 |
"webSearchModeOption": {"autoMode": False, "webMode": False, "offlineMode": True},
|
293 |
"session": {
|
294 |
+
"user": {"name": random_name, "email": random_email, "image": "https://lh3.googleusercontent.com/a/...=s96-c", "subscriptionStatus": "PREMIUM"},
|
295 |
"expires": datetime.now(timezone.utc).isoformat(timespec='milliseconds').replace('+00:00', 'Z'),
|
296 |
"subscriptionCache": {"customerId": random_customer_id, "status": "PREMIUM", "isTrialSubscription": "False", "expiryTimestamp": 1744652408, "lastChecked": int(time.time() * 1000)},
|
297 |
+
"beastMode": False,
|
298 |
+
"reasoningMode": False,
|
299 |
+
"designerMode": False,
|
300 |
+
"workspaceId": "",
|
301 |
},
|
302 |
}
|
303 |
|
|
|
334 |
final_snapzion_links.extend(snapzion_urls)
|
335 |
cleaned_content = strip_model_prefix(chunk, model_prefix)
|
336 |
completion_tokens += calculate_tokens(cleaned_content, request.model)
|
337 |
+
|
338 |
+
# Handle function call responses
|
339 |
+
function_call = None
|
340 |
+
if cleaned_content and cleaned_content.startswith("{"):
|
341 |
+
try:
|
342 |
+
function_call = json.loads(cleaned_content)
|
343 |
+
cleaned_content = None # Content must be null for function calls
|
344 |
+
except json.JSONDecodeError:
|
345 |
+
pass
|
346 |
+
|
347 |
+
yield "data: " + json.dumps(create_chat_completion_data(
|
348 |
+
cleaned_content,
|
349 |
+
request.model,
|
350 |
+
timestamp,
|
351 |
+
request_id,
|
352 |
+
system_fingerprint,
|
353 |
+
prompt_tokens,
|
354 |
+
completion_tokens,
|
355 |
+
finish_reason=None,
|
356 |
+
function_call=function_call
|
357 |
+
)) + "\n\n"
|
358 |
yield "data: " + json.dumps(create_chat_completion_data("", request.model, timestamp, request_id, system_fingerprint, prompt_tokens, completion_tokens, "stop")) + "\n\n"
|
359 |
yield "data: [DONE]\n\n"
|
360 |
except httpx.HTTPStatusError as e:
|
|
|
408 |
logger.error("Failed to retrieve h-value.")
|
409 |
raise HTTPException(status_code=500, detail="Missing h-value.")
|
410 |
|
411 |
+
messages = [message_to_dict(msg, model_prefix=model_prefix, tools=request.tools) for msg in request.messages]
|
412 |
|
413 |
json_data = {
|
414 |
"agentMode": agent_mode,
|
415 |
"clickedAnswer2": False,
|
416 |
"clickedAnswer3": False,
|
|
|
417 |
"clickedForceWebSearch": False,
|
418 |
"codeInterpreterMode": False,
|
419 |
"codeModelMode": True,
|
|
|
440 |
"visitFromDelta": False,
|
441 |
"webSearchModePrompt": False,
|
442 |
"vscodeClient": False,
|
|
|
|
|
|
|
443 |
"customProfile": {"name": "", "occupation": "", "traits": [], "additionalInfo": "", "enableNewChats": False},
|
444 |
"webSearchModeOption": {"autoMode": False, "webMode": False, "offlineMode": True},
|
445 |
"session": {
|
446 |
+
"user": {"name": random_name, "email": random_email, "image": "https://lh3.googleusercontent.com/a/...=s96-c", "subscriptionStatus": "PREMIUM"},
|
447 |
"expires": datetime.now(timezone.utc).isoformat(timespec='milliseconds').replace('+00:00', 'Z'),
|
448 |
"subscriptionCache": {"customerId": random_customer_id, "status": "PREMIUM", "isTrialSubscription": "False", "expiryTimestamp": 1744652408, "lastChecked": int(time.time() * 1000)},
|
449 |
+
"beastMode": False,
|
450 |
+
"reasoningMode": False,
|
451 |
+
"designerMode": False,
|
452 |
+
"workspaceId": "",
|
453 |
},
|
454 |
}
|
455 |
|