mgbam commited on
Commit
10fdd4a
·
verified ·
1 Parent(s): 9b1a7e0

Update core/llm_clients.py

Browse files
Files changed (1) hide show
  1. core/llm_clients.py +97 -86
core/llm_clients.py CHANGED
@@ -5,63 +5,71 @@ from huggingface_hub import InferenceClient
5
  import time # For potential retries or delays
6
 
7
  # --- Configuration ---
8
- GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
9
- HF_TOKEN = os.getenv("HF_TOKEN")
 
10
 
 
11
  GEMINI_API_CONFIGURED = False
12
  HF_API_CONFIGURED = False
13
 
 
14
  hf_inference_client = None
15
- google_gemini_model_instances = {} # To cache initialized Gemini model instances
 
16
 
17
- # --- Initialization Function (to be called from app.py) ---
18
  def initialize_all_clients():
19
- global GEMINI_API_CONFIGURED, HF_API_CONFIGURED, hf_inference_client
20
 
 
 
21
  # Google Gemini
22
- if GOOGLE_API_KEY:
 
 
23
  try:
 
 
 
24
  genai.configure(api_key=GOOGLE_API_KEY)
 
 
 
 
25
  GEMINI_API_CONFIGURED = True
26
- print("INFO: llm_clients.py - Google Gemini API configured successfully.")
27
  except Exception as e:
28
- GEMINI_API_CONFIGURED = False # Ensure it's False on error
29
- print(f"ERROR: llm_clients.py - Failed to configure Google Gemini API: {e}")
 
30
  else:
31
- print("WARNING: llm_clients.py - GOOGLE_API_KEY not found in environment variables.")
 
32
 
33
  # Hugging Face
34
- if HF_TOKEN:
 
 
35
  try:
36
  hf_inference_client = InferenceClient(token=HF_TOKEN)
 
 
37
  HF_API_CONFIGURED = True
38
- print("INFO: llm_clients.py - Hugging Face InferenceClient initialized successfully.")
39
  except Exception as e:
40
- HF_API_CONFIGURED = False # Ensure it's False on error
41
- print(f"ERROR: llm_clients.py - Failed to initialize Hugging Face InferenceClient: {e}")
 
 
42
  else:
43
- print("WARNING: llm_clients.py - HF_TOKEN not found in environment variables.")
44
-
45
- def _get_gemini_model_instance(model_id, system_instruction=None):
46
- """
47
- Manages Gemini model instances.
48
- Gemini's genai.GenerativeModel is fairly lightweight to create,
49
- but caching can avoid repeated setup if system_instruction is complex or model loading is slow.
50
- For now, creating a new one each time is fine unless performance becomes an issue.
51
- """
52
- if not GEMINI_API_CONFIGURED:
53
- raise ConnectionError("Google Gemini API not configured or configuration failed.")
54
- try:
55
- # For gemini-1.5 models, system_instruction is preferred.
56
- # For older gemini-1.0, system instructions might need to be part of the 'contents'.
57
- return genai.GenerativeModel(
58
- model_name=model_id,
59
- system_instruction=system_instruction
60
- )
61
- except Exception as e:
62
- print(f"ERROR: llm_clients.py - Failed to get Gemini model instance for {model_id}: {e}")
63
- raise
64
 
 
65
  class LLMResponse:
66
  def __init__(self, text=None, error=None, success=True, raw_response=None, model_id_used="unknown"):
67
  self.text = text
@@ -72,55 +80,68 @@ class LLMResponse:
72
 
73
  def __str__(self):
74
  if self.success:
75
- return self.text if self.text is not None else ""
76
  return f"ERROR (Model: {self.model_id_used}): {self.error}"
77
 
 
78
  def call_huggingface_api(prompt_text, model_id, temperature=0.7, max_new_tokens=512, system_prompt_text=None):
 
79
  if not HF_API_CONFIGURED or not hf_inference_client:
80
- return LLMResponse(error="Hugging Face API not configured (HF_TOKEN missing or client init failed).", success=False, model_id_used=model_id)
 
 
81
 
82
  full_prompt = prompt_text
83
- # Llama-style system prompt formatting; adjust if using other HF model families
84
  if system_prompt_text:
85
- full_prompt = f"<s>[INST] <<SYS>>\n{system_prompt_text}\n<</SYS>>\n\n{prompt_text} [/INST]"
86
 
87
  try:
88
- use_sample = temperature > 0.001 # API might treat 0 as no sampling
 
89
  raw_response = hf_inference_client.text_generation(
90
  full_prompt, model=model_id, max_new_tokens=max_new_tokens,
91
- temperature=temperature if use_sample else None, # None or omit if not sampling
92
  do_sample=use_sample,
93
- # top_p=0.9 if use_sample else None, # Optional
94
  stream=False
95
  )
 
96
  return LLMResponse(text=raw_response, raw_response=raw_response, model_id_used=model_id)
97
  except Exception as e:
98
- error_msg = f"HF API Error ({model_id}): {type(e).__name__} - {str(e)}"
99
  print(f"ERROR: llm_clients.py - {error_msg}")
100
  return LLMResponse(error=error_msg, success=False, raw_response=e, model_id_used=model_id)
101
 
102
  def call_gemini_api(prompt_text, model_id, temperature=0.7, max_new_tokens=768, system_prompt_text=None):
 
103
  if not GEMINI_API_CONFIGURED:
104
- return LLMResponse(error="Google Gemini API not configured (GOOGLE_API_KEY missing or config failed).", success=False, model_id_used=model_id)
 
 
105
 
106
  try:
107
- model_instance = _get_gemini_model_instance(model_id, system_instruction=system_prompt_text)
 
 
 
 
 
 
108
 
109
  generation_config = genai.types.GenerationConfig(
110
  temperature=temperature,
111
  max_output_tokens=max_new_tokens
112
- # top_p=0.9 # Optional
113
  )
114
- # For Gemini, the main prompt goes directly to generate_content if system_instruction is used.
 
 
 
115
  raw_response = model_instance.generate_content(
116
- prompt_text, # User prompt
117
  generation_config=generation_config,
118
  stream=False
119
- # safety_settings=[ # Optional: Adjust safety settings if needed, be very careful
120
- # {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
121
- # {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
122
- # ]
123
  )
 
 
124
 
125
  if raw_response.prompt_feedback and raw_response.prompt_feedback.block_reason:
126
  reason = raw_response.prompt_feedback.block_reason_message or raw_response.prompt_feedback.block_reason
@@ -128,56 +149,46 @@ def call_gemini_api(prompt_text, model_id, temperature=0.7, max_new_tokens=768,
128
  print(f"WARNING: llm_clients.py - {error_msg}")
129
  return LLMResponse(error=error_msg, success=False, raw_response=raw_response, model_id_used=model_id)
130
 
131
- if not raw_response.candidates: # No candidates usually means it was blocked or an issue.
132
- error_msg = "Gemini API: No candidates returned in response. Possibly blocked or internal error."
133
- # Check prompt_feedback again, as it might be populated even if candidates are empty.
134
- if raw_response.prompt_feedback and raw_response.prompt_feedback.block_reason:
135
- reason = raw_response.prompt_feedback.block_reason_message or raw_response.prompt_feedback.block_reason
136
- error_msg = f"Gemini API: Your prompt was blocked (no candidates). Reason: {reason}. Try rephrasing."
137
  print(f"WARNING: llm_clients.py - {error_msg}")
138
  return LLMResponse(error=error_msg, success=False, raw_response=raw_response, model_id_used=model_id)
139
 
140
-
141
- # Check the first candidate
142
  candidate = raw_response.candidates[0]
143
  if not candidate.content or not candidate.content.parts:
144
- finish_reason = str(candidate.finish_reason).upper()
145
- if finish_reason == "SAFETY":
146
- error_msg = f"Gemini API: Response generation stopped by safety filters. Finish Reason: {finish_reason}."
147
- elif finish_reason == "RECITATION":
148
- error_msg = f"Gemini API: Response generation stopped due to recitation policy. Finish Reason: {finish_reason}."
149
- elif finish_reason == "MAX_TOKENS":
150
- error_msg = f"Gemini API: Response generation stopped due to max tokens. Consider increasing max_new_tokens. Finish Reason: {finish_reason}."
151
- # In this case, there might still be partial text.
152
- # For simplicity, we'll treat it as an incomplete generation here.
153
- # You could choose to return partial text if desired.
154
- # return LLMResponse(text="[PARTIAL RESPONSE - MAX TOKENS REACHED]", ..., model_id_used=model_id)
155
- else:
156
- error_msg = f"Gemini API: Empty response or no content parts. Finish Reason: {finish_reason}."
157
  print(f"WARNING: llm_clients.py - {error_msg}")
158
- # Try to get text even if finish_reason is not 'STOP' but not ideal
159
- # This part might need refinement based on how you want to handle partial/stopped responses
160
  partial_text = ""
161
- if candidate.content and candidate.content.parts and candidate.content.parts[0].text:
162
  partial_text = candidate.content.parts[0].text
163
- if partial_text:
164
- return LLMResponse(text=partial_text + f"\n[Note: Generation stopped due to {finish_reason}]", raw_response=raw_response, model_id_used=model_id)
165
- else:
 
166
  return LLMResponse(error=error_msg, success=False, raw_response=raw_response, model_id_used=model_id)
167
 
168
- return LLMResponse(text=candidate.content.parts[0].text, raw_response=raw_response, model_id_used=model_id)
 
 
169
 
170
  except Exception as e:
171
- error_msg = f"Gemini API Call Error ({model_id}): {type(e).__name__} - {str(e)}"
172
- # More specific error messages based on common Google API errors
173
  if "API key not valid" in str(e) or "PERMISSION_DENIED" in str(e):
174
- error_msg = f"Gemini API Error ({model_id}): API key invalid or permission denied. Check GOOGLE_API_KEY and ensure Gemini API is enabled. Original: {str(e)}"
175
- elif "Could not find model" in str(e) or "ील नहीं मिला" in str(e): # Hindi for "model not found"
176
  error_msg = f"Gemini API Error ({model_id}): Model ID '{model_id}' not found or inaccessible with your key. Original: {str(e)}"
177
  elif "User location is not supported" in str(e):
178
  error_msg = f"Gemini API Error ({model_id}): User location not supported for this model/API. Original: {str(e)}"
179
- elif "Quota exceeded" in str(e):
180
  error_msg = f"Gemini API Error ({model_id}): API quota exceeded. Please check your Google Cloud quotas. Original: {str(e)}"
181
-
182
  print(f"ERROR: llm_clients.py - {error_msg}")
183
  return LLMResponse(error=error_msg, success=False, raw_response=e, model_id_used=model_id)
 
5
  import time # For potential retries or delays
6
 
7
  # --- Configuration ---
8
+ # These will be populated by os.getenv()
9
+ GOOGLE_API_KEY = None
10
+ HF_TOKEN = None
11
 
12
+ # Status flags, default to False
13
  GEMINI_API_CONFIGURED = False
14
  HF_API_CONFIGURED = False
15
 
16
+ # Client instances
17
  hf_inference_client = None
18
+ # google_gemini_model_instances cache is not strictly necessary as genai.GenerativeModel is light.
19
+ # Removing it for now to simplify, can be added back if model instantiation proves slow.
20
 
21
+ # --- Initialization Function (to be called from app.py's global scope) ---
22
  def initialize_all_clients():
23
+ global GOOGLE_API_KEY, HF_TOKEN, GEMINI_API_CONFIGURED, HF_API_CONFIGURED, hf_inference_client
24
 
25
+ print("INFO: llm_clients.py - Attempting to initialize all API clients...")
26
+
27
  # Google Gemini
28
+ GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
29
+ if GOOGLE_API_KEY and GOOGLE_API_KEY.strip(): # Check if key is not None and not just whitespace
30
+ print("INFO: llm_clients.py - GOOGLE_API_KEY found in environment.")
31
  try:
32
+ # Test configuration by making a very simple, non-resource-intensive call
33
+ # or by listing models if supported and cheap.
34
+ # For now, genai.configure() is the main check.
35
  genai.configure(api_key=GOOGLE_API_KEY)
36
+ # Optionally, try to list models or a similar lightweight check if genai.configure isn't enough
37
+ # models = [m for m in genai.list_models() if 'generateContent' in m.supported_generation_methods]
38
+ # if not models:
39
+ # raise Exception("No usable Gemini models found with this API key, or API not fully enabled.")
40
  GEMINI_API_CONFIGURED = True
41
+ print("SUCCESS: llm_clients.py - Google Gemini API configured successfully.")
42
  except Exception as e:
43
+ GEMINI_API_CONFIGURED = False
44
+ print(f"ERROR: llm_clients.py - Failed to configure/validate Google Gemini API. Key value might be invalid, API not enabled in Google Cloud, or other issue.")
45
+ print(f" Gemini Init Error Details: {type(e).__name__}: {e}")
46
  else:
47
+ GEMINI_API_CONFIGURED = False # Explicitly set if key is missing/empty
48
+ print("WARNING: llm_clients.py - GOOGLE_API_KEY not found or is empty in environment variables.")
49
 
50
  # Hugging Face
51
+ HF_TOKEN = os.getenv("HF_TOKEN")
52
+ if HF_TOKEN and HF_TOKEN.strip(): # Check if token is not None and not just whitespace
53
+ print("INFO: llm_clients.py - HF_TOKEN found in environment.")
54
  try:
55
  hf_inference_client = InferenceClient(token=HF_TOKEN)
56
+ # Optionally, you could try a very quick ping to a known small public model if client init isn't enough
57
+ # hf_inference_client.text_generation("ping", model="gpt2", max_new_tokens=1)
58
  HF_API_CONFIGURED = True
59
+ print("SUCCESS: llm_clients.py - Hugging Face InferenceClient initialized successfully.")
60
  except Exception as e:
61
+ HF_API_CONFIGURED = False
62
+ print(f"ERROR: llm_clients.py - Failed to initialize Hugging Face InferenceClient. Token might be invalid or other issue.")
63
+ print(f" HF Init Error Details: {type(e).__name__}: {e}")
64
+ hf_inference_client = None # Ensure client is None on failure
65
  else:
66
+ HF_API_CONFIGURED = False # Explicitly set if token is missing/empty
67
+ print("WARNING: llm_clients.py - HF_TOKEN not found or is empty in environment variables.")
68
+
69
+ print(f"INFO: llm_clients.py - Initialization complete. Gemini Configured: {GEMINI_API_CONFIGURED}, HF Configured: {HF_API_CONFIGURED}")
70
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
+ # This class remains useful for standardizing responses
73
  class LLMResponse:
74
  def __init__(self, text=None, error=None, success=True, raw_response=None, model_id_used="unknown"):
75
  self.text = text
 
80
 
81
  def __str__(self):
82
  if self.success:
83
+ return str(self.text) if self.text is not None else "" # Ensure text is string
84
  return f"ERROR (Model: {self.model_id_used}): {self.error}"
85
 
86
+
87
  def call_huggingface_api(prompt_text, model_id, temperature=0.7, max_new_tokens=512, system_prompt_text=None):
88
+ print(f"DEBUG: llm_clients.py - call_huggingface_api attempt for model: {model_id}")
89
  if not HF_API_CONFIGURED or not hf_inference_client:
90
+ error_msg = "Hugging Face API not configured (HF_TOKEN missing, client init failed, or token invalid)."
91
+ print(f"ERROR: llm_clients.py - {error_msg}")
92
+ return LLMResponse(error=error_msg, success=False, model_id_used=model_id)
93
 
94
  full_prompt = prompt_text
 
95
  if system_prompt_text:
96
+ full_prompt = f"<s>[INST] <<SYS>>\n{system_prompt_text}\n<</SYS>>\n\n{prompt_text} [/INST]" # Llama-style
97
 
98
  try:
99
+ print(f" HF API Call - Prompt (first 100 chars): {full_prompt[:100]}...")
100
+ use_sample = temperature > 0.001
101
  raw_response = hf_inference_client.text_generation(
102
  full_prompt, model=model_id, max_new_tokens=max_new_tokens,
103
+ temperature=temperature if use_sample else None,
104
  do_sample=use_sample,
 
105
  stream=False
106
  )
107
+ print(f" HF API Call - Success for model: {model_id}. Response (first 100 chars): {str(raw_response)[:100]}...")
108
  return LLMResponse(text=raw_response, raw_response=raw_response, model_id_used=model_id)
109
  except Exception as e:
110
+ error_msg = f"HF API Error during text_generation ({model_id}): {type(e).__name__} - {str(e)}"
111
  print(f"ERROR: llm_clients.py - {error_msg}")
112
  return LLMResponse(error=error_msg, success=False, raw_response=e, model_id_used=model_id)
113
 
114
  def call_gemini_api(prompt_text, model_id, temperature=0.7, max_new_tokens=768, system_prompt_text=None):
115
+ print(f"DEBUG: llm_clients.py - call_gemini_api attempt for model: {model_id}")
116
  if not GEMINI_API_CONFIGURED:
117
+ error_msg = "Google Gemini API not configured (GOOGLE_API_KEY missing, config failed, or key invalid)."
118
+ print(f"ERROR: llm_clients.py - {error_msg}")
119
+ return LLMResponse(error=error_msg, success=False, model_id_used=model_id)
120
 
121
  try:
122
+ # genai.GenerativeModel is the recommended way to get a model instance.
123
+ # system_instruction is preferred for newer models (like 1.5 series).
124
+ print(f" Gemini API Call - Getting model instance for: {model_id}")
125
+ model_instance = genai.GenerativeModel(
126
+ model_name=model_id,
127
+ system_instruction=system_prompt_text # Pass system prompt here
128
+ )
129
 
130
  generation_config = genai.types.GenerationConfig(
131
  temperature=temperature,
132
  max_output_tokens=max_new_tokens
 
133
  )
134
+
135
+ print(f" Gemini API Call - Prompt (first 100 chars): {prompt_text[:100]}...")
136
+ if system_prompt_text: print(f" Gemini API Call - System Prompt (first 100 chars): {system_prompt_text[:100]}...")
137
+
138
  raw_response = model_instance.generate_content(
139
+ prompt_text, # User prompt directly if system_instruction is used
140
  generation_config=generation_config,
141
  stream=False
 
 
 
 
142
  )
143
+ print(f" Gemini API Call - Raw response received for model: {model_id}. Prompt feedback: {raw_response.prompt_feedback}, Candidates: {'Yes' if raw_response.candidates else 'No'}")
144
+
145
 
146
  if raw_response.prompt_feedback and raw_response.prompt_feedback.block_reason:
147
  reason = raw_response.prompt_feedback.block_reason_message or raw_response.prompt_feedback.block_reason
 
149
  print(f"WARNING: llm_clients.py - {error_msg}")
150
  return LLMResponse(error=error_msg, success=False, raw_response=raw_response, model_id_used=model_id)
151
 
152
+ if not raw_response.candidates:
153
+ error_msg = "Gemini API: No candidates returned in response. This often indicates the prompt was blocked or an internal error occurred before generation."
154
+ if raw_response.prompt_feedback: error_msg += f" Prompt Feedback: {raw_response.prompt_feedback}"
 
 
 
155
  print(f"WARNING: llm_clients.py - {error_msg}")
156
  return LLMResponse(error=error_msg, success=False, raw_response=raw_response, model_id_used=model_id)
157
 
 
 
158
  candidate = raw_response.candidates[0]
159
  if not candidate.content or not candidate.content.parts:
160
+ finish_reason = str(candidate.finish_reason if candidate.finish_reason else "UNKNOWN").upper()
161
+ error_msg = f"Gemini API: Response generation stopped or yielded no content parts. Finish Reason: {finish_reason}."
162
+ if finish_reason == "SAFETY": error_msg += " Likely due to safety filters."
163
+ elif finish_reason == "RECITATION": error_msg += " Likely due to recitation policy."
164
+ elif finish_reason == "MAX_TOKENS": error_msg += " Consider increasing max_new_tokens if content seems truncated."
165
+
 
 
 
 
 
 
 
166
  print(f"WARNING: llm_clients.py - {error_msg}")
167
+ # Attempt to extract partial text if MAX_TOKENS or other non-error finish reasons
 
168
  partial_text = ""
169
+ if candidate.content and candidate.content.parts and hasattr(candidate.content.parts[0], 'text'):
170
  partial_text = candidate.content.parts[0].text
171
+
172
+ if partial_text and finish_reason != "SAFETY" and finish_reason != "RECITATION" and finish_reason != "OTHER": # Only return partial if not a hard block
173
+ return LLMResponse(text=partial_text + f"\n[Note: Generation ended due to {finish_reason}]", raw_response=raw_response, model_id_used=model_id)
174
+ else: # If safety/recitation or truly no text, return as error
175
  return LLMResponse(error=error_msg, success=False, raw_response=raw_response, model_id_used=model_id)
176
 
177
+ response_text = candidate.content.parts[0].text
178
+ print(f" Gemini API Call - Success for model: {model_id}. Response text (first 100 chars): {response_text[:100]}...")
179
+ return LLMResponse(text=response_text, raw_response=raw_response, model_id_used=model_id)
180
 
181
  except Exception as e:
182
+ error_msg = f"Gemini API Call Exception ({model_id}): {type(e).__name__} - {str(e)}"
183
+ # Specific error parsing from previous version is good, let's keep it.
184
  if "API key not valid" in str(e) or "PERMISSION_DENIED" in str(e):
185
+ error_msg = f"Gemini API Error ({model_id}): API key invalid or permission denied. Check GOOGLE_API_KEY and ensure Gemini API is enabled in Google Cloud. Original: {str(e)}"
186
+ elif "Could not find model" in str(e) or "ील नहीं मिला" in str(e):
187
  error_msg = f"Gemini API Error ({model_id}): Model ID '{model_id}' not found or inaccessible with your key. Original: {str(e)}"
188
  elif "User location is not supported" in str(e):
189
  error_msg = f"Gemini API Error ({model_id}): User location not supported for this model/API. Original: {str(e)}"
190
+ elif "Quota exceeded" in str(e): # Check for "Quota" in the error message from Google
191
  error_msg = f"Gemini API Error ({model_id}): API quota exceeded. Please check your Google Cloud quotas. Original: {str(e)}"
192
+
193
  print(f"ERROR: llm_clients.py - {error_msg}")
194
  return LLMResponse(error=error_msg, success=False, raw_response=e, model_id_used=model_id)