bibibi12345 committed on
Commit
e0b9c59
·
1 Parent(s): fa22c58

image test

Browse files
Files changed (1) hide show
  1. app/main.py +139 -36
app/main.py CHANGED
@@ -1,16 +1,16 @@
1
  from fastapi import FastAPI, HTTPException, Depends, Header, Request
2
  from fastapi.responses import JSONResponse, StreamingResponse
3
  from fastapi.security import APIKeyHeader
4
- from pydantic import BaseModel, ConfigDict
5
- from typing import List, Dict, Any, Optional, Union
 
 
6
  import json
7
  import time
8
  import os
9
  import glob
10
  import random
11
  from google.oauth2 import service_account
12
- # from vertexai.preview.generative_models import GenerativeModel, HarmCategory, HarmBlockThreshold, SafetySetting
13
- import vertexai
14
  import config
15
 
16
  from google.genai import types
@@ -137,9 +137,20 @@ class CredentialManager:
137
  credential_manager = CredentialManager()
138
 
139
  # Define data models
 
 
 
 
 
 
 
 
 
 
 
140
  class OpenAIMessage(BaseModel):
141
  role: str
142
- content: Union[str, List[Dict[str, str]]] # Allow string or list of dicts for content
143
 
144
  class OpenAIRequest(BaseModel):
145
  model: str
@@ -195,33 +206,52 @@ async def startup_event():
195
  print("WARNING: Failed to initialize Vertex AI authentication")
196
 
197
  # Conversion functions
198
- def create_gemini_prompt(messages: List[OpenAIMessage]) -> str:
199
- prompt = ""
200
-
201
- # Extract system message if present
202
- system_message = None
 
 
203
  for message in messages:
204
- if message.role == "system":
205
- # Handle both string and list[dict] content types
206
- if isinstance(message.content, str):
207
- system_message = message.content
208
- elif isinstance(message.content, list) and message.content and isinstance(message.content[0], dict) and 'text' in message.content[0]:
209
- system_message = message.content[0]['text']
210
- else:
211
- # Handle unexpected format or raise error? For now, assume it's usable or skip.
212
- system_message = str(message.content) # Fallback, might need refinement
213
  break
214
 
215
- # If system message exists, prepend it
216
- if system_message:
217
- prompt += f"System: {system_message}\n\n"
218
-
219
- # Add other messages
220
- for message in messages:
221
- if message.role == "system":
222
- continue # Already handled
223
 
224
- if message.role == "user":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
  # Handle both string and list[dict] content types
226
  content_text = ""
227
  if isinstance(message.content, str):
@@ -229,19 +259,90 @@ def create_gemini_prompt(messages: List[OpenAIMessage]) -> str:
229
  elif isinstance(message.content, list) and message.content and isinstance(message.content[0], dict) and 'text' in message.content[0]:
230
  content_text = message.content[0]['text']
231
  else:
232
- # Fallback for unexpected format
233
  content_text = str(message.content)
234
 
235
  if message.role == "user":
236
  prompt += f"Human: {content_text}\n"
237
  elif message.role == "assistant":
238
  prompt += f"AI: {content_text}\n"
 
 
 
 
 
 
 
 
 
239
 
240
- # Add final AI prompt if last message was from user
241
- if messages[-1].role == "user":
242
- prompt += "AI: "
 
 
 
 
 
 
 
 
 
 
 
 
 
243
 
244
- return prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
 
246
  def create_generation_config(request: OpenAIRequest) -> Dict[str, Any]:
247
  config = {}
@@ -589,9 +690,10 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_ap
589
  # If multiple candidates are requested, we'll generate them sequentially
590
  for candidate_index in range(candidate_count):
591
  # Generate content with streaming
 
592
  responses = client.models.generate_content_stream(
593
  model=gemini_model,
594
- contents=prompt,
595
  config=generation_config,
596
  )
597
 
@@ -623,13 +725,14 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_ap
623
  # Make sure generation_config has candidate_count set
624
  if "candidate_count" not in generation_config:
625
  generation_config["candidate_count"] = request.n
626
-
627
  response = client.models.generate_content(
628
  model=gemini_model,
629
- contents=prompt,
630
  config=generation_config,
631
  )
632
 
 
633
  openai_response = convert_to_openai_format(response, request.model)
634
  return JSONResponse(content=openai_response)
635
  except Exception as generate_error:
 
1
  from fastapi import FastAPI, HTTPException, Depends, Header, Request
2
  from fastapi.responses import JSONResponse, StreamingResponse
3
  from fastapi.security import APIKeyHeader
4
+ from pydantic import BaseModel, ConfigDict, Field
5
+ from typing import List, Dict, Any, Optional, Union, Literal
6
+ import base64
7
+ import re
8
  import json
9
  import time
10
  import os
11
  import glob
12
  import random
13
  from google.oauth2 import service_account
 
 
14
  import config
15
 
16
  from google.genai import types
 
137
  credential_manager = CredentialManager()
138
 
139
  # Define data models
140
class ImageUrl(BaseModel):
    """Wrapper holding the URL of an image attachment (plain or data: URL)."""
    url: str


class ContentPartImage(BaseModel):
    """An image part of an OpenAI-style multimodal message content list."""
    type: Literal["image_url"]
    image_url: ImageUrl


class ContentPartText(BaseModel):
    """A text part of an OpenAI-style multimodal message content list."""
    type: Literal["text"]
    text: str


class OpenAIMessage(BaseModel):
    """A single chat message.

    ``content`` is either a plain string or a list of multimodal parts;
    raw dicts are also accepted as a fallback for shapes pydantic does not
    match against the typed part models.
    """
    role: str
    content: Union[str, List[Union[ContentPartText, ContentPartImage, Dict[str, Any]]]]
154
 
155
  class OpenAIRequest(BaseModel):
156
  model: str
 
206
  print("WARNING: Failed to initialize Vertex AI authentication")
207
 
208
  # Conversion functions
209
def _part_from_data_url(image_url: str):
    """Decode an inline ``data:`` URL into a Gemini image Part.

    Returns None for non-data URLs (remote images are not fetched here) or
    for data URLs that do not match the ``data:<mime>;base64,<data>`` shape.
    """
    if not image_url.startswith('data:'):
        return None
    mime_match = re.match(r'data:([^;]+);base64,(.+)', image_url)
    if not mime_match:
        return None
    mime_type, b64_data = mime_match.groups()
    return types.Part.from_bytes(data=base64.b64decode(b64_data), mime_type=mime_type)


def create_gemini_prompt(messages: List["OpenAIMessage"]) -> Union[str, List[Any]]:
    """
    Convert OpenAI messages to Gemini format.

    Returns a single string prompt when every message is text-only, or a list
    of content parts (prefixed strings and image Parts) when any message
    carries an image — the Gemini SDK accepts either form for ``contents``.
    """
    # Detect whether any message carries an image part. Parts may arrive as
    # validated models (ContentPartImage) or as raw dicts, depending on how
    # pydantic resolved the Union in OpenAIMessage.content.
    has_images = False
    for message in messages:
        if isinstance(message.content, list):
            for part in message.content:
                if isinstance(part, dict) and part.get('type') == 'image_url':
                    has_images = True
                    break
                elif isinstance(part, ContentPartImage):
                    has_images = True
                    break
        if has_images:
            break

    # Text-only path: flatten the conversation into one transcript string.
    if not has_images:
        prompt = ""

        # Extract the first system message, if any, and prepend it.
        system_message = None
        for message in messages:
            if message.role == "system":
                if isinstance(message.content, str):
                    system_message = message.content
                elif isinstance(message.content, list) and message.content and isinstance(message.content[0], dict) and 'text' in message.content[0]:
                    system_message = message.content[0]['text']
                else:
                    # Unexpected shape: fall back to a plain string rendering.
                    system_message = str(message.content)
                break

        if system_message:
            prompt += f"System: {system_message}\n\n"

        # Add the remaining user/assistant turns.
        for message in messages:
            if message.role == "system":
                continue  # Already handled above.

            if isinstance(message.content, str):
                content_text = message.content
            elif isinstance(message.content, list) and message.content and isinstance(message.content[0], dict) and 'text' in message.content[0]:
                content_text = message.content[0]['text']
            else:
                # Fallback for unexpected format.
                content_text = str(message.content)

            if message.role == "user":
                prompt += f"Human: {content_text}\n"
            elif message.role == "assistant":
                prompt += f"AI: {content_text}\n"

        # Cue the model to answer when the conversation ends on a user turn.
        # Guard against an empty message list (messages[-1] would raise).
        if messages and messages[-1].role == "user":
            prompt += "AI: "

        return prompt

    # Multimodal path: build a list of content parts for the SDK.
    gemini_contents = []

    # Extract the first system message (text only) and add it ahead of the
    # conversation, mirroring the text-only path.
    for message in messages:
        if message.role == "system":
            if isinstance(message.content, str):
                gemini_contents.append(f"System: {message.content}")
            elif isinstance(message.content, list):
                system_text = ""
                for part in message.content:
                    if isinstance(part, dict) and part.get('type') == 'text':
                        system_text += part.get('text', '')
                    elif isinstance(part, ContentPartText):
                        system_text += part.text
                if system_text:
                    gemini_contents.append(f"System: {system_text}")
            break

    # Process user and assistant messages.
    for message in messages:
        if message.role == "system":
            continue  # Already handled above.

        # String content: add as a single prefixed text part.
        if isinstance(message.content, str):
            prefix = "Human: " if message.role == "user" else "AI: "
            gemini_contents.append(f"{prefix}{message.content}")

        # List content: merge text parts first so any caption precedes the
        # image bytes it refers to, then append decoded image parts.
        elif isinstance(message.content, list):
            text_content = ""
            for part in message.content:
                if isinstance(part, dict) and part.get('type') == 'text':
                    text_content += part.get('text', '')
                elif isinstance(part, ContentPartText):
                    text_content += part.text

            if text_content:
                prefix = "Human: " if message.role == "user" else "AI: "
                gemini_contents.append(f"{prefix}{text_content}")

            for part in message.content:
                if isinstance(part, dict) and part.get('type') == 'image_url':
                    image_part = _part_from_data_url(part.get('image_url', {}).get('url', ''))
                elif isinstance(part, ContentPartImage):
                    image_part = _part_from_data_url(part.image_url.url)
                else:
                    continue
                if image_part is not None:
                    gemini_contents.append(image_part)

    return gemini_contents
346
 
347
  def create_generation_config(request: OpenAIRequest) -> Dict[str, Any]:
348
  config = {}
 
690
  # If multiple candidates are requested, we'll generate them sequentially
691
  for candidate_index in range(candidate_count):
692
  # Generate content with streaming
693
+ # Handle both string and list content formats (for images)
694
  responses = client.models.generate_content_stream(
695
  model=gemini_model,
696
+ contents=prompt, # This can be either a string or a list of content parts
697
  config=generation_config,
698
  )
699
 
 
725
  # Make sure generation_config has candidate_count set
726
  if "candidate_count" not in generation_config:
727
  generation_config["candidate_count"] = request.n
728
+ # Handle both string and list content formats (for images)
729
  response = client.models.generate_content(
730
  model=gemini_model,
731
+ contents=prompt, # This can be either a string or a list of content parts
732
  config=generation_config,
733
  )
734
 
735
+
736
  openai_response = convert_to_openai_format(response, request.model)
737
  return JSONResponse(content=openai_response)
738
  except Exception as generate_error: