# algoforge_prime/core/llm_clients.py
import os
import google.generativeai as genai
from huggingface_hub import InferenceClient
import time # For potential retries or delays

# --- Configuration ---
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
HF_TOKEN = os.getenv("HF_TOKEN")

GEMINI_API_CONFIGURED = False
HF_API_CONFIGURED = False

hf_inference_client = None
google_gemini_model_instances = {} # To cache initialized Gemini model instances

# --- Initialization Function (to be called from app.py) ---
def initialize_all_clients():
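    """
    Reads GOOGLE_API_KEY and HF_TOKEN from the environment and initializes the
    Gemini and Hugging Face clients. Call this once at startup (e.g. from app.py)
    before using call_gemini_api / call_huggingface_api.
    """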
    global GEMINI_API_CONFIGURED, HF_API_CONFIGURED, hf_inference_client
    
    # Google Gemini
    if GOOGLE_API_KEY:
        try:
            genai.configure(api_key=GOOGLE_API_KEY)
            GEMINI_API_CONFIGURED = True
            print("INFO: llm_clients.py - Google Gemini API configured successfully.")
        except Exception as e:
            GEMINI_API_CONFIGURED = False # Ensure it's False on error
            print(f"ERROR: llm_clients.py - Failed to configure Google Gemini API: {e}")
    else:
        print("WARNING: llm_clients.py - GOOGLE_API_KEY not found in environment variables.")

    # Hugging Face
    if HF_TOKEN:
        try:
            hf_inference_client = InferenceClient(token=HF_TOKEN)
            HF_API_CONFIGURED = True
            print("INFO: llm_clients.py - Hugging Face InferenceClient initialized successfully.")
        except Exception as e:
            HF_API_CONFIGURED = False # Ensure it's False on error
            print(f"ERROR: llm_clients.py - Failed to initialize Hugging Face InferenceClient: {e}")
    else:
        print("WARNING: llm_clients.py - HF_TOKEN not found in environment variables.")

def _get_gemini_model_instance(model_id, system_instruction=None):
    """
    Returns a genai.GenerativeModel for model_id, cached per (model_id, system_instruction).
    genai.GenerativeModel is fairly lightweight to create, so the cache mainly avoids
    repeated setup when the same model and system instruction are reused.
    """
    if not GEMINI_API_CONFIGURED:
        raise ConnectionError("Google Gemini API not configured or configuration failed.")
    cache_key = (model_id, system_instruction)
    if cache_key in google_gemini_model_instances:
        return google_gemini_model_instances[cache_key]
    try:
        # For gemini-1.5 models, system_instruction is supported directly.
        # For older gemini-1.0 models, system instructions may need to be folded into the 'contents'.
        instance = genai.GenerativeModel(
            model_name=model_id,
            system_instruction=system_instruction
        )
        google_gemini_model_instances[cache_key] = instance
        return instance
    except Exception as e:
        print(f"ERROR: llm_clients.py - Failed to get Gemini model instance for {model_id}: {e}")
        raise

class LLMResponse:
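    """
    Uniform wrapper for results from either backend. On success, `text` holds the
    generated output; on failure, `success` is False and `error` describes what went
    wrong. `raw_response` keeps the underlying API object (or exception) for debugging.
    """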
    def __init__(self, text=None, error=None, success=True, raw_response=None, model_id_used="unknown"):
        self.text = text
        self.error = error
        self.success = success
        self.raw_response = raw_response
        self.model_id_used = model_id_used

    def __str__(self):
        if self.success:
            return self.text if self.text is not None else ""
        return f"ERROR (Model: {self.model_id_used}): {self.error}"

def call_huggingface_api(prompt_text, model_id, temperature=0.7, max_new_tokens=512, system_prompt_text=None):
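    """
    Calls the Hugging Face Inference API for text generation and wraps the result
    in an LLMResponse. An optional system prompt is injected using Llama-2-style
    [INST]/<<SYS>> tags; adjust the formatting for other model families.
    """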
    if not HF_API_CONFIGURED or not hf_inference_client:
        return LLMResponse(error="Hugging Face API not configured (HF_TOKEN missing or client init failed).", success=False, model_id_used=model_id)

    full_prompt = prompt_text
    # Llama-style system prompt formatting; adjust if using other HF model families
    if system_prompt_text:
        full_prompt = f"<s>[INST] <<SYS>>\n{system_prompt_text}\n<</SYS>>\n\n{prompt_text} [/INST]"

    try:
        use_sample = temperature > 0.001 # API might treat 0 as no sampling
        raw_response = hf_inference_client.text_generation(
            full_prompt, model=model_id, max_new_tokens=max_new_tokens,
            temperature=temperature if use_sample else None, # None or omit if not sampling
            do_sample=use_sample, 
            # top_p=0.9 if use_sample else None, # Optional
            stream=False
        )
        return LLMResponse(text=raw_response, raw_response=raw_response, model_id_used=model_id)
    except Exception as e:
        error_msg = f"HF API Error ({model_id}): {type(e).__name__} - {str(e)}"
        print(f"ERROR: llm_clients.py - {error_msg}")
        return LLMResponse(error=error_msg, success=False, raw_response=e, model_id_used=model_id)

def call_gemini_api(prompt_text, model_id, temperature=0.7, max_new_tokens=768, system_prompt_text=None):
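    """
    Calls the Google Gemini API via generate_content and wraps the result in an
    LLMResponse, translating blocked prompts, empty candidates, and early finish
    reasons into readable error messages.
    """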
    if not GEMINI_API_CONFIGURED:
        return LLMResponse(error="Google Gemini API not configured (GOOGLE_API_KEY missing or config failed).", success=False, model_id_used=model_id)

    try:
        model_instance = _get_gemini_model_instance(model_id, system_instruction=system_prompt_text)
        
        generation_config = genai.types.GenerationConfig(
            temperature=temperature,
            max_output_tokens=max_new_tokens
            # top_p=0.9 # Optional
        )
        # For Gemini, the main prompt goes directly to generate_content if system_instruction is used.
        raw_response = model_instance.generate_content(
            prompt_text, # User prompt
            generation_config=generation_config,
            stream=False
            # safety_settings=[ # Optional: Adjust safety settings if needed, be very careful
            #    {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
            #    {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
            # ]
        )

        if raw_response.prompt_feedback and raw_response.prompt_feedback.block_reason:
            reason = raw_response.prompt_feedback.block_reason_message or raw_response.prompt_feedback.block_reason
            error_msg = f"Gemini API: Your prompt was blocked. Reason: {reason}. Try rephrasing."
            print(f"WARNING: llm_clients.py - {error_msg}")
            return LLMResponse(error=error_msg, success=False, raw_response=raw_response, model_id_used=model_id)

        if not raw_response.candidates: # No candidates usually means it was blocked or an issue.
            error_msg = "Gemini API: No candidates returned in response. Possibly blocked or internal error."
            # Check prompt_feedback again, as it might be populated even if candidates are empty.
            if raw_response.prompt_feedback and raw_response.prompt_feedback.block_reason:
                 reason = raw_response.prompt_feedback.block_reason_message or raw_response.prompt_feedback.block_reason
                 error_msg = f"Gemini API: Your prompt was blocked (no candidates). Reason: {reason}. Try rephrasing."
            print(f"WARNING: llm_clients.py - {error_msg}")
            return LLMResponse(error=error_msg, success=False, raw_response=raw_response, model_id_used=model_id)


        # Check the first candidate. finish_reason is an enum; use its name so the
        # comparisons below stay stable across library versions.
        candidate = raw_response.candidates[0]
        finish_reason = getattr(candidate.finish_reason, "name", str(candidate.finish_reason)).upper()

        if not candidate.content or not candidate.content.parts:
            # No text came back at all; report why generation stopped.
            if finish_reason == "SAFETY":
                error_msg = f"Gemini API: Response generation stopped by safety filters. Finish Reason: {finish_reason}."
            elif finish_reason == "RECITATION":
                error_msg = f"Gemini API: Response generation stopped due to recitation policy. Finish Reason: {finish_reason}."
            elif finish_reason == "MAX_TOKENS":
                error_msg = f"Gemini API: Response generation hit max tokens before any text was returned. Consider increasing max_new_tokens. Finish Reason: {finish_reason}."
            else:
                error_msg = f"Gemini API: Empty response or no content parts. Finish Reason: {finish_reason}."
            print(f"WARNING: llm_clients.py - {error_msg}")
            return LLMResponse(error=error_msg, success=False, raw_response=raw_response, model_id_used=model_id)

        # Text was returned; if generation stopped early (e.g. MAX_TOKENS), annotate the partial output.
        text = candidate.content.parts[0].text
        if finish_reason not in ("STOP", "FINISH_REASON_UNSPECIFIED"):
            text += f"\n[Note: Generation stopped due to {finish_reason}]"
        return LLMResponse(text=text, raw_response=raw_response, model_id_used=model_id)

    except Exception as e:
        error_msg = f"Gemini API Call Error ({model_id}): {type(e).__name__} - {str(e)}"
        # More specific error messages based on common Google API errors
        if "API key not valid" in str(e) or "PERMISSION_DENIED" in str(e):
            error_msg = f"Gemini API Error ({model_id}): API key invalid or permission denied. Check GOOGLE_API_KEY and ensure Gemini API is enabled. Original: {str(e)}"
        elif "Could not find model" in str(e) or "ील नहीं मिला" in str(e): # Hindi for "model not found"
            error_msg = f"Gemini API Error ({model_id}): Model ID '{model_id}' not found or inaccessible with your key. Original: {str(e)}"
        elif "User location is not supported" in str(e):
            error_msg = f"Gemini API Error ({model_id}): User location not supported for this model/API. Original: {str(e)}"
        elif "Quota exceeded" in str(e):
            error_msg = f"Gemini API Error ({model_id}): API quota exceeded. Please check your Google Cloud quotas. Original: {str(e)}"

        print(f"ERROR: llm_clients.py - {error_msg}")
        return LLMResponse(error=error_msg, success=False, raw_response=e, model_id_used=model_id)
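
# --- Example usage (illustrative sketch, not part of the library API) ---
# A minimal demo of how app.py might drive this module, assuming GOOGLE_API_KEY
# and/or HF_TOKEN are set in the environment. The model IDs below are placeholders;
# substitute models your keys actually have access to.
if __name__ == "__main__":
    initialize_all_clients()

    if GEMINI_API_CONFIGURED:
        gemini_resp = call_gemini_api(
            "Explain binary search in two sentences.",
            model_id="gemini-1.5-flash-latest",  # placeholder model ID
            temperature=0.3,
            max_new_tokens=256,
            system_prompt_text="You are a concise algorithms tutor."
        )
        print(f"Gemini -> {gemini_resp}")

    if HF_API_CONFIGURED:
        hf_resp = call_huggingface_api(
            "Explain binary search in two sentences.",
            model_id="mistralai/Mistral-7B-Instruct-v0.2",  # placeholder model ID
            temperature=0.3,
            max_new_tokens=256
        )
        print(f"HF -> {hf_resp}")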