bibibi12345 committed
Commit 52d9215 · 1 Parent(s): dec9a2c

added openai mode for express mode models

app/openai_handler.py CHANGED
@@ -5,7 +5,8 @@ This module encapsulates all OpenAI-specific logic that was previously in chat_a
 import json
 import time
 import asyncio
-from typing import Dict, Any, AsyncGenerator
+import httpx
+from typing import Dict, Any, AsyncGenerator, Optional
 
 from fastapi.responses import JSONResponse, StreamingResponse
 import openai
@@ -21,13 +22,104 @@ from api_helpers import (
 )
 from message_processing import extract_reasoning_by_tags
 from credentials_manager import _refresh_auth
+from project_id_discovery import discover_project_id
+
+
+# Wrapper classes to mimic OpenAI SDK responses for direct httpx calls
+class FakeChatCompletionChunk:
+    """A fake ChatCompletionChunk to wrap the dictionary from a direct API stream."""
+    def __init__(self, data: Dict[str, Any]):
+        self._data = data
+
+    def model_dump(self, exclude_unset=True, exclude_none=True) -> Dict[str, Any]:
+        return self._data
+
+class FakeChatCompletion:
+    """A fake ChatCompletion to wrap the dictionary from a direct non-streaming API call."""
+    def __init__(self, data: Dict[str, Any]):
+        self._data = data
+
+    def model_dump(self, exclude_unset=True, exclude_none=True) -> Dict[str, Any]:
+        return self._data
+
+class ExpressClientWrapper:
+    """
+    A wrapper that mimics the openai.AsyncOpenAI client interface but uses direct
+    httpx calls for Vertex AI Express Mode. This allows it to be used with the
+    existing response handling logic.
+    """
+    def __init__(self, project_id: str, api_key: str, location: str = "global"):
+        self.project_id = project_id
+        self.api_key = api_key
+        self.location = location
+        self.base_url = f"https://aiplatform.googleapis.com/v1beta1/projects/{self.project_id}/locations/{self.location}/endpoints/openapi"
+
+        # The 'chat.completions' structure mimics the real OpenAI client
+        self.chat = self
+        self.completions = self
+
+    async def _stream_generator(self, response: httpx.Response) -> AsyncGenerator[FakeChatCompletionChunk, None]:
+        """Processes the SSE stream from httpx and yields fake chunk objects."""
+        async for line in response.aiter_lines():
+            if line.startswith("data:"):
+                json_str = line[len("data: "):].strip()
+                if json_str == "[DONE]":
+                    break
+                try:
+                    data = json.loads(json_str)
+                    yield FakeChatCompletionChunk(data)
+                except json.JSONDecodeError:
+                    print(f"Warning: Could not decode JSON from stream line: {json_str}")
+                    continue
+
+    async def _streaming_create(self, **kwargs) -> AsyncGenerator[FakeChatCompletionChunk, None]:
+        """Handles the creation of a streaming request using httpx."""
+        endpoint = f"{self.base_url}/chat/completions"
+        headers = {"Content-Type": "application/json"}
+        params = {"key": self.api_key}
+
+        payload = kwargs.copy()
+        if 'extra_body' in payload:
+            payload.update(payload.pop('extra_body'))
+
+        async with httpx.AsyncClient(timeout=300) as client:
+            async with client.stream("POST", endpoint, headers=headers, params=params, json=payload, timeout=None) as response:
+                response.raise_for_status()
+                async for chunk in self._stream_generator(response):
+                    yield chunk
+
+    async def create(self, **kwargs) -> Any:
+        """
+        Mimics the 'create' method of the OpenAI client.
+        It builds and sends a direct HTTP request using httpx, delegating
+        to the appropriate streaming or non-streaming handler.
+        """
+        is_streaming = kwargs.get("stream", False)
+
+        if is_streaming:
+            return self._streaming_create(**kwargs)
+
+        # Non-streaming logic
+        endpoint = f"{self.base_url}/chat/completions"
+        headers = {"Content-Type": "application/json"}
+        params = {"key": self.api_key}
+
+        payload = kwargs.copy()
+        if 'extra_body' in payload:
+            payload.update(payload.pop('extra_body'))
+
+        async with httpx.AsyncClient(timeout=300) as client:
+            response = await client.post(endpoint, headers=headers, params=params, json=payload, timeout=None)
+            response.raise_for_status()
+            return FakeChatCompletion(response.json())
 
 
 class OpenAIDirectHandler:
     """Handles OpenAI Direct mode operations including client creation and response processing."""
 
-    def __init__(self, credential_manager):
+    def __init__(self, credential_manager=None, express_key_manager=None):
         self.credential_manager = credential_manager
+        self.express_key_manager = express_key_manager
         self.safety_settings = [
             {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
             {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
@@ -35,7 +127,7 @@ class OpenAIDirectHandler:
             {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
             {"category": 'HARM_CATEGORY_CIVIC_INTEGRITY', "threshold": 'OFF'}
         ]
-
+
     def create_openai_client(self, project_id: str, gcp_token: str, location: str = "global") -> openai.AsyncOpenAI:
         """Create an OpenAI client configured for Vertex AI endpoint."""
         endpoint_url = (
@@ -80,7 +172,7 @@ class OpenAIDirectHandler:
 
     async def handle_streaming_response(
         self,
-        openai_client: openai.AsyncOpenAI,
+        openai_client: Any, # Can be openai.AsyncOpenAI or our wrapper
        openai_params: Dict[str, Any],
        openai_extra_body: Dict[str, Any],
        request: OpenAIRequest
@@ -107,7 +199,7 @@ class OpenAIDirectHandler:
 
     async def _true_stream_generator(
         self,
-        openai_client: openai.AsyncOpenAI,
+        openai_client: Any, # Can be openai.AsyncOpenAI or our wrapper
         openai_params: Dict[str, Any],
         openai_extra_body: Dict[str, Any],
         request: OpenAIRequest
@@ -136,6 +228,7 @@ class OpenAIDirectHandler:
                         delta = choices[0].get('delta')
                         if delta and isinstance(delta, dict):
                             # Always remove extra_content if present
+
                             if 'extra_content' in delta:
                                 del delta['extra_content']
 
@@ -242,7 +335,7 @@ class OpenAIDirectHandler:
 
     async def handle_non_streaming_response(
         self,
-        openai_client: openai.AsyncOpenAI,
+        openai_client: Any, # Can be openai.AsyncOpenAI or our wrapper
         openai_params: Dict[str, Any],
         openai_extra_body: Dict[str, Any],
         request: OpenAIRequest
@@ -296,44 +389,55 @@ class OpenAIDirectHandler:
             content=create_openai_error_response(500, error_msg, "server_error")
         )
 
-    async def process_request(self, request: OpenAIRequest, base_model_name: str):
+    async def process_request(self, request: OpenAIRequest, base_model_name: str, is_express: bool = False):
         """Main entry point for processing OpenAI Direct mode requests."""
-        print(f"INFO: Using OpenAI Direct Path for model: {request.model}")
-
-        # Get credentials
-        rotated_credentials, rotated_project_id = self.credential_manager.get_credentials()
-
-        if not rotated_credentials or not rotated_project_id:
-            error_msg = "OpenAI Direct Mode requires GCP credentials, but none were available or loaded successfully."
-            print(f"ERROR: {error_msg}")
-            return JSONResponse(
-                status_code=500,
-                content=create_openai_error_response(500, error_msg, "server_error")
-            )
-
-        print(f"INFO: [OpenAI Direct Path] Using credentials for project: {rotated_project_id}")
-        gcp_token = _refresh_auth(rotated_credentials)
+        print(f"INFO: Using OpenAI Direct Path for model: {request.model} (Express: {is_express})")
 
-        if not gcp_token:
-            error_msg = f"Failed to obtain valid GCP token for OpenAI client (Project: {rotated_project_id})."
+        client: Any = None # Can be openai.AsyncOpenAI or our wrapper
+
+        try:
+            if is_express:
+                if not self.express_key_manager:
+                    raise Exception("Express mode requires an ExpressKeyManager, but it was not provided.")
+
+                key_tuple = self.express_key_manager.get_express_api_key()
+                if not key_tuple:
+                    raise Exception("OpenAI Express Mode requires an API key, but none were available.")
+
+                _, express_api_key = key_tuple
+                project_id = await discover_project_id(express_api_key)
+
+                client = ExpressClientWrapper(project_id=project_id, api_key=express_api_key)
+                print(f"INFO: [OpenAI Express Path] Using ExpressClientWrapper for project: {project_id}")
+
+            else: # Standard SA-based OpenAI SDK Path
+                if not self.credential_manager:
+                    raise Exception("Standard OpenAI Direct mode requires a CredentialManager.")
+
+                rotated_credentials, rotated_project_id = self.credential_manager.get_credentials()
+                if not rotated_credentials or not rotated_project_id:
+                    raise Exception("OpenAI Direct Mode requires GCP credentials, but none were available.")
+
+                print(f"INFO: [OpenAI Direct Path] Using credentials for project: {rotated_project_id}")
+                gcp_token = _refresh_auth(rotated_credentials)
+                if not gcp_token:
+                    raise Exception(f"Failed to obtain valid GCP token for OpenAI client (Project: {rotated_project_id}).")
+
+                client = self.create_openai_client(rotated_project_id, gcp_token)
+
+            model_id = f"google/{base_model_name}"
+            openai_params = self.prepare_openai_params(request, model_id)
+            openai_extra_body = self.prepare_extra_body()
+
+            if request.stream:
+                return await self.handle_streaming_response(
+                    client, openai_params, openai_extra_body, request
+                )
+            else:
+                return await self.handle_non_streaming_response(
+                    client, openai_params, openai_extra_body, request
+                )
+        except Exception as e:
+            error_msg = f"Error in process_request for {request.model}: {e}"
             print(f"ERROR: {error_msg}")
-            return JSONResponse(
-                status_code=500,
-                content=create_openai_error_response(500, error_msg, "server_error")
-            )
-
-        # Create client and prepare parameters
-        openai_client = self.create_openai_client(rotated_project_id, gcp_token)
-        model_id = f"google/{base_model_name}"
-        openai_params = self.prepare_openai_params(request, model_id)
-        openai_extra_body = self.prepare_extra_body()
-
-        # Handle streaming vs non-streaming
-        if request.stream:
-            return await self.handle_streaming_response(
-                openai_client, openai_params, openai_extra_body, request
-            )
-        else:
-            return await self.handle_non_streaming_response(
-                openai_client, openai_params, openai_extra_body, request
-            )
+            return JSONResponse(status_code=500, content=create_openai_error_response(500, error_msg, "server_error"))
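The Fake* classes and the self.chat = self / self.completions = self assignments work because the handler's downstream code only touches two surfaces: await client.chat.completions.create(...) and response.model_dump(...). A minimal, runnable sketch of that duck-typing trick with a stubbed create (illustrative names, not the repo's actual classes):

    import asyncio
    from typing import Any, Dict

    class MiniFakeCompletion:
        """Stands in for an OpenAI SDK response: model_dump() just returns the stored dict."""
        def __init__(self, data: Dict[str, Any]):
            self._data = data

        def model_dump(self, exclude_unset=True, exclude_none=True) -> Dict[str, Any]:
            return self._data

    class MiniClientWrapper:
        """Duck-types the slice of openai.AsyncOpenAI that the handler uses."""
        def __init__(self):
            # Both attribute hops resolve back to this object, so
            # client.chat.completions.create is simply self.create.
            self.chat = self
            self.completions = self

        async def create(self, **kwargs) -> MiniFakeCompletion:
            # Stub: a real implementation would POST to the Vertex AI
            # OpenAI-compatible endpoint here, as ExpressClientWrapper does.
            return MiniFakeCompletion({"echo": kwargs})

    async def main():
        client = MiniClientWrapper()
        resp = await client.chat.completions.create(model="google/gemini-2.0-flash", stream=False)
        print(resp.model_dump())  # {'echo': {'model': 'google/gemini-2.0-flash', 'stream': False}}

    asyncio.run(main())

Because the wrapper satisfies exactly the interface the existing handlers call, handle_streaming_response and handle_non_streaming_response need no changes beyond loosening their openai_client type hints to Any.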
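_stream_generator above is plain server-sent-events parsing: keep only data: lines, stop at the [DONE] sentinel, skip payloads that fail to decode. The same loop, runnable against canned lines standing in for httpx.Response.aiter_lines() (the sample chunks are illustrative, not real API output):

    import asyncio
    import json

    async def fake_sse_lines():
        # Stand-in for httpx.Response.aiter_lines() on a streaming response.
        for line in [
            'data: {"choices": [{"delta": {"content": "Hel"}}]}',
            "",  # SSE events are separated by blank lines
            'data: {"choices": [{"delta": {"content": "lo"}}]}',
            "data: not-json",  # undecodable payloads are skipped
            "data: [DONE]",
        ]:
            yield line

    async def parse_sse(lines):
        async for line in lines:
            if line.startswith("data:"):
                json_str = line[len("data:"):].strip()
                if json_str == "[DONE]":
                    break  # sentinel: the stream is finished
                try:
                    yield json.loads(json_str)
                except json.JSONDecodeError:
                    continue

    async def main():
        async for chunk in parse_sse(fake_sse_lines()):
            print(chunk["choices"][0]["delta"]["content"], end="")
        print()  # prints: Hello

    asyncio.run(main())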
app/routes/chat_api.py CHANGED
@@ -46,9 +46,10 @@ async def chat_completions(fastapi_request: Request, request: OpenAIRequest, api
     is_openai_direct_model = False
     if request.model.endswith(OPENAI_DIRECT_SUFFIX):
         temp_name_for_marker_check = request.model[:-len(OPENAI_DIRECT_SUFFIX)]
-        if temp_name_for_marker_check.startswith(PAY_PREFIX):
-            is_openai_direct_model = True
-        elif EXPERIMENTAL_MARKER in temp_name_for_marker_check:
+        # An OpenAI model can be prefixed with PAY, EXPRESS, or contain EXP
+        if temp_name_for_marker_check.startswith(PAY_PREFIX) or \
+           temp_name_for_marker_check.startswith(EXPRESS_PREFIX) or \
+           EXPERIMENTAL_MARKER in temp_name_for_marker_check:
             is_openai_direct_model = True
     is_auto_model = request.model.endswith("-auto")
     is_grounded_search = request.model.endswith("-search")
@@ -175,8 +176,12 @@ async def chat_completions(fastapi_request: Request, request: OpenAIRequest, api
 
     if is_openai_direct_model:
         # Use the new OpenAI handler
-        openai_handler = OpenAIDirectHandler(credential_manager_instance)
-        return await openai_handler.process_request(request, base_model_name)
+        if is_express_model_request:
+            openai_handler = OpenAIDirectHandler(express_key_manager=express_key_manager_instance)
+            return await openai_handler.process_request(request, base_model_name, is_express=True)
+        else:
+            openai_handler = OpenAIDirectHandler(credential_manager=credential_manager_instance)
+            return await openai_handler.process_request(request, base_model_name)
     elif is_auto_model:
         print(f"Processing auto model: {request.model}")
         attempts = [
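The rewritten marker check folds three cases into a single condition: a [PAY]-prefixed id, an [EXPRESS]-prefixed id, or an experimental model id. A standalone sketch of the same classification; the constant values here are assumptions for illustration (EXPRESS_PREFIX matches the models_api.py diff below, the others are defined in the repo's config):

    # Assumed values for illustration; the repo's config defines the real constants.
    OPENAI_DIRECT_SUFFIX = "-openai"
    PAY_PREFIX = "[PAY] "
    EXPRESS_PREFIX = "[EXPRESS] "
    EXPERIMENTAL_MARKER = "-exp-"

    def is_openai_direct(model: str) -> bool:
        """Mirrors the new marker check in chat_completions."""
        if not model.endswith(OPENAI_DIRECT_SUFFIX):
            return False
        stem = model[:-len(OPENAI_DIRECT_SUFFIX)]
        return (stem.startswith(PAY_PREFIX)
                or stem.startswith(EXPRESS_PREFIX)
                or EXPERIMENTAL_MARKER in stem)

    print(is_openai_direct("[EXPRESS] gemini-2.0-flash-openai"))  # True -- the case this commit adds
    print(is_openai_direct("[PAY] gemini-2.0-flash-openai"))      # True
    print(is_openai_direct("gemini-2.0-flash-auto"))              # False (no -openai suffix)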
app/routes/models_api.py CHANGED
@@ -106,26 +106,33 @@ async def list_models(fastapi_request: Request, api_key: str = Depends(get_api_k
 
     # Ensure uniqueness again after adding suffixes
     # Add OpenAI direct variations if SA creds are available
-    if has_sa_creds: # OpenAI direct mode only works with SA credentials
-        # `all_model_ids` contains the comprehensive list of base models that are eligible based on current credentials
-        # We iterate through this to determine which ones get an -openai variation.
-        # `raw_vertex_models` is used here to ensure we only add -openai suffix to models that are
-        # fundamentally Vertex models, not just any model that might appear in `all_model_ids` (e.g. from Express list exclusively)
-        # if express only key is provided.
-        # We iterate through the base models from the main Vertex list.
-        for base_model_id_for_openai in raw_vertex_models: # Iterate through original list of GAIA/Vertex base models
+    # Add OpenAI direct variations for SA credentials
+    if has_sa_creds:
+        for base_model_id_for_openai in raw_vertex_models:
             display_model_id = ""
             if EXPERIMENTAL_MARKER in base_model_id_for_openai:
                 display_model_id = f"{base_model_id_for_openai}{OPENAI_DIRECT_SUFFIX}"
             else:
                 display_model_id = f"{PAY_PREFIX}{base_model_id_for_openai}{OPENAI_DIRECT_SUFFIX}"
 
-            # Check if already added (e.g. if remote config somehow already listed it or added as a base model)
             if display_model_id and not any(m['id'] == display_model_id for m in dynamic_models_data):
                 dynamic_models_data.append({
                     "id": display_model_id, "object": "model", "created": current_time, "owned_by": "google",
                     "permission": [], "root": base_model_id_for_openai, "parent": None
                 })
+
+    # Add OpenAI direct variations for Express keys
+    if has_express_key:
+        EXPRESS_PREFIX = "[EXPRESS] "
+        for base_model_id_for_express_openai in raw_express_models:
+            # Express models are prefixed with [EXPRESS]
+            display_model_id = f"{EXPRESS_PREFIX}{base_model_id_for_express_openai}{OPENAI_DIRECT_SUFFIX}"
+
+            if display_model_id and not any(m['id'] == display_model_id for m in dynamic_models_data):
+                dynamic_models_data.append({
+                    "id": display_model_id, "object": "model", "created": current_time, "owned_by": "google",
+                    "permission": [], "root": base_model_id_for_express_openai, "parent": None
+                })
     # final_models_data_map = {m["id"]: m for m in dynamic_models_data}
     # model_list = list(final_models_data_map.values())
     # model_list.sort()
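The net effect of the new has_express_key branch: each raw Express model gains one extra listing entry carrying the [EXPRESS] prefix and the -openai suffix, which chat_api.py then routes to the Express path. A quick sketch of the id construction (the model names and the suffix value are illustrative):

    import time

    OPENAI_DIRECT_SUFFIX = "-openai"  # assumed value of the repo constant
    EXPRESS_PREFIX = "[EXPRESS] "

    raw_express_models = ["gemini-2.0-flash", "gemini-2.5-pro"]  # illustrative
    current_time = int(time.time())
    dynamic_models_data = []

    for base in raw_express_models:
        display_model_id = f"{EXPRESS_PREFIX}{base}{OPENAI_DIRECT_SUFFIX}"
        # Same de-duplication guard as the diff: skip ids already listed.
        if not any(m["id"] == display_model_id for m in dynamic_models_data):
            dynamic_models_data.append({
                "id": display_model_id, "object": "model", "created": current_time,
                "owned_by": "google", "permission": [], "root": base, "parent": None,
            })

    print([m["id"] for m in dynamic_models_data])
    # ['[EXPRESS] gemini-2.0-flash-openai', '[EXPRESS] gemini-2.5-pro-openai']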