sandeep-huggingface committed
Commit af2cce2 · verified · 1 Parent(s): 4745ef7
Files changed (1)
  1. app.py +454 -0
app.py ADDED
@@ -0,0 +1,454 @@
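"""RAG CSV Chatbot.

Gradio app: upload CSV files, index them into a vector store (FAISS, Chroma,
or Qdrant), and chat with the data through a conversational retrieval chain
backed by a Hugging Face inference endpoint.
"""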
import gradio as gr
import os
import pandas as pd
from typing import List
from langchain.text_splitter import (
    RecursiveCharacterTextSplitter,
    CharacterTextSplitter,
    TokenTextSplitter
)
from langchain_community.vectorstores import FAISS, Chroma, Qdrant
from langchain.chains import ConversationalRetrievalChain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_huggingface import HuggingFaceEndpoint
from langchain.memory import ConversationBufferMemory
from langchain.schema import Document

list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.2"]
list_llm_simple = [os.path.basename(llm) for llm in list_llm]
api_token = os.getenv("HF_TOKEN")

CHUNK_SIZES = {
    "small": {"recursive": 512, "fixed": 512, "token": 256},
    "medium": {"recursive": 1024, "fixed": 1024, "token": 512}
}

def get_text_splitter(strategy: str, chunk_size: int = 1024, chunk_overlap: int = 64):
    """Get text splitter based on strategy"""
    splitters = {
        "recursive": RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap
        ),
        "fixed": CharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap
        ),
        "token": TokenTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap
        )
    }
    return splitters.get(strategy)

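# Illustrative usage (not called directly by the app; load_doc below wires the
# presets in): build a recursive splitter at the "small" preset and split docs.
#   splitter = get_text_splitter("recursive", chunk_size=512, chunk_overlap=64)
#   chunks = splitter.split_documents(docs)
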
def csv_to_documents(file_path: str) -> List[Document]:
    """Convert CSV file to LangChain documents with enhanced metadata"""
    try:
        # Read CSV file
        df = pd.read_csv(file_path)

        # Get basic info about the CSV
        filename = os.path.basename(file_path)
        total_rows = len(df)
        columns = list(df.columns)

        documents = []

        # Create documents from each row
        for idx, row in df.iterrows():
            # Create a readable text representation of the row
            row_text_parts = []

            # Add column headers and values
            for col in df.columns:
                value = str(row[col]) if pd.notna(row[col]) else "N/A"
                row_text_parts.append(f"{col}: {value}")

            # Combine all column-value pairs
            content = "\n".join(row_text_parts)

            # Create document with rich metadata
            doc = Document(
                page_content=content,
                metadata={
                    "source": filename,
                    "row": idx + 1,  # 1-based row numbering
                    "total_rows": total_rows,
                    "columns": ", ".join(columns),
                    "file_path": file_path
                }
            )
            documents.append(doc)

        return documents

    except Exception as e:
        print(f"Error processing CSV file {file_path}: {str(e)}")
        return []

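# For a CSV with columns "name,score", each row becomes a Document roughly like:
#   page_content = "name: Alice\nscore: 91"
#   metadata = {"source": "data.csv", "row": 1, "total_rows": 2, ...}
# ("name", "score", "Alice", "data.csv" are made-up example values.)
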
def load_doc(list_file_path: List[str], splitting_strategy: str, chunk_size: str):
    """Load and process CSV documents"""
    chunk_size_value = CHUNK_SIZES[chunk_size][splitting_strategy]

    # Load all CSV files and convert to documents
    all_documents = []
    for file_path in list_file_path:
        documents = csv_to_documents(file_path)
        all_documents.extend(documents)

    if not all_documents:
        return []

    # Apply text splitting
    text_splitter = get_text_splitter(splitting_strategy, chunk_size_value)
    doc_splits = text_splitter.split_documents(all_documents)

    return doc_splits

def create_db(splits, db_choice: str = "faiss"):
    """Create vector database from document splits"""
    embeddings = HuggingFaceEmbeddings()
    db_creators = {
        "faiss": lambda: FAISS.from_documents(splits, embeddings),
        "chroma": lambda: Chroma.from_documents(splits, embeddings),
        "qdrant": lambda: Qdrant.from_documents(
            splits,
            embeddings,
            location=":memory:",
            collection_name="csv_docs"
        )
    }
    return db_creators[db_choice]()

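# Sketch of direct use (the app calls this via initialize_database): an
# in-memory FAISS index over pre-split docs, queried outside any chain.
#   db = create_db(doc_splits, "faiss")
#   hits = db.similarity_search("revenue by region", k=3)
# ("revenue by region" is a placeholder query.)
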
def initialize_database(list_file_obj, splitting_strategy, chunk_size, db_choice, progress=gr.Progress()):
    """Initialize vector database with error handling"""
    try:
        if not list_file_obj:
            return None, "No files uploaded. Please upload CSV documents first."

        list_file_path = [x.name for x in list_file_obj if x is not None]
        if not list_file_path:
            return None, "No valid files found. Please upload CSV documents."

        # Validate that all files are CSV
        non_csv_files = [path for path in list_file_path if not path.lower().endswith('.csv')]
        if non_csv_files:
            return None, f"Non-CSV files detected: {', '.join([os.path.basename(f) for f in non_csv_files])}. Please upload only CSV files."

        progress(0.2, desc="Loading CSV files...")
        doc_splits = load_doc(list_file_path, splitting_strategy, chunk_size)

        if not doc_splits:
            return None, "No content extracted from CSV documents. Please check if the files contain data."

        progress(0.6, desc="Creating vector database...")
        vector_db = create_db(doc_splits, db_choice)

        progress(1.0, desc="Database created successfully!")

        num_files = len(list_file_path)
        num_chunks = len(doc_splits)
        file_names = [os.path.basename(f) for f in list_file_path]

        success_msg = (f"Database created successfully!\n"
                       f"📁 Files processed: {num_files} ({', '.join(file_names)})\n"
                       f"📊 Document chunks: {num_chunks}\n"
                       f"🔧 Strategy: {splitting_strategy} splitting\n"
                       f"💾 Database: {db_choice}")

        return vector_db, success_msg

    except Exception as e:
        return None, f"Error creating database: {str(e)}"

def initialize_llmchain(llm_choice, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    """Initialize LLM chain with error handling"""
    try:
        if vector_db is None:
            return None, "Please create vector database first."

        progress(0.3, desc="Initializing LLM...")
        llm_model = list_llm[llm_choice]

        llm = HuggingFaceEndpoint(
            repo_id=llm_model,
            huggingfacehub_api_token=api_token,
            temperature=temperature,
            max_new_tokens=max_tokens,
            top_k=top_k
        )

        progress(0.7, desc="Setting up memory and retriever...")
        memory = ConversationBufferMemory(
            memory_key="chat_history",
            output_key='answer',
            return_messages=True
        )

        # The Top K slider also sets retrieval depth, matching its UI description
        retriever = vector_db.as_retriever(search_kwargs={"k": int(top_k)})
        qa_chain = ConversationalRetrievalChain.from_llm(
            llm,
            retriever=retriever,
            memory=memory,
            return_source_documents=True
        )

        progress(1.0, desc="LLM initialized successfully!")

        success_msg = (f"LLM initialized successfully!\n"
                       f"🤖 Model: {os.path.basename(llm_model)}\n"
                       f"🌡️ Temperature: {temperature}\n"
                       f"📝 Max tokens: {max_tokens}\n"
                       f"🔍 Top K: {top_k}")

        return qa_chain, success_msg

    except Exception as e:
        return None, f"Error initializing LLM: {str(e)}"

def conversation(qa_chain, message, history):
    """Conversation function with CSV-specific source formatting"""
    try:
        response = qa_chain.invoke({
            "question": message,
            "chat_history": [(hist[0], hist[1]) for hist in history]
        })

        response_answer = response["answer"]
        if "Helpful Answer:" in response_answer:
            response_answer = response_answer.split("Helpful Answer:")[-1].strip()

        # Get source documents
        sources = response["source_documents"][:3]
        source_contents = []
        source_info = []

        for source in sources:
            # Format source content for CSV data
            content = source.page_content.strip()
            metadata = source.metadata

            # Create readable source info for CSV
            source_file = metadata.get("source", "Unknown")
            row_num = metadata.get("row", 0)

            source_contents.append(content)
            source_info.append(f"File: {source_file} | Row: {row_num}")

        # Pad with empty values if needed
        while len(source_contents) < 3:
            source_contents.append("")
            source_info.append("No source")

        return (
            qa_chain,
            gr.update(value=""),
            history + [(message, response_answer)],
            source_contents[0],
            source_info[0],
            source_contents[1],
            source_info[1],
            source_contents[2],
            source_info[2]
        )

    except Exception as e:
        error_msg = f"Error in conversation: {str(e)}"
        return (
            qa_chain,
            gr.update(value=""),
            history + [(message, error_msg)],
            "", "Error", "", "Error", "", "Error"
        )

def demo():
    with gr.Blocks(theme=gr.themes.Default(primary_hue="green", secondary_hue="blue", neutral_hue="slate")) as demo:
        vector_db = gr.State()
        qa_chain = gr.State()

        gr.HTML("<center><h1>📊 RAG CSV Chatbot</h1></center>")
        gr.HTML("<center><p>Upload CSV files and chat with your data using advanced RAG techniques</p></center>")

        with gr.Row():
            with gr.Column(scale=86):
                gr.Markdown("### 📝 Step 1 - Configure and Initialize RAG Pipeline")

                document = gr.Files(
                    height=300,
                    file_count="multiple",
                    file_types=[".csv"],
                    interactive=True,
                    label="Upload CSV documents",
                    elem_id="file_upload"
                )

                with gr.Row():
                    splitting_strategy = gr.Radio(
                        ["recursive", "fixed", "token"],
                        label="Text Splitting Strategy",
                        value="recursive",
                        info="How to split CSV data into chunks"
                    )
                    db_choice = gr.Radio(
                        ["faiss", "chroma", "qdrant"],
                        label="Vector Database",
                        value="faiss",
                        info="Vector storage backend"
                    )
                    chunk_size = gr.Radio(
                        ["small", "medium"],
                        label="Chunk Size",
                        value="medium",
                        info="Size of text chunks for processing"
                    )

                with gr.Row():
                    db_btn = gr.Button("🔄 Create Vector Database", variant="primary")

                db_progress = gr.Textbox(
                    value="❌ Not initialized - Please upload CSV files and create database",
                    show_label=False,
                    interactive=False,
                    lines=4
                )

                gr.Markdown("### 🤖 Step 2 - Configure LLM")

                llm_choice = gr.Radio(
                    list_llm_simple,
                    label="Available LLMs",
                    value=list_llm_simple[0],
                    type="index",
                    info="Choose the language model for responses"
                )

                with gr.Accordion("🔧 LLM Parameters", open=False):
                    temperature = gr.Slider(
                        minimum=0.01,
                        maximum=1.0,
                        value=0.5,
                        step=0.1,
                        label="Temperature",
                        info="Controls randomness in responses"
                    )
                    max_tokens = gr.Slider(
                        minimum=128,
                        maximum=4096,
                        value=2048,
                        step=128,
                        label="Max Tokens",
                        info="Maximum length of generated responses"
                    )
                    top_k = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=3,
                        step=1,
                        label="Top K",
                        info="Number of top documents to retrieve"
                    )

                with gr.Row():
                    init_llm_btn = gr.Button("🚀 Initialize LLM", variant="primary", interactive=False)

                llm_progress = gr.Textbox(
                    value="❌ Not initialized - Please create database first",
                    show_label=False,
                    interactive=False,
                    lines=4
                )

            with gr.Column(scale=200):
                gr.Markdown("### 💬 Step 3 - Chat with Your CSV Data")

                chatbot = gr.Chatbot(
                    height=505,
                    show_label=False,
                    elem_id="chatbot",
                    placeholder="Your conversation will appear here after initializing the system..."
                )

                with gr.Accordion("📋 Source References", open=False):
                    gr.Markdown("*Top 3 most relevant sources from your CSV data:*")
                    with gr.Row():
                        with gr.Column():
                            source1 = gr.Textbox(label="📄 Source 1", lines=3, interactive=False)
                            info1 = gr.Textbox(label="ℹ️ Source 1 Info", interactive=False)
                    with gr.Row():
                        with gr.Column():
                            source2 = gr.Textbox(label="📄 Source 2", lines=3, interactive=False)
                            info2 = gr.Textbox(label="ℹ️ Source 2 Info", interactive=False)
                    with gr.Row():
                        with gr.Column():
                            source3 = gr.Textbox(label="📄 Source 3", lines=3, interactive=False)
                            info3 = gr.Textbox(label="ℹ️ Source 3 Info", interactive=False)

                with gr.Row():
                    msg = gr.Textbox(
                        placeholder="Ask questions about your CSV data... (e.g., 'What are the main trends?', 'Summarize the key findings', 'What patterns do you see?')",
                        show_label=False,
                        scale=4,
                        interactive=False
                    )
                    submit_btn = gr.Button("📤 Send", scale=1, interactive=False)

                with gr.Row():
                    clear_btn = gr.ClearButton(
                        [msg, chatbot, source1, info1, source2, info2, source3, info3],
                        value="🗑️ Clear Chat",
                        scale=1
                    )

                gr.Markdown("### 💡 Tips for Better Results")
                gr.Markdown("""
                - **Ask specific questions** about your data (e.g., "What are the highest values in column X?")
                - **Request summaries** (e.g., "Summarize the key insights from this dataset")
                - **Compare data** (e.g., "Compare categories A and B")
                - **Ask for trends** (e.g., "What patterns do you see over time?")
                """)

        # Event handlers
        db_btn.click(
            initialize_database,
            inputs=[document, splitting_strategy, chunk_size, db_choice],
            outputs=[vector_db, db_progress]
        ).then(
            lambda x: gr.update(interactive=True) if x is not None else gr.update(interactive=False),
            inputs=[vector_db],
            outputs=[init_llm_btn]
        )

        init_llm_btn.click(
            initialize_llmchain,
            inputs=[llm_choice, temperature, max_tokens, top_k, vector_db],
            outputs=[qa_chain, llm_progress]
        ).then(
            lambda x: [gr.update(interactive=True), gr.update(interactive=True)] if x is not None else [gr.update(interactive=False), gr.update(interactive=False)],
            inputs=[qa_chain],
            outputs=[msg, submit_btn]
        )

        msg.submit(
            conversation,
            inputs=[qa_chain, msg, chatbot],
            outputs=[qa_chain, msg, chatbot, source1, info1, source2, info2, source3, info3]
        )

        submit_btn.click(
            conversation,
            inputs=[qa_chain, msg, chatbot],
            outputs=[qa_chain, msg, chatbot, source1, info1, source2, info2, source3, info3]
        )

        clear_btn.click(
            lambda: [[], "", "", "", "", "", ""],
            outputs=[chatbot, source1, info1, source2, info2, source3, info3]
        )

    demo.queue().launch(debug=True)

if __name__ == "__main__":
    demo()
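
# To run locally, the HF_TOKEN environment variable (read above via
# os.getenv) must hold a Hugging Face API token with access to the two
# hosted models; then: python app.py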