bibibi12345 committed on
Commit
c1c72bd
·
1 Parent(s): e942db3

added openai direct format

Browse files
Files changed (2) hide show
  1. app/main.py +256 -31
  2. app/requirements.txt +2 -1
app/main.py CHANGED
@@ -15,6 +15,8 @@ import random
15
  import urllib.parse
16
  from google.oauth2 import service_account
17
  import config
 
 
18
 
19
  from google.genai import types
20
 
@@ -1149,6 +1151,15 @@ async def list_models(api_key: str = Depends(get_api_key)):
1149
  "root": "gemini-2.5-pro-exp-03-25",
1150
  "parent": None,
1151
  },
 
 
 
 
 
 
 
 
 
1152
  {
1153
  "id": "gemini-2.5-pro-preview-03-25",
1154
  "object": "model",
@@ -1336,6 +1347,15 @@ def create_openai_error_response(status_code: int, message: str, error_type: str
1336
  }
1337
  }
1338
 
 
 
 
 
 
 
 
 
 
1339
  @app.post("/v1/chat/completions")
1340
  async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_api_key)): # Add request parameter
1341
  try:
@@ -1348,38 +1368,232 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_ap
1348
  )
1349
  return JSONResponse(status_code=400, content=error_response)
1350
 
1351
- # Check model type and extract base model name
1352
- is_auto_model = request.model.endswith("-auto")
1353
- is_grounded_search = request.model.endswith("-search")
1354
- is_encrypted_model = request.model.endswith("-encrypt")
1355
- is_encrypted_full_model = request.model.endswith("-encrypt-full")
1356
- is_nothinking_model = request.model.endswith("-nothinking")
1357
- is_max_thinking_model = request.model.endswith("-max")
1358
 
1359
- if is_auto_model:
1360
- base_model_name = request.model.replace("-auto", "")
1361
- elif is_grounded_search:
1362
- base_model_name = request.model.replace("-search", "")
1363
- elif is_encrypted_model:
1364
- base_model_name = request.model.replace("-encrypt", "")
1365
- elif is_encrypted_full_model:
1366
- base_model_name = request.model.replace("-encrypt-full", "")
1367
- elif is_nothinking_model:
1368
- base_model_name = request.model.replace("-nothinking","")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1369
  # Specific check for the flash model requiring budget
1370
- if base_model_name != "gemini-2.5-flash-preview-04-17":
1371
- error_response = create_openai_error_response(
1372
- 400, f"Model '{request.model}' does not support -nothinking variant", "invalid_request_error"
1373
- )
1374
- return JSONResponse(status_code=400, content=error_response)
1375
- elif is_max_thinking_model:
1376
- base_model_name = request.model.replace("-max","")
 
 
1377
  # Specific check for the flash model requiring budget
1378
- if base_model_name != "gemini-2.5-flash-preview-04-17":
1379
- error_response = create_openai_error_response(
1380
- 400, f"Model '{request.model}' does not support -max variant", "invalid_request_error"
1381
- )
1382
- return JSONResponse(status_code=400, content=error_response)
 
1383
  else:
1384
  base_model_name = request.model
1385
 
@@ -1418,7 +1632,8 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_ap
1418
  types.SafetySetting(category="HARM_CATEGORY_HATE_SPEECH", threshold="OFF"),
1419
  types.SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="OFF"),
1420
  types.SafetySetting(category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="OFF"),
1421
- types.SafetySetting(category="HARM_CATEGORY_HARASSMENT", threshold="OFF")
 
1422
  ]
1423
  generation_config["safety_settings"] = safety_settings
1424
 
@@ -1518,8 +1733,18 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_ap
1518
  # --- Main Logic ---
1519
  last_error = None
1520
 
1521
- if is_auto_model:
 
 
 
 
 
 
 
 
 
1522
  print(f"Processing auto model: {request.model}")
 
1523
  # Define encryption instructions for system_instruction
1524
  encryption_instructions = [
1525
  "// AI Assistant Configuration //",
 
15
  import urllib.parse
16
  from google.oauth2 import service_account
17
  import config
18
+ import openai # Added import
19
+ from google.auth.transport.requests import Request as AuthRequest # Added import
20
 
21
  from google.genai import types
22
 
 
1151
  "root": "gemini-2.5-pro-exp-03-25",
1152
  "parent": None,
1153
  },
1154
+ { # Added new model entry for OpenAI endpoint
1155
+ "id": "gemini-2.5-pro-exp-03-25-openai",
1156
+ "object": "model",
1157
+ "created": int(time.time()),
1158
+ "owned_by": "google",
1159
+ "permission": [],
1160
+ "root": "gemini-2.5-pro-exp-03-25", # Underlying model
1161
+ "parent": None,
1162
+ },
1163
  {
1164
  "id": "gemini-2.5-pro-preview-03-25",
1165
  "object": "model",
 
1347
  }
1348
  }
1349
 
1350
+ # Helper for token refresh
1351
+ def _refresh_auth(credentials):
1352
+ try:
1353
+ credentials.refresh(AuthRequest())
1354
+ return credentials.token
1355
+ except Exception as e:
1356
+ print(f"Error refreshing GCP token: {e}")
1357
+ return None
1358
+
1359
  @app.post("/v1/chat/completions")
1360
  async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_api_key)): # Add request parameter
1361
  try:
 
1368
  )
1369
  return JSONResponse(status_code=400, content=error_response)
1370
 
1371
+ # --- Handle specific OpenAI client model ---
1372
+ if request.model.endswith("-openai"): # Generalized check for suffix
1373
+ print(f"INFO: Using OpenAI library path for model: {request.model}")
1374
+ base_model_name = request.model.replace("-openai", "") # Extract base model name
1375
+ UNDERLYING_MODEL_ID = f"google/{base_model_name}" # Add google/ prefix
 
 
1376
 
1377
+ # --- Determine Credentials for OpenAI Client (Correct Priority) ---
1378
+ credentials_to_use = None
1379
+ project_id_to_use = None
1380
+ credential_source = "unknown"
1381
+
1382
+ # Priority 1: GOOGLE_CREDENTIALS_JSON (JSON String in Env Var)
1383
+ credentials_json_str = os.environ.get("GOOGLE_CREDENTIALS_JSON")
1384
+ if credentials_json_str:
1385
+ try:
1386
+ credentials_info = json.loads(credentials_json_str)
1387
+ if not isinstance(credentials_info, dict): raise ValueError("JSON is not a dict")
1388
+ required = ["type", "project_id", "private_key_id", "private_key", "client_email"]
1389
+ if any(f not in credentials_info for f in required): raise ValueError("Missing required fields")
1390
+
1391
+ credentials = service_account.Credentials.from_service_account_info(
1392
+ credentials_info, scopes=['https://www.googleapis.com/auth/cloud-platform']
1393
+ )
1394
+ project_id = credentials.project_id
1395
+ credentials_to_use = credentials
1396
+ project_id_to_use = project_id
1397
+ credential_source = "GOOGLE_CREDENTIALS_JSON env var"
1398
+ print(f"INFO: [OpenAI Path] Using credentials from {credential_source} for project: {project_id_to_use}")
1399
+ except Exception as e:
1400
+ print(f"WARNING: [OpenAI Path] Error processing GOOGLE_CREDENTIALS_JSON: {e}. Trying next method.")
1401
+ credentials_to_use = None # Ensure reset if failed
1402
+
1403
+ # Priority 2: Credential Manager (Rotated Files)
1404
+ if credentials_to_use is None:
1405
+ print(f"INFO: [OpenAI Path] Checking Credential Manager (directory: {credential_manager.credentials_dir})")
1406
+ rotated_credentials, rotated_project_id = credential_manager.get_next_credentials()
1407
+ if rotated_credentials and rotated_project_id:
1408
+ credentials_to_use = rotated_credentials
1409
+ project_id_to_use = rotated_project_id
1410
+ credential_source = f"Credential Manager file (Index: {credential_manager.current_index -1 if credential_manager.current_index > 0 else len(credential_manager.credentials_files) - 1})"
1411
+ print(f"INFO: [OpenAI Path] Using credentials from {credential_source} for project: {project_id_to_use}")
1412
+ else:
1413
+ print(f"INFO: [OpenAI Path] No credentials loaded via Credential Manager.")
1414
+
1415
+ # Priority 3: GOOGLE_APPLICATION_CREDENTIALS (File Path in Env Var)
1416
+ if credentials_to_use is None:
1417
+ file_path = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
1418
+ if file_path:
1419
+ print(f"INFO: [OpenAI Path] Checking GOOGLE_APPLICATION_CREDENTIALS file path: {file_path}")
1420
+ if os.path.exists(file_path):
1421
+ try:
1422
+ credentials = service_account.Credentials.from_service_account_file(
1423
+ file_path, scopes=['https://www.googleapis.com/auth/cloud-platform']
1424
+ )
1425
+ project_id = credentials.project_id
1426
+ credentials_to_use = credentials
1427
+ project_id_to_use = project_id
1428
+ credential_source = "GOOGLE_APPLICATION_CREDENTIALS file path"
1429
+ print(f"INFO: [OpenAI Path] Using credentials from {credential_source} for project: {project_id_to_use}")
1430
+ except Exception as e:
1431
+ print(f"ERROR: [OpenAI Path] Failed to load credentials from GOOGLE_APPLICATION_CREDENTIALS path ({file_path}): {e}")
1432
+ else:
1433
+ print(f"ERROR: [OpenAI Path] GOOGLE_APPLICATION_CREDENTIALS file does not exist at path: {file_path}")
1434
+
1435
+ # Error if no credentials found after all checks
1436
+ if credentials_to_use is None or project_id_to_use is None:
1437
+ error_msg = "No valid credentials found for OpenAI client path. Tried GOOGLE_CREDENTIALS_JSON, Credential Manager, and GOOGLE_APPLICATION_CREDENTIALS."
1438
+ print(f"ERROR: {error_msg}")
1439
+ error_response = create_openai_error_response(500, error_msg, "server_error")
1440
+ return JSONResponse(status_code=500, content=error_response)
1441
+ # --- Credentials Determined ---
1442
+
1443
+ # Get/Refresh GCP Token from the chosen credentials (credentials_to_use)
1444
+ gcp_token = None
1445
+ if credentials_to_use.expired or not credentials_to_use.token:
1446
+ print(f"INFO: [OpenAI Path] Refreshing GCP token (Source: {credential_source})...")
1447
+ gcp_token = _refresh_auth(credentials_to_use)
1448
+ else:
1449
+ gcp_token = credentials_to_use.token
1450
+
1451
+ if not gcp_token:
1452
+ error_msg = f"Failed to obtain valid GCP token for OpenAI client (Source: {credential_source})."
1453
+ print(f"ERROR: {error_msg}")
1454
+ error_response = create_openai_error_response(500, error_msg, "server_error")
1455
+ return JSONResponse(status_code=500, content=error_response)
1456
+
1457
+ # Configuration using determined Project ID
1458
+ PROJECT_ID = project_id_to_use
1459
+ LOCATION = "us-central1" # Assuming same location as genai client
1460
+ VERTEX_AI_OPENAI_ENDPOINT_URL = (
1461
+ f"https://{LOCATION}-aiplatform.googleapis.com/v1beta1/"
1462
+ f"projects/{PROJECT_ID}/locations/{LOCATION}/endpoints/openapi"
1463
+ )
1464
+ # UNDERLYING_MODEL_ID is now set above based on the request
1465
+
1466
+ # Initialize Async OpenAI Client
1467
+ openai_client = openai.AsyncOpenAI(
1468
+ base_url=VERTEX_AI_OPENAI_ENDPOINT_URL,
1469
+ api_key=gcp_token,
1470
+ )
1471
+
1472
+ # Define standard safety settings (as used elsewhere)
1473
+ openai_safety_settings = [
1474
+ {
1475
+ "category": "HARM_CATEGORY_HARASSMENT",
1476
+ "threshold": "OFF"
1477
+ },
1478
+ {
1479
+ "category": "HARM_CATEGORY_HATE_SPEECH",
1480
+ "threshold": "OFF"
1481
+ },
1482
+ {
1483
+ "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
1484
+ "threshold": "OFF"
1485
+ },
1486
+ {
1487
+ "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
1488
+ "threshold": "OFF"
1489
+ },
1490
+ {
1491
+ "category": 'HARM_CATEGORY_CIVIC_INTEGRITY',
1492
+ "threshold": 'OFF'
1493
+ }
1494
+ ]
1495
+
1496
+ # Prepare parameters for OpenAI client call
1497
+ openai_params = {
1498
+ "model": UNDERLYING_MODEL_ID,
1499
+ "messages": [msg.model_dump(exclude_unset=True) for msg in request.messages],
1500
+ "temperature": request.temperature,
1501
+ "max_tokens": request.max_tokens,
1502
+ "top_p": request.top_p,
1503
+ "stream": request.stream,
1504
+ "stop": request.stop,
1505
+ # "presence_penalty": request.presence_penalty,
1506
+ # "frequency_penalty": request.frequency_penalty,
1507
+ "seed": request.seed,
1508
+ "n": request.n,
1509
+ # Note: logprobs/response_logprobs mapping might need adjustment
1510
+ # Note: top_k is not directly supported by standard OpenAI API spec
1511
+ }
1512
+ # Add safety settings via extra_body
1513
+ openai_extra_body = {
1514
+ 'google': {
1515
+ 'safety_settings': openai_safety_settings
1516
+ }
1517
+ }
1518
+ openai_params = {k: v for k, v in openai_params.items() if v is not None}
1519
+
1520
+
1521
+ # Make the call using OpenAI client
1522
+ if request.stream:
1523
+ async def openai_stream_generator():
1524
+ try:
1525
+ stream = await openai_client.chat.completions.create(
1526
+ **openai_params,
1527
+ extra_body=openai_extra_body # Pass safety settings here
1528
+ )
1529
+ async for chunk in stream:
1530
+ yield f"data: {chunk.model_dump_json()}\n\n"
1531
+ yield "data: [DONE]\n\n"
1532
+ except Exception as stream_error:
1533
+ error_msg = f"Error during OpenAI client streaming for {request.model}: {str(stream_error)}"
1534
+ print(error_msg)
1535
+ error_response_content = create_openai_error_response(500, error_msg, "server_error")
1536
+ yield f"data: {json.dumps(error_response_content)}\n\n"
1537
+ yield "data: [DONE]\n\n"
1538
+
1539
+ return StreamingResponse(openai_stream_generator(), media_type="text/event-stream")
1540
+ else:
1541
+ try:
1542
+ response = await openai_client.chat.completions.create(
1543
+ **openai_params,
1544
+ extra_body=openai_extra_body # Pass safety settings here
1545
+ )
1546
+ return JSONResponse(content=response.model_dump(exclude_unset=True))
1547
+ except Exception as generate_error:
1548
+ error_msg = f"Error calling OpenAI client for {request.model}: {str(generate_error)}"
1549
+ print(error_msg)
1550
+ error_response = create_openai_error_response(500, error_msg, "server_error")
1551
+ return JSONResponse(status_code=500, content=error_response)
1552
+
1553
+ # --- End of specific OpenAI client model handling ---
1554
+
1555
+ # Initialize flags before checking suffixes
1556
+ is_auto_model = False
1557
+ is_grounded_search = False
1558
+ is_encrypted_model = False
1559
+ is_encrypted_full_model = False
1560
+ is_nothinking_model = False
1561
+ is_max_thinking_model = False
1562
+ base_model_name = request.model # Default to the full name
1563
+
1564
+ # Check model type and extract base model name
1565
+ if request.model.endswith("-auto"):
1566
+ is_auto_model = True
1567
+ base_model_name = request.model.replace("-auto", "")
1568
+ elif request.model.endswith("-search"):
1569
+ is_grounded_search = True
1570
+ base_model_name = request.model.replace("-search", "")
1571
+ elif request.model.endswith("-encrypt"):
1572
+ is_encrypted_model = True
1573
+ base_model_name = request.model.replace("-encrypt", "")
1574
+ elif request.model.endswith("-encrypt-full"):
1575
+ is_encrypted_full_model = True
1576
+ base_model_name = request.model.replace("-encrypt-full", "")
1577
+ elif request.model.endswith("-nothinking"):
1578
+ is_nothinking_model = True
1579
+ base_model_name = request.model.replace("-nothinking","")
1580
  # Specific check for the flash model requiring budget
1581
+ # Specific check for the flash model requiring budget
1582
+ if base_model_name != "gemini-2.5-flash-preview-04-17":
1583
+ error_response = create_openai_error_response(
1584
+ 400, f"Model '{request.model}' does not support -nothinking variant", "invalid_request_error"
1585
+ )
1586
+ return JSONResponse(status_code=400, content=error_response)
1587
+ elif request.model.endswith("-max"):
1588
+ is_max_thinking_model = True
1589
+ base_model_name = request.model.replace("-max","")
1590
  # Specific check for the flash model requiring budget
1591
+ # Specific check for the flash model requiring budget
1592
+ if base_model_name != "gemini-2.5-flash-preview-04-17":
1593
+ error_response = create_openai_error_response(
1594
+ 400, f"Model '{request.model}' does not support -max variant", "invalid_request_error"
1595
+ )
1596
+ return JSONResponse(status_code=400, content=error_response)
1597
  else:
1598
  base_model_name = request.model
1599
 
 
1632
  types.SafetySetting(category="HARM_CATEGORY_HATE_SPEECH", threshold="OFF"),
1633
  types.SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="OFF"),
1634
  types.SafetySetting(category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="OFF"),
1635
+ types.SafetySetting(category="HARM_CATEGORY_HARASSMENT", threshold="OFF"),
1636
+ types.SafetySetting(category="HARM_CATEGORY_CIVIC_INTEGRITY", threshold="OFF")
1637
  ]
1638
  generation_config["safety_settings"] = safety_settings
1639
 
 
1733
  # --- Main Logic ---
1734
  last_error = None
1735
 
1736
+ # --- Main Logic --- (Ensure flags are correctly set if the first 'if' wasn't met)
1737
+ # Re-evaluate flags based on elif structure for clarity if needed, or rely on the fact that the first 'if' returned.
1738
+ is_auto_model = request.model.endswith("-auto") # This will be False if the first 'if' was True
1739
+ is_grounded_search = request.model.endswith("-search")
1740
+ is_encrypted_model = request.model.endswith("-encrypt")
1741
+ is_encrypted_full_model = request.model.endswith("-encrypt-full")
1742
+ is_nothinking_model = request.model.endswith("-nothinking")
1743
+ is_max_thinking_model = request.model.endswith("-max")
1744
+
1745
+ if is_auto_model: # This remains the primary check after the openai specific one
1746
  print(f"Processing auto model: {request.model}")
1747
+ base_model_name = request.model.replace("-auto", "") # Ensure base_model_name is set here too
1748
  # Define encryption instructions for system_instruction
1749
  encryption_instructions = [
1750
  "// AI Assistant Configuration //",
app/requirements.txt CHANGED
@@ -3,4 +3,5 @@ uvicorn==0.27.1
3
  google-auth==2.38.0
4
  google-cloud-aiplatform==1.86.0
5
  pydantic==2.6.1
6
- google-genai==1.8.0
 
 
3
  google-auth==2.38.0
4
  google-cloud-aiplatform==1.86.0
5
  pydantic==2.6.1
6
+ google-genai==1.8.0
7
+ openai