bibibi12345 commited on
Commit
afcace3
·
1 Parent(s): ebfe9fb

added safety score. added proxy

Browse files
app/api_helpers.py CHANGED
@@ -2,27 +2,20 @@ import json
2
  import time
3
  import math
4
  import asyncio
5
- import base64
6
- import random
7
  from typing import List, Dict, Any, Callable, Union, Optional
8
 
9
  from fastapi.responses import JSONResponse, StreamingResponse
10
  from google.auth.transport.requests import Request as AuthRequest
11
  from google.genai import types
12
- from google.genai.types import GenerateContentResponse
13
- from google import genai
14
  from openai import AsyncOpenAI
15
- from openai.types.chat import ChatCompletionMessage, ChatCompletionMessageToolCall
16
- from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction
17
 
18
  from models import OpenAIRequest, OpenAIMessage
19
  from message_processing import (
20
- deobfuscate_text,
21
  convert_to_openai_format,
22
  convert_chunk_to_openai,
23
- create_final_chunk,
24
- parse_gemini_response_for_reasoning_and_content,
25
- extract_reasoning_by_tags
26
  )
27
  import config as app_config
28
  from config import VERTEX_REASONING_TAG
@@ -123,12 +116,13 @@ def create_generation_config(request: OpenAIRequest) -> Dict[str, Any]:
123
  if request.seed is not None: config["seed"] = request.seed
124
  if request.n is not None: config["candidate_count"] = request.n
125
 
 
126
  config["safety_settings"] = [
127
- types.SafetySetting(category="HARM_CATEGORY_HATE_SPEECH", threshold="OFF"),
128
- types.SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="OFF"),
129
- types.SafetySetting(category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="OFF"),
130
- types.SafetySetting(category="HARM_CATEGORY_HARASSMENT", threshold="OFF"),
131
- types.SafetySetting(category="HARM_CATEGORY_CIVIC_INTEGRITY", threshold="OFF")
132
  ]
133
  # config["thinking_config"] = {"include_thoughts": True}
134
 
@@ -355,18 +349,29 @@ async def openai_fake_stream_generator(
355
  raw_response_obj = await api_call_task
356
  openai_response_dict = raw_response_obj.model_dump(exclude_unset=True, exclude_none=True)
357
 
 
 
 
 
 
 
 
 
 
 
 
358
  if openai_response_dict.get("choices") and \
359
  isinstance(openai_response_dict["choices"], list) and \
360
  len(openai_response_dict["choices"]) > 0:
361
 
362
- first_choice_dict_item = openai_response_dict["choices"]
363
- if first_choice_dict_item and isinstance(first_choice_dict_item, dict) :
364
- choice_message_ref = first_choice_dict_item.get("message", {})
365
  original_content = choice_message_ref.get("content")
366
  if isinstance(original_content, str):
367
  reasoning_text, actual_content = extract_reasoning_by_tags(original_content, VERTEX_REASONING_TAG)
368
  choice_message_ref["content"] = actual_content
369
- if reasoning_text:
370
  choice_message_ref["reasoning_content"] = reasoning_text
371
 
372
  async for chunk_sse in _chunk_openai_response_dict_for_sse(
 
2
  import time
3
  import math
4
  import asyncio
 
 
5
  from typing import List, Dict, Any, Callable, Union, Optional
6
 
7
  from fastapi.responses import JSONResponse, StreamingResponse
8
  from google.auth.transport.requests import Request as AuthRequest
9
  from google.genai import types
 
 
10
  from openai import AsyncOpenAI
11
+
 
12
 
13
  from models import OpenAIRequest, OpenAIMessage
14
  from message_processing import (
 
15
  convert_to_openai_format,
16
  convert_chunk_to_openai,
17
+ extract_reasoning_by_tags,
18
+ _create_safety_ratings_html
 
19
  )
20
  import config as app_config
21
  from config import VERTEX_REASONING_TAG
 
116
  if request.seed is not None: config["seed"] = request.seed
117
  if request.n is not None: config["candidate_count"] = request.n
118
 
119
+ safety_threshold = "BLOCK_NONE" if app_config.SAFETY_SCORE else "BLOCK_ONLY_HIGH"
120
  config["safety_settings"] = [
121
+ types.SafetySetting(category="HARM_CATEGORY_HATE_SPEECH", threshold=safety_threshold),
122
+ types.SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold=safety_threshold),
123
+ types.SafetySetting(category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold=safety_threshold),
124
+ types.SafetySetting(category="HARM_CATEGORY_HARASSMENT", threshold=safety_threshold),
125
+ types.SafetySetting(category="HARM_CATEGORY_CIVIC_INTEGRITY", threshold=safety_threshold)
126
  ]
127
  # config["thinking_config"] = {"include_thoughts": True}
128
 
 
349
  raw_response_obj = await api_call_task
350
  openai_response_dict = raw_response_obj.model_dump(exclude_unset=True, exclude_none=True)
351
 
352
+ if app_config.SAFETY_SCORE and hasattr(raw_response_obj, "choices") and raw_response_obj.choices:
353
+ for i, choice_obj in enumerate(raw_response_obj.choices):
354
+ if hasattr(choice_obj, "safety_ratings") and choice_obj.safety_ratings:
355
+ safety_html = _create_safety_ratings_html(choice_obj.safety_ratings)
356
+ if i < len(openai_response_dict.get("choices", [])):
357
+ choice_dict = openai_response_dict["choices"][i]
358
+ message_dict = choice_dict.get("message")
359
+ if message_dict:
360
+ current_content = message_dict.get("content") or ""
361
+ message_dict["content"] = current_content + safety_html
362
+
363
  if openai_response_dict.get("choices") and \
364
  isinstance(openai_response_dict["choices"], list) and \
365
  len(openai_response_dict["choices"]) > 0:
366
 
367
+ first_choice_dict_item = openai_response_dict["choices"]
368
+ if first_choice_dict_item and isinstance(first_choice_dict_item, dict) :
369
+ choice_message_ref = first_choice_dict_item.get("message", {})
370
  original_content = choice_message_ref.get("content")
371
  if isinstance(original_content, str):
372
  reasoning_text, actual_content = extract_reasoning_by_tags(original_content, VERTEX_REASONING_TAG)
373
  choice_message_ref["content"] = actual_content
374
+ if reasoning_text:
375
  choice_message_ref["reasoning_content"] = reasoning_text
376
 
377
  async for chunk_sse in _chunk_openai_response_dict_for_sse(
app/config.py CHANGED
@@ -36,4 +36,11 @@ VERTEX_REASONING_TAG = "vertex_think_tag"
36
  # Round-robin credential selection strategy
37
  ROUNDROBIN = os.environ.get("ROUNDROBIN", "false").lower() == "true"
38
 
39
- # Validation logic moved to app/auth.py
 
 
 
 
 
 
 
 
36
  # Round-robin credential selection strategy
37
  ROUNDROBIN = os.environ.get("ROUNDROBIN", "false").lower() == "true"
38
 
39
+ # Safety score display setting
40
+ SAFETY_SCORE = os.environ.get("SAFETY_SCORE", "false").lower() == "true"
41
+ # Validation logic moved to app/auth.py
42
+
43
+ # Proxy settings
44
+ HTTPS_PROXY = os.environ.get("HTTPS_PROXY")
45
+ SOCKS_PROXY = os.environ.get("SOCKS_PROXY")
46
+ SSL_CERT_FILE = os.environ.get("SSL_CERT_FILE")
app/main.py CHANGED
@@ -1,9 +1,5 @@
1
  from fastapi import FastAPI, Depends # Depends might be used by root endpoint
2
- # from fastapi.responses import JSONResponse # Not used
3
  from fastapi.middleware.cors import CORSMiddleware
4
- # import asyncio # Not used
5
- # import os # Not used
6
-
7
 
8
  # Local module imports
9
  from auth import get_api_key # Potentially for root endpoint
@@ -15,8 +11,6 @@ from vertex_ai_init import init_vertex_ai
15
  from routes import models_api
16
  from routes import chat_api
17
 
18
- # import config as app_config # Not directly used in main.py
19
-
20
  app = FastAPI(title="OpenAI to Gemini Adapter")
21
 
22
  app.add_middleware(
 
1
  from fastapi import FastAPI, Depends # Depends might be used by root endpoint
 
2
  from fastapi.middleware.cors import CORSMiddleware
 
 
 
3
 
4
  # Local module imports
5
  from auth import get_api_key # Potentially for root endpoint
 
11
  from routes import models_api
12
  from routes import chat_api
13
 
 
 
14
  app = FastAPI(title="OpenAI to Gemini Adapter")
15
 
16
  app.add_middleware(
app/message_processing.py CHANGED
@@ -4,7 +4,8 @@ import json
4
  import time
5
  import random # For more unique tool_call_id
6
  import urllib.parse
7
- from typing import List, Dict, Any, Union, Literal, Tuple
 
8
 
9
  from google.genai import types
10
  from models import OpenAIMessage, ContentPartText, ContentPartImage
@@ -292,6 +293,56 @@ def create_encrypted_full_gemini_prompt(messages: List[OpenAIMessage]) -> List[t
292
  return create_encrypted_gemini_prompt(processed_messages)
293
 
294
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
295
  def deobfuscate_text(text: str) -> str:
296
  if not text: return text
297
  placeholder = "___TRIPLE_BACKTICK_PLACEHOLDER___"
@@ -385,6 +436,13 @@ def process_gemini_response_to_openai_dict(gemini_response_obj: Any, request_mod
385
  reasoning_str = deobfuscate_text(reasoning_str)
386
  normal_content_str = deobfuscate_text(normal_content_str)
387
 
 
 
 
 
 
 
 
388
  message_payload["content"] = normal_content_str
389
  if reasoning_str:
390
  message_payload['reasoning_content'] = reasoning_str
@@ -482,6 +540,13 @@ def convert_chunk_to_openai(chunk: Any, model_name: str, response_id: str, candi
482
  reasoning_text = deobfuscate_text(reasoning_text)
483
  normal_text = deobfuscate_text(normal_text)
484
 
 
 
 
 
 
 
 
485
  if reasoning_text: delta_payload['reasoning_content'] = reasoning_text
486
  if normal_text: # Only add content if it's non-empty
487
  delta_payload['content'] = normal_text
 
4
  import time
5
  import random # For more unique tool_call_id
6
  import urllib.parse
7
+ from typing import List, Dict, Any, Tuple
8
+ from app import config as app_config
9
 
10
  from google.genai import types
11
  from models import OpenAIMessage, ContentPartText, ContentPartImage
 
293
  return create_encrypted_gemini_prompt(processed_messages)
294
 
295
 
296
+ def _create_safety_ratings_html(safety_ratings: list) -> str:
297
+ """Generates a styled HTML block for safety ratings."""
298
+ if not safety_ratings:
299
+ return ""
300
+
301
+ # Find the rating with the highest probability score
302
+ highest_rating = max(safety_ratings, key=lambda r: r.probability_score)
303
+ highest_score = highest_rating.probability_score
304
+
305
+ # Determine color based on the highest score
306
+ if highest_score <= 0.33:
307
+ color = "#0f8" # green
308
+ elif highest_score <= 0.66:
309
+ color = "yellow"
310
+ else:
311
+ color = "#bf555d"
312
+
313
+ # Format the summary line for the highest score
314
+ summary_category = highest_rating.category.name.replace('HARM_CATEGORY_', '').replace('_', ' ').title()
315
+ summary_probability = highest_rating.probability.name
316
+ # Using .7f for score and .8f for severity as per example's precision
317
+ summary_score_str = f"{highest_rating.probability_score:.7f}" if highest_rating.probability_score is not None else "None"
318
+ summary_severity_str = f"{highest_rating.severity_score:.8f}" if highest_rating.severity_score is not None else "None"
319
+ summary_line = f"{summary_category}: {summary_probability} (Score: {summary_score_str}, Severity: {summary_severity_str})"
320
+
321
+ # Format the list of all ratings for the <pre> block
322
+ ratings_list = []
323
+ for rating in safety_ratings:
324
+ category = rating.category.name.replace('HARM_CATEGORY_', '').replace('_', ' ').title()
325
+ probability = rating.probability.name
326
+ score_str = f"{rating.probability_score:.7f}" if rating.probability_score is not None else "None"
327
+ severity_str = f"{rating.severity_score:.8f}" if rating.severity_score is not None else "None"
328
+ ratings_list.append(f"{category}: {probability} (Score: {score_str}, Severity: {severity_str})")
329
+ all_ratings_str = '\n'.join(ratings_list)
330
+
331
+ # CSS Style as specified
332
+ css_style = "<style>.cb{border:1px solid #444;margin:10px;border-radius:4px;background:#111}.cb summary{padding:8px;cursor:pointer;background:#222}.cb pre{margin:0;padding:10px;border-top:1px solid #444;white-space:pre-wrap}</style>"
333
+
334
+ # Final HTML structure
335
+ html_output = (
336
+ f'{css_style}'
337
+ f'<details class="cb">'
338
+ f'<summary style="color:{color}">{summary_line} ▼</summary>'
339
+ f'<pre>\\n--- Safety Ratings ---\\n{all_ratings_str}\\n</pre>'
340
+ f'</details>'
341
+ )
342
+
343
+ return html_output
344
+
345
+
346
  def deobfuscate_text(text: str) -> str:
347
  if not text: return text
348
  placeholder = "___TRIPLE_BACKTICK_PLACEHOLDER___"
 
436
  reasoning_str = deobfuscate_text(reasoning_str)
437
  normal_content_str = deobfuscate_text(normal_content_str)
438
 
439
+ if app_config.SAFETY_SCORE and hasattr(candidate, 'safety_ratings') and candidate.safety_ratings:
440
+ safety_html = _create_safety_ratings_html(candidate.safety_ratings)
441
+ if reasoning_str:
442
+ reasoning_str += safety_html
443
+ else:
444
+ normal_content_str += safety_html
445
+
446
  message_payload["content"] = normal_content_str
447
  if reasoning_str:
448
  message_payload['reasoning_content'] = reasoning_str
 
540
  reasoning_text = deobfuscate_text(reasoning_text)
541
  normal_text = deobfuscate_text(normal_text)
542
 
543
+ if app_config.SAFETY_SCORE and hasattr(candidate, 'safety_ratings') and candidate.safety_ratings:
544
+ safety_html = _create_safety_ratings_html(candidate.safety_ratings)
545
+ if reasoning_text:
546
+ reasoning_text += safety_html
547
+ else:
548
+ normal_text += safety_html
549
+
550
  if reasoning_text: delta_payload['reasoning_content'] = reasoning_text
551
  if normal_text: # Only add content if it's non-empty
552
  delta_payload['content'] = normal_text
app/openai_handler.py CHANGED
@@ -4,13 +4,11 @@ This module encapsulates all OpenAI-specific logic that was previously in chat_a
4
  """
5
  import json
6
  import time
7
- import asyncio
8
  import httpx
9
- from typing import Dict, Any, AsyncGenerator, Optional
10
 
11
  from fastapi.responses import JSONResponse, StreamingResponse
12
  import openai
13
- from google.auth.transport.requests import Request as AuthRequest
14
 
15
  from models import OpenAIRequest
16
  from config import VERTEX_REASONING_TAG
@@ -82,7 +80,11 @@ class ExpressClientWrapper:
82
  if 'extra_body' in payload:
83
  payload.update(payload.pop('extra_body'))
84
 
85
- async with httpx.AsyncClient(timeout=300) as client:
 
 
 
 
86
  async with client.stream("POST", endpoint, headers=headers, params=params, json=payload, timeout=None) as response:
87
  response.raise_for_status()
88
  async for chunk in self._stream_generator(response):
@@ -108,7 +110,11 @@ class ExpressClientWrapper:
108
  if 'extra_body' in payload:
109
  payload.update(payload.pop('extra_body'))
110
 
111
- async with httpx.AsyncClient(timeout=300) as client:
 
 
 
 
112
  response = await client.post(endpoint, headers=headers, params=params, json=payload, timeout=None)
113
  response.raise_for_status()
114
  return FakeChatCompletion(response.json())
@@ -120,12 +126,13 @@ class OpenAIDirectHandler:
120
  def __init__(self, credential_manager=None, express_key_manager=None):
121
  self.credential_manager = credential_manager
122
  self.express_key_manager = express_key_manager
 
123
  self.safety_settings = [
124
- {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
125
- {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
126
- {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
127
- {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
128
- {"category": 'HARM_CATEGORY_CIVIC_INTEGRITY', "threshold": 'OFF'}
129
  ]
130
 
131
  def create_openai_client(self, project_id: str, gcp_token: str, location: str = "global") -> openai.AsyncOpenAI:
@@ -135,9 +142,18 @@ class OpenAIDirectHandler:
135
  f"projects/{project_id}/locations/{location}/endpoints/openapi"
136
  )
137
 
 
 
 
 
 
 
 
 
138
  return openai.AsyncOpenAI(
139
  base_url=endpoint_url,
140
  api_key=gcp_token, # OAuth token
 
141
  )
142
 
143
  def prepare_openai_params(self, request: OpenAIRequest, model_id: str, is_openai_search: bool = False) -> Dict[str, Any]:
 
4
  """
5
  import json
6
  import time
 
7
  import httpx
8
+ from typing import Dict, Any, AsyncGenerator
9
 
10
  from fastapi.responses import JSONResponse, StreamingResponse
11
  import openai
 
12
 
13
  from models import OpenAIRequest
14
  from config import VERTEX_REASONING_TAG
 
80
  if 'extra_body' in payload:
81
  payload.update(payload.pop('extra_body'))
82
 
83
+ proxy = app_config.SOCKS_PROXY or app_config.HTTPS_PROXY
84
+ client_args = {'timeout': 300, 'proxies': proxy}
85
+ if app_config.SSL_CERT_FILE:
86
+ client_args['verify'] = app_config.SSL_CERT_FILE
87
+ async with httpx.AsyncClient(**client_args) as client:
88
  async with client.stream("POST", endpoint, headers=headers, params=params, json=payload, timeout=None) as response:
89
  response.raise_for_status()
90
  async for chunk in self._stream_generator(response):
 
110
  if 'extra_body' in payload:
111
  payload.update(payload.pop('extra_body'))
112
 
113
+ proxy = app_config.SOCKS_PROXY or app_config.HTTPS_PROXY
114
+ client_args = {'timeout': 300, 'proxies': proxy}
115
+ if app_config.SSL_CERT_FILE:
116
+ client_args['verify'] = app_config.SSL_CERT_FILE
117
+ async with httpx.AsyncClient(**client_args) as client:
118
  response = await client.post(endpoint, headers=headers, params=params, json=payload, timeout=None)
119
  response.raise_for_status()
120
  return FakeChatCompletion(response.json())
 
126
  def __init__(self, credential_manager=None, express_key_manager=None):
127
  self.credential_manager = credential_manager
128
  self.express_key_manager = express_key_manager
129
+ safety_threshold = "BLOCK_NONE" if app_config.SAFETY_SCORE else "OFF"
130
  self.safety_settings = [
131
+ {"category": "HARM_CATEGORY_HARASSMENT", "threshold": safety_threshold},
132
+ {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": safety_threshold},
133
+ {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": safety_threshold},
134
+ {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": safety_threshold},
135
+ {"category": 'HARM_CATEGORY_CIVIC_INTEGRITY', "threshold": safety_threshold}
136
  ]
137
 
138
  def create_openai_client(self, project_id: str, gcp_token: str, location: str = "global") -> openai.AsyncOpenAI:
 
142
  f"projects/{project_id}/locations/{location}/endpoints/openapi"
143
  )
144
 
145
+ proxy = app_config.SOCKS_PROXY or app_config.HTTPS_PROXY
146
+ client_args = {}
147
+ if proxy:
148
+ client_args['proxies'] = proxy
149
+ if app_config.SSL_CERT_FILE:
150
+ client_args['verify'] = app_config.SSL_CERT_FILE
151
+
152
+ http_client = httpx.AsyncClient(**client_args) if client_args else None
153
  return openai.AsyncOpenAI(
154
  base_url=endpoint_url,
155
  api_key=gcp_token, # OAuth token
156
+ http_client=http_client,
157
  )
158
 
159
  def prepare_openai_params(self, request: OpenAIRequest, model_id: str, is_openai_search: bool = False) -> Dict[str, Any]:
app/project_id_discovery.py CHANGED
@@ -2,11 +2,21 @@ import aiohttp
2
  import json
3
  import re
4
  from typing import Dict, Optional
 
5
 
6
  # Global cache for project IDs: {api_key: project_id}
7
  PROJECT_ID_CACHE: Dict[str, str] = {}
8
 
9
 
 
 
 
 
 
 
 
 
 
10
  async def discover_project_id(api_key: str) -> str:
11
  """
12
  Discover project ID by triggering an intentional error with a non-existent model.
@@ -34,9 +44,10 @@ async def discover_project_id(api_key: str) -> str:
34
  "contents": [{"role": "user", "parts": [{"text": "test"}]}]
35
  }
36
 
 
37
  async with aiohttp.ClientSession() as session:
38
  try:
39
- async with session.post(error_url, json=payload) as response:
40
  response_text = await response.text()
41
 
42
  try:
 
2
  import json
3
  import re
4
  from typing import Dict, Optional
5
+ from app import config
6
 
7
  # Global cache for project IDs: {api_key: project_id}
8
  PROJECT_ID_CACHE: Dict[str, str] = {}
9
 
10
 
11
+ def _get_proxy_url() -> Optional[str]:
12
+ """Get proxy URL from config."""
13
+ if config.SOCKS_PROXY:
14
+ return config.SOCKS_PROXY
15
+ if config.HTTPS_PROXY:
16
+ return config.HTTPS_PROXY
17
+ return None
18
+
19
+
20
  async def discover_project_id(api_key: str) -> str:
21
  """
22
  Discover project ID by triggering an intentional error with a non-existent model.
 
44
  "contents": [{"role": "user", "parts": [{"text": "test"}]}]
45
  }
46
 
47
+ proxy = _get_proxy_url()
48
  async with aiohttp.ClientSession() as session:
49
  try:
50
+ async with session.post(error_url, json=payload, proxy=proxy, ssl=getattr(config, "SSL_CERT_FILE", None)) as response:
51
  response_text = await response.text()
52
 
53
  try:
app/requirements.txt CHANGED
@@ -4,7 +4,7 @@ google-auth==2.38.0
4
  google-cloud-aiplatform==1.86.0
5
  pydantic==2.6.1
6
  google-genai==1.17.0
7
- httpx>=0.25.0
8
  openai
9
  google-auth-oauthlib
10
  aiohttp
 
4
  google-cloud-aiplatform==1.86.0
5
  pydantic==2.6.1
6
  google-genai==1.17.0
7
+ httpx[socks]>=0.25.0
8
  openai
9
  google-auth-oauthlib
10
  aiohttp
app/routes/chat_api.py CHANGED
@@ -1,6 +1,5 @@
1
  import asyncio
2
  import json
3
- import random
4
  from fastapi import APIRouter, Depends, Request
5
  from fastapi.responses import JSONResponse, StreamingResponse
6
 
@@ -11,7 +10,6 @@ from google import genai
11
  # Local module imports
12
  from models import OpenAIRequest
13
  from auth import get_api_key
14
- import config as app_config
15
  from message_processing import (
16
  create_gemini_prompt,
17
  create_encrypted_gemini_prompt,
 
1
  import asyncio
2
  import json
 
3
  from fastapi import APIRouter, Depends, Request
4
  from fastapi.responses import JSONResponse, StreamingResponse
5
 
 
10
  # Local module imports
11
  from models import OpenAIRequest
12
  from auth import get_api_key
 
13
  from message_processing import (
14
  create_gemini_prompt,
15
  create_encrypted_gemini_prompt,
app/routes/models_api.py CHANGED
@@ -3,7 +3,6 @@ from fastapi import APIRouter, Depends, Request
3
  from typing import List, Dict, Any, Set
4
  from auth import get_api_key
5
  from model_loader import get_vertex_models, get_vertex_express_models, refresh_models_config_cache
6
- import config as app_config
7
  from credentials_manager import CredentialManager
8
 
9
  router = APIRouter()
 
3
  from typing import List, Dict, Any, Set
4
  from auth import get_api_key
5
  from model_loader import get_vertex_models, get_vertex_express_models, refresh_models_config_cache
 
6
  from credentials_manager import CredentialManager
7
 
8
  router = APIRouter()
app/vertex_ai_init.py CHANGED
@@ -1,8 +1,8 @@
1
  import json
2
- import asyncio # Added for await
3
  from google import genai
4
  from credentials_manager import CredentialManager, parse_multiple_json_credentials
5
  import config as app_config
 
6
  from model_loader import refresh_models_config_cache # Import new model loader function
7
 
8
  # VERTEX_EXPRESS_MODELS list is now dynamically loaded via model_loader
@@ -11,6 +11,16 @@ from model_loader import refresh_models_config_cache # Import new model loader f
11
 
12
  # Global 'client' and 'get_vertex_client()' are removed.
13
 
 
 
 
 
 
 
 
 
 
 
14
  async def init_vertex_ai(credential_manager_instance: CredentialManager) -> bool: # Made async
15
  """
16
  Initializes the credential manager with credentials from GOOGLE_CREDENTIALS_JSON (if provided)
@@ -85,7 +95,11 @@ async def init_vertex_ai(credential_manager_instance: CredentialManager) -> bool
85
  temp_creds_val, temp_project_id_val = credential_manager_instance.get_credentials()
86
  if temp_creds_val and temp_project_id_val:
87
  try:
88
- _ = genai.Client(vertexai=True, credentials=temp_creds_val, project=temp_project_id_val, location="global")
 
 
 
 
89
  print(f"INFO: Successfully validated a credential from Credential Manager (Project: {temp_project_id_val}). Initialization check passed.")
90
  return True
91
  except Exception as e_val:
 
1
  import json
 
2
  from google import genai
3
  from credentials_manager import CredentialManager, parse_multiple_json_credentials
4
  import config as app_config
5
+ from google.generativeai.client import types
6
  from model_loader import refresh_models_config_cache # Import new model loader function
7
 
8
  # VERTEX_EXPRESS_MODELS list is now dynamically loaded via model_loader
 
11
 
12
  # Global 'client' and 'get_vertex_client()' are removed.
13
 
14
+ def _get_http_options() -> Optional[types.HttpOptions]:
15
+ """Get http options from config."""
16
+ if app_config.SOCKS_PROXY:
17
+ return types.HttpOptions(
18
+ client_args={'proxy': app_config.SOCKS_PROXY},
19
+ async_client_args={'proxy': app_config.SOCKS_PROXY},
20
+ )
21
+ return None
22
+
23
+
24
  async def init_vertex_ai(credential_manager_instance: CredentialManager) -> bool: # Made async
25
  """
26
  Initializes the credential manager with credentials from GOOGLE_CREDENTIALS_JSON (if provided)
 
95
  temp_creds_val, temp_project_id_val = credential_manager_instance.get_credentials()
96
  if temp_creds_val and temp_project_id_val:
97
  try:
98
+ http_options = _get_http_options()
99
+ if http_options:
100
+ _ = genai.Client(vertexai=True, credentials=temp_creds_val, project=temp_project_id_val, location="global", http_options=http_options)
101
+ else:
102
+ _ = genai.Client(vertexai=True, credentials=temp_creds_val, project=temp_project_id_val, location="global")
103
  print(f"INFO: Successfully validated a credential from Credential Manager (Project: {temp_project_id_val}). Initialization check passed.")
104
  return True
105
  except Exception as e_val: