JaganathC commited on
Commit
dfc699a
·
verified ·
1 Parent(s): 064caab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -291
app.py CHANGED
@@ -1,342 +1,123 @@
1
- """
2
- PDF-based chatbot with Retrieval-Augmented Generation
3
- """
 
 
 
 
 
 
4
 
5
  import os
6
  import gradio as gr
7
-
8
  from dotenv import load_dotenv
9
-
10
  import indexing
11
  import retrieval
 
12
 
13
-
14
- # default_persist_directory = './chroma_HF/'
15
  list_llm = [
16
  "mistralai/Mistral-7B-Instruct-v0.3",
17
  "microsoft/Phi-3.5-mini-instruct",
18
  "meta-llama/Llama-3.1-8B-Instruct",
19
  "meta-llama/Llama-3.2-3B-Instruct",
20
- "meta-llama/Llama-3.2-1B-Instruct",
21
- "HuggingFaceTB/SmolLM2-1.7B-Instruct",
22
- "HuggingFaceH4/zephyr-7b-beta",
23
- "HuggingFaceH4/zephyr-7b-gemma-v0.1",
24
  "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
25
  "google/gemma-2-2b-it",
26
- "google/gemma-2-9b-it",
27
- "Qwen/Qwen2.5-1.5B-Instruct",
28
  "Qwen/Qwen2.5-3B-Instruct",
29
- "Qwen/Qwen2.5-7B-Instruct",
30
  ]
31
  list_llm_simple = [os.path.basename(llm) for llm in list_llm]
32
 
33
 
34
- # Load environment file - HuggingFace API key
35
  def retrieve_api():
36
- """Retrieve HuggingFace API Key"""
37
- _ = load_dotenv()
38
  global huggingfacehub_api_token
39
  huggingfacehub_api_token = os.environ.get("HUGGINGFACE_API_KEY")
40
 
41
 
42
- # Initialize database
43
- def initialize_database(
44
- list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()
45
- ):
46
- """Initialize database"""
47
-
48
- # Create list of documents (when valid)
49
  list_file_path = [x.name for x in list_file_obj if x is not None]
50
-
51
- # Create collection_name for vector database
52
- progress(0.1, desc="Creating collection name...")
53
  collection_name = indexing.create_collection_name(list_file_path[0])
54
-
55
- progress(0.25, desc="Loading document...")
56
- # Load document and create splits
57
- doc_splits = indexing.load_doc(list_file_path, chunk_size, chunk_overlap)
58
-
59
- # Create or load vector database
60
- progress(0.5, desc="Generating vector database...")
61
-
62
- # global vector_db
63
  vector_db = indexing.create_db(doc_splits, collection_name)
 
64
 
65
- return vector_db, collection_name, "Complete!"
66
 
67
-
68
- # Initialize LLM
69
- def initialize_llm(
70
- llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()
71
- ):
72
- """Initialize LLM"""
73
-
74
- # print("llm_option",llm_option)
75
  llm_name = list_llm[llm_option]
76
- print("llm_name: ", llm_name)
77
  qa_chain = retrieval.initialize_llmchain(
78
  llm_name, huggingfacehub_api_token, llm_temperature, max_tokens, top_k, vector_db, progress
79
  )
80
  return qa_chain, "Complete!"
81
 
82
 
83
- # Chatbot conversation
84
  def conversation(qa_chain, message, history):
85
- """Chatbot conversation"""
 
 
86
 
87
- qa_chain, new_history, response_sources = retrieval.invoke_qa_chain(
88
- qa_chain, message, history
89
- )
90
 
91
- # Format output gradio components
92
- response_source1 = response_sources[0].page_content.strip()
93
- response_source2 = response_sources[1].page_content.strip()
94
- response_source3 = response_sources[2].page_content.strip()
95
- # Langchain sources are zero-based
96
- response_source1_page = response_sources[0].metadata["page"] + 1
97
- response_source2_page = response_sources[1].metadata["page"] + 1
98
- response_source3_page = response_sources[2].metadata["page"] + 1
99
 
100
- return (
101
- qa_chain,
102
- gr.update(value=""),
103
- new_history,
104
- response_source1,
105
- response_source1_page,
106
- response_source2,
107
- response_source2_page,
108
- response_source3,
109
- response_source3_page,
110
- )
111
-
112
-
113
- SPACE_TITLE = """
114
- <center><h2>PDF-based chatbot</center></h2>
115
- <h3>Ask any questions about your PDF documents</h3>
116
- """
117
 
118
- SPACE_INFO = """
119
- <b>Description:</b> This AI assistant, using Langchain and open-source LLMs, performs retrieval-augmented generation (RAG) from your PDF documents. \
120
- The user interface explicitely shows multiple steps to help understand the RAG workflow.
121
- This chatbot takes past questions into account when generating answers (via conversational memory), and includes document references for clarity purposes.<br>
122
- <br><b>Notes:</b> Updated space with more recent LLM models (Qwen 2.5, Llama 3.2, SmolLM2 series)
123
- <br><b>Warning:</b> This space uses the free CPU Basic hardware from Hugging Face. Some steps and LLM models used below (free inference endpoints) can take some time to generate a reply.
124
- """
125
 
126
 
127
- # Gradio User Interface
128
  def gradio_ui():
129
- """Gradio User Interface"""
130
-
131
- with gr.Blocks(theme="base") as demo:
132
  vector_db = gr.State()
133
  qa_chain = gr.State()
134
  collection_name = gr.State()
135
 
136
- gr.Markdown(SPACE_TITLE)
137
- gr.Markdown(SPACE_INFO)
138
-
139
- with gr.Tab("Step 1 - Upload PDF"):
140
- with gr.Row():
141
- document = gr.File(
142
- height=200,
143
- file_count="multiple",
144
- file_types=[".pdf"],
145
- interactive=True,
146
- label="Upload your PDF documents (single or multiple)",
147
- )
148
-
149
- with gr.Tab("Step 2 - Process document"):
150
- with gr.Row():
151
- db_btn = gr.Radio(
152
- ["ChromaDB"],
153
- label="Vector database type",
154
- value="ChromaDB",
155
- type="index",
156
- info="Choose your vector database",
157
- )
158
- with gr.Accordion("Advanced options - Document text splitter", open=False):
159
- with gr.Row():
160
- slider_chunk_size = gr.Slider(
161
- minimum=100,
162
- maximum=1000,
163
- value=600,
164
- step=20,
165
- label="Chunk size",
166
- info="Chunk size",
167
- interactive=True,
168
- )
169
- with gr.Row():
170
- slider_chunk_overlap = gr.Slider(
171
- minimum=10,
172
- maximum=200,
173
- value=40,
174
- step=10,
175
- label="Chunk overlap",
176
- info="Chunk overlap",
177
- interactive=True,
178
- )
179
- with gr.Row():
180
- db_progress = gr.Textbox(
181
- label="Vector database initialization", value="None"
182
- )
183
- with gr.Row():
184
- db_btn = gr.Button("Generate vector database")
185
-
186
- with gr.Tab("Step 3 - Initialize QA chain"):
187
- with gr.Row():
188
- llm_btn = gr.Radio(
189
- list_llm_simple,
190
- label="LLM models",
191
- value=list_llm_simple[6],
192
- type="index",
193
- info="Choose your LLM model",
194
- )
195
- with gr.Accordion("Advanced options - LLM model", open=False):
196
- with gr.Row():
197
- slider_temperature = gr.Slider(
198
- minimum=0.01,
199
- maximum=1.0,
200
- value=0.7,
201
- step=0.1,
202
- label="Temperature",
203
- info="Model temperature",
204
- interactive=True,
205
- )
206
- with gr.Row():
207
- slider_maxtokens = gr.Slider(
208
- minimum=224,
209
- maximum=4096,
210
- value=1024,
211
- step=32,
212
- label="Max Tokens",
213
- info="Model max tokens",
214
- interactive=True,
215
- )
216
- with gr.Row():
217
- slider_topk = gr.Slider(
218
- minimum=1,
219
- maximum=10,
220
- value=3,
221
- step=1,
222
- label="top-k samples",
223
- info="Model top-k samples",
224
- interactive=True,
225
- )
226
- with gr.Row():
227
- llm_progress = gr.Textbox(value="None", label="QA chain initialization")
228
- with gr.Row():
229
- qachain_btn = gr.Button("Initialize Question Answering chain")
230
-
231
- with gr.Tab("Step 4 - Chatbot"):
232
- chatbot = gr.Chatbot(height=300, type="tuples")
233
- with gr.Accordion("Advanced - Document references", open=False):
234
- with gr.Row():
235
- doc_source1 = gr.Textbox(
236
- label="Reference 1", lines=2, container=True, scale=20
237
- )
238
- source1_page = gr.Number(label="Page", scale=1)
239
- with gr.Row():
240
- doc_source2 = gr.Textbox(
241
- label="Reference 2", lines=2, container=True, scale=20
242
- )
243
- source2_page = gr.Number(label="Page", scale=1)
244
- with gr.Row():
245
- doc_source3 = gr.Textbox(
246
- label="Reference 3", lines=2, container=True, scale=20
247
- )
248
- source3_page = gr.Number(label="Page", scale=1)
249
- with gr.Row():
250
- msg = gr.Textbox(
251
- placeholder="Type message (e.g. 'Can you summarize this document in one paragraph?')",
252
- container=True,
253
- )
254
- with gr.Row():
255
- submit_btn = gr.Button("Submit message")
256
- clear_btn = gr.ClearButton(
257
- components=[msg, chatbot], value="Clear conversation"
258
- )
259
-
260
- # Preprocessing events
261
- db_btn.click(
262
- initialize_database,
263
- inputs=[document, slider_chunk_size, slider_chunk_overlap],
264
- outputs=[vector_db, collection_name, db_progress],
265
- )
266
- qachain_btn.click(
267
- initialize_llm,
268
- inputs=[
269
- llm_btn,
270
- slider_temperature,
271
- slider_maxtokens,
272
- slider_topk,
273
- vector_db,
274
- ],
275
- outputs=[qa_chain, llm_progress],
276
- ).then(
277
- lambda: [None, "", 0, "", 0, "", 0],
278
- inputs=None,
279
- outputs=[
280
- chatbot,
281
- doc_source1,
282
- source1_page,
283
- doc_source2,
284
- source2_page,
285
- doc_source3,
286
- source3_page,
287
- ],
288
- queue=False,
289
- )
290
-
291
- # Chatbot events
292
- msg.submit(
293
- conversation,
294
- inputs=[qa_chain, msg, chatbot],
295
- outputs=[
296
- qa_chain,
297
- msg,
298
- chatbot,
299
- doc_source1,
300
- source1_page,
301
- doc_source2,
302
- source2_page,
303
- doc_source3,
304
- source3_page,
305
- ],
306
- queue=False,
307
- )
308
- submit_btn.click(
309
- conversation,
310
- inputs=[qa_chain, msg, chatbot],
311
- outputs=[
312
- qa_chain,
313
- msg,
314
- chatbot,
315
- doc_source1,
316
- source1_page,
317
- doc_source2,
318
- source2_page,
319
- doc_source3,
320
- source3_page,
321
- ],
322
- queue=False,
323
- )
324
- clear_btn.click(
325
- lambda: [None, "", 0, "", 0, "", 0],
326
- inputs=None,
327
- outputs=[
328
- chatbot,
329
- doc_source1,
330
- source1_page,
331
- doc_source2,
332
- source2_page,
333
- doc_source3,
334
- source3_page,
335
- ],
336
- queue=False,
337
- )
338
- demo.queue().launch(debug=True)
339
-
340
 
341
  if __name__ == "__main__":
342
  retrieve_api()
 
1
+ # ✅ Enhanced GenAI Assistant with:
2
+ # - PDF/TXT support
3
+ # - Ask Anything + Challenge Me modes
4
+ # - Auto Summary (<=150 words)
5
+ # - Memory handling
6
+ # - Reference highlighting
7
+ # - Stunning UI (Gradio upgraded)
8
+
9
+ # --- FILE: app.py ---
10
 
11
  import os
12
  import gradio as gr
 
13
  from dotenv import load_dotenv
 
14
  import indexing
15
  import retrieval
16
+ import utils
17
 
 
 
18
  list_llm = [
19
  "mistralai/Mistral-7B-Instruct-v0.3",
20
  "microsoft/Phi-3.5-mini-instruct",
21
  "meta-llama/Llama-3.1-8B-Instruct",
22
  "meta-llama/Llama-3.2-3B-Instruct",
 
 
 
 
23
  "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
24
  "google/gemma-2-2b-it",
 
 
25
  "Qwen/Qwen2.5-3B-Instruct",
 
26
  ]
27
  list_llm_simple = [os.path.basename(llm) for llm in list_llm]
28
 
29
 
 
30
  def retrieve_api():
31
+ load_dotenv()
 
32
  global huggingfacehub_api_token
33
  huggingfacehub_api_token = os.environ.get("HUGGINGFACE_API_KEY")
34
 
35
 
36
+ def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
 
 
 
 
 
 
37
  list_file_path = [x.name for x in list_file_obj if x is not None]
 
 
 
38
  collection_name = indexing.create_collection_name(list_file_path[0])
39
+ doc_splits, full_text = indexing.load_doc(list_file_path, chunk_size, chunk_overlap)
40
+ summary = utils.generate_summary(full_text)
 
 
 
 
 
 
 
41
  vector_db = indexing.create_db(doc_splits, collection_name)
42
+ return vector_db, collection_name, summary, "Complete!"
43
 
 
44
 
45
+ def initialize_llm(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
 
 
 
 
 
 
 
46
  llm_name = list_llm[llm_option]
 
47
  qa_chain = retrieval.initialize_llmchain(
48
  llm_name, huggingfacehub_api_token, llm_temperature, max_tokens, top_k, vector_db, progress
49
  )
50
  return qa_chain, "Complete!"
51
 
52
 
 
53
  def conversation(qa_chain, message, history):
54
+ qa_chain, new_history, response_sources = retrieval.invoke_qa_chain(qa_chain, message, history)
55
+ highlights = utils.extract_highlight_snippets(response_sources)
56
+ return qa_chain, gr.update(value=""), new_history, *highlights
57
 
 
 
 
58
 
59
+ def challenge_me(qa_chain):
60
+ questions = utils.generate_challenge_questions(qa_chain)
61
+ return questions
 
 
 
 
 
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
+ def evaluate_answers(qa_chain, questions, user_answers):
65
+ feedback = utils.evaluate_responses(qa_chain, questions, user_answers)
66
+ return feedback
 
 
 
 
67
 
68
 
 
69
  def gradio_ui():
70
+ with gr.Blocks(theme=gr.themes.Monochrome(), css="footer {display:none}") as demo:
 
 
71
  vector_db = gr.State()
72
  qa_chain = gr.State()
73
  collection_name = gr.State()
74
 
75
+ gr.Markdown("""<h1 style='text-align:center;'>📚 GenAI Document Assistant</h1>
76
+ <h3 style='text-align:center;color:gray;'>Smart, interactive reading of research papers, legal docs, and more.</h3>""")
77
+
78
+ with gr.Tab("1️⃣ Upload Document"):
79
+ document = gr.File(label="Upload PDF or TXT", file_types=[".pdf", ".txt"], file_count="multiple")
80
+ slider_chunk_size = gr.Slider(100, 1000, value=600, step=20, label="Chunk Size")
81
+ slider_chunk_overlap = gr.Slider(10, 200, value=40, step=10, label="Chunk Overlap")
82
+ db_progress = gr.Textbox(label="Processing Status")
83
+ summary_box = gr.Textbox(label="Auto Summary (≤ 150 words)", lines=5)
84
+ db_btn = gr.Button("📥 Process Document")
85
+
86
+ with gr.Tab("2️⃣ QA Chain Initialization"):
87
+ llm_btn = gr.Radio(list_llm_simple, label="Select LLM", value=list_llm_simple[0], type="index")
88
+ slider_temperature = gr.Slider(0.01, 1.0, value=0.7, step=0.1, label="Temperature")
89
+ slider_maxtokens = gr.Slider(224, 4096, value=1024, step=32, label="Max Tokens")
90
+ slider_topk = gr.Slider(1, 10, value=3, step=1, label="Top-K")
91
+ llm_progress = gr.Textbox(label="LLM Status")
92
+ qachain_btn = gr.Button("⚙️ Initialize QA Chain")
93
+
94
+ with gr.Tab("3️⃣ Ask Anything"):
95
+ chatbot = gr.Chatbot(height=300)
96
+ msg = gr.Textbox(placeholder="Ask a question from the document...")
97
+ submit_btn = gr.Button("💬 Ask")
98
+ clear_btn = gr.ClearButton([msg, chatbot])
99
+ ref1 = gr.Textbox(label="Reference 1")
100
+ ref2 = gr.Textbox(label="Reference 2")
101
+ ref3 = gr.Textbox(label="Reference 3")
102
+
103
+ with gr.Tab("4️⃣ Challenge Me"):
104
+ challenge_btn = gr.Button("🎯 Generate Questions")
105
+ q1 = gr.Textbox(label="Question 1")
106
+ a1 = gr.Textbox(label="Your Answer 1")
107
+ q2 = gr.Textbox(label="Question 2")
108
+ a2 = gr.Textbox(label="Your Answer 2")
109
+ q3 = gr.Textbox(label="Question 3")
110
+ a3 = gr.Textbox(label="Your Answer 3")
111
+ eval_btn = gr.Button("✅ Submit Answers")
112
+ feedback = gr.Textbox(label="Feedback", lines=5)
113
+
114
+ db_btn.click(initialize_database, [document, slider_chunk_size, slider_chunk_overlap], [vector_db, collection_name, summary_box, db_progress])
115
+ qachain_btn.click(initialize_llm, [llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], [qa_chain, llm_progress])
116
+ submit_btn.click(conversation, [qa_chain, msg, chatbot], [qa_chain, msg, chatbot, ref1, ref2, ref3])
117
+ challenge_btn.click(challenge_me, [qa_chain], [q1, q2, q3])
118
+ eval_btn.click(evaluate_answers, [qa_chain, [q1, q2, q3], [a1, a2, a3]], [feedback])
119
+
120
+ demo.launch(debug=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
  if __name__ == "__main__":
123
  retrieve_api()