Futuresony committed
Commit 0990c9c · verified · 1 Parent(s): 43855ed

Delete app.py

Files changed (1):
  app.py +0 -1133
app.py DELETED
@@ -1,1133 +0,0 @@
import os
import gradio as gr
from huggingface_hub import InferenceClient
import torch
import re
import warnings
import time
import json
# Removed specific transformers imports that might not be strictly necessary for InferenceClient
# from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer, util, CrossEncoder
import gspread
# Removed google.auth.default as service account from dict is used
# from google.auth import default
from tqdm import tqdm
from ddgs import DDGS
import spacy
from datetime import date, timedelta, datetime
from dateutil.relativedelta import relativedelta
import traceback
import base64
import dateparser
from dateparser.search import search_dates
import pytz
# Removed userdata as secrets are accessed via environment variables in Spaces
# from google.colab import userdata
from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset
from huggingface_hub import HfApi, login  # login is used for the initial Hub auth

import faiss
import numpy as np
import pickle

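# Implied runtime dependencies (assumed pip package names, for reference only):
# gradio, huggingface-hub, torch, sentence-transformers, gspread, tqdm, ddgs,
# spacy, python-dateutil, dateparser, pytz, datasets, faiss-cpu (or faiss-gpu), numpy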

# --- SQL Logging Imports and Connection Placeholder (Removed for HF Space) ---
# Removed SQL related code as per user's request to use HF Datasets
# ---

# Define the dataset name (replace with your actual Hugging Face username and desired dataset name)
# Ensure this dataset is set to private on the Hugging Face Hub
dataset_name = "Futuresony/Logs_Conversation"  # REPLACE WITH YOUR ACTUAL DATASET NAME

# Global variable to store the dataset
conversation_dataset = None

# No need to initialize HfApi with a token here; login() below handles auth
# hf_api = HfApi(token=HF_TOKEN)

# Suppress warnings
warnings.filterwarnings("ignore", category=UserWarning)

# Define global variables and load secrets from environment variables for HF Spaces
# HF_TOKEN is accessed via os.environ and also picked up by huggingface_hub's login()
HF_TOKEN = os.getenv("HF_TOKEN")
# Check whether HF_TOKEN is loaded without printing its value
print(f"HF_TOKEN loaded: {'*' * len(HF_TOKEN) if HF_TOKEN else 'None'}")

SHEET_ID = "19ipxC2vHYhpXCefpxpIkpeYdI43a1Ku2kYwecgUULIw"
# GOOGLE_BASE64_CREDENTIALS is accessed via os.environ
GOOGLE_BASE64_CREDENTIALS = os.getenv("GOOGLE_BASE64_CREDENTIALS")

# SECRET_API_KEY is accessed via os.environ
SECRET_API_KEY = os.getenv("APP_API_KEY")
# Check whether SECRET_API_KEY is loaded without printing its value
print(f"SECRET_API_KEY loaded: {'*' * len(SECRET_API_KEY) if SECRET_API_KEY else 'None'}")

if not SECRET_API_KEY:
    print("Warning: APP_API_KEY secret not set. API key validation will fail.")
elif not SECRET_API_KEY.startswith("fs_"):
    print("Warning: APP_API_KEY secret does not start with 'fs_'. Please check your secret.")

# Authenticate with Hugging Face Hub using the token from the environment
try:
    print("Attempting to authenticate with Hugging Face Hub...")
    # login() automatically looks for HF_TOKEN in environment variables
    login(add_to_git_credential=True)
    print("Hugging Face Hub authentication successful.")
except Exception as e:
    print(f"Hugging Face Hub authentication failed: {e}")
    print(traceback.format_exc())

# Initialize InferenceClient for the primary model (LLaMA-3.3-70B-Instruct)
primary_client = InferenceClient("meta-llama/Llama-3.3-70B-Instruct", token=HF_TOKEN)
print("Primary model (LLaMA-3.3-70B-Instruct) client initialized.")

# Initialize InferenceClient for the fallback model (Gemma-2-9b-it)
fallback_client = InferenceClient("google/gemma-2-9b-it", token=HF_TOKEN)
print("Fallback model (Gemma-2-9b-it) client initialized.")
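# Two-tier routing (see determine_tool_usage / generate_text below): questions judged
# 'complex' are answered by the primary LLaMA client; 'simple' questions, Swahili
# greetings, date calculations, and business-info RAG synthesis use the Gemma fallback.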

# Load the spaCy model used for sentence splitting
nlp = None
try:
    nlp = spacy.load("en_core_web_sm")
    print("SpaCy model 'en_core_web_sm' loaded.")
except OSError:
    print("SpaCy model 'en_core_web_sm' not found. Downloading...")
    try:
        import subprocess
        subprocess.run(["python", "-m", "spacy", "download", "en_core_web_sm"], check=True)
        nlp = spacy.load("en_core_web_sm")
        print("SpaCy model 'en_core_web_sm' downloaded and loaded.")
    except Exception as e:
        print(f"Failed to download or load SpaCy model: {e}")

# Load the SentenceTransformer used for RAG/business-info retrieval and semantic detection
embedder = None
try:
    print("Attempting to load Sentence Transformer (sentence-transformers/paraphrase-MiniLM-L6-v2)...")
    embedder = SentenceTransformer("sentence-transformers/paraphrase-MiniLM-L6-v2")
    print("Sentence Transformer loaded.")
except Exception as e:
    print(f"Error loading Sentence Transformer: {e}")

# Load a Cross-Encoder model for re-ranking retrieved documents
reranker = None
try:
    print("Attempting to load Cross-Encoder Reranker (cross-encoder/ms-marco-MiniLM-L6-v2)...")
    reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L6-v2')
    print("Cross-Encoder Reranker loaded.")
except Exception as e:
    print(f"Error loading Cross-Encoder Reranker: {e}")
    print("Please ensure the model identifier 'cross-encoder/ms-marco-MiniLM-L6-v2' is correct and accessible on Hugging Face Hub.")
    print(traceback.format_exc())
    reranker = None

# Google Sheets authentication
gc = None
def authenticate_google_sheets():
    """Authenticates with Google Sheets using base64-encoded credentials."""
    global gc
    print("Authenticating Google Account...")
    if not GOOGLE_BASE64_CREDENTIALS:
        print("Error: GOOGLE_BASE64_CREDENTIALS secret not found.")
        return False
    try:
        credentials_json = base64.b64decode(GOOGLE_BASE64_CREDENTIALS).decode('utf-8')
        credentials = json.loads(credentials_json)
        gc = gspread.service_account_from_dict(credentials)
        print("Google Sheets authentication successful via service account.")
        return True
    except Exception as e:
        print(f"Google Sheets authentication failed: {e}")
        print(traceback.format_exc())
        print("Please ensure your GOOGLE_BASE64_CREDENTIALS secret is correctly set and contains valid service account credentials.")
        return False
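# The GOOGLE_BASE64_CREDENTIALS secret is expected to hold the base64 encoding of a
# Google service-account JSON key, e.g. produced (illustrative) with:
#   base64 -w0 service_account.json
# gspread.service_account_from_dict() then consumes the decoded dict directly.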

# Google Sheets data loading and embedding for RAG
data = []
descriptions_for_embedding = []
embeddings = torch.tensor([])  # This will store embeddings for RAG data
business_info_available = False

def load_business_info():
    """Loads business information from the Google Sheet and creates embeddings."""
    global data, descriptions_for_embedding, embeddings, business_info_available
    business_info_available = False
    if gc is None:
        print("Skipping Google Sheet loading: Google Sheets client not authenticated.")
        return
    if not SHEET_ID:
        print("Error: SHEET_ID not set.")
        return
    try:
        sheet = gc.open_by_key(SHEET_ID).sheet1
        print(f"Successfully opened Google Sheet with ID: {SHEET_ID}")
        data_records = sheet.get_all_records()
        if not data_records:
            print(f"Warning: No data records found in Google Sheet with ID: {SHEET_ID}")
            data = []
            descriptions_for_embedding = []
        else:
            filtered_data = [row for row in data_records if row.get('Service') and row.get('Description')]
            if not filtered_data:
                print("Warning: Filtered data is empty after checking for 'Service' and 'Description'.")
                data = []
                descriptions_for_embedding = []
            else:
                data = filtered_data
                descriptions_for_embedding = [f"Service: {row['Service']}. Description: {row['Description']}" for row in data]
        if descriptions_for_embedding and embedder is not None:
            print("Encoding descriptions for RAG...")
            try:
                embeddings = embedder.encode(descriptions_for_embedding, convert_to_tensor=True)
                print("Encoding complete. RAG embeddings created.")
                business_info_available = True
            except Exception as e:
                print(f"Error during description encoding for RAG: {e}")
                embeddings = torch.tensor([])
                business_info_available = False
        else:
            print("Skipping encoding descriptions for RAG: No descriptions found or embedder not available.")
            embeddings = torch.tensor([])
            business_info_available = False
        print(f"Loaded {len(descriptions_for_embedding)} entries from Google Sheet for embedding/RAG.")
        if not business_info_available:
            print("Business information retrieval (RAG) is NOT available.")
        else:
            print("Business information retrieval (RAG) is available.")
    except gspread.exceptions.SpreadsheetNotFound:
        print(f"Error: Google Sheet with ID '{SHEET_ID}' not found.")
        print("Please check the SHEET_ID and ensure your authenticated Google Account has access to this sheet.")
        business_info_available = False
    except Exception as e:
        print(f"An error occurred while accessing the Google Sheet: {e}")
        print(traceback.format_exc())
        business_info_available = False
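# Illustrative: a sheet row {'Service': 'Project Management', 'Description': '...'}
# is flattened to the string "Service: Project Management. Description: ..." before
# being embedded, so the service name and its description share one vector.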

# Business-info retrieval (RAG)
def retrieve_business_info(query: str, top_n: int = 3) -> list:
    """
    Retrieves relevant business information from loaded data based on a query.
    """
    global data, embeddings
    if not business_info_available or embedder is None or not descriptions_for_embedding or not data or embeddings.numel() == 0:
        print("Business information retrieval is not available or RAG data is empty.")
        return []
    try:
        query_embedding = embedder.encode(query, convert_to_tensor=True)
        # Ensure both tensors are on the same device for the cosine similarity calculation
        if query_embedding.device != embeddings.device:
            query_embedding = query_embedding.to(embeddings.device)

        cosine_scores = util.cos_sim(query_embedding, embeddings)[0]
        top_results_indices = torch.topk(cosine_scores, k=min(top_n, len(data)))[1].tolist()
        top_results = [data[i] for i in top_results_indices]
        if reranker is not None and top_results:
            print("Re-ranking top results...")
            rerank_pairs = [(query, descriptions_for_embedding[i]) for i in top_results_indices]
            rerank_scores = reranker.predict(rerank_pairs)
            reranked_indices = sorted(range(len(rerank_scores)), key=lambda i: rerank_scores[i], reverse=True)
            reranked_results = [top_results[i] for i in reranked_indices]
            print("Re-ranking complete.")
            return reranked_results
        else:
            return top_results
    except Exception as e:
        print(f"Error during business information retrieval: {e}")
        print(traceback.format_exc())
        return []
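# Retrieval is two-stage: the bi-encoder (MiniLM) does a fast cosine-similarity scan
# for recall, then the cross-encoder re-scores each (query, description) pair jointly
# for precision and reorders the top_n hits before they reach the LLM.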

# Perform a DuckDuckGo search and return results with URLs
def perform_duckduckgo_search(query: str, max_results: int = 5):
    """
    Performs a search using DuckDuckGo and returns a list of result dictionaries.
    """
    print(f"Executing Tool: perform_duckduckgo_search with query='{query}'")
    search_results_list = []
    try:
        time.sleep(1)
        with DDGS() as ddgs:
            search_query = query.strip()
            if not search_query or len(search_query.split()) < 2:
                print(f"Skipping search for short query: '{search_query}'")
                return []
            print(f"Sending search query to DuckDuckGo: '{search_query}'")
            results_generator = ddgs.text(search_query, max_results=max_results)
            results_found = False
            for r in results_generator:
                search_results_list.append(r)
                results_found = True
            print(f"Raw results from DuckDuckGo: {search_results_list}")
            if not results_found and max_results > 0:
                print(f"DuckDuckGo search for '{search_query}' returned no results.")
            elif results_found:
                print(f"DuckDuckGo search for '{search_query}' completed. Found {len(search_results_list)} results.")
    except Exception as e:
        print(f"Error during DuckDuckGo search for '{search_query if 'search_query' in locals() else query}': {e}")
        print(traceback.format_exc())
        return []
    return search_results_list

# Semantic date/time detection and calculation using dateparser
def perform_date_calculation(query: str) -> str | None:
    """
    Analyzes the query for date/time information using dateparser.
    """
    print(f"Executing Tool: perform_date_calculation with query='{query}' using dateparser.search_dates")
    try:
        eafrica_tz = pytz.timezone('Africa/Dar_es_Salaam')
        now = datetime.now(eafrica_tz)
    except pytz.UnknownTimeZoneError:
        print("Error: Unknown timezone 'Africa/Dar_es_Salaam'. Using default system time.")
        now = datetime.now()
    try:
        found = search_dates(
            query,
            settings={
                "PREFER_DATES_FROM": "future",
                "RELATIVE_BASE": now
            },
            languages=['sw', 'en']
        )
        if not found:
            print("dateparser.search_dates could not parse any date/time.")
            return None
        text_snippet, parsed = found[0]
        print(f"dateparser.search_dates found: text='{text_snippet}', parsed='{parsed}'")
        is_swahili = any(swahili_phrase in query.lower() for swahili_phrase in ['tarehe', 'siku', 'saa', 'muda', 'leo', 'kesho', 'jana', 'ngapi', 'gani', 'mwezi', 'mwaka', 'habari', 'mambo', 'shikamoo', 'karibu', 'asante'])

        if is_swahili:
            query_lower = query.lower().strip()
            if query_lower in ['habari', 'mambo', 'habari gani']:
                return "Nzuri! Habari zako?"  # "Good! How are you?"
            elif query_lower in ['shikamoo']:
                return "Marahaba!"  # customary reply to the respectful greeting "Shikamoo"
            elif query_lower in ['asante']:  # "thank you"
                return "Karibu!"  # "you're welcome"
            elif query_lower in ['karibu']:  # "welcome"
                return "Asante!"  # "thank you"

        if now.tzinfo is not None and parsed.tzinfo is None:
            parsed = now.tzinfo.localize(parsed)
        elif now.tzinfo is None and parsed.tzinfo is not None:
            parsed = parsed.replace(tzinfo=None)

        if parsed.date() == now.date():
            if abs((parsed - now).total_seconds()) < 60 or parsed.time() == datetime.min.time():
                print("Query parsed to today's date and the time is close to 'now' or midnight; returning current time/date.")
                if is_swahili:
                    return f"Kwa saa za Afrika Mashariki (Tanzania), tarehe ya leo ni {now.strftime('%A, %d %B %Y')} na saa ni {now.strftime('%H:%M:%S')}."
                else:
                    return f"In East Africa (Tanzania), the current date is {now.strftime('%A, %d %B %Y')} and the time is {now.strftime('%H:%M:%S')}."
            else:
                print(f"Query parsed to a specific time today: {parsed.strftime('%H:%M:%S')}")
                if is_swahili:
                    return f"Hiyo inafanyika leo, {parsed.strftime('%A, %d %B %Y')}, saa {parsed.strftime('%H:%M:%S')} saa za Afrika Mashariki."
                else:
                    return f"That falls on today, {parsed.strftime('%A, %d %B %Y')}, at {parsed.strftime('%H:%M:%S')} East Africa Time."
        else:
            print(f"Query parsed to a specific date: {parsed.strftime('%A, %d %B %Y')} at {parsed.strftime('%H:%M:%S')}")
            time_str = parsed.strftime('%H:%M:%S')
            date_str = parsed.strftime('%A, %d %B %Y')
            if parsed.tzinfo:
                tz_name = parsed.tzinfo.tzname(parsed) or 'UTC'
                if is_swahili:
                    return f"Hiyo inafanyika tarehe {date_str} saa {time_str} {tz_name}."
                else:
                    return f"That falls on {date_str} at {time_str} {tz_name}."
            else:
                if is_swahili:
                    return f"Hiyo inafanyika tarehe {date_str} saa {time_str}."
                else:
                    return f"That falls on {date_str} at {time_str}."
    except Exception as e:
        print(f"Error during dateparser.search_dates execution: {e}")
        print(traceback.format_exc())
        return f"An error occurred while parsing date/time: {e}"
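# Illustrative: with RELATIVE_BASE=now in Africa/Dar_es_Salaam and
# PREFER_DATES_FROM='future', a query like "kesho" ("tomorrow") parses to tomorrow's
# date in East Africa Time, and "next Tuesday" resolves to the upcoming Tuesday.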

# Determine whether a query requires a tool and whether it is simple or complex,
# which also decides the routing to the primary vs. fallback model
def determine_tool_usage(query: str) -> tuple[str, str]:
    """
    Analyzes the query to determine if a specific tool is needed and its complexity.
    Returns a tuple: (tool_name, complexity_level)
    Complexity levels: 'simple' (fallback model), 'complex' (primary model)
    """
    query_lower = query.lower()

    swahili_conversational_phrases = ['habari', 'mambo', 'shikamoo', 'karibu', 'asante', 'habari gani']
    if any(swahili_phrase in query_lower for swahili_phrase in swahili_conversational_phrases):
        print(f"Detected a Swahili conversational phrase: '{query}'. Using 'date_calculation' tool and 'simple' complexity.")
        return "date_calculation", "simple"  # Simple conversational queries are routed to the fallback model

    # Check for business info retrieval first
    if business_info_available:
        # Use a simple LLM call to check if the query is business-related
        messages_business_check = [{"role": "user", "content": f"Does the following query ask about a specific person, service, offering, or description that is likely to be found *only* within a specific business's internal knowledge base, and not general knowledge? For example, questions about 'Salum' or 'Jackson Kisanga' are likely business-related, while questions about 'the current president of the USA' or 'who won the Ballon d'Or' are general knowledge. Answer only 'yes' or 'no'. Query: {query}"}]
        try:
            business_check_response = primary_client.chat_completion(  # Use primary client for this check
                messages=messages_business_check,
                max_tokens=10,
                temperature=0.1
            ).choices[0].message.content.strip().lower()
            if business_check_response == "yes":
                print(f"Detected as specific business info query based on LLM check: '{query}'. Using 'business_info_retrieval' tool and 'simple' complexity.")
                # Business info RAG is handled by the fallback model
                return "business_info_retrieval", "simple"
            else:
                print(f"LLM check indicates not a specific business info query: '{query}'")
        except Exception as e:
            print(f"Error during LLM call for business info check for query '{query}': {e}")
            print(traceback.format_exc())
            print(f"Proceeding without business info check for query '{query}' due to error.")

    # Check for date/time calculation
    date_time_check_result = perform_date_calculation(query)
    if date_time_check_result is not None and not any(phrase in query_lower for phrase in swahili_conversational_phrases):
        print(f"Detected as date/time calculation query based on dateparser result for: '{query}'. Using 'date_calculation' tool and 'simple' complexity.")
        # Date calculation is handled by the fallback model
        return "date_calculation", "simple"

    # Check whether a web search is needed for general knowledge or current info
    messages_tool_determination_search = [{"role": "user", "content": f"Does the following query require searching the web for current or general knowledge information (e.g., news, facts, definitions, current events)? Respond ONLY with 'duckduckgo_search' or 'none'. Query: {query}"}]
    try:
        search_determination_response = primary_client.chat_completion(  # Use primary client for this check
            messages=messages_tool_determination_search,
            max_tokens=20,
            temperature=0.1,
            top_p=0.9
        ).choices[0].message.content or ""
        response_lower = search_determination_response.strip().lower()
        if "duckduckgo_search" in response_lower:
            print(f"Model-determined tool for '{query}': 'duckduckgo_search'. Using 'complex' complexity.")
            # Web search queries are generally more complex and routed to the primary model
            return "duckduckgo_search", "complex"
        else:
            print(f"Model-determined tool for '{query}': 'none' (for search).")
    except Exception as e:
        print(f"Error during LLM call for search tool determination for query '{query}': {e}")
        print(traceback.format_exc())
        print(f"Proceeding without search tool check for query '{query}' due to error.")

    # If no specific tool is determined, route based on query complexity
    messages_complexity = [{"role": "user", "content": f"Is the following query simple or complex? A simple query is a basic question, a greeting, or a question that can be answered with a short, direct response. A complex query requires detailed understanding, multiple steps, or external information synthesis. Respond ONLY with 'simple' or 'complex'. Query: {query}"}]
    try:
        complexity_response = primary_client.chat_completion(  # Use primary client for the complexity check
            messages=messages_complexity,
            max_tokens=10,
            temperature=0.1
        ).choices[0].message.content.strip().lower()

        if "complex" in complexity_response:
            print(f"Determined query complexity for '{query}': 'complex'. Using 'none' tool.")
            return "none", "complex"  # No tool; complex query routed to the primary model
        else:
            print(f"Determined query complexity for '{query}': 'simple'. Using 'none' tool.")
            return "none", "simple"  # No tool; simple query routed to the fallback model

    except Exception as e:
        print(f"Error during LLM call for complexity determination for query '{query}': {e}")
        print(traceback.format_exc())
        print(f"Defaulting query '{query}' to 'complex' due to error.")
        return "none", "complex"  # Default to complex on error
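# Routing summary, as implemented above:
#   Swahili greeting detected        -> ('date_calculation', 'simple')
#   LLM says business-specific       -> ('business_info_retrieval', 'simple')
#   dateparser finds a date/time     -> ('date_calculation', 'simple')
#   LLM says web search needed       -> ('duckduckgo_search', 'complex')
#   otherwise LLM judges complexity  -> ('none', 'simple' | 'complex'); 'complex' on error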

# Summarize the chat history
def summarize_chat_history(chat_history: list[dict]) -> str:
    """
    Summarizes the provided chat history using the LLM.
    Uses the primary client for summarization.
    """
    print("\n--- Summarizing chat history ---")
    if not chat_history:
        print("Chat history is empty, no summarization needed.")
        return ""

    history_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in chat_history])

    prompt_for_summary = f"""
Please provide a concise summary of the following conversation history.
Conversation History:
{history_text}

Summary:
"""
    try:
        messages_summary = [{"role": "user", "content": prompt_for_summary}]
        summary_response = primary_client.chat_completion(  # Use primary client
            messages=messages_summary,
            max_tokens=200,  # Adjust based on desired summary length
            temperature=0.3,
            top_p=0.9
        ).choices[0].message.content or ""
        print("Chat history summarization successful using primary client.")
        return summary_response.strip()
    except Exception as e:
        print(f"Error during LLM call for chat history summarization (primary client): {e}")
        print(traceback.format_exc())
        return "Unable to summarize previous conversation."

# Generate text using the LLM, incorporating tool results if available;
# uses the primary or fallback client based on complexity
def generate_text(prompt: str, tool_results: dict = None, chat_history: list[dict] = None, complexity_level: str = 'complex') -> str:
    """
    Generates text using the configured LLM (primary or fallback), optionally incorporating tool results and chat history.
    Implements conversation summarization and windowing for long histories.
    """
    persona_instructions = """You are absa_ai, an AI developed on August 7, 2025, by the absa team. Your knowledge about business data comes from the company's internal Google Sheet.
You are a friendly and helpful chatbot. Respond to greetings appropriately (e.g., "Hello!", "Hi there!", "Habari!"). If the user uses Swahili greetings or simple conversational phrases, respond in Swahili. Otherwise, respond in English unless the query is clearly in Swahili. Handle conversational flow and ask follow-up questions when appropriate.
If the user asks a question about other companies or general knowledge, answer their question. However, subtly remind them that your primary expertise and purpose are related to Absa-specific information.
"""
    messages = [{"role": "user", "content": persona_instructions}]

    # --- Conversation Summarization and Windowing ---
    SUMMARY_THRESHOLD = 10   # Summarize after 10 turns (5 user/assistant pairs)
    HISTORY_WINDOW_SIZE = 4  # Keep the last 4 turns (2 user/assistant pairs)

    if chat_history:
        print(f"Current chat history length: {len(chat_history)}")
        if len(chat_history) > SUMMARY_THRESHOLD:
            print("Chat history exceeds threshold, summarizing older turns.")
            history_to_summarize = chat_history[:-HISTORY_WINDOW_SIZE]
            summary = summarize_chat_history(history_to_summarize)  # summarize_chat_history uses the primary client
            if summary:
                messages.append({"role": "assistant", "content": f"Summary of previous conversation: {summary}"})
                print("Added summary to messages.")
            recent_history = chat_history[-HISTORY_WINDOW_SIZE:]
            print(f"Including last {len(recent_history)} turns from history.")
            for message_dict in recent_history:
                role = message_dict.get("role")
                content = message_dict.get("content")
                if role in ["user", "assistant"] and content is not None:
                    messages.append({"role": role, "content": content})
        else:
            print("Including full chat history in LLM prompt.")
            for message_dict in chat_history:
                role = message_dict.get("role")
                content = message_dict.get("content")
                if role in ["user", "assistant"] and content is not None:
                    messages.append({"role": role, "content": content})

    current_user_content = prompt
    if tool_results and any(tool_results.values()):
        current_user_content += "\n\nTool Results:\n"
        for question, results in tool_results.items():
            if results and results != "none":  # Only include if results are not None or "none"
                current_user_content += f"--- Results for: {question} ---\n"
                if isinstance(results, list):
                    for i, result in enumerate(results):
                        if isinstance(result, dict) and 'Service' in result and 'Description' in result:
                            current_user_content += f"Business Info {i+1}:\nService: {result.get('Service', 'N/A')}\nDescription: {result.get('Description', 'N/A')}\n\n"
                        elif isinstance(result, dict) and 'url' in result:
                            current_user_content += f"Search Result {i+1}:\nTitle: {result.get('title', 'N/A')}\nURL: {result.get('url', 'N/A')}\nSnippet: {result.get('body', 'N/A')}\n\n"
                        else:
                            current_user_content += f"{result}\n\n"
                elif isinstance(results, dict):
                    for key, value in results.items():
                        current_user_content += f"{key}: {value}\n"
                    current_user_content += "\n"
                else:
                    current_user_content += f"{results}\n\n"

        current_user_content += "Based on the provided tool results and the conversation history, answer the user's latest query. If a question was answered by a tool, use the tool's result directly in your response. Maintain the language of the original query if possible, especially for simple greetings or direct questions answered by tools."
        print("Added tool results and instruction to final prompt.")
    else:
        current_user_content += "Based on the conversation history, answer the user's latest query."
        print("No tool results to add to final prompt, relying on conversation history.")

    messages.append({"role": "user", "content": current_user_content})

    generation_config = {
        "temperature": 0.7,
        "max_new_tokens": 500,
        "top_p": 0.95,
        "top_k": 50,
        "do_sample": True,
    }

    try:
        if complexity_level == 'complex':
            print("Using primary client for generation.")
            response = primary_client.chat_completion(
                messages=messages,
                max_tokens=generation_config.get("max_new_tokens", 512),
                temperature=generation_config.get("temperature", 0.7),
                top_p=generation_config.get("top_p", 0.95)
            ).choices[0].message.content or ""
            print("LLM generation successful using primary client.")
        else:  # complexity_level == 'simple' or fallback needed
            print("Using fallback client for generation.")
            # Use fallback_client for chat completion with Gemma
            response = fallback_client.chat_completion(
                messages=messages,
                max_tokens=generation_config.get("max_new_tokens", 512),
                temperature=generation_config.get("temperature", 0.7),
                top_p=generation_config.get("top_p", 0.95)
            ).choices[0].message.content or ""
            print("LLM generation successful using fallback client.")

        return response.strip()
    except Exception as e:
        print(f"Error during final LLM generation (primary or fallback): {e}")
        print(traceback.format_exc())
        return "An error occurred while generating the final response."
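# Message layout sent to chat_completion (illustrative):
#   [0]  persona instructions (sent as a user turn)
#   [1]  optional assistant turn "Summary of previous conversation: ..."
#        (only when len(chat_history) > SUMMARY_THRESHOLD)
#   [2..] the last HISTORY_WINDOW_SIZE turns (or the full history when short)
#   [-1] the latest user query, with any tool results appended as plain text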

# Log conversation data to the Hugging Face Dataset and push it to the Hub
def log_conversation(user_query: str, model_response: str, tool_details: dict = None, user_id: str = None):
    """
    Logs conversation data (query, response, timestamp, optional details) to the Hugging Face Dataset
    and pushes the changes to the Hub.
    """
    global conversation_dataset  # Access the global dataset variable
    global dataset_name  # Access the dataset name

    print("\n--- Attempting to log conversation to Hugging Face Dataset ---")

    if conversation_dataset is None:
        print("Warning: Hugging Face dataset not loaded or created. Skipping conversation logging.")
        return

    try:
        timestamp = datetime.now().isoformat()
        # Ensure tool_details is a JSON string or None
        tool_details_json = json.dumps(tool_details) if tool_details is not None else None
        # Handle potential None values for user_id
        user_id_val = user_id if user_id is not None else "anonymous"

        # Create a dictionary for the new log entry
        new_log_entry = {
            'timestamp': timestamp,
            'user_id': user_id_val,
            'user_query': user_query,
            'model_response': model_response,
            'tool_details': tool_details_json
        }

        # Append the new log entry to the 'train' split of the dataset
        new_row_dataset = Dataset.from_dict({key: [value] for key, value in new_log_entry.items()})

        # Check if the 'train' split exists before concatenating
        if 'train' in conversation_dataset:
            conversation_dataset['train'] = concatenate_datasets([conversation_dataset['train'], new_row_dataset])
        else:
            # If 'train' doesn't exist (e.g., first log entry), start it from this row
            conversation_dataset['train'] = new_row_dataset

        print("Conversation data successfully added to the dataset object.")

        # --- Pushing to the Hugging Face Hub ---
        print(f"Attempting to push dataset to {dataset_name}...")
        # Use the push_to_hub method of the DatasetDict, with a commit message for clarity
        conversation_dataset.push_to_hub(dataset_name, token=HF_TOKEN, commit_message=f"Add conversation log: {timestamp}")
        print(f"Successfully pushed dataset to {dataset_name}.")

    except Exception as e:
        print(f"An unexpected error occurred during Hugging Face Dataset logging and pushing: {e}")
        print(traceback.format_exc())

# --- Caching Implementation ---
# Path for the FAISS index file
FAISS_INDEX_FILE = "cache.index"
# Path for the metadata file (query text, response, timestamp)
CACHE_METADATA_FILE = "cache_metadata.pkl"

# Global variables for the FAISS index and metadata
faiss_index = None
cache_metadata = {}

# Dimension of the embeddings (must match the embedder model's output dimension);
# for 'sentence-transformers/paraphrase-MiniLM-L6-v2' the dimension is 384
EMBEDDING_DIM = 384

CACHE_SIMILARITY_THRESHOLD = 0.9  # Cosine similarity threshold for a cache hit
CACHE_EXPIRATION_DAYS = 7  # Cache entries expire after 7 days

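# Note on the threshold: IndexFlatL2 returns squared L2 distances. Because
# get_query_embedding() L2-normalizes embeddings, ||a - b||^2 = 2 * (1 - cos(a, b)),
# so a cosine threshold of 0.9 corresponds to a squared distance of at most
# 2 * (1 - 0.9) = 0.2; check_cache() applies exactly this conversion.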

def initialize_cache():
    """Initializes or loads the FAISS index and cache metadata."""
    global faiss_index, cache_metadata
    print("\n--- Initializing Cache ---")
    if os.path.exists(FAISS_INDEX_FILE) and os.path.exists(CACHE_METADATA_FILE):
        print("Loading existing cache...")
        try:
            faiss_index = faiss.read_index(FAISS_INDEX_FILE)
            with open(CACHE_METADATA_FILE, 'rb') as f:
                cache_metadata = pickle.load(f)
            print(f"Cache loaded successfully. Current cache size: {faiss_index.ntotal}")
            # Clean up expired entries on load
            cleanup_expired_cache_entries()
        except Exception as e:
            print(f"Error loading cache files: {e}. Initializing new cache.")
            print(traceback.format_exc())
            faiss_index = faiss.IndexFlatL2(EMBEDDING_DIM)  # Using L2 distance
            cache_metadata = {}
            save_cache()  # Save empty cache
    else:
        print("No existing cache found. Initializing new cache.")
        faiss_index = faiss.IndexFlatL2(EMBEDDING_DIM)  # Using L2 distance
        cache_metadata = {}
        save_cache()  # Save empty cache

def save_cache():
    """Saves the FAISS index and cache metadata to files."""
    global faiss_index, cache_metadata
    if faiss_index is None:
        print("Warning: FAISS index not initialized. Cannot save cache.")
        return
    print("Saving cache...")
    try:
        faiss.write_index(faiss_index, FAISS_INDEX_FILE)
        with open(CACHE_METADATA_FILE, 'wb') as f:
            pickle.dump(cache_metadata, f)
        print("Cache saved successfully.")
    except Exception as e:
        print(f"Error saving cache files: {e}")
        print(traceback.format_exc())

def get_query_embedding(query: str):
    """Generates an L2-normalized embedding for the given query."""
    if embedder is None:
        print("Warning: Embedder not available. Cannot generate query embedding for caching.")
        return None
    try:
        # Return a normalized numpy array for FAISS; normalization makes the
        # squared-L2 distance a monotone function of cosine similarity
        return embedder.encode(query, convert_to_tensor=False, normalize_embeddings=True)
    except Exception as e:
        print(f"Error generating embedding for query '{query}': {e}")
        print(traceback.format_exc())
        return None

def add_to_cache(query: str, response: str):
    """Adds the query, response, and timestamp to the cache."""
    global faiss_index, cache_metadata
    if embedder is None or faiss_index is None:
        print("Warning: Embedder or FAISS index not available. Cannot add query to cache.")
        return

    try:
        query_embedding = get_query_embedding(query)
        if query_embedding is None:
            return

        # Add the embedding to the FAISS index
        faiss_index.add(np.array([query_embedding]))  # add() expects a numpy array of shape (n, dim)

        # Store metadata (query, response, timestamp) keyed by the FAISS index ID;
        # the last added embedding gets the index faiss_index.ntotal - 1
        cache_id = faiss_index.ntotal - 1
        now = datetime.now()
        cache_metadata[cache_id] = {
            'query': query,  # Store original query for debugging/verification
            'response': response,
            'timestamp': now,
            'count': 1  # Initialize count
        }
        print(f"Added query and response to cache with ID {cache_id}.")
        save_cache()  # Save cache after adding
        print(f"Current cache size: {faiss_index.ntotal}")

    except Exception as e:
        print(f"Error adding query to cache: {e}")
        print(traceback.format_exc())

def check_cache(query: str):
    """Checks the cache for a similar query and returns the cached response if found and not expired."""
    global faiss_index, cache_metadata
    if faiss_index is None or embedder is None or faiss_index.ntotal == 0:
        print("Cache is empty or not available. Skipping cache check.")
        return None

    try:
        query_embedding = get_query_embedding(query)
        if query_embedding is None:
            return None

        # Search the FAISS index for similar embeddings;
        # D holds squared L2 distances, I the indices of the nearest neighbors
        D, I = faiss_index.search(np.array([query_embedding]), 1)  # Search for the single nearest neighbor

        # A neighbor counts as a hit when its squared L2 distance is within the cosine
        # threshold: for normalized vectors, d^2 = 2 * (1 - cosine)
        if I[0][0] != -1 and D[0][0] <= 2 * (1 - CACHE_SIMILARITY_THRESHOLD):
            cached_id = int(I[0][0])
            print(f"Found potential cache hit with ID {cached_id} and distance {D[0][0]:.4f}.")

            if cached_id in cache_metadata:
                cached_data = cache_metadata[cached_id]
                now = datetime.now()
                # Check for expiration
                if (now - cached_data['timestamp']).days <= CACHE_EXPIRATION_DAYS:
                    print(f"Cache hit! Returning cached response for query: '{query}'")
                    # Update timestamp and count on cache hit
                    cache_metadata[cached_id]['timestamp'] = now
                    cache_metadata[cached_id]['count'] += 1
                    save_cache()  # Save cache after updating metadata
                    return cached_data['response']
                else:
                    print(f"Cache entry with ID {cached_id} found but expired.")
                    # Expired entries are removed by cleanup_expired_cache_entries
            else:
                print(f"Cache ID {cached_id} found in index but not in metadata. Cache inconsistency.")

        print(f"No suitable cache entry found for query: '{query}'")
        return None

    except Exception as e:
        print(f"Error during cache check: {e}")
        print(traceback.format_exc())
        return None

def cleanup_expired_cache_entries():
    """Removes expired entries from the cache and rebuilds the FAISS index if necessary."""
    global faiss_index, cache_metadata
    if faiss_index is None or faiss_index.ntotal == 0:
        print("Cache is empty or not initialized. No expired entries to clean.")
        return

    print("Cleaning up expired cache entries...")
    now = datetime.now()
    expired_ids = [
        cache_id for cache_id, cached_data in cache_metadata.items()
        if (now - cached_data['timestamp']).days > CACHE_EXPIRATION_DAYS
    ]

    if expired_ids:
        print(f"Found {len(expired_ids)} expired cache entries.")
        # Remove from metadata
        for cache_id in expired_ids:
            del cache_metadata[cache_id]

        # Rebuild FAISS index with non-expired entries
        if cache_metadata:
            print("Rebuilding FAISS index with non-expired entries...")
            try:
                # Re-embed the surviving queries, and re-key the metadata so that its
                # keys match the row IDs of the rebuilt index
                non_expired_embeddings = []
                rebuilt_metadata = {}
                for cache_id, cached_data in sorted(cache_metadata.items()):  # Sort by old ID to keep a stable order
                    # The original query is needed to re-embed the entry
                    original_query = cached_data.get('query')
                    if original_query and embedder:
                        embedding = get_query_embedding(original_query)
                        if embedding is not None:
                            non_expired_embeddings.append(embedding)
                            rebuilt_metadata[len(non_expired_embeddings) - 1] = cached_data
                        else:
                            print(f"Error re-embedding query '{original_query}'. Skipping.")

                if non_expired_embeddings:
                    print(f"Re-embedded {len(non_expired_embeddings)} non-expired queries.")
                    faiss_index = faiss.IndexFlatL2(EMBEDDING_DIM)
                    faiss_index.add(np.array(non_expired_embeddings))
                    cache_metadata = rebuilt_metadata
                    print(f"FAISS index rebuilt. New size: {faiss_index.ntotal}")
                else:
                    print("No non-expired entries to rebuild FAISS index. Clearing index.")
                    faiss_index = faiss.IndexFlatL2(EMBEDDING_DIM)
                    cache_metadata = {}  # Clear metadata if the index is cleared

            except Exception as e:
                print(f"Error rebuilding FAISS index: {e}")
                print(traceback.format_exc())
                # On error it is safer to clear the cache to avoid inconsistencies
                print("Clearing cache due to rebuild error.")
                faiss_index = faiss.IndexFlatL2(EMBEDDING_DIM)
                cache_metadata = {}

        else:
            print("All cache entries expired. Clearing FAISS index and metadata.")
            faiss_index = faiss.IndexFlatL2(EMBEDDING_DIM)
            cache_metadata = {}

        save_cache()  # Save after cleanup
    else:
        print("No expired cache entries found.")

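# Note: on Hugging Face Spaces the local filesystem is typically ephemeral, so
# cache.index and cache_metadata.pkl may not survive a Space restart unless
# persistent storage is enabled for the Space.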
# Main chat function with query breakdown and per-question tool execution
def chat(query: str, chat_history: list[dict], api_key: str):
    """
    Processes user queries by breaking down multi-part queries, determining and
    executing appropriate tools for each question, and synthesizing results
    using the LLM. Incorporates caching for repeated questions and routes
    to the primary or fallback model based on complexity.
    """
    print(f"\n--- chat function received new query ---")
    print(f" query: {query}")
    print(f" Validating against SECRET_API_KEY: {'*' * len(SECRET_API_KEY) if SECRET_API_KEY else 'None'}")
    print(f" chat_history: {chat_history}")
    print(f" api_key from UI: {'*' * len(api_key) if api_key else 'None'}")

    if not SECRET_API_KEY:
        print("Error: APP_API_KEY secret not set in Hugging Face Space Secrets.")
        # Log the failure before returning
        log_conversation(
            user_query=query,
            model_response="API key validation failed: Application not configured correctly. APP_API_KEY secret is missing.",
            tool_details={"validation_status": "failed", "reason": "secret_not_set"},
            user_id="unknown"
        )
        return "API key validation failed: Application not configured correctly. APP_API_KEY secret is missing."

    if api_key != SECRET_API_KEY:
        print("Error: API key from UI does not match SECRET_API_KEY.")
        # Log the failure before returning
        log_conversation(
            user_query=query,
            model_response="API key validation failed: Invalid API key provided.",
            tool_details={"validation_status": "failed", "reason": "invalid_api_key"},
            user_id="unknown"
        )
        return "API key validation failed: Invalid API key provided."

    # --- Cache check ---
    cached_response = check_cache(query)
    if cached_response:
        print(f"Returning cached response for query: '{query}'")
        # Log the cached response
        try:
            user_id_to_log = "anonymous"
            if chat_history:
                for turn in chat_history:
                    if turn.get("role") == "user" and "user_id:" in turn.get("content", "").lower():
                        match = re.search(r"user_id:\s*(\S+)", turn.get("content", ""), re.IGNORECASE)
                        if match:
                            user_id_to_log = match.group(1)
                            break

            log_conversation(
                user_query=query,
                model_response=cached_response,
                tool_details={"cache_status": "hit"},
                user_id=user_id_to_log
            )
        except Exception as e:
            print(f"Error during logging of cached response: {e}")
            print(traceback.format_exc())

        return cached_response

    print("\n--- Breaking down query ---")
    # Use the primary client for query breakdown, as it is generally better at understanding complex queries
    prompt_for_question_breakdown = f"""
Analyze the following query and list each distinct question found within it.
Present each question on a new line, starting with a hyphen.
Query: {query}
"""
    try:
        messages_question_breakdown = primary_client.chat_completion(  # Use primary client
            messages=[{"role": "user", "content": prompt_for_question_breakdown}],
            max_tokens=100,
            temperature=0.1,
            top_p=0.9
        ).choices[0].message.content or ""
        individual_questions = [line.strip() for line in messages_question_breakdown.split('\n') if line.strip()]
        cleaned_questions = [re.sub(r'^[-*]?\s*', '', q) for q in individual_questions if not q.strip().lower().startswith('note:')]
        print("Individual questions identified:")
        for q in cleaned_questions:
            print(f"- {q}")
    except Exception as e:
        print(f"Error during LLM call for question breakdown (primary client): {e}")
        print(traceback.format_exc())
        print(f"Proceeding with the original query as a single question due to breakdown error.")
        cleaned_questions = [query]

    print("\n--- Determining tools and complexity per question ---")
    determined_tools_and_complexity = {}
    for question in cleaned_questions:
        print(f"\nAnalyzing question for tool determination and complexity: '{question}'")
        tool, complexity = determine_tool_usage(question)  # determine_tool_usage uses the primary client for its checks
        determined_tools_and_complexity[question] = {"tool": tool, "complexity": complexity}
        print(f"Determined tool and complexity for '{question}': Tool='{tool}', Complexity='{complexity}'")

    print("\nSummary of determined tools and complexity per question:")
    for question, details in determined_tools_and_complexity.items():
        print(f"'{question}': Tool='{details['tool']}', Complexity='{details['complexity']}'")

    print("\n--- Executing tools and collecting results ---")
    tool_results = {}
    for question, details in determined_tools_and_complexity.items():
        tool = details['tool']
        print(f"\nExecuting tool '{tool}' for question: '{question}'")
        result = None
        if tool == "date_calculation":
            result = perform_date_calculation(question)
            tool_results[question] = result
        elif tool == "duckduckgo_search":
            result = perform_duckduckgo_search(question)
            tool_results[question] = result
        elif tool == "business_info_retrieval":
            result = retrieve_business_info(question)
            tool_results[question] = result
        elif tool == "none":
            print(f"Skipping tool execution for question: '{question}' as tool is 'none'. The LLM will handle it.")
            tool_results[question] = "none"

    print("\n--- Collected Tool Results ---")
    if tool_results:
        for question, result in tool_results.items():
            print(f"\nQuestion: {question}")
            print(f"Result: {result}")
    else:
        print("No tool results were collected.")
    print("\n--------------------------")

    print("\n--- Generating final response ---")

    # Determine the overall complexity to choose the final generation model:
    # if any question was determined to be 'complex', use the primary model
    overall_complexity = 'simple'
    for details in determined_tools_and_complexity.values():
        if details['complexity'] == 'complex':
            overall_complexity = 'complex'
            break
    print(f"Overall query complexity determined as: '{overall_complexity}'")

    final_response = generate_text(query, tool_results, chat_history, complexity_level=overall_complexity)
    print("\n--- Final Response from LLM ---")
    print(final_response)
    print("\n----------------------------")

    # --- Add the response to the cache ---
    # The entire query and the final response are cached, not the individual questions
    add_to_cache(query, final_response)

    try:
        user_id_to_log = "anonymous"
        if chat_history:
            for turn in chat_history:
                if turn.get("role") == "user" and "user_id:" in turn.get("content", "").lower():
                    match = re.search(r"user_id:\s*(\S+)", turn.get("content", ""), re.IGNORECASE)
                    if match:
                        user_id_to_log = match.group(1)
                        break

        logged_tool_details = {}
        for question, details in determined_tools_and_complexity.items():
            logged_tool_details[question] = {
                "tool_used": details['tool'],
                "complexity": details['complexity'],
                "raw_output": tool_results.get(question)
            }
        logged_tool_details["cache_status"] = "miss"  # Log a cache miss when generating a new response
        logged_tool_details["model_used_for_generation"] = "primary" if overall_complexity == 'complex' else "fallback"

        # Call the logging function (logs to the Hugging Face Dataset)
        log_conversation(
            user_query=query,
            model_response=final_response,
            tool_details=logged_tool_details,
            user_id=user_id_to_log
        )
    except Exception as e:
        print(f"Error during conversation logging after response generation: {e}")
        print(traceback.format_exc())

    return final_response

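# Illustrative breakdown: "What is the time in East Africa and search for latest AI news"
# should split into two questions, e.g. "What is the time in East Africa?" routed to
# date_calculation (simple) and "Search for the latest AI news" routed to
# duckduckgo_search (complex), making the overall response complex (primary model).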

# Gradio interface setup
if __name__ == "__main__":
    # Load or create the Hugging Face Dataset on startup
    try:
        # Attempt to load the existing dataset
        print(f"Attempting to load dataset from {dataset_name} on startup...")
        # Use load_dataset to load directly from the Hub
        conversation_dataset = load_dataset(dataset_name, token=HF_TOKEN)
        print(f"Successfully loaded existing dataset from {dataset_name} on startup.")
        print(conversation_dataset)

    except Exception as e:
        print(f"Dataset not found or failed to load from {dataset_name} on startup: {e}")
        print("Creating a new dataset object on startup...")

        # Define the schema for conversation logs.
        # 'string' is used for every column for simplicity; tool_details holds a JSON string
        log_schema = {
            'timestamp': 'string',
            'user_id': 'string',
            'user_query': 'string',
            'model_response': 'string',
            'tool_details': 'string'  # Store a JSON string here
        }

        # Create an empty dataset with the defined schema
        empty_data = {col: [] for col in log_schema.keys()}
        new_dataset = Dataset.from_dict(empty_data)

        # Wrap the dataset in a DatasetDict
        conversation_dataset = DatasetDict({'train': new_dataset})

        print(f"Created a new empty dataset object with schema: {log_schema}")
        print(conversation_dataset)

    authenticate_google_sheets()
    load_business_info()  # This also creates the RAG embeddings if data is loaded

    if nlp is None:
        print("Warning: SpaCy model not loaded. Sentence splitting may not work correctly.")
    if embedder is None:
        print("Warning: Sentence Transformer (embedder) not loaded. RAG will not be available.")
    if reranker is None:
        print("Warning: Cross-Encoder Reranker not loaded. Re-ranking of RAG results will not be performed.")
    if not business_info_available:
        print("Warning: Business information (Google Sheet data) not loaded successfully. "
              "RAG will not be available. Please ensure the GOOGLE_BASE64_CREDENTIALS secret is set correctly.")

    DESCRIPTION = """
# LLM with Tools (DuckDuckGo Search, Date Calculation, Business Info RAG, Hugging Face Dataset Logging) and a Two-Tier Model System
Ask me anything! I can perform web searches, calculate dates, and retrieve business information using RAG; conversation data is logged to a Hugging Face Dataset. I use a primary LLaMA-70B model for complex queries and a fallback Gemma-2-9b-it model for simpler ones and RAG synthesis.
"""

    demo = gr.ChatInterface(
        fn=chat,
        stop_btn=None,
        examples=[
            ["Hello there! How are you doing?"],
            ["What is the current time in East Africa?"],
            ["Tell me about the 'Project Management' service from Absa."],
            ["Search the web for the latest news on AI."],
            ["Habari!"],
            ["What is the date next Tuesday?"],
            ["What is the time in East Africa and search for latest AI news"],
            ["Who is Jackson Kisanga?"],  # Example for business info retrieval
            ["What is the weather like in London?"],  # Example for web search
            ["Tell me a joke."],  # Example for a simple query
        ],
        cache_examples=False,
        type="messages",
        description=DESCRIPTION,
        fill_height=True,
        additional_inputs=[
            gr.Textbox(label="API Key", type="password", placeholder="Enter your API key (starts with fs_)", interactive=True)
        ],
        additional_inputs_accordion="API Key (Required)"
    )

    try:
        # Initialize the cache before launching the demo
        initialize_cache()
        demo.launch(debug=True, show_error=True)
    except Exception as e:
        print(f"Error launching Gradio interface: {e}")
        print(traceback.format_exc())
        print("Please check the console output for more details on the error.")