Niansuh committed on
Commit
ac74d4c
·
verified ·
1 Parent(s): ccd9777

Update api/utils.py

Browse files
Files changed (1) hide show
  1. api/utils.py +101 -58
api/utils.py CHANGED
@@ -3,33 +3,103 @@ import json
3
  import uuid
4
  import asyncio
5
  import random
6
- from fastapi import HTTPException, Request
 
 
7
  import httpx
 
 
8
  from api.config import (
9
  MODEL_MAPPING,
10
  get_headers_api_chat,
 
11
  BASE_URL,
12
  AGENT_MODE,
13
  TRENDING_AGENT_MODE,
14
- MODEL_PREFIXES
 
15
  )
 
16
  from api.logger import setup_logger
17
- from api import validate
18
 
19
  logger = setup_logger(__name__)
20
 
21
- async def process_streaming_response(request: ChatRequest, request_obj: Request):
22
- referer_url = BASE_URL
23
- client_ip = request_obj.client.host # Get the client IP
24
- logger.info(f"Processing streaming response - Model: {request.model} - URL: {referer_url} - IP: {client_ip}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  agent_mode = AGENT_MODE.get(request.model, {})
27
  trending_agent_mode = TRENDING_AGENT_MODE.get(request.model, {})
28
  model_prefix = MODEL_PREFIXES.get(request.model, "")
29
 
30
  headers_api_chat = get_headers_api_chat(referer_url)
31
- validated_token = validate.getHid() # Get validated token
32
- logger.info(f"Retrieved validated token for IP {client_ip}: {validated_token}")
 
 
 
 
 
 
33
 
34
  json_data = {
35
  "agentMode": agent_mode,
@@ -38,6 +108,7 @@ async def process_streaming_response(request: ChatRequest, request_obj: Request)
38
  "clickedForceWebSearch": False,
39
  "codeModelMode": True,
40
  "githubToken": None,
 
41
  "isChromeExt": False,
42
  "isMicMode": False,
43
  "maxTokens": request.max_tokens,
@@ -70,27 +141,36 @@ async def process_streaming_response(request: ChatRequest, request_obj: Request)
70
  content = line
71
  if content.startswith("$@$v=undefined-rv1$@$"):
72
  content = content[21:]
73
- yield f"data: {json.dumps(create_chat_completion_data(content, request.model, timestamp))}\n\n"
 
 
74
  yield f"data: {json.dumps(create_chat_completion_data('', request.model, timestamp, 'stop'))}\n\n"
75
  yield "data: [DONE]\n\n"
76
  except httpx.HTTPStatusError as e:
77
- logger.error(f"HTTP error occurred (IP: {client_ip}): {e}")
78
  raise HTTPException(status_code=e.response.status_code, detail=str(e))
79
  except httpx.RequestError as e:
80
- logger.error(f"Error occurred during request (IP: {client_ip}): {e}")
81
  raise HTTPException(status_code=500, detail=str(e))
82
 
83
- async def process_non_streaming_response(request: ChatRequest, request_obj: Request):
84
- referer_url = BASE_URL
85
- client_ip = request_obj.client.host
86
- logger.info(f"Processing non-streaming response - Model: {request.model} - URL: {referer_url} - IP: {client_ip}")
 
87
 
88
  agent_mode = AGENT_MODE.get(request.model, {})
89
  trending_agent_mode = TRENDING_AGENT_MODE.get(request.model, {})
90
  model_prefix = MODEL_PREFIXES.get(request.model, "")
91
 
92
  headers_api_chat = get_headers_api_chat(referer_url)
93
- validated_token = validate.getHid() # Get validated token
 
 
 
 
 
 
94
 
95
  json_data = {
96
  "agentMode": agent_mode,
@@ -99,6 +179,7 @@ async def process_non_streaming_response(request: ChatRequest, request_obj: Requ
99
  "clickedForceWebSearch": False,
100
  "codeModelMode": True,
101
  "githubToken": None,
 
102
  "isChromeExt": False,
103
  "isMicMode": False,
104
  "maxTokens": request.max_tokens,
@@ -125,10 +206,10 @@ async def process_non_streaming_response(request: ChatRequest, request_obj: Requ
125
  async for chunk in response.aiter_text():
126
  full_response += chunk
127
  except httpx.HTTPStatusError as e:
128
- logger.error(f"HTTP error occurred (IP: {client_ip}): {e}")
129
  raise HTTPException(status_code=e.response.status_code, detail=str(e))
130
  except httpx.RequestError as e:
131
- logger.error(f"Error occurred during request (IP: {client_ip}): {e}")
132
  raise HTTPException(status_code=500, detail=str(e))
133
  if full_response.startswith("$@$v=undefined-rv1$@$"):
134
  full_response = full_response[21:]
@@ -148,42 +229,4 @@ async def process_non_streaming_response(request: ChatRequest, request_obj: Requ
148
  }
149
  ],
150
  "usage": None,
151
- }
152
-
153
- def create_chat_completion_data(content: str, model: str, timestamp: int, finish_reason: Optional[str] = None):
154
- return {
155
- "id": f"chatcmpl-{uuid.uuid4()}",
156
- "object": "chat.completion.chunk",
157
- "created": timestamp,
158
- "model": model,
159
- "choices": [
160
- {
161
- "index": 0,
162
- "delta": {"content": content, "role": "assistant"},
163
- "finish_reason": finish_reason,
164
- }
165
- ],
166
- "usage": None,
167
- }
168
-
169
- def message_to_dict(message, model_prefix: Optional[str] = None):
170
- content = message.content if isinstance(message.content, str) else message.content[0]["text"]
171
- if model_prefix:
172
- content = f"{model_prefix} {content}"
173
- if isinstance(message.content, list) and len(message.content) == 2 and "image_url" in message.content[1]:
174
- return {
175
- "role": message.role,
176
- "content": content,
177
- "data": {
178
- "imageBase64": message.content[1]["image_url"]["url"],
179
- "fileText": "",
180
- "title": "snapshot",
181
- },
182
- }
183
- return {"role": message.role, "content": content}
184
-
185
- def strip_model_prefix(content: str, model_prefix: Optional[str] = None) -> str:
186
- if model_prefix and content.startswith(model_prefix):
187
- logger.debug(f"Stripping prefix '{model_prefix}' from content.")
188
- return content[len(model_prefix):].strip()
189
- return content
 
3
  import uuid
4
  import asyncio
5
  import random
6
+ import string
7
+ from typing import Any, Dict, Optional
8
+
9
  import httpx
10
+ from fastapi import HTTPException
11
+ from api import validate # Import validate to use getHid
12
  from api.config import (
13
  MODEL_MAPPING,
14
  get_headers_api_chat,
15
+ get_headers_chat,
16
  BASE_URL,
17
  AGENT_MODE,
18
  TRENDING_AGENT_MODE,
19
+ MODEL_PREFIXES,
20
+ MODEL_REFERERS
21
  )
22
+ from api.models import ChatRequest
23
  from api.logger import setup_logger
 
24
 
25
  logger = setup_logger(__name__)
26
 
27
def generate_chat_id(length: int = 7) -> str:
    """Return a random alphanumeric chat identifier of the given length."""
    pool = string.ascii_letters + string.digits
    picks = random.choices(pool, k=length)
    return "".join(picks)
31
+
32
def create_chat_completion_data(
    content: str, model: str, timestamp: int, finish_reason: Optional[str] = None
) -> Dict[str, Any]:
    """Assemble one OpenAI-style ``chat.completion.chunk`` payload.

    A fresh ``chatcmpl-<uuid4>`` id is minted per call; the single choice
    carries *content* as an assistant delta with the given *finish_reason*.
    """
    choice = {
        "index": 0,
        "delta": {"content": content, "role": "assistant"},
        "finish_reason": finish_reason,
    }
    return {
        "id": f"chatcmpl-{uuid.uuid4()}",
        "object": "chat.completion.chunk",
        "created": timestamp,
        "model": model,
        "choices": [choice],
        "usage": None,
    }
50
+
51
def message_to_dict(message, model_prefix: Optional[str] = None):
    """Convert a chat message object into the upstream payload dict.

    Text comes from str content directly, otherwise from the first list
    element's ``"text"`` field; an optional model prefix is prepended.
    Two-part list content whose second element carries ``"image_url"`` is
    emitted with the base64 image attached under ``"data"``.
    """
    raw = message.content
    text = raw if isinstance(raw, str) else raw[0]["text"]
    if model_prefix:
        text = f"{model_prefix} {text}"
    has_image = isinstance(raw, list) and len(raw) == 2 and "image_url" in raw[1]
    if not has_image:
        return {"role": message.role, "content": text}
    # Base64 images are always forwarded, regardless of model.
    return {
        "role": message.role,
        "content": text,
        "data": {
            "imageBase64": raw[1]["image_url"]["url"],
            "fileText": "",
            "title": "snapshot",
        },
    }
68
+
69
def strip_model_prefix(content: str, model_prefix: Optional[str] = None) -> str:
    """Remove the model prefix from the response content if present."""
    if not model_prefix or not content.startswith(model_prefix):
        return content
    logger.debug(f"Stripping prefix '{model_prefix}' from content.")
    return content[len(model_prefix):].strip()
76
+
77
def get_referer_url(chat_id: str, model: str) -> str:
    """Generate the referer URL based on specific models listed in MODEL_REFERERS."""
    if model not in MODEL_REFERERS:
        return BASE_URL
    return f"{BASE_URL}/chat/{chat_id}?model={model}"
83
+
84
+ # Process streaming response with headers from config.py
85
+ async def process_streaming_response(request: ChatRequest):
86
+ chat_id = generate_chat_id()
87
+ referer_url = get_referer_url(chat_id, request.model)
88
+ logger.info(f"Generated Chat ID: {chat_id} - Model: {request.model} - URL: {referer_url}")
89
 
90
  agent_mode = AGENT_MODE.get(request.model, {})
91
  trending_agent_mode = TRENDING_AGENT_MODE.get(request.model, {})
92
  model_prefix = MODEL_PREFIXES.get(request.model, "")
93
 
94
  headers_api_chat = get_headers_api_chat(referer_url)
95
+ validated_token = validate.getHid() # Get the validated token from validate.py
96
+ logger.info(f"Retrieved validated token: {validated_token}")
97
+
98
+
99
+ if request.model == 'o1-preview':
100
+ delay_seconds = random.randint(1, 60)
101
+ logger.info(f"Introducing a delay of {delay_seconds} seconds for model 'o1-preview' (Chat ID: {chat_id})")
102
+ await asyncio.sleep(delay_seconds)
103
 
104
  json_data = {
105
  "agentMode": agent_mode,
 
108
  "clickedForceWebSearch": False,
109
  "codeModelMode": True,
110
  "githubToken": None,
111
+ "id": chat_id,
112
  "isChromeExt": False,
113
  "isMicMode": False,
114
  "maxTokens": request.max_tokens,
 
141
  content = line
142
  if content.startswith("$@$v=undefined-rv1$@$"):
143
  content = content[21:]
144
+ cleaned_content = strip_model_prefix(content, model_prefix)
145
+ yield f"data: {json.dumps(create_chat_completion_data(cleaned_content, request.model, timestamp))}\n\n"
146
+
147
  yield f"data: {json.dumps(create_chat_completion_data('', request.model, timestamp, 'stop'))}\n\n"
148
  yield "data: [DONE]\n\n"
149
  except httpx.HTTPStatusError as e:
150
+ logger.error(f"HTTP error occurred for Chat ID {chat_id}: {e}")
151
  raise HTTPException(status_code=e.response.status_code, detail=str(e))
152
  except httpx.RequestError as e:
153
+ logger.error(f"Error occurred during request for Chat ID {chat_id}: {e}")
154
  raise HTTPException(status_code=500, detail=str(e))
155
 
156
+ # Process non-streaming response with headers from config.py
157
+ async def process_non_streaming_response(request: ChatRequest):
158
+ chat_id = generate_chat_id()
159
+ referer_url = get_referer_url(chat_id, request.model)
160
+ logger.info(f"Generated Chat ID: {chat_id} - Model: {request.model} - URL: {referer_url}")
161
 
162
  agent_mode = AGENT_MODE.get(request.model, {})
163
  trending_agent_mode = TRENDING_AGENT_MODE.get(request.model, {})
164
  model_prefix = MODEL_PREFIXES.get(request.model, "")
165
 
166
  headers_api_chat = get_headers_api_chat(referer_url)
167
+ headers_chat = get_headers_chat(referer_url, next_action=str(uuid.uuid4()), next_router_state_tree=json.dumps([""]))
168
+ validated_token = validate.getHid() # Get the validated token from validate.py
169
+
170
+ if request.model == 'o1-preview':
171
+ delay_seconds = random.randint(20, 60)
172
+ logger.info(f"Introducing a delay of {delay_seconds} seconds for model 'o1-preview' (Chat ID: {chat_id})")
173
+ await asyncio.sleep(delay_seconds)
174
 
175
  json_data = {
176
  "agentMode": agent_mode,
 
179
  "clickedForceWebSearch": False,
180
  "codeModelMode": True,
181
  "githubToken": None,
182
+ "id": chat_id,
183
  "isChromeExt": False,
184
  "isMicMode": False,
185
  "maxTokens": request.max_tokens,
 
206
  async for chunk in response.aiter_text():
207
  full_response += chunk
208
  except httpx.HTTPStatusError as e:
209
+ logger.error(f"HTTP error occurred for Chat ID {chat_id}: {e}")
210
  raise HTTPException(status_code=e.response.status_code, detail=str(e))
211
  except httpx.RequestError as e:
212
+ logger.error(f"Error occurred during request for Chat ID {chat_id}: {e}")
213
  raise HTTPException(status_code=500, detail=str(e))
214
  if full_response.startswith("$@$v=undefined-rv1$@$"):
215
  full_response = full_response[21:]
 
229
  }
230
  ],
231
  "usage": None,
232
+ }