Niansuh committed on
Commit
e25a40c
·
verified ·
1 Parent(s): 2814bc0

Update api/utils.py

Browse files
Files changed (1) hide show
  1. api/utils.py +36 -48
api/utils.py CHANGED
@@ -4,21 +4,18 @@ import uuid
4
  import asyncio
5
  import random
6
  import string
7
- from typing import Any, Dict, Optional, List
8
 
9
  import httpx
10
  from fastapi import HTTPException
11
  from api.config import (
12
- MODEL_MAPPING,
13
- AGENT_MODE,
14
- TRENDING_AGENT_MODE,
15
- get_headers,
 
16
  BASE_URL,
17
  )
18
- from api.models import ChatRequest
19
- from api.logger import setup_logger
20
-
21
- logger = setup_logger(__name__)
22
 
23
  # Helper function to create a random alphanumeric chat ID
24
  def generate_chat_id(length: int = 7) -> str:
@@ -44,10 +41,11 @@ def create_chat_completion_data(
44
  "usage": None,
45
  }
46
 
47
- # Function to convert message to dictionary format for API
48
- def message_to_dict(message) -> Dict[str, Any]:
49
  content = message.content if isinstance(message.content, str) else message.content[0]["text"]
50
  if isinstance(message.content, list) and len(message.content) == 2 and "image_url" in message.content[1]:
 
51
  return {
52
  "role": message.role,
53
  "content": content,
@@ -59,37 +57,21 @@ def message_to_dict(message) -> Dict[str, Any]:
59
  }
60
  return {"role": message.role, "content": content}
61
 
62
- # Function to retrieve agent modes for a specific model
63
- def get_agent_modes(model: str) -> Dict[str, Any]:
64
- """Returns specific agent mode configurations for agent models."""
65
- agent_mode = AGENT_MODE.get(model, {})
66
- trending_agent_mode = TRENDING_AGENT_MODE.get(model, {})
67
-
68
- if agent_mode or trending_agent_mode:
69
- logger.info(f"Applying agent configurations for model '{model}'")
70
- else:
71
- logger.info(f"Model '{model}' is not an agent model; defaulting to standard mode")
72
-
73
- return {
74
- "agentMode": agent_mode,
75
- "trendingAgentMode": trending_agent_mode,
76
- }
77
-
78
  # Process streaming response with headers from config.py
79
- async def process_streaming_response(request: ChatRequest):
80
  chat_id = generate_chat_id()
81
  logger.info(f"Generated Chat ID: {chat_id} - Model: {request.model}")
82
 
83
- # Check if the model is an agent model and apply corresponding agent modes
84
- agent_modes = get_agent_modes(request.model)
85
- headers = get_headers()
 
86
 
87
  json_data = {
88
- "agentMode": agent_modes["agentMode"],
89
- "trendingAgentMode": agent_modes["trendingAgentMode"],
90
  "clickedAnswer2": False,
91
  "clickedAnswer3": False,
92
- "clickedForceWebSearch": False,
93
  "codeModelMode": True,
94
  "githubToken": None,
95
  "id": chat_id,
@@ -101,20 +83,20 @@ async def process_streaming_response(request: ChatRequest):
101
  "playgroundTemperature": request.temperature,
102
  "playgroundTopP": request.top_p,
103
  "previewToken": None,
 
104
  "userId": None,
105
- "userSelectedModel": MODEL_MAPPING.get(request.model, request.model),
106
  "userSystemPrompt": None,
107
  "validated": "69783381-2ce4-4dbd-ac78-35e9063feabc",
108
  "visitFromDelta": False,
109
  }
110
 
111
- logger.debug(f"Data Payload for {request.model}: {json_data}") # Inspect payload for accuracy
112
  async with httpx.AsyncClient() as client:
113
  try:
114
  async with client.stream(
115
  "POST",
116
  f"{BASE_URL}/api/chat",
117
- headers=headers,
118
  json=json_data,
119
  timeout=100,
120
  ) as response:
@@ -122,7 +104,11 @@ async def process_streaming_response(request: ChatRequest):
122
  async for line in response.aiter_lines():
123
  timestamp = int(datetime.now().timestamp())
124
  if line:
125
- yield f"data: {json.dumps(create_chat_completion_data(line, request.model, timestamp))}\n\n"
 
 
 
 
126
  yield f"data: {json.dumps(create_chat_completion_data('', request.model, timestamp, 'stop'))}\n\n"
127
  yield "data: [DONE]\n\n"
128
  except httpx.HTTPStatusError as e:
@@ -133,17 +119,17 @@ async def process_streaming_response(request: ChatRequest):
133
  raise HTTPException(status_code=500, detail=str(e))
134
 
135
  # Process non-streaming response with headers from config.py
136
- async def process_non_streaming_response(request: ChatRequest):
137
  chat_id = generate_chat_id()
138
  logger.info(f"Generated Chat ID: {chat_id} - Model: {request.model}")
139
 
140
- # Retrieve agent modes based on the model
141
- agent_modes = get_agent_modes(request.model)
142
- headers = get_headers()
 
143
 
144
  json_data = {
145
- "agentMode": agent_modes["agentMode"],
146
- "trendingAgentMode": agent_modes["trendingAgentMode"],
147
  "clickedAnswer2": False,
148
  "clickedAnswer3": False,
149
  "clickedForceWebSearch": False,
@@ -158,19 +144,19 @@ async def process_non_streaming_response(request: ChatRequest):
158
  "playgroundTemperature": request.temperature,
159
  "playgroundTopP": request.top_p,
160
  "previewToken": None,
 
161
  "userId": None,
162
- "userSelectedModel": MODEL_MAPPING.get(request.model, request.model),
163
  "userSystemPrompt": None,
164
  "validated": "69783381-2ce4-4dbd-ac78-35e9063feabc",
165
  "visitFromDelta": False,
166
  }
167
 
168
- logger.debug(f"Data Payload for {request.model}: {json_data}") # Inspect payload for accuracy
169
  full_response = ""
170
  async with httpx.AsyncClient() as client:
171
  try:
172
  async with client.stream(
173
- method="POST", url=f"{BASE_URL}/api/chat", headers=headers, json=json_data
174
  ) as response:
175
  response.raise_for_status()
176
  async for chunk in response.aiter_text():
@@ -181,6 +167,8 @@ async def process_non_streaming_response(request: ChatRequest):
181
  except httpx.RequestError as e:
182
  logger.error(f"Error occurred during request for Chat ID {chat_id}: {e}")
183
  raise HTTPException(status_code=500, detail=str(e))
 
 
184
 
185
  return {
186
  "id": f"chatcmpl-{uuid.uuid4()}",
@@ -195,4 +183,4 @@ async def process_non_streaming_response(request: ChatRequest):
195
  }
196
  ],
197
  "usage": None,
198
- }
 
4
  import asyncio
5
  import random
6
  import string
7
+ from typing import Any, Dict, Optional
8
 
9
  import httpx
10
  from fastapi import HTTPException
11
  from api.config import (
12
+ models,
13
+ model_aliases,
14
+ agentMode,
15
+ trendingAgentMode,
16
+ get_headers_api_chat,
17
  BASE_URL,
18
  )
 
 
 
 
19
 
20
  # Helper function to create a random alphanumeric chat ID
21
  def generate_chat_id(length: int = 7) -> str:
 
41
  "usage": None,
42
  }
43
 
44
+ # Function to convert message to dictionary format
45
+ def message_to_dict(message):
46
  content = message.content if isinstance(message.content, str) else message.content[0]["text"]
47
  if isinstance(message.content, list) and len(message.content) == 2 and "image_url" in message.content[1]:
48
+ # Ensure base64 images are always included for all models
49
  return {
50
  "role": message.role,
51
  "content": content,
 
57
  }
58
  return {"role": message.role, "content": content}
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  # Process streaming response with headers from config.py
61
+ async def process_streaming_response(request):
62
  chat_id = generate_chat_id()
63
  logger.info(f"Generated Chat ID: {chat_id} - Model: {request.model}")
64
 
65
+ agent_mode = agentMode.get(request.model, {})
66
+ trending_agent_mode = trendingAgentMode.get(request.model, {})
67
+
68
+ headers_api_chat = get_headers_api_chat(f"{BASE_URL}/?model={request.model}")
69
 
70
  json_data = {
71
+ "agentMode": agent_mode,
 
72
  "clickedAnswer2": False,
73
  "clickedAnswer3": False,
74
+ "clickedForceWebSearch": False,
75
  "codeModelMode": True,
76
  "githubToken": None,
77
  "id": chat_id,
 
83
  "playgroundTemperature": request.temperature,
84
  "playgroundTopP": request.top_p,
85
  "previewToken": None,
86
+ "trendingAgentMode": trending_agent_mode,
87
  "userId": None,
88
+ "userSelectedModel": request.model,
89
  "userSystemPrompt": None,
90
  "validated": "69783381-2ce4-4dbd-ac78-35e9063feabc",
91
  "visitFromDelta": False,
92
  }
93
 
 
94
  async with httpx.AsyncClient() as client:
95
  try:
96
  async with client.stream(
97
  "POST",
98
  f"{BASE_URL}/api/chat",
99
+ headers=headers_api_chat,
100
  json=json_data,
101
  timeout=100,
102
  ) as response:
 
104
  async for line in response.aiter_lines():
105
  timestamp = int(datetime.now().timestamp())
106
  if line:
107
+ content = line
108
+ if content.startswith("$@$v=undefined-rv1$@$"):
109
+ content = content[21:]
110
+ yield f"data: {json.dumps(create_chat_completion_data(content, request.model, timestamp))}\n\n"
111
+
112
  yield f"data: {json.dumps(create_chat_completion_data('', request.model, timestamp, 'stop'))}\n\n"
113
  yield "data: [DONE]\n\n"
114
  except httpx.HTTPStatusError as e:
 
119
  raise HTTPException(status_code=500, detail=str(e))
120
 
121
  # Process non-streaming response with headers from config.py
122
+ async def process_non_streaming_response(request):
123
  chat_id = generate_chat_id()
124
  logger.info(f"Generated Chat ID: {chat_id} - Model: {request.model}")
125
 
126
+ agent_mode = agentMode.get(request.model, {})
127
+ trending_agent_mode = trendingAgentMode.get(request.model, {})
128
+
129
+ headers_api_chat = get_headers_api_chat(f"{BASE_URL}/?model={request.model}")
130
 
131
  json_data = {
132
+ "agentMode": agent_mode,
 
133
  "clickedAnswer2": False,
134
  "clickedAnswer3": False,
135
  "clickedForceWebSearch": False,
 
144
  "playgroundTemperature": request.temperature,
145
  "playgroundTopP": request.top_p,
146
  "previewToken": None,
147
+ "trendingAgentMode": trending_agent_mode,
148
  "userId": None,
149
+ "userSelectedModel": request.model,
150
  "userSystemPrompt": None,
151
  "validated": "69783381-2ce4-4dbd-ac78-35e9063feabc",
152
  "visitFromDelta": False,
153
  }
154
 
 
155
  full_response = ""
156
  async with httpx.AsyncClient() as client:
157
  try:
158
  async with client.stream(
159
+ method="POST", url=f"{BASE_URL}/api/chat", headers=headers_api_chat, json=json_data
160
  ) as response:
161
  response.raise_for_status()
162
  async for chunk in response.aiter_text():
 
167
  except httpx.RequestError as e:
168
  logger.error(f"Error occurred during request for Chat ID {chat_id}: {e}")
169
  raise HTTPException(status_code=500, detail=str(e))
170
+ if full_response.startswith("$@$v=undefined-rv1$@$"):
171
+ full_response = full_response[21:]
172
 
173
  return {
174
  "id": f"chatcmpl-{uuid.uuid4()}",
 
183
  }
184
  ],
185
  "usage": None,
186
+ }