merterbak committed · verified
Commit 8a498c0 · Parent(s): 910824c

Create app.py

Files changed (1): app.py (+562, new file)

app.py ADDED:
import gradio as gr
import os
import PyPDF2
import logging
import torch
import threading
import time
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
    StoppingCriteria,
    StoppingCriteriaList,
)
from transformers import logging as hf_logging
import spaces
from llama_index.core import (
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
    Document as LlamaDocument,
)
from llama_index.core import Settings
from llama_index.core.node_parser import (
    HierarchicalNodeParser,
    get_leaf_nodes,
    get_root_nodes,
)
from llama_index.core.retrievers import AutoMergingRetriever
from llama_index.core.storage.docstore import SimpleDocumentStore
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from tqdm import tqdm

os.environ["TOKENIZERS_PARALLELISM"] = "false"
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
hf_logging.set_verbosity_error()

MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("HF_TOKEN not found in environment variables")

# Custom UI
TITLE = "<h1><center>Multi-Document RAG with Llama 3.1-8B Model</center></h1>"
DESCRIPTION = """
<center>
<p>Upload PDF or text files to get started!</p>
<p>After asking a question, wait for the RAG system to retrieve the relevant nodes and pass them to the LLM.</p>
</center>
"""
CSS = """
.upload-section {
    max-width: 400px;
    margin: 0 auto;
    padding: 10px;
    border: 2px dashed #ccc;
    border-radius: 10px;
}
.upload-button {
    background: #34c759 !important;
    color: white !important;
    border-radius: 25px !important;
}
.chatbot-container {
    margin-top: 20px;
}
.status-output {
    margin-top: 10px;
    font-size: 14px;
}
.processing-info {
    margin-top: 5px;
    font-size: 12px;
    color: #666;
}
.info-container {
    margin-top: 10px;
    padding: 10px;
    border-radius: 5px;
}
.file-list {
    margin-top: 0;
    max-height: 200px;
    overflow-y: auto;
    padding: 5px;
    border: 1px solid #eee;
    border-radius: 5px;
}
.stats-box {
    margin-top: 10px;
    padding: 10px;
    border-radius: 5px;
    font-size: 12px;
}
.submit-btn {
    background: #1a73e8 !important;
    color: white !important;
    border-radius: 25px !important;
    margin-left: 10px;
    padding: 5px 10px;
    font-size: 16px;
}
.input-row {
    display: flex;
    align-items: center;
}
@media (min-width: 768px) {
    .main-container {
        display: flex;
        justify-content: space-between;
        gap: 20px;
    }
    .upload-section {
        flex: 1;
        max-width: 300px;
    }
    .chatbot-container {
        flex: 2;
        margin-top: 0;
    }
}
"""

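# The model, tokenizer, and per-file bookkeeping live in module-level globals
# so that repeated @spaces.GPU calls reuse a single loaded copy of the 8B
# model instead of reloading it for every request.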
global_model = None
global_tokenizer = None
global_file_info = {}

def initialize_model_and_tokenizer():
    global global_model, global_tokenizer
    if global_model is None or global_tokenizer is None:
        logger.info("Initializing model and tokenizer...")
        global_tokenizer = AutoTokenizer.from_pretrained(MODEL, token=HF_TOKEN)
        global_model = AutoModelForCausalLM.from_pretrained(
            MODEL,
            device_map="auto",
            trust_remote_code=True,
            token=HF_TOKEN,
            torch_dtype=torch.float16
        )
        logger.info("Model and tokenizer initialized successfully")

def get_llm(temperature=0.7, max_new_tokens=256, top_p=0.95, top_k=50):
    global global_model, global_tokenizer
    if global_model is None or global_tokenizer is None:
        initialize_model_and_tokenizer()

    return HuggingFaceLLM(
        context_window=4096,
        max_new_tokens=max_new_tokens,
        tokenizer=global_tokenizer,
        model=global_model,
        generate_kwargs={
            "do_sample": True,
            "temperature": temperature,
            "top_k": top_k,
            "top_p": top_p
        }
    )

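# Reads an uploaded .txt or .pdf file and returns (text, word_count, error);
# unsupported extensions are reported back to the caller instead of raising.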
def extract_text_from_document(file):
    file_name = file.name
    file_extension = os.path.splitext(file_name)[1].lower()

    if file_extension == '.txt':
        text = file.read().decode('utf-8')
        return text, len(text.split()), None
    elif file_extension == '.pdf':
        pdf_reader = PyPDF2.PdfReader(file)
        text = "\n\n".join(page.extract_text() for page in pdf_reader.pages)
        return text, len(text.split()), None
    else:
        return None, 0, ValueError(f"Unsupported file format: {file_extension}")

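# Builds or extends the per-session index: each upload is parsed into
# hierarchical 2048/512/128 chunks, the leaf chunks are embedded into a
# VectorStoreIndex, and everything is persisted to ./<session_hash>_index so
# stream_chat can reload it later.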
@spaces.GPU()
def create_or_update_index(files, request: gr.Request):
    global global_file_info

    if not files:
        return "Please provide files.", ""

    start_time = time.time()
    user_id = request.session_hash
    save_dir = f"./{user_id}_index"
    # Initialize LlamaIndex modules
    llm = get_llm()
    embed_model = HuggingFaceEmbedding(model_name=EMBEDDING_MODEL, token=HF_TOKEN)
    Settings.llm = llm
    Settings.embed_model = embed_model
    file_stats = []
    new_documents = []

    for file in tqdm(files, desc="Processing files"):
        file_basename = os.path.basename(file.name)
        text, word_count, error = extract_text_from_document(file)
        if error:
            logger.error(f"Error processing file {file_basename}: {str(error)}")
            file_stats.append({
                "name": file_basename,
                "words": 0,
                "status": f"error: {str(error)}"
            })
            continue

        doc = LlamaDocument(
            text=text,
            metadata={
                "file_name": file_basename,
                "word_count": word_count,
                "source": "user_upload"
            }
        )
        new_documents.append(doc)

        file_stats.append({
            "name": file_basename,
            "words": word_count,
            "status": "processed"
        })

        global_file_info[file_basename] = {
            "word_count": word_count,
            "processed_at": time.time()
        }

    node_parser = HierarchicalNodeParser.from_defaults(
        chunk_sizes=[2048, 512, 128],
        chunk_overlap=20
    )
    logger.info(f"Parsing {len(new_documents)} documents into hierarchical nodes")
    new_nodes = node_parser.get_nodes_from_documents(new_documents)
    new_leaf_nodes = get_leaf_nodes(new_nodes)
    new_root_nodes = get_root_nodes(new_nodes)
    logger.info(f"Generated {len(new_nodes)} total nodes ({len(new_root_nodes)} root, {len(new_leaf_nodes)} leaf)")
    node_ancestry = {}
    for node in new_nodes:
        if hasattr(node, 'metadata') and 'file_name' in node.metadata:
            file_origin = node.metadata['file_name']
            if file_origin not in node_ancestry:
                node_ancestry[file_origin] = 0
            node_ancestry[file_origin] += 1

    if os.path.exists(save_dir):
        logger.info(f"Loading existing index from {save_dir}")
        storage_context = StorageContext.from_defaults(persist_dir=save_dir)
        index = load_index_from_storage(storage_context, settings=Settings)
        docstore = storage_context.docstore

        docstore.add_documents(new_nodes)
        for node in tqdm(new_leaf_nodes, desc="Adding leaf nodes to index"):
            index.insert_nodes([node])

        total_docs = len(docstore.docs)
        logger.info(f"Updated index with {len(new_nodes)} new nodes from {len(new_documents)} files")
    else:
        logger.info("Creating new index")
        docstore = SimpleDocumentStore()
        storage_context = StorageContext.from_defaults(docstore=docstore)
        docstore.add_documents(new_nodes)

        index = VectorStoreIndex(
            new_leaf_nodes,
            storage_context=storage_context,
            settings=Settings
        )
        total_docs = len(new_documents)
        logger.info(f"Created new index with {len(new_nodes)} nodes from {len(new_documents)} files")

    index.storage_context.persist(persist_dir=save_dir)
    # custom outputs after processing files
    file_list_html = "<div class='file-list'>"
    for stat in file_stats:
        status_color = "#4CAF50" if stat["status"] == "processed" else "#f44336"
        file_list_html += f"<div><span style='color:{status_color}'>●</span> {stat['name']} - {stat['words']} words</div>"
    file_list_html += "</div>"
    processing_time = time.time() - start_time
    stats_output = "<div class='stats-box'>"
    stats_output += f"✓ Processed {len(files)} files in {processing_time:.2f} seconds<br>"
    stats_output += f"✓ Created {len(new_nodes)} nodes ({len(new_leaf_nodes)} leaf nodes)<br>"
    stats_output += f"✓ Total documents in index: {total_docs}<br>"
    stats_output += f"✓ Index saved to: {save_dir}<br>"
    stats_output += "</div>"
    output_container = "<div class='info-container'>"
    output_container += file_list_html
    output_container += stats_output
    output_container += "</div>"
    return f"Successfully indexed {len(files)} files.", output_container

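# Chat handler: reloads the per-session index, runs auto-merging retrieval
# (leaf chunks whose siblings are also retrieved are replaced by their parent
# chunk), injects the merged context into the system prompt, and streams the
# model's reply token by token from a background generation thread.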
@spaces.GPU()
def stream_chat(
    message: str,
    history: list,
    system_prompt: str,
    temperature: float,
    max_new_tokens: int,
    top_p: float,
    top_k: int,
    penalty: float,
    retriever_k: int,
    merge_threshold: float,
    request: gr.Request
):
    if not request:
        yield history + [{"role": "assistant", "content": "Session initialization failed. Please refresh the page."}]
        return
    user_id = request.session_hash
    index_dir = f"./{user_id}_index"
    if not os.path.exists(index_dir):
        yield history + [{"role": "assistant", "content": "Please upload documents first."}]
        return

    max_new_tokens = int(max_new_tokens) if isinstance(max_new_tokens, (int, float)) else 1024
    temperature = float(temperature) if isinstance(temperature, (int, float)) else 0.9
    top_p = float(top_p) if isinstance(top_p, (int, float)) else 0.95
    top_k = int(top_k) if isinstance(top_k, (int, float)) else 50
    penalty = float(penalty) if isinstance(penalty, (int, float)) else 1.2
    retriever_k = int(retriever_k) if isinstance(retriever_k, (int, float)) else 15
    merge_threshold = float(merge_threshold) if isinstance(merge_threshold, (int, float)) else 0.5
    llm = get_llm(temperature=temperature, max_new_tokens=max_new_tokens, top_p=top_p, top_k=top_k)
    embed_model = HuggingFaceEmbedding(model_name=EMBEDDING_MODEL, token=HF_TOKEN)
    Settings.llm = llm
    Settings.embed_model = embed_model
    storage_context = StorageContext.from_defaults(persist_dir=index_dir)
    index = load_index_from_storage(storage_context, settings=Settings)
    base_retriever = index.as_retriever(similarity_top_k=retriever_k)
    auto_merging_retriever = AutoMergingRetriever(
        base_retriever,
        storage_context=storage_context,
        simple_ratio_thresh=merge_threshold,
        verbose=True
    )
    logger.info(f"Query: {message}")
    retrieval_start = time.time()
    base_nodes = base_retriever.retrieve(message)
    logger.info(f"Retrieved {len(base_nodes)} base nodes in {time.time() - retrieval_start:.2f}s")
    base_file_sources = {}
    for node in base_nodes:
        if hasattr(node.node, 'metadata') and 'file_name' in node.node.metadata:
            file_name = node.node.metadata['file_name']
            if file_name not in base_file_sources:
                base_file_sources[file_name] = 0
            base_file_sources[file_name] += 1
    logger.info(f"Base retrieval file distribution: {base_file_sources}")
    merging_start = time.time()
    merged_nodes = auto_merging_retriever.retrieve(message)
    logger.info(f"Retrieved {len(merged_nodes)} merged nodes in {time.time() - merging_start:.2f}s")
    merged_file_sources = {}
    for node in merged_nodes:
        if hasattr(node.node, 'metadata') and 'file_name' in node.node.metadata:
            file_name = node.node.metadata['file_name']
            if file_name not in merged_file_sources:
                merged_file_sources[file_name] = 0
            merged_file_sources[file_name] += 1
    logger.info(f"Merged retrieval file distribution: {merged_file_sources}")
    context = "\n\n".join([n.node.text for n in merged_nodes])
    source_info = ""
    if merged_file_sources:
        source_info = "\n\nRetrieved information from files: " + ", ".join(merged_file_sources.keys())
    formatted_system_prompt = f"{system_prompt}\n\nDocument Context:\n{context}{source_info}"
    messages = [{"role": "system", "content": formatted_system_prompt}]
    for entry in history:
        messages.append(entry)
    messages.append({"role": "user", "content": message})
    prompt = global_tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    stop_event = threading.Event()

    class StopOnEvent(StoppingCriteria):
        def __init__(self, stop_event):
            super().__init__()
            self.stop_event = stop_event

        def __call__(self, input_ids, scores, **kwargs):
            return self.stop_event.is_set()

    stopping_criteria = StoppingCriteriaList([StopOnEvent(stop_event)])
    streamer = TextIteratorStreamer(
        global_tokenizer,
        skip_prompt=True,
        skip_special_tokens=True
    )
    inputs = global_tokenizer(prompt, return_tensors="pt").to(global_model.device)
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=penalty,
        do_sample=True,
        stopping_criteria=stopping_criteria
    )
    thread = threading.Thread(target=global_model.generate, kwargs=generation_kwargs)
    thread.start()
    updated_history = history + [
        {"role": "user", "content": message},
        {"role": "assistant", "content": ""}
    ]
    yield updated_history
    partial_response = ""
    try:
        for new_text in streamer:
            partial_response += new_text
            updated_history[-1]["content"] = partial_response
            yield updated_history
        # Generation finished normally; reap the worker thread before returning.
        thread.join()
        yield updated_history
    except GeneratorExit:
        # Client disconnected mid-stream: signal the stopping criteria and wait
        # for the generation thread to exit.
        stop_event.set()
        thread.join()
        raise

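# Gradio UI: file upload and indexing controls on one side, the chatbot on the
# other, plus an "Advanced Settings" accordion whose generation and retrieval
# parameters are passed straight into stream_chat.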
def create_demo():
    with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
        gr.HTML(TITLE)
        gr.HTML(DESCRIPTION)

        with gr.Row(elem_classes="main-container"):
            with gr.Column(elem_classes="upload-section"):
                file_upload = gr.File(
                    file_count="multiple",
                    label="Drag and Drop Files Here",
                    file_types=[".pdf", ".txt"],
                    elem_id="file-upload"
                )
                upload_button = gr.Button("Upload & Index", elem_classes="upload-button")
                status_output = gr.Textbox(
                    label="Status",
                    placeholder="Upload files to start...",
                    interactive=False
                )
                file_info_output = gr.HTML(
                    label="File Information",
                    elem_classes="processing-info"
                )
                upload_button.click(
                    fn=create_or_update_index,
                    inputs=[file_upload],
                    outputs=[status_output, file_info_output]
                )

            with gr.Column(elem_classes="chatbot-container"):
                chatbot = gr.Chatbot(
                    height=500,
                    placeholder="Chat with your documents here... Type your question below.",
                    show_label=False,
                    type="messages"
                )
                with gr.Row(elem_classes="input-row"):
                    message_input = gr.Textbox(
                        placeholder="Type your question here...",
                        show_label=False,
                        container=False,
                        lines=1,
                        scale=8
                    )
                    submit_button = gr.Button("➤", elem_classes="submit-btn", scale=1)

        with gr.Accordion("Advanced Settings", open=False):
            system_prompt = gr.Textbox(
                value="As a knowledgeable assistant, your task is to provide detailed and context-rich answers based on the relevant information from all uploaded documents. When information is sourced from multiple documents, summarize the key points from each and explain how they relate, noting any connections or contradictions. Your response should be thorough, informative, and easy to understand.",
                label="System Prompt",
                lines=3
            )

            with gr.Tab("Generation Parameters"):
                temperature = gr.Slider(
                    minimum=0,
                    maximum=1,
                    step=0.1,
                    value=0.9,
                    label="Temperature"
                )
                max_new_tokens = gr.Slider(
                    minimum=128,
                    maximum=8192,
                    step=64,
                    value=1024,
                    label="Max New Tokens",
                )
                top_p = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    step=0.1,
                    value=0.95,
                    label="Top P"
                )
                top_k = gr.Slider(
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=50,
                    label="Top K"
                )
                penalty = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    step=0.1,
                    value=1.2,
                    label="Repetition Penalty"
                )

            with gr.Tab("Retrieval Parameters"):
                retriever_k = gr.Slider(
                    minimum=5,
                    maximum=30,
                    step=1,
                    value=15,
                    label="Initial Retrieval Size (Top K)"
                )
                merge_threshold = gr.Slider(
                    minimum=0.1,
                    maximum=0.9,
                    step=0.1,
                    value=0.5,
                    label="Merge Threshold (lower = more merging)"
                )

        submit_button.click(
            fn=stream_chat,
            inputs=[
                message_input,
                chatbot,
                system_prompt,
                temperature,
                max_new_tokens,
                top_p,
                top_k,
                penalty,
                retriever_k,
                merge_threshold
            ],
            outputs=chatbot
        )

        message_input.submit(
            fn=stream_chat,
            inputs=[
                message_input,
                chatbot,
                system_prompt,
                temperature,
                max_new_tokens,
                top_p,
                top_k,
                penalty,
                retriever_k,
                merge_threshold
            ],
            outputs=chatbot
        )

    return demo

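# Running the script directly requires HF_TOKEN (checked at import time) and
# access to the gated Meta-Llama-3.1-8B-Instruct weights; the @spaces.GPU
# decorators are intended for Hugging Face Spaces and should be harmless
# no-ops elsewhere.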
if __name__ == "__main__":
    initialize_model_and_tokenizer()
    demo = create_demo()
    demo.launch()