Niansuh commited on
Commit
db66062
·
verified ·
1 Parent(s): 539e646

Update api/utils.py

Browse files
Files changed (1) hide show
  1. api/utils.py +38 -17
api/utils.py CHANGED
@@ -4,16 +4,16 @@ import uuid
4
  import asyncio
5
  import random
6
  import string
7
- from typing import Any, Dict, Optional
8
 
9
  import httpx
10
  from fastapi import HTTPException
11
  from api.config import (
12
  MODEL_MAPPING,
 
 
13
  get_headers,
14
  BASE_URL,
15
- AGENT_MODE,
16
- TRENDING_AGENT_MODE
17
  )
18
  from api.models import ChatRequest
19
  from api.logger import setup_logger
@@ -44,9 +44,11 @@ def create_chat_completion_data(
44
  "usage": None,
45
  }
46
 
47
- # Function to convert message to dictionary format
48
- def message_to_dict(message):
49
  content = message.content if isinstance(message.content, str) else message.content[0]["text"]
 
 
50
  if isinstance(message.content, list) and len(message.content) == 2 and "image_url" in message.content[1]:
51
  return {
52
  "role": message.role,
@@ -59,16 +61,25 @@ def message_to_dict(message):
59
  }
60
  return {"role": message.role, "content": content}
61
 
62
- # Process streaming response
 
 
 
 
 
 
 
63
  async def process_streaming_response(request: ChatRequest):
64
  chat_id = generate_chat_id()
65
- agent_mode = AGENT_MODE.get(request.model, {})
66
- trending_agent_mode = TRENDING_AGENT_MODE.get(request.model, {})
67
 
 
 
68
  headers = get_headers()
69
 
70
  json_data = {
71
- "agentMode": agent_mode,
 
72
  "clickedAnswer2": False,
73
  "clickedAnswer3": False,
74
  "clickedForceWebSearch": False,
@@ -83,7 +94,6 @@ async def process_streaming_response(request: ChatRequest):
83
  "playgroundTemperature": request.temperature,
84
  "playgroundTopP": request.top_p,
85
  "previewToken": None,
86
- "trendingAgentMode": trending_agent_mode,
87
  "userId": None,
88
  "userSelectedModel": MODEL_MAPPING.get(request.model, request.model),
89
  "userSystemPrompt": None,
@@ -91,9 +101,16 @@ async def process_streaming_response(request: ChatRequest):
91
  "visitFromDelta": False,
92
  }
93
 
 
94
  async with httpx.AsyncClient() as client:
95
  try:
96
- async with client.stream("POST", f"{BASE_URL}/api/chat", headers=headers, json=json_data, timeout=100) as response:
 
 
 
 
 
 
97
  response.raise_for_status()
98
  async for line in response.aiter_lines():
99
  timestamp = int(datetime.now().timestamp())
@@ -108,16 +125,18 @@ async def process_streaming_response(request: ChatRequest):
108
  logger.error(f"Error occurred during request for Chat ID {chat_id}: {e}")
109
  raise HTTPException(status_code=500, detail=str(e))
110
 
111
- # Process non-streaming response
112
  async def process_non_streaming_response(request: ChatRequest):
113
  chat_id = generate_chat_id()
114
- agent_mode = AGENT_MODE.get(request.model, {})
115
- trending_agent_mode = TRENDING_AGENT_MODE.get(request.model, {})
116
 
 
 
117
  headers = get_headers()
118
 
119
  json_data = {
120
- "agentMode": agent_mode,
 
121
  "clickedAnswer2": False,
122
  "clickedAnswer3": False,
123
  "clickedForceWebSearch": False,
@@ -132,7 +151,6 @@ async def process_non_streaming_response(request: ChatRequest):
132
  "playgroundTemperature": request.temperature,
133
  "playgroundTopP": request.top_p,
134
  "previewToken": None,
135
- "trendingAgentMode": trending_agent_mode,
136
  "userId": None,
137
  "userSelectedModel": MODEL_MAPPING.get(request.model, request.model),
138
  "userSystemPrompt": None,
@@ -140,10 +158,13 @@ async def process_non_streaming_response(request: ChatRequest):
140
  "visitFromDelta": False,
141
  }
142
 
 
143
  full_response = ""
144
  async with httpx.AsyncClient() as client:
145
  try:
146
- async with client.stream("POST", f"{BASE_URL}/api/chat", headers=headers, json=json_data) as response:
 
 
147
  response.raise_for_status()
148
  async for chunk in response.aiter_text():
149
  full_response += chunk
 
4
  import asyncio
5
  import random
6
  import string
7
+ from typing import Any, Dict, Optional, List
8
 
9
  import httpx
10
  from fastapi import HTTPException
11
  from api.config import (
12
  MODEL_MAPPING,
13
+ AGENT_MODE,
14
+ TRENDING_AGENT_MODE,
15
  get_headers,
16
  BASE_URL,
 
 
17
  )
18
  from api.models import ChatRequest
19
  from api.logger import setup_logger
 
44
  "usage": None,
45
  }
46
 
47
+ # Function to convert message to dictionary format for API
48
+ def message_to_dict(message, model_prefix: Optional[str] = None) -> Dict[str, Any]:
49
  content = message.content if isinstance(message.content, str) else message.content[0]["text"]
50
+ if model_prefix:
51
+ content = f"{model_prefix} {content}"
52
  if isinstance(message.content, list) and len(message.content) == 2 and "image_url" in message.content[1]:
53
  return {
54
  "role": message.role,
 
61
  }
62
  return {"role": message.role, "content": content}
63
 
64
+ # Function to get agent modes for specific models
65
+ def get_agent_modes(model: str) -> Dict[str, Any]:
66
+ return {
67
+ "agentMode": AGENT_MODE.get(model, {}),
68
+ "trendingAgentMode": TRENDING_AGENT_MODE.get(model, {}),
69
+ }
70
+
71
+ # Process streaming response with headers from config.py
72
  async def process_streaming_response(request: ChatRequest):
73
  chat_id = generate_chat_id()
74
+ logger.info(f"Generated Chat ID: {chat_id} - Model: {request.model}")
 
75
 
76
+ # Retrieve agent modes based on the model
77
+ agent_modes = get_agent_modes(request.model)
78
  headers = get_headers()
79
 
80
  json_data = {
81
+ "agentMode": agent_modes["agentMode"],
82
+ "trendingAgentMode": agent_modes["trendingAgentMode"],
83
  "clickedAnswer2": False,
84
  "clickedAnswer3": False,
85
  "clickedForceWebSearch": False,
 
94
  "playgroundTemperature": request.temperature,
95
  "playgroundTopP": request.top_p,
96
  "previewToken": None,
 
97
  "userId": None,
98
  "userSelectedModel": MODEL_MAPPING.get(request.model, request.model),
99
  "userSystemPrompt": None,
 
101
  "visitFromDelta": False,
102
  }
103
 
104
+ logger.debug(f"Data Payload: {json_data}") # Debugging line to inspect payload
105
  async with httpx.AsyncClient() as client:
106
  try:
107
+ async with client.stream(
108
+ "POST",
109
+ f"{BASE_URL}/api/chat",
110
+ headers=headers,
111
+ json=json_data,
112
+ timeout=100,
113
+ ) as response:
114
  response.raise_for_status()
115
  async for line in response.aiter_lines():
116
  timestamp = int(datetime.now().timestamp())
 
125
  logger.error(f"Error occurred during request for Chat ID {chat_id}: {e}")
126
  raise HTTPException(status_code=500, detail=str(e))
127
 
128
+ # Process non-streaming response with headers from config.py
129
  async def process_non_streaming_response(request: ChatRequest):
130
  chat_id = generate_chat_id()
131
+ logger.info(f"Generated Chat ID: {chat_id} - Model: {request.model}")
 
132
 
133
+ # Retrieve agent modes based on the model
134
+ agent_modes = get_agent_modes(request.model)
135
  headers = get_headers()
136
 
137
  json_data = {
138
+ "agentMode": agent_modes["agentMode"],
139
+ "trendingAgentMode": agent_modes["trendingAgentMode"],
140
  "clickedAnswer2": False,
141
  "clickedAnswer3": False,
142
  "clickedForceWebSearch": False,
 
151
  "playgroundTemperature": request.temperature,
152
  "playgroundTopP": request.top_p,
153
  "previewToken": None,
 
154
  "userId": None,
155
  "userSelectedModel": MODEL_MAPPING.get(request.model, request.model),
156
  "userSystemPrompt": None,
 
158
  "visitFromDelta": False,
159
  }
160
 
161
+ logger.debug(f"Data Payload: {json_data}") # Debugging line to inspect payload
162
  full_response = ""
163
  async with httpx.AsyncClient() as client:
164
  try:
165
+ async with client.stream(
166
+ method="POST", url=f"{BASE_URL}/api/chat", headers=headers, json=json_data
167
+ ) as response:
168
  response.raise_for_status()
169
  async for chunk in response.aiter_text():
170
  full_response += chunk