Update app.py
app.py
CHANGED
@@ -40,10 +40,10 @@ class TransformersLLM(LLM):
 max_new_tokens = kwargs.pop('max_new_tokens', 256)
 temperature = kwargs.pop('temperature', 0.7)
 do_sample = kwargs.pop('do_sample', True)
-
+
 # Call super with model_name explicitly
 super().__init__(model_name=model_name, **kwargs)
-
+
 # Set our custom attributes
 self.pipeline = pipeline_obj
 self.max_new_tokens = max_new_tokens
@@ -59,7 +59,7 @@ class TransformersLLM(LLM):
 user_question = prompt.split("Question:")[-1].replace("Helpful Answer:", "").strip()
 else:
 user_question = prompt
-
+
 # Create a focused prompt for the model
 if "qwen" in self.model_name.lower():
 system_prompt = "You are a helpful assistant that analyzes CSV data and answers questions accurately and concisely."
@@ -71,7 +71,7 @@ class TransformersLLM(LLM):
 formatted_prompt = f"Question: {user_question}\n\nBased on the provided CSV data, please provide a clear and informative answer:\n\nAnswer:"

 print(f"Generating response for: {user_question[:100]}...")
-
+
 # Generate response
 with torch.no_grad():
 response = self.pipeline(
@@ -82,10 +82,10 @@ class TransformersLLM(LLM):
 pad_token_id=self.pipeline.tokenizer.eos_token_id,
 return_full_text=False
 )
-
+
 if response and len(response) > 0:
 generated_text = response[0]['generated_text'].strip()
-
+
 # Clean up the response
 if "assistant" in generated_text:
 generated_text = generated_text.split("assistant")[-1].strip()
@@ -93,18 +93,18 @@ class TransformersLLM(LLM):
 generated_text = generated_text.split("<|im_end|>")[0].strip()
 if "<|eot_id|>" in generated_text:
 generated_text = generated_text.split("<|eot_id|>")[0].strip()
-
+
 # Remove any remaining special tokens
 for token in ["<|im_start|>", "<|im_end|>", "<|eot_id|>", "<|begin_of_text|>", "<|end_of_text|>"]:
 generated_text = generated_text.replace(token, "")
-
+
 generated_text = generated_text.strip()
-
+
 if len(generated_text) > 10:
 return generated_text
-
+
 return "I apologize, but I couldn't generate a meaningful response. Please try rephrasing your question."
-
+
 except Exception as e:
 print(f"Error in LLM generation: {e}")
 return f"I encountered an error while processing your question: {str(e)}. Please try again."
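For context on the hunks above: `TransformersLLM` appears to be a custom LangChain `LLM` subclass wrapping a Transformers text-generation pipeline, but the full class body is not part of this diff. The following is only a minimal sketch of what such a wrapper typically looks like; the class name and any field not visible in the diff (`pipeline`, `model_name`, `max_new_tokens`, `temperature`, `do_sample` are visible) are assumptions.

```python
# Minimal sketch of a custom LangChain LLM wrapper around a transformers pipeline.
# Import path assumes the classic langchain package used elsewhere in this app.
from typing import Any, List, Optional
from langchain.llms.base import LLM

class TransformersLLMSketch(LLM):
    model_name: str = ""
    max_new_tokens: int = 256
    temperature: float = 0.7
    do_sample: bool = True
    pipeline: Any = None  # the transformers text-generation pipeline

    @property
    def _llm_type(self) -> str:
        return "custom_transformers"

    def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
        # Generate text and return only the newly generated part.
        out = self.pipeline(
            prompt,
            max_new_tokens=self.max_new_tokens,
            do_sample=self.do_sample,
            temperature=self.temperature,
            return_full_text=False,
        )
        return out[0]["generated_text"].strip()
```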
@@ -115,8 +115,8 @@ class TransformersLLM(LLM):

 # Available models
 AVAILABLE_MODELS = {
-"
-"
+"Llama-3.2-1B-Instruct": "meta-llama/Llama-3.2-1B-Instruct",
+"Qwen2.5-0.5B-Instruct": "Qwen/Qwen2.5-0.5B-Instruct"
 }

 CHUNK_SIZES = {
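This hunk swaps the model registry to much smaller instruct checkpoints (1B and 0.5B), which are easier to fit in a constrained Space. As a quick standalone sanity check, either id can be exercised with the plain `transformers` pipeline API; this snippet is illustrative and not part of the app (the Llama checkpoint is gated and needs an accepted license on the Hub):

```python
# Standalone check that one of the new, smaller checkpoints loads and generates.
from transformers import pipeline

pipe = pipeline("text-generation", model="Qwen/Qwen2.5-0.5B-Instruct")
out = pipe("Summarize what a CSV file is in one sentence.",
           max_new_tokens=64, do_sample=False)
print(out[0]["generated_text"])
```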
@@ -128,10 +128,10 @@ CHUNK_SIZES = {
 def load_model(model_choice: str, progress=gr.Progress()):
 """Load the selected model with proper memory management"""
 global current_model, current_tokenizer, current_pipeline
-
+
 try:
 model_id = AVAILABLE_MODELS[model_choice]
-
+
 with model_lock:
 # Clear existing model if different
 if current_model is not None:
@@ -144,22 +144,22 @@ def load_model(model_choice: str, progress=gr.Progress()):
 else:
 # Same model already loaded
 return current_pipeline, f"✅ Model {model_choice} already loaded!"
-
+
 progress(0.2, desc=f"Loading tokenizer for {model_choice}...")
 print(f"Loading tokenizer for {model_id}...")
-
+
 tokenizer = AutoTokenizer.from_pretrained(
-model_id,
+model_id,
 trust_remote_code=True
 )
-
+
 # Set pad token if not exists
 if tokenizer.pad_token is None:
 tokenizer.pad_token = tokenizer.eos_token
-
+
 progress(0.5, desc=f"Loading model {model_choice}...")
 print(f"Loading model {model_id}...")
-
+
 # Load model with appropriate settings for Colab
 # model = AutoModelForCausalLM.from_pretrained(
 # model_id,
@@ -190,7 +190,7 @@ def load_model(model_choice: str, progress=gr.Progress()):
 )
 progress(0.8, desc="Creating pipeline...")
 print("Creating text generation pipeline...")
-
+
 # Create pipeline
 # pipe = pipeline(
 # "text-generation",
@@ -210,11 +210,11 @@ def load_model(model_choice: str, progress=gr.Progress()):
 current_model = model
 current_tokenizer = tokenizer
 current_pipeline = pipe
-
+
 progress(1.0, desc="Model loaded successfully!")
-
+
 return pipe, f"✅ Model {model_choice} loaded successfully!"
-
+
 except Exception as e:
 print(f"Error loading model: {e}")
 traceback.print_exc()
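The actual construction of `pipe` is elided in this diff (the visible `pipeline(` lines are commented out), so the exact arguments the Space uses are not recoverable here. A typical way to build a text-generation pipeline from the already loaded model and tokenizer looks roughly like the sketch below; every argument value is an assumption.

```python
# Hypothetical reconstruction of the pipeline setup; argument values are guesses.
from transformers import pipeline

pipe = pipeline(
    "text-generation",
    model=model,          # the AutoModelForCausalLM loaded above
    tokenizer=tokenizer,  # the AutoTokenizer loaded above
    max_new_tokens=256,
    do_sample=True,
    temperature=0.7,
    pad_token_id=tokenizer.eos_token_id,
)
```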
@@ -250,7 +250,7 @@ def csv_to_documents(file_path: str) -> List[Document]:
 # Read CSV file with multiple encoding attempts
 df = None
 encodings = ['utf-8', 'latin-1', 'cp1252', 'iso-8859-1']
-
+
 for encoding in encodings:
 try:
 df = pd.read_csv(file_path, encoding=encoding)
@@ -258,7 +258,7 @@ def csv_to_documents(file_path: str) -> List[Document]:
 break
 except UnicodeDecodeError:
 continue
-
+
 if df is None:
 print(f"Could not read {file_path} with any encoding")
 return []
@@ -269,12 +269,12 @@ def csv_to_documents(file_path: str) -> List[Document]:

 # Clean the dataframe
 df = df.dropna(how='all')  # Remove completely empty rows
-
+
 # Get basic info about the CSV
 filename = os.path.basename(file_path)
 total_rows = len(df)
 columns = list(df.columns)
-
+
 print(f"Processing {filename}: {total_rows} rows, {len(columns)} columns")
 print(f"Columns: {columns}")

@@ -282,7 +282,7 @@ def csv_to_documents(file_path: str) -> List[Document]:

 # Create a summary document first
 summary_content = f"Dataset: {filename}\nTotal rows: {total_rows}\nColumns: {', '.join(columns)}\n"
-
+
 # Add column statistics if numeric columns exist
 numeric_cols = df.select_dtypes(include=['number']).columns
 if len(numeric_cols) > 0:
@@ -314,7 +314,7 @@ def csv_to_documents(file_path: str) -> List[Document]:
 try:
 # Create a more structured text representation
 row_text_parts = []
-
+
 # Add row identifier
 row_text_parts.append(f"Row {idx + 1} of {filename}")

@@ -323,7 +323,7 @@ def csv_to_documents(file_path: str) -> List[Document]:
 value = row[col]
 if pd.isna(value):
 continue  # Skip NaN values
-
+
 # Clean and format the value
 if isinstance(value, (int, float)):
 formatted_value = f"{value:,.2f}" if isinstance(value, float) else f"{value:,}"
@@ -331,7 +331,7 @@ def csv_to_documents(file_path: str) -> List[Document]:
 formatted_value = str(value).replace('\n', ' ').replace('\r', ' ').strip()
 if len(formatted_value) > 100:
 formatted_value = formatted_value[:100] + "..."
-
+
 row_text_parts.append(f"{col}: {formatted_value}")

 # Combine all parts
@@ -395,18 +395,18 @@ def load_doc(list_file_path: List[str], splitting_strategy: str, chunk_size: str

 # Apply text splitting with adjusted parameters
 text_splitter = get_text_splitter(
-splitting_strategy,
-chunk_size_value,
+splitting_strategy,
+chunk_size_value,
 max(50, chunk_size_value // 10)  # Dynamic overlap
 )
 doc_splits = text_splitter.split_documents(all_documents)

 print(f"Total document chunks after splitting: {len(doc_splits)}")
-
+
 # Print some sample chunks for debugging
 for i, split in enumerate(doc_splits[:3]):
 print(f"Sample chunk {i+1}: {split.page_content[:100]}...")
-
+
 return doc_splits

 except Exception as e:
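The splitter is called with a dynamic overlap of `max(50, chunk_size_value // 10)`, i.e. 10% of the chunk size with a floor of 50 characters: a 1000-character chunk gets a 100-character overlap, a 300-character chunk gets 50. `get_text_splitter` itself is not shown in this diff; a minimal sketch of one plausible implementation, assuming LangChain's `RecursiveCharacterTextSplitter`:

```python
# Illustrative splitter using the same dynamic-overlap rule as the call site above;
# the app's real get_text_splitter strategies are not visible in this diff.
from langchain.text_splitter import RecursiveCharacterTextSplitter

chunk_size_value = 1000
splitter = RecursiveCharacterTextSplitter(
    chunk_size=chunk_size_value,
    chunk_overlap=max(50, chunk_size_value // 10),  # 100 for a 1000-char chunk
    separators=["\n\n", "\n", " ", ""],
)
```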
@@ -421,7 +421,7 @@ def create_db(splits, db_choice: str = "faiss"):
 raise ValueError("No document splits provided")

 print(f"Creating {db_choice} database with {len(splits)} documents")
-
+
 # Use a reliable embedding model with better parameters
 embeddings = HuggingFaceEmbeddings(
 model_name="sentence-transformers/all-MiniLM-L6-v2",
@@ -434,29 +434,29 @@ def create_db(splits, db_choice: str = "faiss"):
 'batch_size': 16
 }
 )
-
+
 print("Testing embeddings with sample text...")
 test_embedding = embeddings.embed_query("test query")
 print(f"Embedding dimension: {len(test_embedding)}")
-
+
 db_creators = {
 "faiss": lambda: FAISS.from_documents(splits, embeddings),
 "chroma": lambda: Chroma.from_documents(
-splits,
+splits,
 embeddings,
 persist_directory=None  # In-memory
 )
 }
-
+
 db = db_creators[db_choice]()
 print(f"Successfully created {db_choice} database")
-
+
 # Test the database with a simple query
 test_results = db.similarity_search("test", k=1)
 print(f"Database test successful, found {len(test_results)} results")
-
+
 return db
-
+
 except Exception as e:
 print(f"Error creating database: {e}")
 traceback.print_exc()
@@ -513,15 +513,15 @@ def initialize_llmchain(model_choice, temperature, max_tokens, top_k, vector_db,
 return None, "❌ Please create vector database first."

 progress(0.2, desc="Loading model...")
-
+
 # Load the selected model
 pipeline_obj, model_status = load_model(model_choice, progress)
-
+
 if pipeline_obj is None:
 return None, model_status

 progress(0.7, desc="Creating LLM instance...")
-
+
 # Create our custom LLM wrapper
 llm = TransformersLLM(
 model_name=AVAILABLE_MODELS[model_choice],
@@ -530,9 +530,9 @@ def initialize_llmchain(model_choice, temperature, max_tokens, top_k, vector_db,
 temperature=max(0.1, min(1.0, temperature)),
 do_sample=temperature > 0.1
 )
-
+
 progress(0.8, desc="Setting up retriever...")
-
+
 # Create retriever with optimized parameters
 retriever = vector_db.as_retriever(
 search_type="similarity",
@@ -541,7 +541,7 @@ def initialize_llmchain(model_choice, temperature, max_tokens, top_k, vector_db,
 "fetch_k": min(max(3, top_k * 2), 20)  # Fetch more, then filter
 }
 )
-
+
 # Test the retriever
 try:
 test_docs = retriever.get_relevant_documents("test query")
@@ -551,7 +551,7 @@ def initialize_llmchain(model_choice, temperature, max_tokens, top_k, vector_db,
 return None, f"❌ Database retriever failed: {str(e)}"

 progress(0.9, desc="Creating QA chain...")
-
+
 # Create QA chain with error handling
 qa_chain = RetrievalQA.from_chain_type(
 llm=llm,
@@ -568,7 +568,7 @@ def initialize_llmchain(model_choice, temperature, max_tokens, top_k, vector_db,
 f"🌡️ Temperature: {temperature}\n"
 f"📝 Max tokens: {max_tokens}\n"
 f"🔍 Retriever K: {top_k}")
-
+
 return qa_chain, success_msg

 except Exception as e:
@@ -601,14 +601,14 @@ def conversation(qa_chain, message, history):
 print(f"\n{'='*50}")
 print(f"Processing question: {message}")
 print(f"{'='*50}")
-
+
 # Enhance the query for better CSV data understanding
 enhanced_query = f"""Based on the CSV data provided, please answer this question: {message.strip()}

 Please provide a clear, informative answer that directly addresses the question. If you're analyzing data, include specific values, trends, or patterns you observe."""
-
+
 print(f"Enhanced query: {enhanced_query}")
-
+
 # Call the QA chain with timeout handling
 start_time = time.time()
 try:
@@ -628,7 +628,7 @@ Please provide a clear, informative answer that directly addresses the question.
 fallback_response += "Please try rephrasing your question or try again."
 else:
 fallback_response = "I couldn't find relevant information in your CSV data for this question. Please try a different question or check if your data contains the information you're looking for."
-
+
 return (
 qa_chain,
 gr.update(value=""),
@@ -643,18 +643,18 @@ Please provide a clear, informative answer that directly addresses the question.
 except Exception as fallback_error:
 print(f"Fallback also failed: {fallback_error}")
 error_response = f"I encountered an error processing your question: {str(qa_error)}. Please try:\n\n1. Using a simpler question\n2. Waiting a moment and trying again\n3. Reloading the model"
-
+
 return (
 qa_chain,
 gr.update(value=""),
 history + [(message, error_response)],
 f"Error: {str(qa_error)}", "Error processing", "", "No source", "", "No source"
 )
-
+
 # Extract and process the response
 print(f"Raw response type: {type(response)}")
 print(f"Response keys: {response.keys() if isinstance(response, dict) else 'Not a dict'}")
-
+
 if isinstance(response, dict):
 # Get the answer
 response_answer = response.get("result") or response.get("answer") or str(response)
@@ -672,15 +672,15 @@ Please provide a clear, informative answer that directly addresses the question.
 response_answer = response_answer.replace("Based on the following context, please provide a helpful and accurate answer", "").strip()
 response_answer = response_answer.replace("Helpful Answer:", "").strip()
 response_answer = response_answer.replace("Answer:", "").strip()
-
+
 # Remove repeated prompts
 if enhanced_query[:50] in response_answer:
 response_answer = response_answer.replace(enhanced_query, "").strip()
-
+
 # Ensure we have a meaningful response
 if len(response_answer.strip()) < 10:
 response_answer = "I was able to process your question, but the response was too brief. Please try rephrasing your question or providing more context."
-
+
 if not response_answer or response_answer.strip() == "":
 response_answer = "I apologize, but I couldn't generate a meaningful response to your question. Please try rephrasing your question or ensure your CSV data contains relevant information."

@@ -703,14 +703,14 @@ Please provide a clear, informative answer that directly addresses the question.
 content = content[:300] + "..."

 source_contents.append(content)
-
+
 if doc_type == "csv_summary":
 source_info.append(f"Summary of {source_file}")
 else:
 source_info.append(f"File: {source_file} | Row: {row_info}")
-
+
 print(f"Source {i+1}: {source_info[-1][:50]}...")
-
+
 except Exception as e:
 print(f"Error processing source {i}: {e}")
 source_contents.append(f"Error processing source: {str(e)}")
@@ -724,11 +724,14 @@ Please provide a clear, informative answer that directly addresses the question.
 print(f"Final response length: {len(response_answer)} characters")
 print(f"Response preview: {response_answer[:100]}...")
 print(f"Sources: {[info for info in source_info if info != 'No additional sources']}")
-
+
+new_history = history.copy()
+new_history.append({"role": "user", "content": message})
+new_history.append({"role": "assistant", "content": response_answer})
 return (
 qa_chain,
 gr.update(value=""),
-
+new_history,
 source_contents[0],
 source_info[0],
 source_contents[1],
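This is the main functional change in `conversation`: it now builds `new_history` as a list of role/content dicts and returns that in place of whatever value the removed line passed at that position (the removed line's content is not shown in the diff). That dict shape matches Gradio's "messages" chat format; a minimal sketch, assuming a `gr.Chatbot(type="messages")` component receives it:

```python
# Sketch of the messages-style history that new_history now produces.
# Assumes the chatbot component in the UI is created with type="messages".
import gradio as gr

history = [
    {"role": "user", "content": "Which column has the highest average?"},
    {"role": "assistant", "content": "The 'revenue' column, at roughly 4,200."},
]
chatbot = gr.Chatbot(type="messages", value=history)
```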
@@ -742,9 +745,9 @@ Please provide a clear, informative answer that directly addresses the question.
 print(f"Error: {str(e)}")
 print(f"Error type: {type(e).__name__}")
 traceback.print_exc()
-
+
 error_msg = f"❌ I encountered an error while processing your question:\n\n{str(e)}\n\nPlease try:\n1. Using a simpler question\n2. Waiting a moment and trying again\n3. Reloading the model\n4. Recreating the database"
-
+
 return (
 qa_chain,
 gr.update(value=""),
@@ -905,7 +908,7 @@ def demo():
 **Available Models:**
 - **Qwen2.5-7B-Instruct**: Advanced Chinese-English bilingual model, excellent for analysis
 - **Llama-3.1-8B-Instruct**: Meta's powerful instruction-following model
-
+
 **Note**: Models are loaded locally with 4-bit quantization for memory efficiency. First load may take several minutes.
 """)

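The note above mentions 4-bit quantization, but the corresponding `from_pretrained` call is commented out elsewhere in the file, so the exact configuration is not visible in this diff. A common way to do this with `transformers` and `bitsandbytes` is sketched below; all values are assumptions, not the Space's actual settings.

```python
# Hypothetical 4-bit loading setup; requires the bitsandbytes package and a GPU.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct",  # one of the ids in AVAILABLE_MODELS
    quantization_config=bnb_config,
    device_map="auto",
)
```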
@@ -994,7 +997,7 @@ def demo():
 )

 demo.queue().launch(
-debug=True,
+debug=True,
 share=False,
 show_error=True
 )