# algoforge_prime/core/llm_clients.py
import os
import google.generativeai as genai
from huggingface_hub import InferenceClient
import time # For potential retries or delays
# --- Configuration ---
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
HF_TOKEN = os.getenv("HF_TOKEN")
GEMINI_API_CONFIGURED = False
HF_API_CONFIGURED = False
hf_inference_client = None
google_gemini_model_instances = {} # To cache initialized Gemini model instances
# --- Initialization Function (to be called from app.py) ---
def initialize_all_clients():
    global GEMINI_API_CONFIGURED, HF_API_CONFIGURED, hf_inference_client

    # Google Gemini
    if GOOGLE_API_KEY:
        try:
            genai.configure(api_key=GOOGLE_API_KEY)
            GEMINI_API_CONFIGURED = True
            print("INFO: llm_clients.py - Google Gemini API configured successfully.")
        except Exception as e:
            GEMINI_API_CONFIGURED = False  # Ensure it's False on error
            print(f"ERROR: llm_clients.py - Failed to configure Google Gemini API: {e}")
    else:
        print("WARNING: llm_clients.py - GOOGLE_API_KEY not found in environment variables.")

    # Hugging Face
    if HF_TOKEN:
        try:
            hf_inference_client = InferenceClient(token=HF_TOKEN)
            HF_API_CONFIGURED = True
            print("INFO: llm_clients.py - Hugging Face InferenceClient initialized successfully.")
        except Exception as e:
            HF_API_CONFIGURED = False  # Ensure it's False on error
            print(f"ERROR: llm_clients.py - Failed to initialize Hugging Face InferenceClient: {e}")
    else:
        print("WARNING: llm_clients.py - HF_TOKEN not found in environment variables.")
def _get_gemini_model_instance(model_id, system_instruction=None):
    """
    Returns a Gemini model instance for the given model_id and system instruction.
    genai.GenerativeModel is fairly lightweight to create, but instances are cached
    (keyed by model_id and system_instruction) to avoid repeated setup when the
    system instruction is complex or model initialization is slow.
    """
    if not GEMINI_API_CONFIGURED:
        raise ConnectionError("Google Gemini API not configured or configuration failed.")
    cache_key = (model_id, system_instruction)
    if cache_key in google_gemini_model_instances:
        return google_gemini_model_instances[cache_key]
    try:
        # For gemini-1.5 models, system_instruction is preferred.
        # For older gemini-1.0 models, system instructions may need to be folded into
        # the prompt contents instead (see the fallback sketch after this function).
        model_instance = genai.GenerativeModel(
            model_name=model_id,
            system_instruction=system_instruction
        )
        google_gemini_model_instances[cache_key] = model_instance
        return model_instance
    except Exception as e:
        print(f"ERROR: llm_clients.py - Failed to get Gemini model instance for {model_id}: {e}")
        raise
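
# The comment above notes that older gemini-1.0 models may not accept system_instruction.
# A minimal fallback sketch (not part of the original module; the helper name is an
# assumption): fold the system prompt into the user prompt itself, so callers can use the
# same path for models without system_instruction support.
def _prepend_system_prompt(prompt_text, system_prompt_text=None):
    """Return a single prompt string with the system prompt folded in, for models
    that do not support a separate system_instruction."""
    if not system_prompt_text:
        return prompt_text
    return f"{system_prompt_text.strip()}\n\n{prompt_text}"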
class LLMResponse:
    def __init__(self, text=None, error=None, success=True, raw_response=None, model_id_used="unknown"):
        self.text = text
        self.error = error
        self.success = success
        self.raw_response = raw_response
        self.model_id_used = model_id_used

    def __str__(self):
        if self.success:
            return self.text if self.text is not None else ""
        return f"ERROR (Model: {self.model_id_used}): {self.error}"
def call_huggingface_api(prompt_text, model_id, temperature=0.7, max_new_tokens=512, system_prompt_text=None):
    if not HF_API_CONFIGURED or not hf_inference_client:
        return LLMResponse(error="Hugging Face API not configured (HF_TOKEN missing or client init failed).", success=False, model_id_used=model_id)

    full_prompt = prompt_text
    # Llama-style system prompt formatting; adjust if using other HF model families
    if system_prompt_text:
        full_prompt = f"<s>[INST] <<SYS>>\n{system_prompt_text}\n<</SYS>>\n\n{prompt_text} [/INST]"

    try:
        use_sample = temperature > 0.001  # API might treat 0 as no sampling
        raw_response = hf_inference_client.text_generation(
            full_prompt, model=model_id, max_new_tokens=max_new_tokens,
            temperature=temperature if use_sample else None,  # None or omit if not sampling
            do_sample=use_sample,
            # top_p=0.9 if use_sample else None,  # Optional
            stream=False
        )
        return LLMResponse(text=raw_response, raw_response=raw_response, model_id_used=model_id)
    except Exception as e:
        error_msg = f"HF API Error ({model_id}): {type(e).__name__} - {str(e)}"
        print(f"ERROR: llm_clients.py - {error_msg}")
        return LLMResponse(error=error_msg, success=False, raw_response=e, model_id_used=model_id)
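
# The `time` import above is noted as being "for potential retries or delays". The sketch
# below shows one way that could look: a generic retry wrapper around either call_* function.
# It is an illustrative addition, not part of the original module; the name
# _call_with_retries and its defaults are assumptions.
def _call_with_retries(call_fn, *args, max_retries=2, backoff_seconds=2.0, **kwargs):
    """Call `call_fn` (e.g. call_huggingface_api or call_gemini_api), retrying failed
    attempts with a simple linear backoff. Returns the last LLMResponse either way."""
    response = None
    for attempt in range(max_retries + 1):
        response = call_fn(*args, **kwargs)
        if response.success:
            return response
        if attempt < max_retries:
            delay = backoff_seconds * (attempt + 1)
            print(f"WARNING: llm_clients.py - Attempt {attempt + 1} failed "
                  f"({response.error}); retrying in {delay:.1f}s.")
            time.sleep(delay)
    return response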
def call_gemini_api(prompt_text, model_id, temperature=0.7, max_new_tokens=768, system_prompt_text=None):
    if not GEMINI_API_CONFIGURED:
        return LLMResponse(error="Google Gemini API not configured (GOOGLE_API_KEY missing or config failed).", success=False, model_id_used=model_id)

    try:
        model_instance = _get_gemini_model_instance(model_id, system_instruction=system_prompt_text)

        generation_config = genai.types.GenerationConfig(
            temperature=temperature,
            max_output_tokens=max_new_tokens
            # top_p=0.9  # Optional
        )

        # For Gemini, the user prompt goes directly to generate_content when system_instruction is used.
        raw_response = model_instance.generate_content(
            prompt_text,  # User prompt
            generation_config=generation_config,
            stream=False
            # safety_settings=[  # Optional: adjust safety settings if needed; be very careful
            #     {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
            #     {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
            # ]
        )

        if raw_response.prompt_feedback and raw_response.prompt_feedback.block_reason:
            reason = raw_response.prompt_feedback.block_reason_message or raw_response.prompt_feedback.block_reason
            error_msg = f"Gemini API: Your prompt was blocked. Reason: {reason}. Try rephrasing."
            print(f"WARNING: llm_clients.py - {error_msg}")
            return LLMResponse(error=error_msg, success=False, raw_response=raw_response, model_id_used=model_id)

        if not raw_response.candidates:  # No candidates usually means the prompt was blocked or an internal issue occurred.
            error_msg = "Gemini API: No candidates returned in response. Possibly blocked or internal error."
            # Check prompt_feedback again; it may be populated even when candidates are empty.
            if raw_response.prompt_feedback and raw_response.prompt_feedback.block_reason:
                reason = raw_response.prompt_feedback.block_reason_message or raw_response.prompt_feedback.block_reason
                error_msg = f"Gemini API: Your prompt was blocked (no candidates). Reason: {reason}. Try rephrasing."
            print(f"WARNING: llm_clients.py - {error_msg}")
            return LLMResponse(error=error_msg, success=False, raw_response=raw_response, model_id_used=model_id)

        # Check the first candidate
        candidate = raw_response.candidates[0]
        if not candidate.content or not candidate.content.parts:
            # finish_reason is an enum; use its name (e.g. "SAFETY") rather than str(),
            # which yields something like "FinishReason.SAFETY" and would never match below.
            finish_reason = getattr(candidate.finish_reason, "name", str(candidate.finish_reason)).upper()
            if finish_reason == "SAFETY":
                error_msg = f"Gemini API: Response generation stopped by safety filters. Finish Reason: {finish_reason}."
            elif finish_reason == "RECITATION":
                error_msg = f"Gemini API: Response generation stopped due to recitation policy. Finish Reason: {finish_reason}."
            elif finish_reason == "MAX_TOKENS":
                error_msg = f"Gemini API: Response generation stopped due to max tokens. Consider increasing max_new_tokens. Finish Reason: {finish_reason}."
                # With no content parts there is no partial text to return, so this is
                # treated as an incomplete generation rather than a success.
            else:
                error_msg = f"Gemini API: Empty response or no content parts. Finish Reason: {finish_reason}."
            print(f"WARNING: llm_clients.py - {error_msg}")
            return LLMResponse(error=error_msg, success=False, raw_response=raw_response, model_id_used=model_id)

        return LLMResponse(text=candidate.content.parts[0].text, raw_response=raw_response, model_id_used=model_id)
    except Exception as e:
        error_msg = f"Gemini API Call Error ({model_id}): {type(e).__name__} - {str(e)}"
        # More specific error messages based on common Google API errors
        if "API key not valid" in str(e) or "PERMISSION_DENIED" in str(e):
            error_msg = f"Gemini API Error ({model_id}): API key invalid or permission denied. Check GOOGLE_API_KEY and ensure the Gemini API is enabled. Original: {str(e)}"
        elif "Could not find model" in str(e) or "मॉडल नहीं मिला" in str(e):  # Hindi for "model not found"
            error_msg = f"Gemini API Error ({model_id}): Model ID '{model_id}' not found or inaccessible with your key. Original: {str(e)}"
        elif "User location is not supported" in str(e):
            error_msg = f"Gemini API Error ({model_id}): User location not supported for this model/API. Original: {str(e)}"
        elif "Quota exceeded" in str(e):
            error_msg = f"Gemini API Error ({model_id}): API quota exceeded. Please check your Google Cloud quotas. Original: {str(e)}"
        print(f"ERROR: llm_clients.py - {error_msg}")
        return LLMResponse(error=error_msg, success=False, raw_response=e, model_id_used=model_id)
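
# A minimal usage sketch, not part of the original module: it assumes the file is run
# directly with GOOGLE_API_KEY and/or HF_TOKEN set, and the model IDs and prompts below
# are only illustrative examples, not values the module defines.
if __name__ == "__main__":
    initialize_all_clients()

    if GEMINI_API_CONFIGURED:
        gemini_resp = call_gemini_api(
            "Explain binary search in two sentences.",
            model_id="gemini-1.5-flash-latest",  # example model ID; substitute your own
            system_prompt_text="You are a concise algorithms tutor."
        )
        print(gemini_resp)

    if HF_API_CONFIGURED:
        hf_resp = call_huggingface_api(
            "Explain binary search in two sentences.",
            model_id="meta-llama/Llama-2-7b-chat-hf",  # example model ID; substitute your own
            system_prompt_text="You are a concise algorithms tutor."
        )
        print(hf_resp)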