zamalali committed on
Commit 5a8ab80 · 1 Parent(s): 08de48d

Add initial project files, update README, and implement utility functions for PDF and image processing

Files changed (5)
  1. README.md +1 -1
  2. app.py +478 -0
  3. packages.txt +4 -0
  4. requirements.txt +15 -0
  5. utils.py +51 -0
README.md CHANGED
@@ -1,6 +1,6 @@
  ---
  title: Multimodal Chat PDF
- emoji: 🏆
+ emoji: 🏃‍♀️
  colorFrom: red
  colorTo: purple
  sdk: gradio
app.py ADDED
@@ -0,0 +1,478 @@
+ import base64
+ import chromadb
+ import gc
+ import gradio as gr
+ import io
+ import numpy as np
+ import os
+ import pandas as pd
+ import pymupdf
+ from pypdf import PdfReader
+ import spaces
+ import torch
+ from PIL import Image
+ from chromadb.utils import embedding_functions
+ from chromadb.utils.data_loaders import ImageLoader
+ from doctr.io import DocumentFile
+ from doctr.models import ocr_predictor
+ from gradio.themes.utils import sizes
+ from langchain import PromptTemplate
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.llms import HuggingFaceEndpoint
+ from transformers import LlavaNextForConditionalGeneration, LlavaNextProcessor
+ from utils import *
+
+
+ def result_to_text(result, as_text=False) -> str or list:
+     full_doc = []
+     for _, page in enumerate(result.pages, start=1):
+         text = ""
+         for block in page.blocks:
+             text += "\n\t"
+             for line in block.lines:
+                 for word in line.words:
+                     text += word.value + " "
+
+         full_doc.append(clean_text(text) + "\n\n")
+
+     return "\n".join(full_doc) if as_text else full_doc
+
+
+ # docTR OCR pipeline: db_resnet50 for text detection, CRNN (MobileNetV3 backbone) for recognition
+ ocr_model = ocr_predictor(
+     "db_resnet50",
+     "crnn_mobilenet_v3_large",
+     pretrained=True,
+     assume_straight_pages=True,
+ )
+
+
+ # LLaVA-NeXT (Mistral-7B) is loaded in 4-bit only when a GPU is available; it captions extracted images
+ if torch.cuda.is_available():
+     processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
+     vision_model = LlavaNextForConditionalGeneration.from_pretrained(
+         "llava-hf/llava-v1.6-mistral-7b-hf",
+         torch_dtype=torch.float16,
+         low_cpu_mem_usage=True,
+         load_in_4bit=True,
+     )
+
+
+ @spaces.GPU()
+ def get_image_description(image):
+     torch.cuda.empty_cache()
+     gc.collect()
+
+     # n = len(prompt)
+     prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
+
+     inputs = processor(prompt, image, return_tensors="pt").to("cuda:0")
+     output = vision_model.generate(**inputs, max_new_tokens=100)
+     return processor.decode(output[0], skip_special_tokens=True)
+
+
+ CSS = """
+ #table_col {background-color: rgb(33, 41, 54);}
+ """
+
+
+ # def get_vectordb(text, images, tables):
+ def get_vectordb(text, images, img_doc_files):
+     client = chromadb.EphemeralClient()
+     loader = ImageLoader()
+     sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
+         model_name="multi-qa-mpnet-base-dot-v1"
+     )
+     # Recreate the collections on every run so stale documents are not queried
+     if "text_db" in [i.name for i in client.list_collections()]:
+         client.delete_collection("text_db")
+     if "image_db" in [i.name for i in client.list_collections()]:
+         client.delete_collection("image_db")
+
+     text_collection = client.get_or_create_collection(
+         name="text_db",
+         embedding_function=sentence_transformer_ef,
+         data_loader=loader,
+     )
+     image_collection = client.get_or_create_collection(
+         name="image_db",
+         embedding_function=sentence_transformer_ef,
+         data_loader=loader,
+         metadata={"hnsw:space": "cosine"},
+     )
+     descs = []
+     for i in range(len(images)):
+         try:
+             descs.append(img_doc_files[i] + "\n" + get_image_description(images[i]))
+         except Exception:
+             descs.append("Could not generate image description due to some error")
+             gr.Error("Could not generate image descriptions. Your GPU limit may have been exhausted. Please try again after an hour.")
+         print(descs[-1])
+         print()
+
+     # image_descriptions = get_image_descriptions(images)
+     image_dict = [{"image": image_to_bytes(img)} for img in images]
+
+     if len(images) > 0:
+         image_collection.add(
+             ids=[str(i) for i in range(len(images))],
+             documents=descs,
+             metadatas=image_dict,
+         )
+
+     splitter = RecursiveCharacterTextSplitter(
+         chunk_size=500,
+         chunk_overlap=10,
+     )
+
+     if len(text.replace(" ", "").replace("\n", "")) == 0:
+         gr.Error("No text found in documents")
+     else:
+         docs = splitter.create_documents([text])
+         doc_texts = [i.page_content for i in docs]
+         text_collection.add(
+             ids=[str(i) for i in list(range(len(doc_texts)))], documents=doc_texts
+         )
+     return client
+
+
+ def extract_only_text(reader):
+     text = ""
+     for _, page in enumerate(reader.pages):
+         text += page.extract_text()  # accumulate text from every page, not just the last one
+     return text.strip()
+
+
+ def extract_data_from_pdfs(
+     docs, session, include_images, do_ocr, progress=gr.Progress()
+ ):
+     if len(docs) == 0:
+         raise gr.Error("No documents to process")
+     progress(0, "Extracting Images")
+
+     # images = extract_images(docs)
+
+     progress(0.25, "Extracting Text")
+
+     all_text = ""
+
+     images = []
+     img_docs = []
+     for doc in docs:
+         if do_ocr == "Get Text With OCR":
+             pdf_doc = DocumentFile.from_pdf(doc)
+             result = ocr_model(pdf_doc)
+             all_text += result_to_text(result, as_text=True) + "\n\n"
+         else:
+             reader = PdfReader(doc)
+             all_text += extract_only_text(reader) + "\n\n"
+
+         if include_images == "Include Images":
+             imgs = extract_images([doc])
+             images.extend(imgs)
+             img_docs.extend([doc.split("/")[-1] for _ in range(len(imgs))])
+
+     progress(
+         0.6, "Generating image descriptions and inserting everything into vectorDB"
+     )
+     vectordb = get_vectordb(all_text, images, img_docs)
+
+     progress(1, "Completed")
+     session["processed"] = True
+     return (
+         vectordb,
+         session,
+         gr.Row(visible=True),
+         all_text[:2000] + "...",
+         # display,
+         images[:2],
+         "<h1 style='text-align: center'>Completed</h1>",
+         # image_descriptions
+     )
+
+
+ sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
+     model_name="multi-qa-mpnet-base-dot-v1"
+ )
+
+
+ def conversation(
+     vectordb_client,
+     msg,
+     num_context,
+     img_context,
+     history,
+     temperature,
+     max_new_tokens,
+     hf_token,
+     model_path,
+ ):
+     # Use the caller's own endpoint when a token and model path are supplied,
+     # otherwise fall back to the Space's default Llama-3-8B-Instruct endpoint
+     if hf_token.strip() != "" and model_path.strip() != "":
+         llm = HuggingFaceEndpoint(
+             repo_id=model_path,
+             temperature=temperature,
+             max_new_tokens=max_new_tokens,
+             huggingfacehub_api_token=hf_token,
+         )
+     else:
+         llm = HuggingFaceEndpoint(
+             repo_id="meta-llama/Meta-Llama-3-8B-Instruct",
+             temperature=temperature,
+             max_new_tokens=max_new_tokens,
+             huggingfacehub_api_token=os.getenv("P_HF_TOKEN", "None"),
+         )
+
+     text_collection = vectordb_client.get_collection(
+         "text_db", embedding_function=sentence_transformer_ef
+     )
+     image_collection = vectordb_client.get_collection(
+         "image_db", embedding_function=sentence_transformer_ef
+     )
+
+     results = text_collection.query(
+         query_texts=[msg], include=["documents"], n_results=num_context
+     )["documents"][0]
+     similar_images = image_collection.query(
+         query_texts=[msg],
+         include=["metadatas", "distances", "documents"],
+         n_results=img_context,
+     )
+     img_links = [i["image"] for i in similar_images["metadatas"][0]]
+
+     images_and_locs = [
+         Image.open(io.BytesIO(base64.b64decode(i[1])))
+         for i in zip(similar_images["distances"][0], img_links)
+     ]
+     img_desc = "\n".join(similar_images["documents"][0])
+     if len(img_links) == 0:
+         img_desc = "No Images Are Provided"
+     template = """
+     Context:
+     {context}
+
+     Included Images:
+     {images}
+
+     Question:
+     {question}
+
+     Answer:
+
+     """
+     prompt = PromptTemplate(
+         template=template, input_variables=["context", "images", "question"]
+     )
+     context = "\n\n".join(results)
+     # references = [gr.Textbox(i, visible=True, interactive=False) for i in results]
+     response = llm(prompt.format(context=context, question=msg, images=img_desc))
+     return history + [(msg, response)], results, images_and_locs
+
+
+ def check_validity_and_llm(session_states):
+     if session_states.get("processed", False):
+         return gr.Tabs(selected=2)
+     raise gr.Error("Please extract data first")
+
+
+ with gr.Blocks(css=CSS, theme=gr.themes.Soft(text_size=sizes.text_md)) as demo:
+     vectordb = gr.State()
+     doc_collection = gr.State(value=[])
+     session_states = gr.State(value={})
+     references = gr.State(value=[])
+
+     gr.Markdown(
+         """<h2><center>Multimodal PDF Chatbot</center></h2>
+         <h3><center><b>Interact With Your PDF Documents</b></center></h3>"""
+     )
+     gr.Markdown(
+         """<center><h3><b>Note: </b> This application leverages advanced Retrieval-Augmented Generation (RAG) techniques to provide context-aware responses from your PDF documents</h3></center><br>
+         <center>Utilizing multimodal capabilities, this chatbot can interpret and answer queries based on both textual and visual information within your PDFs.</center>"""
+     )
+     gr.Markdown(
+         """
+         <center><b>Warning: </b> Extracting text and images from your document and generating embeddings may take some time due to the use of OCR and multimodal LLMs for image description</center>
+         """
+     )
+     with gr.Tabs() as tabs:
+         with gr.TabItem("Upload PDFs", id=0) as pdf_tab:
+             with gr.Row():
+                 with gr.Column():
+                     documents = gr.File(
+                         file_count="multiple",
+                         file_types=["pdf"],
+                         interactive=True,
+                         label="Upload your PDF file/s",
+                     )
+             pdf_btn = gr.Button(value="Next", elem_id="button1")
+
+         with gr.TabItem("Extract Data", id=1) as preprocess:
+             with gr.Row():
+                 with gr.Column():
+                     back_p1 = gr.Button(value="Back")
+                 with gr.Column():
+                     embed = gr.Button(value="Extract Data")
+                 with gr.Column():
+                     next_p1 = gr.Button(value="Next")
+             with gr.Row():
+                 include_images = gr.Radio(
+                     ["Include Images", "Exclude Images"],
+                     value="Include Images",
+                     label="Include/ Exclude Images",
+                     interactive=True,
+                 )
+                 do_ocr = gr.Radio(
+                     ["Get Text With OCR", "Get Available Text Only"],
+                     value="Get Text With OCR",
+                     label="OCR/ No OCR",
+                     interactive=True,
+                 )
+
+             with gr.Row(equal_height=True, variant="panel") as row:
+                 selected = gr.Dataframe(
+                     interactive=False,
+                     col_count=(1, "fixed"),
+                     headers=["Selected Files"],
+                 )
+                 prog = gr.HTML(
+                     value="<h1 style='text-align: center'>Click the 'Extract' button to extract data from PDFs</h1>"
+                 )
+
+             with gr.Accordion("See Parts of Extracted Data", open=False):
+                 with gr.Column(visible=True) as sample_data:
+                     with gr.Row():
+                         with gr.Column():
+                             ext_text = gr.Textbox(
+                                 label="Sample Extracted Text", lines=15
+                             )
+                         with gr.Column():
+                             images = gr.Gallery(
+                                 label="Sample Extracted Images", columns=1, rows=2
+                             )
+
+         with gr.TabItem("Chat", id=2) as chat_tab:
+             with gr.Accordion("Config (Advanced) (Optional)", open=False):
+                 with gr.Row(variant="panel", equal_height=True):
+                     choice = gr.Radio(
+                         ["chromaDB"],
+                         value="chromaDB",
+                         label="Vector Database",
+                         interactive=True,
+                     )
+                     with gr.Accordion("Use your own model (optional)", open=False):
+                         hf_token = gr.Textbox(
+                             label="HuggingFace Token", interactive=True
+                         )
+                         model_path = gr.Textbox(label="Model Path", interactive=True)
+                 with gr.Row(variant="panel", equal_height=True):
+                     num_context = gr.Slider(
+                         label="Number of text context elements",
+                         minimum=1,
+                         maximum=20,
+                         step=1,
+                         interactive=True,
+                         value=3,
+                     )
+                     img_context = gr.Slider(
+                         label="Number of image context elements",
+                         minimum=1,
+                         maximum=10,
+                         step=1,
+                         interactive=True,
+                         value=2,
+                     )
+                 with gr.Row(variant="panel", equal_height=True):
+                     temp = gr.Slider(
+                         label="Temperature",
+                         minimum=0.1,
+                         maximum=1,
+                         step=0.1,
+                         interactive=True,
+                         value=0.4,
+                     )
+                     max_tokens = gr.Slider(
+                         label="Max Tokens",
+                         minimum=10,
+                         maximum=2000,
+                         step=10,
+                         interactive=True,
+                         value=500,
+                     )
+             with gr.Row():
+                 with gr.Column():
+                     ret_images = gr.Gallery("Similar Images", columns=1, rows=2)
+                 with gr.Column():
+                     chatbot = gr.Chatbot(height=400)
+             with gr.Accordion("Text References", open=False):
+                 # text_context = gr.Row()
+
+                 @gr.render(inputs=references)
+                 def gen_refs(references):
+                     # print(references)
+                     n = len(references)
+                     for i in range(n):
+                         gr.Textbox(
+                             label=f"Reference-{i+1}", value=references[i], lines=3
+                         )
+
+             with gr.Row():
+                 msg = gr.Textbox(
+                     placeholder="Type your question here (e.g. 'What is this document about?')",
+                     interactive=True,
+                     container=True,
+                 )
+             with gr.Row():
+                 submit_btn = gr.Button("Submit message")
+                 clear_btn = gr.ClearButton([msg, chatbot], value="Clear conversation")
+
+     pdf_btn.click(
+         fn=extract_pdfs,
+         inputs=[documents, doc_collection],
+         outputs=[doc_collection, tabs, selected],
+     )
+     embed.click(
+         extract_data_from_pdfs,
+         inputs=[doc_collection, session_states, include_images, do_ocr],
+         outputs=[
+             vectordb,
+             session_states,
+             sample_data,
+             ext_text,
+             images,
+             prog,
+         ],
+     )
+
+     submit_btn.click(
+         conversation,
+         [
+             vectordb,
+             msg,
+             num_context,
+             img_context,
+             chatbot,
+             temp,
+             max_tokens,
+             hf_token,
+             model_path,
+         ],
+         [chatbot, references, ret_images],
+     )
+     msg.submit(
+         conversation,
+         [
+             vectordb,
+             msg,
+             num_context,
+             img_context,
+             chatbot,
+             temp,
+             max_tokens,
+             hf_token,
+             model_path,
+         ],
+         [chatbot, references, ret_images],
+     )
+
+     documents.change(lambda: "<h1 style='text-align: center'>Click the 'Extract' button to extract data from PDFs</h1>", None, prog)
+
+     back_p1.click(lambda: gr.Tabs(selected=0), None, tabs)
+
+     next_p1.click(check_validity_and_llm, session_states, tabs)
+ if __name__ == "__main__":
+     demo.launch()
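
For context, the text chunking that `get_vectordb` performs before indexing into the `text_db` collection can be previewed in isolation. A minimal sketch (not part of the commit) using the same splitter settings; the sample string is made up:

```python
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Same settings as get_vectordb above: ~500-character chunks with a 10-character overlap.
splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=10)

sample_text = "Lorem ipsum dolor sit amet. " * 100  # stand-in for text extracted from a PDF
chunks = splitter.create_documents([sample_text])

# Each chunk becomes one document (with a string id) in the "text_db" Chroma collection.
print(len(chunks), len(chunks[0].page_content))
```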
packages.txt ADDED
@@ -0,0 +1,4 @@
+ poppler-utils
+ tesseract-ocr
+ libtesseract-dev
+ ghostscript
requirements.txt ADDED
@@ -0,0 +1,15 @@
+ chromadb==0.5.3
+ langchain==0.2.5
+ langchain_community==0.2.5
+ langchain-huggingface
+ numpy<2.0.0
+ pandas==2.2.2
+ Pillow==10.3.0
+ pymupdf==1.24.5
+ sentence_transformers==3.0.1
+ accelerate
+ bitsandbytes
+ tf2onnx
+ clean-text[gpl]
+ python-doctr[torch]
+ pypdf
utils.py ADDED
@@ -0,0 +1,51 @@
+ import pymupdf
+ from PIL import Image
+ import io
+ import gradio as gr
+ import base64
+ import pandas as pd
+
+
+ def image_to_bytes(image):
+     # Encode a PIL image as a base64 PNG string so it can be stored in ChromaDB metadata
+     img_byte_arr = io.BytesIO()
+     image.save(img_byte_arr, format="PNG")
+     return base64.b64encode(img_byte_arr.getvalue()).decode("utf-8")
+
+
+ def extract_pdfs(docs, doc_collection):
+     if docs:
+         doc_collection = []
+         doc_collection.extend(docs)
+     return (
+         doc_collection,
+         gr.Tabs(selected=1),
+         pd.DataFrame([i.split("/")[-1] for i in list(docs)], columns=["Filename"]),
+     )
+
+
+ def extract_images(docs):
+     images = []
+     for doc_path in docs:
+         doc = pymupdf.open(doc_path)
+
+         for page_index in range(len(doc)):
+             page = doc[page_index]
+             image_list = page.get_images()
+
+             for _, img in enumerate(image_list, start=1):
+                 xref = img[0]
+                 pix = pymupdf.Pixmap(doc, xref)
+
+                 if pix.n - pix.alpha > 3:  # convert CMYK (and similar) pixmaps to RGB
+                     pix = pymupdf.Pixmap(pymupdf.csRGB, pix)
+
+                 images.append(Image.open(io.BytesIO(pix.pil_tobytes("JPEG"))))
+     return images
+
+
+ def clean_text(text):
+     text = text.strip()
+     cleaned_text = text.replace("\n", " ")
+     cleaned_text = cleaned_text.replace("\t", " ")
+     cleaned_text = cleaned_text.replace("  ", " ")
+     cleaned_text = cleaned_text.strip()
+     return cleaned_text
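
These helpers form a simple round trip: `extract_images` pulls page images out of a PDF, `image_to_bytes` encodes each one as a base64 PNG string for ChromaDB metadata, and the chat tab later decodes that string back into a PIL image. A minimal sketch of that encode/decode cycle (the solid-color image is a stand-in for a real extracted image, not part of the commit):

```python
import base64
import io

from PIL import Image

# Stand-in for one image returned by extract_images(); any PIL image works here.
img = Image.new("RGB", (64, 64), color="red")

# What image_to_bytes() produces and what get_vectordb() stores under metadata["image"].
buf = io.BytesIO()
img.save(buf, format="PNG")
encoded = base64.b64encode(buf.getvalue()).decode("utf-8")

# What conversation() does with that metadata before handing images to the gallery.
restored = Image.open(io.BytesIO(base64.b64decode(encoded)))
print(restored.size)  # (64, 64)
```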