sandeep-huggingface committed on
Commit 061aa76 · verified · 1 Parent(s): 82b7d29

Update app.py

Files changed (1)
  1. app.py +75 -72
app.py CHANGED
@@ -40,10 +40,10 @@ class TransformersLLM(LLM):
         max_new_tokens = kwargs.pop('max_new_tokens', 256)
         temperature = kwargs.pop('temperature', 0.7)
         do_sample = kwargs.pop('do_sample', True)
-
+
         # Call super with model_name explicitly
         super().__init__(model_name=model_name, **kwargs)
-
+
         # Set our custom attributes
         self.pipeline = pipeline_obj
         self.max_new_tokens = max_new_tokens
@@ -59,7 +59,7 @@ class TransformersLLM(LLM):
                 user_question = prompt.split("Question:")[-1].replace("Helpful Answer:", "").strip()
             else:
                 user_question = prompt
-
+
             # Create a focused prompt for the model
             if "qwen" in self.model_name.lower():
                 system_prompt = "You are a helpful assistant that analyzes CSV data and answers questions accurately and concisely."
@@ -71,7 +71,7 @@ class TransformersLLM(LLM):
                 formatted_prompt = f"Question: {user_question}\n\nBased on the provided CSV data, please provide a clear and informative answer:\n\nAnswer:"
 
             print(f"Generating response for: {user_question[:100]}...")
-
+
             # Generate response
             with torch.no_grad():
                 response = self.pipeline(
@@ -82,10 +82,10 @@ class TransformersLLM(LLM):
                     pad_token_id=self.pipeline.tokenizer.eos_token_id,
                     return_full_text=False
                 )
-
+
             if response and len(response) > 0:
                 generated_text = response[0]['generated_text'].strip()
-
+
                 # Clean up the response
                 if "assistant" in generated_text:
                     generated_text = generated_text.split("assistant")[-1].strip()
@@ -93,18 +93,18 @@ class TransformersLLM(LLM):
                     generated_text = generated_text.split("<|im_end|>")[0].strip()
                 if "<|eot_id|>" in generated_text:
                     generated_text = generated_text.split("<|eot_id|>")[0].strip()
-
+
                 # Remove any remaining special tokens
                 for token in ["<|im_start|>", "<|im_end|>", "<|eot_id|>", "<|begin_of_text|>", "<|end_of_text|>"]:
                     generated_text = generated_text.replace(token, "")
-
+
                 generated_text = generated_text.strip()
-
+
                 if len(generated_text) > 10:
                     return generated_text
-
+
             return "I apologize, but I couldn't generate a meaningful response. Please try rephrasing your question."
-
+
         except Exception as e:
             print(f"Error in LLM generation: {e}")
             return f"I encountered an error while processing your question: {str(e)}. Please try again."
@@ -115,8 +115,8 @@ class TransformersLLM(LLM):
 
 # Available models
 AVAILABLE_MODELS = {
-    "Qwen2.5-7B-Instruct": "Qwen/Qwen2.5-7B-Instruct",
-    "Llama-3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct"
+    "Llama-3.2-1B-Instruct": "meta-llama/Llama-3.2-1B-Instruct",
+    "Qwen2.5-0.5B-Instruct": "Qwen/Qwen2.5-0.5B-Instruct"
 }
 
 CHUNK_SIZES = {
@@ -128,10 +128,10 @@ CHUNK_SIZES = {
 def load_model(model_choice: str, progress=gr.Progress()):
     """Load the selected model with proper memory management"""
     global current_model, current_tokenizer, current_pipeline
-
+
     try:
         model_id = AVAILABLE_MODELS[model_choice]
-
+
         with model_lock:
             # Clear existing model if different
             if current_model is not None:
@@ -144,22 +144,22 @@ def load_model(model_choice: str, progress=gr.Progress()):
             else:
                 # Same model already loaded
                 return current_pipeline, f"✅ Model {model_choice} already loaded!"
-
+
         progress(0.2, desc=f"Loading tokenizer for {model_choice}...")
         print(f"Loading tokenizer for {model_id}...")
-
+
         tokenizer = AutoTokenizer.from_pretrained(
-            model_id,
+            model_id,
            trust_remote_code=True
         )
-
+
         # Set pad token if not exists
         if tokenizer.pad_token is None:
             tokenizer.pad_token = tokenizer.eos_token
-
+
         progress(0.5, desc=f"Loading model {model_choice}...")
         print(f"Loading model {model_id}...")
-
+
         # Load model with appropriate settings for Colab
         # model = AutoModelForCausalLM.from_pretrained(
         #     model_id,
@@ -190,7 +190,7 @@ def load_model(model_choice: str, progress=gr.Progress()):
         )
         progress(0.8, desc="Creating pipeline...")
         print("Creating text generation pipeline...")
-
+
         # Create pipeline
         # pipe = pipeline(
         #     "text-generation",
@@ -210,11 +210,11 @@ def load_model(model_choice: str, progress=gr.Progress()):
         current_model = model
         current_tokenizer = tokenizer
         current_pipeline = pipe
-
+
         progress(1.0, desc="Model loaded successfully!")
-
+
         return pipe, f"✅ Model {model_choice} loaded successfully!"
-
+
     except Exception as e:
         print(f"Error loading model: {e}")
         traceback.print_exc()
@@ -250,7 +250,7 @@ def csv_to_documents(file_path: str) -> List[Document]:
         # Read CSV file with multiple encoding attempts
         df = None
         encodings = ['utf-8', 'latin-1', 'cp1252', 'iso-8859-1']
-
+
         for encoding in encodings:
             try:
                 df = pd.read_csv(file_path, encoding=encoding)
@@ -258,7 +258,7 @@ def csv_to_documents(file_path: str) -> List[Document]:
                 break
             except UnicodeDecodeError:
                 continue
-
+
         if df is None:
             print(f"Could not read {file_path} with any encoding")
             return []
@@ -269,12 +269,12 @@ def csv_to_documents(file_path: str) -> List[Document]:
 
         # Clean the dataframe
         df = df.dropna(how='all') # Remove completely empty rows
-
+
         # Get basic info about the CSV
         filename = os.path.basename(file_path)
         total_rows = len(df)
         columns = list(df.columns)
-
+
         print(f"Processing {filename}: {total_rows} rows, {len(columns)} columns")
         print(f"Columns: {columns}")
 
@@ -282,7 +282,7 @@ def csv_to_documents(file_path: str) -> List[Document]:
 
         # Create a summary document first
         summary_content = f"Dataset: {filename}\nTotal rows: {total_rows}\nColumns: {', '.join(columns)}\n"
-
+
         # Add column statistics if numeric columns exist
         numeric_cols = df.select_dtypes(include=['number']).columns
         if len(numeric_cols) > 0:
@@ -314,7 +314,7 @@ def csv_to_documents(file_path: str) -> List[Document]:
             try:
                 # Create a more structured text representation
                 row_text_parts = []
-
+
                 # Add row identifier
                 row_text_parts.append(f"Row {idx + 1} of {filename}")
 
@@ -323,7 +323,7 @@ def csv_to_documents(file_path: str) -> List[Document]:
                     value = row[col]
                     if pd.isna(value):
                         continue # Skip NaN values
-
+
                     # Clean and format the value
                    if isinstance(value, (int, float)):
                         formatted_value = f"{value:,.2f}" if isinstance(value, float) else f"{value:,}"
@@ -331,7 +331,7 @@ def csv_to_documents(file_path: str) -> List[Document]:
                         formatted_value = str(value).replace('\n', ' ').replace('\r', ' ').strip()
                         if len(formatted_value) > 100:
                             formatted_value = formatted_value[:100] + "..."
-
+
                     row_text_parts.append(f"{col}: {formatted_value}")
 
                 # Combine all parts
@@ -395,18 +395,18 @@ def load_doc(list_file_path: List[str], splitting_strategy: str, chunk_size: str
 
         # Apply text splitting with adjusted parameters
         text_splitter = get_text_splitter(
-            splitting_strategy,
-            chunk_size_value,
+            splitting_strategy,
+            chunk_size_value,
             max(50, chunk_size_value // 10) # Dynamic overlap
         )
         doc_splits = text_splitter.split_documents(all_documents)
 
         print(f"Total document chunks after splitting: {len(doc_splits)}")
-
+
         # Print some sample chunks for debugging
         for i, split in enumerate(doc_splits[:3]):
             print(f"Sample chunk {i+1}: {split.page_content[:100]}...")
-
+
         return doc_splits
 
     except Exception as e:
@@ -421,7 +421,7 @@ def create_db(splits, db_choice: str = "faiss"):
             raise ValueError("No document splits provided")
 
         print(f"Creating {db_choice} database with {len(splits)} documents")
-
+
         # Use a reliable embedding model with better parameters
         embeddings = HuggingFaceEmbeddings(
             model_name="sentence-transformers/all-MiniLM-L6-v2",
@@ -434,29 +434,29 @@ def create_db(splits, db_choice: str = "faiss"):
                 'batch_size': 16
             }
         )
-
+
         print("Testing embeddings with sample text...")
         test_embedding = embeddings.embed_query("test query")
         print(f"Embedding dimension: {len(test_embedding)}")
-
+
         db_creators = {
             "faiss": lambda: FAISS.from_documents(splits, embeddings),
             "chroma": lambda: Chroma.from_documents(
-                splits,
+                splits,
                 embeddings,
                 persist_directory=None # In-memory
             )
         }
-
+
         db = db_creators[db_choice]()
         print(f"Successfully created {db_choice} database")
-
+
         # Test the database with a simple query
         test_results = db.similarity_search("test", k=1)
         print(f"Database test successful, found {len(test_results)} results")
-
+
         return db
-
+
     except Exception as e:
         print(f"Error creating database: {e}")
         traceback.print_exc()
@@ -513,15 +513,15 @@ def initialize_llmchain(model_choice, temperature, max_tokens, top_k, vector_db,
             return None, "❌ Please create vector database first."
 
         progress(0.2, desc="Loading model...")
-
+
         # Load the selected model
         pipeline_obj, model_status = load_model(model_choice, progress)
-
+
         if pipeline_obj is None:
             return None, model_status
 
         progress(0.7, desc="Creating LLM instance...")
-
+
         # Create our custom LLM wrapper
         llm = TransformersLLM(
             model_name=AVAILABLE_MODELS[model_choice],
@@ -530,9 +530,9 @@ def initialize_llmchain(model_choice, temperature, max_tokens, top_k, vector_db,
             temperature=max(0.1, min(1.0, temperature)),
            do_sample=temperature > 0.1
         )
-
+
         progress(0.8, desc="Setting up retriever...")
-
+
         # Create retriever with optimized parameters
         retriever = vector_db.as_retriever(
             search_type="similarity",
@@ -541,7 +541,7 @@ def initialize_llmchain(model_choice, temperature, max_tokens, top_k, vector_db,
                 "fetch_k": min(max(3, top_k * 2), 20) # Fetch more, then filter
             }
         )
-
+
         # Test the retriever
         try:
             test_docs = retriever.get_relevant_documents("test query")
@@ -551,7 +551,7 @@ def initialize_llmchain(model_choice, temperature, max_tokens, top_k, vector_db,
             return None, f"❌ Database retriever failed: {str(e)}"
 
         progress(0.9, desc="Creating QA chain...")
-
+
         # Create QA chain with error handling
         qa_chain = RetrievalQA.from_chain_type(
             llm=llm,
@@ -568,7 +568,7 @@ def initialize_llmchain(model_choice, temperature, max_tokens, top_k, vector_db,
                        f"🌡️ Temperature: {temperature}\n"
                        f"📝 Max tokens: {max_tokens}\n"
                        f"🔍 Retriever K: {top_k}")
-
+
         return qa_chain, success_msg
 
     except Exception as e:
@@ -601,14 +601,14 @@ def conversation(qa_chain, message, history):
         print(f"\n{'='*50}")
         print(f"Processing question: {message}")
         print(f"{'='*50}")
-
+
         # Enhance the query for better CSV data understanding
         enhanced_query = f"""Based on the CSV data provided, please answer this question: {message.strip()}
 
 Please provide a clear, informative answer that directly addresses the question. If you're analyzing data, include specific values, trends, or patterns you observe."""
-
+
         print(f"Enhanced query: {enhanced_query}")
-
+
         # Call the QA chain with timeout handling
         start_time = time.time()
         try:
@@ -628,7 +628,7 @@ Please provide a clear, informative answer that directly addresses the question.
                     fallback_response += "Please try rephrasing your question or try again."
                 else:
                     fallback_response = "I couldn't find relevant information in your CSV data for this question. Please try a different question or check if your data contains the information you're looking for."
-
+
                 return (
                     qa_chain,
                     gr.update(value=""),
@@ -643,18 +643,18 @@ Please provide a clear, informative answer that directly addresses the question.
            except Exception as fallback_error:
                 print(f"Fallback also failed: {fallback_error}")
                 error_response = f"I encountered an error processing your question: {str(qa_error)}. Please try:\n\n1. Using a simpler question\n2. Waiting a moment and trying again\n3. Reloading the model"
-
+
                 return (
                     qa_chain,
                     gr.update(value=""),
                     history + [(message, error_response)],
                     f"Error: {str(qa_error)}", "Error processing", "", "No source", "", "No source"
                 )
-
+
         # Extract and process the response
         print(f"Raw response type: {type(response)}")
         print(f"Response keys: {response.keys() if isinstance(response, dict) else 'Not a dict'}")
-
+
         if isinstance(response, dict):
             # Get the answer
             response_answer = response.get("result") or response.get("answer") or str(response)
@@ -672,15 +672,15 @@ Please provide a clear, informative answer that directly addresses the question.
         response_answer = response_answer.replace("Based on the following context, please provide a helpful and accurate answer", "").strip()
         response_answer = response_answer.replace("Helpful Answer:", "").strip()
         response_answer = response_answer.replace("Answer:", "").strip()
-
+
         # Remove repeated prompts
         if enhanced_query[:50] in response_answer:
             response_answer = response_answer.replace(enhanced_query, "").strip()
-
+
         # Ensure we have a meaningful response
         if len(response_answer.strip()) < 10:
             response_answer = "I was able to process your question, but the response was too brief. Please try rephrasing your question or providing more context."
-
+
         if not response_answer or response_answer.strip() == "":
             response_answer = "I apologize, but I couldn't generate a meaningful response to your question. Please try rephrasing your question or ensure your CSV data contains relevant information."
 
@@ -703,14 +703,14 @@ Please provide a clear, informative answer that directly addresses the question.
                     content = content[:300] + "..."
 
                 source_contents.append(content)
-
+
                 if doc_type == "csv_summary":
                     source_info.append(f"Summary of {source_file}")
                 else:
                     source_info.append(f"File: {source_file} | Row: {row_info}")
-
+
                 print(f"Source {i+1}: {source_info[-1][:50]}...")
-
+
             except Exception as e:
                 print(f"Error processing source {i}: {e}")
                 source_contents.append(f"Error processing source: {str(e)}")
@@ -724,11 +724,14 @@ Please provide a clear, informative answer that directly addresses the question.
         print(f"Final response length: {len(response_answer)} characters")
         print(f"Response preview: {response_answer[:100]}...")
         print(f"Sources: {[info for info in source_info if info != 'No additional sources']}")
-
+
+        new_history = history.copy()
+        new_history.append({"role": "user", "content": message})
+        new_history.append({"role": "assistant", "content": response_answer})
         return (
             qa_chain,
             gr.update(value=""),
-            history + [(message, response_answer)],
+            new_history,
             source_contents[0],
             source_info[0],
             source_contents[1],
@@ -742,9 +745,9 @@ Please provide a clear, informative answer that directly addresses the question.
         print(f"Error: {str(e)}")
         print(f"Error type: {type(e).__name__}")
         traceback.print_exc()
-
+
         error_msg = f"❌ I encountered an error while processing your question:\n\n{str(e)}\n\nPlease try:\n1. Using a simpler question\n2. Waiting a moment and trying again\n3. Reloading the model\n4. Recreating the database"
-
+
         return (
            qa_chain,
             gr.update(value=""),
@@ -905,7 +908,7 @@ def demo():
                 **Available Models:**
                 - **Qwen2.5-7B-Instruct**: Advanced Chinese-English bilingual model, excellent for analysis
                 - **Llama-3.1-8B-Instruct**: Meta's powerful instruction-following model
-
+
                 **Note**: Models are loaded locally with 4-bit quantization for memory efficiency. First load may take several minutes.
                 """)
 
@@ -994,7 +997,7 @@ def demo():
     )
 
     demo.queue().launch(
-        debug=True,
+        debug=True,
        share=False,
         show_error=True
    )
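For context on the last substantive change in this diff: the success path of conversation() now builds the chat history as role/content message dicts instead of (user, assistant) tuples. Below is a minimal, self-contained sketch of that messages format, assuming a Gradio Chatbot created with type="messages" (the component setup itself is not part of this diff, so the names here are illustrative only):

import gradio as gr

def respond(message, history):
    # history arrives as a list of {"role": ..., "content": ...} dicts
    new_history = history.copy()
    new_history.append({"role": "user", "content": message})
    new_history.append({"role": "assistant", "content": f"Echo: {message}"})
    # Clear the textbox and return the updated messages-format history
    return "", new_history

with gr.Blocks() as sketch:
    chatbot = gr.Chatbot(type="messages")  # expects role/content dicts
    msg = gr.Textbox()
    msg.submit(respond, [msg, chatbot], [msg, chatbot])

# sketch.launch()  # uncomment to try the messages-format history locally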