Commit e0b9c59 · Parent(s): fa22c58
image test

app/main.py (CHANGED: +139 −36)
@@ -1,16 +1,16 @@
 from fastapi import FastAPI, HTTPException, Depends, Header, Request
 from fastapi.responses import JSONResponse, StreamingResponse
 from fastapi.security import APIKeyHeader
-from pydantic import BaseModel, ConfigDict
-from typing import List, Dict, Any, Optional, Union
+from pydantic import BaseModel, ConfigDict, Field
+from typing import List, Dict, Any, Optional, Union, Literal
+import base64
+import re
 import json
 import time
 import os
 import glob
 import random
 from google.oauth2 import service_account
-# from vertexai.preview.generative_models import GenerativeModel, HarmCategory, HarmBlockThreshold, SafetySetting
-import vertexai
 import config
 
 from google.genai import types
@@ -137,9 +137,20 @@ class CredentialManager:
 credential_manager = CredentialManager()
 
 # Define data models
+class ImageUrl(BaseModel):
+    url: str
+
+class ContentPartImage(BaseModel):
+    type: Literal["image_url"]
+    image_url: ImageUrl
+
+class ContentPartText(BaseModel):
+    type: Literal["text"]
+    text: str
+
 class OpenAIMessage(BaseModel):
     role: str
-    content: Union[str, List[Dict[str, Any]]]
+    content: Union[str, List[Union[ContentPartText, ContentPartImage, Dict[str, Any]]]]
 
 class OpenAIRequest(BaseModel):
     model: str
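For illustration (not part of this commit): a multimodal OpenAI-style message like the following should now validate against the updated models. A minimal sketch, assuming the Space's app/main.py is importable as app.main; the image bytes are placeholder data, not a real image.

import base64
from app.main import OpenAIMessage  # assumed import path for this Space

# Placeholder bytes stand in for a real encoded image.
data_url = "data:image/png;base64," + base64.b64encode(b"\x89PNG\r\n\x1a\n").decode()

msg = OpenAIMessage(
    role="user",
    content=[
        {"type": "text", "text": "What is in this picture?"},
        {"type": "image_url", "image_url": {"url": data_url}},
    ],
)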
@@ -195,33 +206,52 @@ async def startup_event():
     print("WARNING: Failed to initialize Vertex AI authentication")
 
 # Conversion functions
-def create_gemini_prompt(messages: List[OpenAIMessage]) -> str:
-    """Convert OpenAI messages to Gemini format prompt."""
-    prompt = ""
-    # Extract system message if present
-    system_message = None
+def create_gemini_prompt(messages: List[OpenAIMessage]) -> Union[str, List[Any]]:
+    """
+    Convert OpenAI messages to Gemini format.
+    Returns either a string prompt or a list of content parts if images are present.
+    """
+    # Check if any message contains image content
+    has_images = False
     for message in messages:
-        if message.role == "system":
-            # Handle both string and list[dict] content types
-            if isinstance(message.content, str):
-                system_message = message.content
-            elif isinstance(message.content, list) and message.content and isinstance(message.content[0], dict) and 'text' in message.content[0]:
-                system_message = message.content[0]['text']
-            else:
-                # Handle unexpected format or raise error? For now, assume it's usable or skip.
-                system_message = str(message.content) # Fallback, might need refinement
+        if isinstance(message.content, list):
+            for part in message.content:
+                if isinstance(part, dict) and part.get('type') == 'image_url':
+                    has_images = True
+                    break
+                elif isinstance(part, ContentPartImage):
+                    has_images = True
+                    break
+        if has_images:
             break
 
-    # If system message exists, prepend it
-    if system_message:
-        prompt += f"System: {system_message}\n\n"
-
-    # Add other messages
-    for message in messages:
-        if message.role == "system":
-            continue # Already handled
+    # If no images, use the text-only format
+    if not has_images:
+        prompt = ""
 
-
+        # Extract system message if present
+        system_message = None
+        for message in messages:
+            if message.role == "system":
+                # Handle both string and list[dict] content types
+                if isinstance(message.content, str):
+                    system_message = message.content
+                elif isinstance(message.content, list) and message.content and isinstance(message.content[0], dict) and 'text' in message.content[0]:
+                    system_message = message.content[0]['text']
+                else:
+                    # Handle unexpected format or raise error? For now, assume it's usable or skip.
+                    system_message = str(message.content) # Fallback, might need refinement
+                break
+
+        # If system message exists, prepend it
+        if system_message:
+            prompt += f"System: {system_message}\n\n"
+
+        # Add other messages
+        for message in messages:
+            if message.role == "system":
+                continue # Already handled
+
         # Handle both string and list[dict] content types
         content_text = ""
         if isinstance(message.content, str):
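A quick illustration of the text-only path (hypothetical usage, not part of the commit; assumes the same app.main import path as above):

from app.main import OpenAIMessage, create_gemini_prompt  # assumed import path

messages = [
    OpenAIMessage(role="system", content="Be terse."),
    OpenAIMessage(role="user", content="Hello"),
]
prompt = create_gemini_prompt(messages)
# No image parts anywhere, so a plain string comes back:
# "System: Be terse.\n\nHuman: Hello\nAI: "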
@@ -229,19 +259,90 @@ def create_gemini_prompt(messages: List[OpenAIMessage]) -> str:
             elif isinstance(message.content, list) and message.content and isinstance(message.content[0], dict) and 'text' in message.content[0]:
                 content_text = message.content[0]['text']
             else:
-                # Handle unexpected format
+                # Fallback for unexpected format
                 content_text = str(message.content)
 
             if message.role == "user":
                 prompt += f"Human: {content_text}\n"
             elif message.role == "assistant":
                 prompt += f"AI: {content_text}\n"
+
+        # Add final AI prompt if last message was from user
+        if messages[-1].role == "user":
+            prompt += "AI: "
+
+        return prompt
+
+    # If images are present, create a list of content parts
+    gemini_contents = []
 
-    # Add final AI prompt if last message was from user
-    if messages[-1].role == "user":
-        prompt += "AI: "
+    # Extract system message if present and add it first
+    for message in messages:
+        if message.role == "system":
+            if isinstance(message.content, str):
+                gemini_contents.append(f"System: {message.content}")
+            elif isinstance(message.content, list):
+                # Extract text from system message
+                system_text = ""
+                for part in message.content:
+                    if isinstance(part, dict) and part.get('type') == 'text':
+                        system_text += part.get('text', '')
+                    elif isinstance(part, ContentPartText):
+                        system_text += part.text
+                if system_text:
+                    gemini_contents.append(f"System: {system_text}")
+            break
 
-    return prompt
+    # Process user and assistant messages
+    for message in messages:
+        if message.role == "system":
+            continue # Already handled
+
+        # For string content, add as text
+        if isinstance(message.content, str):
+            prefix = "Human: " if message.role == "user" else "AI: "
+            gemini_contents.append(f"{prefix}{message.content}")
+
+        # For list content, process each part
+        elif isinstance(message.content, list):
+            # First collect all text parts
+            text_content = ""
+
+            for part in message.content:
+                # Handle text parts
+                if isinstance(part, dict) and part.get('type') == 'text':
+                    text_content += part.get('text', '')
+                elif isinstance(part, ContentPartText):
+                    text_content += part.text
+
+            # Add the combined text content if any
+            if text_content:
+                prefix = "Human: " if message.role == "user" else "AI: "
+                gemini_contents.append(f"{prefix}{text_content}")
+
+            # Then process image parts
+            for part in message.content:
+                # Handle image parts
+                if isinstance(part, dict) and part.get('type') == 'image_url':
+                    image_url = part.get('image_url', {}).get('url', '')
+                    if image_url.startswith('data:'):
+                        # Extract mime type and base64 data
+                        mime_match = re.match(r'data:([^;]+);base64,(.+)', image_url)
+                        if mime_match:
+                            mime_type, b64_data = mime_match.groups()
+                            image_bytes = base64.b64decode(b64_data)
+                            gemini_contents.append(types.Part.from_bytes(data=image_bytes, mime_type=mime_type))
+                elif isinstance(part, ContentPartImage):
+                    image_url = part.image_url.url
+                    if image_url.startswith('data:'):
+                        # Extract mime type and base64 data
+                        mime_match = re.match(r'data:([^;]+);base64,(.+)', image_url)
+                        if mime_match:
+                            mime_type, b64_data = mime_match.groups()
+                            image_bytes = base64.b64decode(b64_data)
+                            gemini_contents.append(types.Part.from_bytes(data=image_bytes, mime_type=mime_type))
+
+    return gemini_contents
 
 def create_generation_config(request: OpenAIRequest) -> Dict[str, Any]:
     config = {}
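The data-URL handling added above can be exercised on its own with only the standard library; a minimal sketch with placeholder bytes standing in for real image data:

import base64
import re

data_url = "data:image/png;base64," + base64.b64encode(b"\x89PNG\r\n\x1a\n").decode()

mime_match = re.match(r'data:([^;]+);base64,(.+)', data_url)
if mime_match:
    mime_type, b64_data = mime_match.groups()
    image_bytes = base64.b64decode(b64_data)
    print(mime_type, len(image_bytes))  # -> image/png 8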
@@ -589,9 +690,10 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_ap
         # If multiple candidates are requested, we'll generate them sequentially
         for candidate_index in range(candidate_count):
             # Generate content with streaming
+            # Handle both string and list content formats (for images)
             responses = client.models.generate_content_stream(
                 model=gemini_model,
-                contents=prompt,
+                contents=prompt, # This can be either a string or a list of content parts
                 config=generation_config,
             )
 
@@ -623,13 +725,14 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_ap
         # Make sure generation_config has candidate_count set
         if "candidate_count" not in generation_config:
             generation_config["candidate_count"] = request.n
-
+        # Handle both string and list content formats (for images)
         response = client.models.generate_content(
             model=gemini_model,
-            contents=prompt,
+            contents=prompt, # This can be either a string or a list of content parts
            config=generation_config,
         )
 
+
         openai_response = convert_to_openai_format(response, request.model)
         return JSONResponse(content=openai_response)
     except Exception as generate_error:
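For reference, the call shape this handler relies on in the google-genai SDK, which accepts either a plain string or a mixed list of strings and Parts as contents. A hedged sketch, not part of the commit; project, location, and model names are purely illustrative, and the bytes are placeholder data:

from google import genai
from google.genai import types

client = genai.Client(vertexai=True, project="my-project", location="us-central1")  # illustrative values

contents = [
    "Human: What is in this image?",
    types.Part.from_bytes(data=b"\x89PNG\r\n\x1a\n", mime_type="image/png"),  # placeholder bytes
]
response = client.models.generate_content(
    model="gemini-2.0-flash",  # illustrative model name
    contents=contents,  # a plain string is accepted here as well
)
print(response.text)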