Niansuh commited on
Commit
b0275d3
·
verified ·
1 Parent(s): 3b29c2b

Update api/utils.py

Browse files
Files changed (1) hide show
  1. api/utils.py +196 -96
api/utils.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  from datetime import datetime
2
  import json
3
  import uuid
@@ -9,16 +11,17 @@ from typing import Any, Dict, Optional, AsyncGenerator
9
  import httpx
10
  from fastapi import HTTPException
11
  from api.config import (
12
- models,
13
- model_aliases,
14
- ALLOWED_MODELS,
15
- MODEL_MAPPING,
16
  get_headers_api_chat,
17
  BASE_URL,
 
 
18
  )
19
  from api.models import ChatRequest, Message
20
  from api.logger import setup_logger
21
- from api.providers.gizai import GizAI # Import the GizAI provider
22
 
23
  logger = setup_logger(__name__)
24
 
@@ -27,7 +30,7 @@ def generate_chat_id(length: int = 7) -> str:
27
  characters = string.ascii_letters + string.digits
28
  return ''.join(random.choices(characters, k=length))
29
 
30
- # Helper function to create chat completion data
31
  def create_chat_completion_data(
32
  content: str, model: str, timestamp: int, finish_reason: Optional[str] = None
33
  ) -> Dict[str, Any]:
@@ -46,15 +49,11 @@ def create_chat_completion_data(
46
  "usage": None,
47
  }
48
 
49
- # Function to convert message to dictionary format, ensuring base64 data
50
- def message_to_dict(message: Message):
51
- if isinstance(message.content, str):
52
- content = message.content
53
- elif isinstance(message.content, list) and isinstance(message.content[0], dict) and "text" in message.content[0]:
54
- content = message.content[0]["text"]
55
- else:
56
- content = ""
57
-
58
  if isinstance(message.content, list) and len(message.content) == 2 and "image_url" in message.content[1]:
59
  # Ensure base64 images are always included for all models
60
  return {
@@ -68,91 +67,192 @@ def message_to_dict(message: Message):
68
  }
69
  return {"role": message.role, "content": content}
70
 
71
- # Function to resolve model aliases
72
- def resolve_model(model: str) -> str:
73
- if model in MODEL_MAPPING:
74
- return model
75
- elif model in model_aliases:
76
- return model_aliases[model]
77
- else:
78
- logger.warning(f"Model '{model}' not recognized. Using default model '{GizAI.default_model}'.")
79
- return GizAI.default_model # default_model
80
-
81
- # Process streaming response with GizAI provider
 
 
 
 
 
 
 
 
 
 
82
  async def process_streaming_response(request: ChatRequest) -> AsyncGenerator[str, None]:
83
  chat_id = generate_chat_id()
84
- resolved_model = resolve_model(request.model)
85
- logger.info(f"Generated Chat ID: {chat_id} - Model: {resolved_model}")
86
-
87
- # Instantiate the GizAI provider
88
- gizai_provider = GizAI()
89
-
90
- # Create the async generator
91
- async for response in gizai_provider.create_async_generator(
92
- model=resolved_model,
93
- messages=request.messages,
94
- proxy=request.proxy # Assuming 'proxy' is part of ChatRequest; if not, adjust accordingly
95
- ):
96
- timestamp = int(datetime.now().timestamp())
97
- if isinstance(response, ImageResponse):
98
- # Handle image responses
99
- yield f"data: {json.dumps({'image_url': response.images, 'alt': response.alt})}\n\n"
100
- else:
101
- # Handle text responses
102
- yield f"data: {json.dumps(create_chat_completion_data(response, resolved_model, timestamp))}\n\n"
103
-
104
- # Indicate completion
105
- timestamp = int(datetime.now().timestamp())
106
- yield f"data: {json.dumps(create_chat_completion_data('', resolved_model, timestamp, 'stop'))}\n\n"
107
- yield "data: [DONE]\n\n"
108
-
109
- # Process non-streaming response with GizAI provider
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  async def process_non_streaming_response(request: ChatRequest) -> Dict[str, Any]:
111
  chat_id = generate_chat_id()
112
- resolved_model = resolve_model(request.model)
113
- logger.info(f"Generated Chat ID: {chat_id} - Model: {resolved_model}")
114
-
115
- # Instantiate the GizAI provider
116
- gizai_provider = GizAI()
117
-
118
- # Collect the responses
119
- responses = []
120
- async for response in gizai_provider.create_async_generator(
121
- model=resolved_model,
122
- messages=request.messages,
123
- proxy=request.proxy # Assuming 'proxy' is part of ChatRequest; if not, adjust accordingly
124
- ):
125
- if isinstance(response, ImageResponse):
126
- # For image responses, collect image URLs
127
- responses.append({"image_url": response.images, "alt": response.alt})
128
- else:
129
- # For text responses, append the text
130
- responses.append(response)
131
 
132
- return {
133
- "id": f"chatcmpl-{uuid.uuid4()}",
134
- "object": "chat.completion",
135
- "created": int(datetime.now().timestamp()),
136
- "model": resolved_model,
137
- "choices": [
138
- {
139
- "index": 0,
140
- "message": {"role": "assistant", "content": responses},
141
- "finish_reason": "stop",
 
 
 
 
 
 
 
 
142
  }
143
- ],
144
- "usage": None,
145
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
- # Helper function to format prompt from messages
148
- def format_prompt(messages: list[Message]) -> str:
149
- # Implement the prompt formatting as per GizAI's requirements
150
- # Placeholder implementation
151
- formatted_messages = []
152
- for msg in messages:
153
- if isinstance(msg.content, str):
154
- formatted_messages.append(msg.content)
155
- elif isinstance(msg.content, list):
156
- text = msg.content[0].get("text", "")
157
- formatted_messages.append(text)
158
- return "\n".join(formatted_messages)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # api/utils.py
2
+
3
  from datetime import datetime
4
  import json
5
  import uuid
 
11
  import httpx
12
  from fastapi import HTTPException
13
  from api.config import (
14
+ MODELS,
15
+ MODEL_ALIASES,
16
+ DEFAULT_MODEL,
17
+ API_ENDPOINT,
18
  get_headers_api_chat,
19
  BASE_URL,
20
+ MODEL_PREFIXES,
21
+ MODEL_REFERERS
22
  )
23
  from api.models import ChatRequest, Message
24
  from api.logger import setup_logger
 
25
 
26
  logger = setup_logger(__name__)
27
 
 
30
  characters = string.ascii_letters + string.digits
31
  return ''.join(random.choices(characters, k=length))
32
 
33
+ # Helper function to create a chat completion data chunk
34
  def create_chat_completion_data(
35
  content: str, model: str, timestamp: int, finish_reason: Optional[str] = None
36
  ) -> Dict[str, Any]:
 
49
  "usage": None,
50
  }
51
 
52
+ # Function to convert message to dictionary format, ensuring base64 data and optional model prefix
53
+ def message_to_dict(message: Message, model_prefix: Optional[str] = None):
54
+ content = message.content if isinstance(message.content, str) else message.content[0]["text"]
55
+ if model_prefix:
56
+ content = f"{model_prefix} {content}"
 
 
 
 
57
  if isinstance(message.content, list) and len(message.content) == 2 and "image_url" in message.content[1]:
58
  # Ensure base64 images are always included for all models
59
  return {
 
67
  }
68
  return {"role": message.role, "content": content}
69
 
70
+ # Function to strip model prefix from content if present
71
+ def strip_model_prefix(content: str, model_prefix: Optional[str] = None) -> str:
72
+ """Remove the model prefix from the response content if present."""
73
+ if model_prefix and content.startswith(model_prefix):
74
+ logger.debug(f"Stripping prefix '{model_prefix}' from content.")
75
+ return content[len(model_prefix):].strip()
76
+ return content
77
+
78
+ # Function to get the correct referer URL for logging
79
+ def get_referer_url(chat_id: str, model: str) -> str:
80
+ """Generate the referer URL based on specific models listed in MODEL_REFERERS."""
81
+ if model in MODEL_REFERERS:
82
+ return f"{BASE_URL}/chat/{chat_id}?model={model}"
83
+ return BASE_URL
84
+
85
+ # Helper function to format messages
86
+ def format_messages(messages: list[Message]) -> str:
87
+ # Assuming messages need to be concatenated in some way
88
+ return "\n".join([msg.content if isinstance(msg.content, str) else msg.content[0]["text"] for msg in messages])
89
+
90
+ # Process streaming response
91
  async def process_streaming_response(request: ChatRequest) -> AsyncGenerator[str, None]:
92
  chat_id = generate_chat_id()
93
+ referer_url = get_referer_url(chat_id, request.model)
94
+ logger.info(f"Generated Chat ID: {chat_id} - Model: {request.model} - URL: {referer_url}")
95
+
96
+ model = request.model if request.model in MODELS else MODEL_ALIASES.get(request.model, DEFAULT_MODEL)
97
+ model_prefix = MODEL_PREFIXES.get(model, "")
98
+
99
+ headers_api_chat = get_headers_api_chat(referer_url)
100
+
101
+ # Prepare data based on model type
102
+ if model in ['flux1', 'sdxl', 'sd', 'sd35']: # Image models
103
+ prompt = request.messages[-1].content if isinstance(request.messages[-1].content, str) else request.messages[-1].content[0]["text"]
104
+ data = {
105
+ "model": model,
106
+ "input": {
107
+ "width": "1024",
108
+ "height": "1024",
109
+ "steps": 4,
110
+ "output_format": "webp",
111
+ "batch_size": 1,
112
+ "mode": "plan",
113
+ "prompt": prompt
114
+ }
115
+ }
116
+ else: # Chat models
117
+ data = {
118
+ "model": model,
119
+ "input": {
120
+ "messages": [
121
+ {
122
+ "type": "human",
123
+ "content": f"{model_prefix} {format_messages(request.messages)}" if model_prefix else format_messages(request.messages)
124
+ }
125
+ ],
126
+ "mode": "plan"
127
+ },
128
+ "noStream": False # Assuming streaming
129
+ }
130
+
131
+ async with httpx.AsyncClient() as client:
132
+ try:
133
+ async with client.post(
134
+ API_ENDPOINT,
135
+ headers=headers_api_chat,
136
+ json=data,
137
+ timeout=100
138
+ ) as response:
139
+ response.raise_for_status()
140
+ # Assuming the API returns a streaming response
141
+ async for line in response.aiter_lines():
142
+ timestamp = int(datetime.now().timestamp())
143
+ if line:
144
+ content = line
145
+ # Depending on GizAI's response format, adjust parsing
146
+ # Placeholder for content processing
147
+ # Assuming content contains the message
148
+ cleaned_content = strip_model_prefix(content, model_prefix)
149
+ yield f"data: {json.dumps(create_chat_completion_data(cleaned_content, model, timestamp))}\n\n"
150
+
151
+ # Indicate end of stream
152
+ yield f"data: {json.dumps(create_chat_completion_data('', model, int(datetime.now().timestamp()), 'stop'))}\n\n"
153
+ yield "data: [DONE]\n\n"
154
+ except httpx.HTTPStatusError as e:
155
+ logger.error(f"HTTP error occurred for Chat ID {chat_id}: {e}")
156
+ raise HTTPException(status_code=e.response.status_code, detail=str(e))
157
+ except httpx.RequestError as e:
158
+ logger.error(f"Error occurred during request for Chat ID {chat_id}: {e}")
159
+ raise HTTPException(status_code=500, detail=str(e))
160
+
161
+ # Process non-streaming response
162
  async def process_non_streaming_response(request: ChatRequest) -> Dict[str, Any]:
163
  chat_id = generate_chat_id()
164
+ referer_url = get_referer_url(chat_id, request.model)
165
+ logger.info(f"Generated Chat ID: {chat_id} - Model: {request.model} - URL: {referer_url}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
+ model = request.model if request.model in MODELS else MODEL_ALIASES.get(request.model, DEFAULT_MODEL)
168
+ model_prefix = MODEL_PREFIXES.get(model, "")
169
+
170
+ headers_api_chat = get_headers_api_chat(referer_url)
171
+
172
+ # Prepare data based on model type
173
+ if model in ['flux1', 'sdxl', 'sd', 'sd35']: # Image models
174
+ prompt = request.messages[-1].content if isinstance(request.messages[-1].content, str) else request.messages[-1].content[0]["text"]
175
+ data = {
176
+ "model": model,
177
+ "input": {
178
+ "width": "1024",
179
+ "height": "1024",
180
+ "steps": 4,
181
+ "output_format": "webp",
182
+ "batch_size": 1,
183
+ "mode": "plan",
184
+ "prompt": prompt
185
  }
186
+ }
187
+ else: # Chat models
188
+ data = {
189
+ "model": model,
190
+ "input": {
191
+ "messages": [
192
+ {
193
+ "type": "human",
194
+ "content": f"{model_prefix} {format_messages(request.messages)}" if model_prefix else format_messages(request.messages)
195
+ }
196
+ ],
197
+ "mode": "plan"
198
+ },
199
+ "noStream": True # Non-streaming
200
+ }
201
+
202
+ async with httpx.AsyncClient() as client:
203
+ try:
204
+ response = await client.post(
205
+ API_ENDPOINT,
206
+ headers=headers_api_chat,
207
+ json=data,
208
+ timeout=100
209
+ )
210
+ response.raise_for_status()
211
+ response_data = response.json()
212
 
213
+ # Process response based on GizAI's API response structure
214
+ # Placeholder: assuming 'output' contains the generated content
215
+ if model in ['flux1', 'sdxl', 'sd', 'sd35']: # Image models
216
+ if response_data.get('status') == 'completed' and response_data.get('output'):
217
+ images = response_data['output']
218
+ # Assuming images is a list of URLs
219
+ # For non-streaming, return all images at once
220
+ # Adjust according to actual response
221
+ return {
222
+ "id": f"chatcmpl-{uuid.uuid4()}",
223
+ "object": "chat.completion",
224
+ "created": int(datetime.now().timestamp()),
225
+ "model": model,
226
+ "choices": [
227
+ {
228
+ "index": 0,
229
+ "message": {"role": "assistant", "content": "", "images": images},
230
+ "finish_reason": "stop",
231
+ }
232
+ ],
233
+ "usage": None,
234
+ }
235
+ else: # Chat models
236
+ # Assuming response_data contains the full response
237
+ content = response_data.get('output', '')
238
+ cleaned_content = strip_model_prefix(content, model_prefix)
239
+ return {
240
+ "id": f"chatcmpl-{uuid.uuid4()}",
241
+ "object": "chat.completion",
242
+ "created": int(datetime.now().timestamp()),
243
+ "model": model,
244
+ "choices": [
245
+ {
246
+ "index": 0,
247
+ "message": {"role": "assistant", "content": cleaned_content},
248
+ "finish_reason": "stop",
249
+ }
250
+ ],
251
+ "usage": None,
252
+ }
253
+ except httpx.HTTPStatusError as e:
254
+ logger.error(f"HTTP error occurred for Chat ID {chat_id}: {e}")
255
+ raise HTTPException(status_code=e.response.status_code, detail=str(e))
256
+ except httpx.RequestError as e:
257
+ logger.error(f"Error occurred during request for Chat ID {chat_id}: {e}")
258
+ raise HTTPException(status_code=500, detail=str(e))