Niansuh committed on
Commit
bda3109
·
verified ·
1 Parent(s): 6331a12

Update api/utils.py

Browse files
Files changed (1) hide show
  1. api/utils.py +22 -32
api/utils.py CHANGED
@@ -1,31 +1,28 @@
1
  from datetime import datetime
2
  import json
3
  import uuid
4
- import random
5
  import asyncio
 
6
  from typing import Any, Dict, Optional
 
7
  import httpx
8
  from fastapi import HTTPException
9
  from api.config import (
10
  MODEL_MAPPING,
11
- headers,
 
 
12
  AGENT_MODE,
13
  TRENDING_AGENT_MODE,
14
- BASE_URL,
15
  MODEL_PREFIXES,
16
  MODEL_REFERERS
17
  )
18
  from api.models import ChatRequest
19
  from api.logger import setup_logger
20
 
21
- # Initialize logger
22
  logger = setup_logger(__name__)
23
 
24
- # Function to generate a unique chat ID
25
- def generate_chat_id() -> str:
26
- return f"chat-{uuid.uuid4()}"
27
-
28
- # Create chat completion data function
29
  def create_chat_completion_data(
30
  content: str, model: str, timestamp: int, finish_reason: Optional[str] = None
31
  ) -> Dict[str, Any]:
@@ -44,7 +41,7 @@ def create_chat_completion_data(
44
  "usage": None,
45
  }
46
 
47
- # Message to dict converter function
48
  def message_to_dict(message, model_prefix: Optional[str] = None):
49
  if isinstance(message.content, str):
50
  content = message.content
@@ -67,7 +64,7 @@ def message_to_dict(message, model_prefix: Optional[str] = None):
67
  else:
68
  return {"role": message.role, "content": message.content}
69
 
70
- # Strip model prefix function
71
  def strip_model_prefix(content: str, model_prefix: Optional[str] = None) -> str:
72
  """Remove the model prefix from the response content if present."""
73
  if model_prefix and content.startswith(model_prefix):
@@ -76,18 +73,16 @@ def strip_model_prefix(content: str, model_prefix: Optional[str] = None) -> str:
76
  logger.debug("No prefix to strip from content.")
77
  return content
78
 
79
- # Process streaming response
80
  async def process_streaming_response(request: ChatRequest):
81
- chat_id = generate_chat_id() # Generate unique chat ID
82
  agent_mode = AGENT_MODE.get(request.model, {})
83
  trending_agent_mode = TRENDING_AGENT_MODE.get(request.model, {})
84
  model_prefix = MODEL_PREFIXES.get(request.model, "")
85
  referer_path = MODEL_REFERERS.get(request.model, f"/?model={request.model}")
86
- referer_url = f"{BASE_URL}/chat/{chat_id}?model={request.model}" # Updated URL format with chat_id
87
 
88
- # Update headers with dynamic Referer
89
- dynamic_headers = headers.copy()
90
- dynamic_headers['Referer'] = referer_url
91
 
92
  # Introduce delay for 'o1-preview' model
93
  if request.model == 'o1-preview':
@@ -96,7 +91,6 @@ async def process_streaming_response(request: ChatRequest):
96
  await asyncio.sleep(delay_seconds)
97
 
98
  json_data = {
99
- "id": chat_id,
100
  "messages": [message_to_dict(msg, model_prefix=model_prefix) for msg in request.messages],
101
  "previewToken": None,
102
  "userId": None,
@@ -124,7 +118,7 @@ async def process_streaming_response(request: ChatRequest):
124
  async with client.stream(
125
  "POST",
126
  f"{BASE_URL}/api/chat",
127
- headers=dynamic_headers,
128
  json=json_data,
129
  timeout=100,
130
  ) as response:
@@ -135,6 +129,7 @@ async def process_streaming_response(request: ChatRequest):
135
  content = line
136
  if content.startswith("$@$v=undefined-rv1$@$"):
137
  content = content[21:]
 
138
  cleaned_content = strip_model_prefix(content, model_prefix)
139
  yield f"data: {json.dumps(create_chat_completion_data(cleaned_content, request.model, timestamp))}\n\n"
140
 
@@ -147,27 +142,20 @@ async def process_streaming_response(request: ChatRequest):
147
  logger.error(f"Error occurred during request: {e}")
148
  raise HTTPException(status_code=500, detail=str(e))
149
 
150
- # Process non-streaming response
151
  async def process_non_streaming_response(request: ChatRequest):
152
- chat_id = generate_chat_id() # Generate unique chat ID
153
  agent_mode = AGENT_MODE.get(request.model, {})
154
  trending_agent_mode = TRENDING_AGENT_MODE.get(request.model, {})
155
  model_prefix = MODEL_PREFIXES.get(request.model, "")
156
  referer_path = MODEL_REFERERS.get(request.model, f"/?model={request.model}")
157
  referer_url = f"{BASE_URL}{referer_path}"
 
158
 
159
- # Update headers with dynamic Referer
160
- dynamic_headers = headers.copy()
161
- dynamic_headers['Referer'] = referer_url
162
-
163
- # Introduce delay for 'o1-preview' model
164
- if request.model == 'o1-preview':
165
- delay_seconds = random.randint(20, 60)
166
- logger.info(f"Introducing a delay of {delay_seconds} seconds for model 'o1-preview'")
167
- await asyncio.sleep(delay_seconds)
168
 
169
  json_data = {
170
- "id": chat_id,
171
  "messages": [message_to_dict(msg, model_prefix=model_prefix) for msg in request.messages],
172
  "previewToken": None,
173
  "userId": None,
@@ -189,11 +177,12 @@ async def process_non_streaming_response(request: ChatRequest):
189
  "mobileClient": False,
190
  "userSelectedModel": MODEL_MAPPING.get(request.model, request.model),
191
  }
 
192
  full_response = ""
193
  async with httpx.AsyncClient() as client:
194
  try:
195
  async with client.stream(
196
- method="POST", url=f"{BASE_URL}/api/chat", headers=dynamic_headers, json=json_data
197
  ) as response:
198
  response.raise_for_status()
199
  async for chunk in response.aiter_text():
@@ -207,6 +196,7 @@ async def process_non_streaming_response(request: ChatRequest):
207
  if full_response.startswith("$@$v=undefined-rv1$@$"):
208
  full_response = full_response[21:]
209
 
 
210
  cleaned_full_response = strip_model_prefix(full_response, model_prefix)
211
 
212
  return {
 
1
  from datetime import datetime
2
  import json
3
  import uuid
 
4
  import asyncio
5
+ import random
6
  from typing import Any, Dict, Optional
7
+
8
  import httpx
9
  from fastapi import HTTPException
10
  from api.config import (
11
  MODEL_MAPPING,
12
+ get_headers_api_chat,
13
+ get_headers_chat,
14
+ BASE_URL,
15
  AGENT_MODE,
16
  TRENDING_AGENT_MODE,
 
17
  MODEL_PREFIXES,
18
  MODEL_REFERERS
19
  )
20
  from api.models import ChatRequest
21
  from api.logger import setup_logger
22
 
 
23
  logger = setup_logger(__name__)
24
 
25
+ # Helper function to create chat completion data
 
 
 
 
26
  def create_chat_completion_data(
27
  content: str, model: str, timestamp: int, finish_reason: Optional[str] = None
28
  ) -> Dict[str, Any]:
 
41
  "usage": None,
42
  }
43
 
44
+ # Function to convert message to dictionary format with optional model prefix
45
  def message_to_dict(message, model_prefix: Optional[str] = None):
46
  if isinstance(message.content, str):
47
  content = message.content
 
64
  else:
65
  return {"role": message.role, "content": message.content}
66
 
67
+ # Function to strip model prefix from content if present
68
  def strip_model_prefix(content: str, model_prefix: Optional[str] = None) -> str:
69
  """Remove the model prefix from the response content if present."""
70
  if model_prefix and content.startswith(model_prefix):
 
73
  logger.debug("No prefix to strip from content.")
74
  return content
75
 
76
+ # Process streaming response with headers from config.py
77
  async def process_streaming_response(request: ChatRequest):
 
78
  agent_mode = AGENT_MODE.get(request.model, {})
79
  trending_agent_mode = TRENDING_AGENT_MODE.get(request.model, {})
80
  model_prefix = MODEL_PREFIXES.get(request.model, "")
81
  referer_path = MODEL_REFERERS.get(request.model, f"/?model={request.model}")
82
+ referer_url = f"{BASE_URL}{referer_path}"
83
 
84
+ # Generate headers for API chat request
85
+ headers_api_chat = get_headers_api_chat(referer_url)
 
86
 
87
  # Introduce delay for 'o1-preview' model
88
  if request.model == 'o1-preview':
 
91
  await asyncio.sleep(delay_seconds)
92
 
93
  json_data = {
 
94
  "messages": [message_to_dict(msg, model_prefix=model_prefix) for msg in request.messages],
95
  "previewToken": None,
96
  "userId": None,
 
118
  async with client.stream(
119
  "POST",
120
  f"{BASE_URL}/api/chat",
121
+ headers=headers_api_chat,
122
  json=json_data,
123
  timeout=100,
124
  ) as response:
 
129
  content = line
130
  if content.startswith("$@$v=undefined-rv1$@$"):
131
  content = content[21:]
132
+ # Strip the model prefix from the response content
133
  cleaned_content = strip_model_prefix(content, model_prefix)
134
  yield f"data: {json.dumps(create_chat_completion_data(cleaned_content, request.model, timestamp))}\n\n"
135
 
 
142
  logger.error(f"Error occurred during request: {e}")
143
  raise HTTPException(status_code=500, detail=str(e))
144
 
145
+ # Process non-streaming response with headers from config.py
146
  async def process_non_streaming_response(request: ChatRequest):
 
147
  agent_mode = AGENT_MODE.get(request.model, {})
148
  trending_agent_mode = TRENDING_AGENT_MODE.get(request.model, {})
149
  model_prefix = MODEL_PREFIXES.get(request.model, "")
150
  referer_path = MODEL_REFERERS.get(request.model, f"/?model={request.model}")
151
  referer_url = f"{BASE_URL}{referer_path}"
152
+ chat_url = f"{BASE_URL}/chat/{uuid.uuid4()}?model={request.model}"
153
 
154
+ # Generate headers for API chat request and chat request
155
+ headers_api_chat = get_headers_api_chat(referer_url)
156
+ headers_chat = get_headers_chat(chat_url, next_action=str(uuid.uuid4()), next_router_state_tree=json.dumps([""]))
 
 
 
 
 
 
157
 
158
  json_data = {
 
159
  "messages": [message_to_dict(msg, model_prefix=model_prefix) for msg in request.messages],
160
  "previewToken": None,
161
  "userId": None,
 
177
  "mobileClient": False,
178
  "userSelectedModel": MODEL_MAPPING.get(request.model, request.model),
179
  }
180
+
181
  full_response = ""
182
  async with httpx.AsyncClient() as client:
183
  try:
184
  async with client.stream(
185
+ method="POST", url=f"{BASE_URL}/api/chat", headers=headers_api_chat, json=json_data
186
  ) as response:
187
  response.raise_for_status()
188
  async for chunk in response.aiter_text():
 
196
  if full_response.startswith("$@$v=undefined-rv1$@$"):
197
  full_response = full_response[21:]
198
 
199
+ # Strip the model prefix from the full response
200
  cleaned_full_response = strip_model_prefix(full_response, model_prefix)
201
 
202
  return {