# algoforge_prime/core/llm_clients.py
import os
import google.generativeai as genai
from huggingface_hub import InferenceClient
import time  # For potential retries or delays

# --- Configuration ---
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
HF_TOKEN = os.getenv("HF_TOKEN")

GEMINI_API_CONFIGURED = False
HF_API_CONFIGURED = False
hf_inference_client = None
google_gemini_model_instances = {}  # To cache initialized Gemini model instances

# --- Initialization Function (to be called from app.py) ---
def initialize_all_clients():
    global GEMINI_API_CONFIGURED, HF_API_CONFIGURED, hf_inference_client

    # Google Gemini
    if GOOGLE_API_KEY:
        try:
            genai.configure(api_key=GOOGLE_API_KEY)
            GEMINI_API_CONFIGURED = True
            print("INFO: llm_clients.py - Google Gemini API configured successfully.")
        except Exception as e:
            GEMINI_API_CONFIGURED = False  # Ensure it's False on error
            print(f"ERROR: llm_clients.py - Failed to configure Google Gemini API: {e}")
    else:
        print("WARNING: llm_clients.py - GOOGLE_API_KEY not found in environment variables.")

    # Hugging Face
    if HF_TOKEN:
        try:
            hf_inference_client = InferenceClient(token=HF_TOKEN)
            HF_API_CONFIGURED = True
            print("INFO: llm_clients.py - Hugging Face InferenceClient initialized successfully.")
        except Exception as e:
            HF_API_CONFIGURED = False  # Ensure it's False on error
            print(f"ERROR: llm_clients.py - Failed to initialize Hugging Face InferenceClient: {e}")
    else:
        print("WARNING: llm_clients.py - HF_TOKEN not found in environment variables.")
def _get_gemini_model_instance(model_id, system_instruction=None):
    """
    Returns a Gemini model instance for the given model ID.
    genai.GenerativeModel is fairly lightweight to create, but caching (see the
    sketch below) can avoid repeated setup if system_instruction is complex or
    model loading is slow. For now, creating a new one each time is fine unless
    performance becomes an issue.
    """
    if not GEMINI_API_CONFIGURED:
        raise ConnectionError("Google Gemini API not configured or configuration failed.")
    try:
        # For gemini-1.5 models, system_instruction is preferred.
        # For older gemini-1.0 models, system instructions may need to be part of the 'contents'.
        return genai.GenerativeModel(
            model_name=model_id,
            system_instruction=system_instruction
        )
    except Exception as e:
        print(f"ERROR: llm_clients.py - Failed to get Gemini model instance for {model_id}: {e}")
        raise
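
# A minimal sketch of the caching the docstring above mentions, using the module-level
# google_gemini_model_instances dict. The helper name and the (model_id, system_instruction)
# cache key are illustrative assumptions, not part of the original module.
def _get_cached_gemini_model_instance(model_id, system_instruction=None):
    cache_key = (model_id, system_instruction)  # assumes system_instruction is hashable (a str or None)
    if cache_key not in google_gemini_model_instances:
        google_gemini_model_instances[cache_key] = _get_gemini_model_instance(model_id, system_instruction)
    return google_gemini_model_instances[cache_key]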
class LLMResponse:
    def __init__(self, text=None, error=None, success=True, raw_response=None, model_id_used="unknown"):
        self.text = text
        self.error = error
        self.success = success
        self.raw_response = raw_response
        self.model_id_used = model_id_used

    def __str__(self):
        if self.success:
            return self.text if self.text is not None else ""
        return f"ERROR (Model: {self.model_id_used}): {self.error}"
def call_huggingface_api(prompt_text, model_id, temperature=0.7, max_new_tokens=512, system_prompt_text=None):
    if not HF_API_CONFIGURED or not hf_inference_client:
        return LLMResponse(error="Hugging Face API not configured (HF_TOKEN missing or client init failed).", success=False, model_id_used=model_id)

    full_prompt = prompt_text
    # Llama-style system prompt formatting; adjust if using other HF model families
    if system_prompt_text:
        full_prompt = f"<s>[INST] <<SYS>>\n{system_prompt_text}\n<</SYS>>\n\n{prompt_text} [/INST]"

    try:
        use_sample = temperature > 0.001  # API might treat 0 as no sampling
        raw_response = hf_inference_client.text_generation(
            full_prompt, model=model_id, max_new_tokens=max_new_tokens,
            temperature=temperature if use_sample else None,  # None or omit if not sampling
            do_sample=use_sample,
            # top_p=0.9 if use_sample else None,  # Optional
            stream=False
        )
        return LLMResponse(text=raw_response, raw_response=raw_response, model_id_used=model_id)
    except Exception as e:
        error_msg = f"HF API Error ({model_id}): {type(e).__name__} - {str(e)}"
        print(f"ERROR: llm_clients.py - {error_msg}")
        return LLMResponse(error=error_msg, success=False, raw_response=e, model_id_used=model_id)
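
# The `time` import above is reserved "for potential retries or delays"; this is a minimal,
# illustrative retry wrapper (not part of the original module) showing one way it could be used.
# It retries a call function such as call_huggingface_api or call_gemini_api whenever the
# returned LLMResponse is unsuccessful; the attempt count and backoff are arbitrary assumptions.
def call_with_retries(call_fn, *args, max_attempts=3, backoff_seconds=2.0, **kwargs):
    response = None
    for attempt in range(1, max_attempts + 1):
        response = call_fn(*args, **kwargs)
        if response.success:
            return response
        if attempt < max_attempts:
            print(f"WARNING: llm_clients.py - Attempt {attempt} failed ({response.error}); retrying in {backoff_seconds}s.")
            time.sleep(backoff_seconds)
    return response  # Last (failed) response after exhausting retries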
def call_gemini_api(prompt_text, model_id, temperature=0.7, max_new_tokens=768, system_prompt_text=None):
    if not GEMINI_API_CONFIGURED:
        return LLMResponse(error="Google Gemini API not configured (GOOGLE_API_KEY missing or config failed).", success=False, model_id_used=model_id)

    try:
        model_instance = _get_gemini_model_instance(model_id, system_instruction=system_prompt_text)
        generation_config = genai.types.GenerationConfig(
            temperature=temperature,
            max_output_tokens=max_new_tokens
            # top_p=0.9  # Optional
        )
        # For Gemini, the user prompt goes directly to generate_content when system_instruction is used.
        raw_response = model_instance.generate_content(
            prompt_text,  # User prompt
            generation_config=generation_config,
            stream=False
            # safety_settings=[  # Optional: adjust safety settings if needed; be very careful
            #     {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
            #     {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
            # ]
        )

        if raw_response.prompt_feedback and raw_response.prompt_feedback.block_reason:
            # block_reason_message is not present in every SDK version, so fall back to block_reason.
            reason = getattr(raw_response.prompt_feedback, "block_reason_message", None) or raw_response.prompt_feedback.block_reason
            error_msg = f"Gemini API: Your prompt was blocked. Reason: {reason}. Try rephrasing."
            print(f"WARNING: llm_clients.py - {error_msg}")
            return LLMResponse(error=error_msg, success=False, raw_response=raw_response, model_id_used=model_id)

        if not raw_response.candidates:  # No candidates usually means the prompt was blocked or there was an issue.
            error_msg = "Gemini API: No candidates returned in response. Possibly blocked or internal error."
            # Check prompt_feedback again; it may be populated even when candidates are empty.
            if raw_response.prompt_feedback and raw_response.prompt_feedback.block_reason:
                reason = getattr(raw_response.prompt_feedback, "block_reason_message", None) or raw_response.prompt_feedback.block_reason
                error_msg = f"Gemini API: Your prompt was blocked (no candidates). Reason: {reason}. Try rephrasing."
            print(f"WARNING: llm_clients.py - {error_msg}")
            return LLMResponse(error=error_msg, success=False, raw_response=raw_response, model_id_used=model_id)
        # Check the first candidate
        candidate = raw_response.candidates[0]
        # Normalized finish reason; substring checks below tolerate both "STOP" and "FinishReason.STOP" reprs.
        finish_reason = str(candidate.finish_reason).upper()
        has_text = bool(candidate.content and candidate.content.parts and candidate.content.parts[0].text)

        if not has_text:
            if "SAFETY" in finish_reason:
                error_msg = f"Gemini API: Response generation stopped by safety filters. Finish Reason: {finish_reason}."
            elif "RECITATION" in finish_reason:
                error_msg = f"Gemini API: Response generation stopped due to recitation policy. Finish Reason: {finish_reason}."
            elif "MAX_TOKENS" in finish_reason:
                error_msg = f"Gemini API: Response generation stopped due to max tokens before any text was returned. Consider increasing max_new_tokens. Finish Reason: {finish_reason}."
            else:
                error_msg = f"Gemini API: Empty response or no content parts. Finish Reason: {finish_reason}."
            print(f"WARNING: llm_clients.py - {error_msg}")
            return LLMResponse(error=error_msg, success=False, raw_response=raw_response, model_id_used=model_id)

        # If generation stopped early (e.g. MAX_TOKENS) but some text was produced,
        # return the partial text with a note rather than discarding it.
        if "STOP" not in finish_reason:
            return LLMResponse(
                text=candidate.content.parts[0].text + f"\n[Note: Generation stopped due to {finish_reason}]",
                raw_response=raw_response, model_id_used=model_id
            )

        return LLMResponse(text=candidate.content.parts[0].text, raw_response=raw_response, model_id_used=model_id)
    except Exception as e:
        error_msg = f"Gemini API Call Error ({model_id}): {type(e).__name__} - {str(e)}"
        # More specific error messages for common Google API errors
        if "API key not valid" in str(e) or "PERMISSION_DENIED" in str(e):
            error_msg = f"Gemini API Error ({model_id}): API key invalid or permission denied. Check GOOGLE_API_KEY and ensure the Gemini API is enabled. Original: {str(e)}"
        elif "Could not find model" in str(e) or "is not found" in str(e):
            error_msg = f"Gemini API Error ({model_id}): Model ID '{model_id}' not found or inaccessible with your key. Original: {str(e)}"
        elif "User location is not supported" in str(e):
            error_msg = f"Gemini API Error ({model_id}): User location not supported for this model/API. Original: {str(e)}"
        elif "Quota exceeded" in str(e):
            error_msg = f"Gemini API Error ({model_id}): API quota exceeded. Please check your Google Cloud quotas. Original: {str(e)}"
        print(f"ERROR: llm_clients.py - {error_msg}")
        return LLMResponse(error=error_msg, success=False, raw_response=e, model_id_used=model_id)
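
# Illustrative usage sketch (not part of the original module): how a caller such as app.py
# might wire these functions together. The model IDs below are examples; run with
# GOOGLE_API_KEY and/or HF_TOKEN set in the environment.
if __name__ == "__main__":
    initialize_all_clients()

    if GEMINI_API_CONFIGURED:
        gemini_resp = call_gemini_api(
            "Write a one-line docstring for a bubble sort function.",
            model_id="gemini-1.5-flash-latest",
            temperature=0.3,
            system_prompt_text="You are a concise coding assistant.",
        )
        print(str(gemini_resp))

    if HF_API_CONFIGURED:
        hf_resp = call_huggingface_api(
            "Write a one-line docstring for a bubble sort function.",
            model_id="mistralai/Mistral-7B-Instruct-v0.2",
            temperature=0.3,
        )
        print(str(hf_resp))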