Muzammil6376 committed on
Commit 225229c · verified · 1 Parent(s): 2a4ba68

Update app.py

Files changed (1)
  1. app.py +100 -133
app.py CHANGED
@@ -2,44 +2,52 @@ import os
import shutil
from typing import List

import gradio as gr
from PIL import Image

- # Unstructured for rich PDF parsing
from unstructured.partition.pdf import partition_pdf
from unstructured.partition.utils.constants import PartitionStrategy

# Vision-language captioning (BLIP)
- from transformers import BlipProcessor, BlipForConditionalGeneration

- # Hugging Face Inference client
from huggingface_hub import InferenceClient

# FAISS vectorstore
- from langchain.vectorstores.faiss import FAISS

# ── Globals ───────────────────────────────────────────────────────────────────
- retriever = None # FAISS retriever for multimodal content
- current_pdf_name = None # Name of the currently loaded PDF
- combined_texts: List[str] = [] # Combined text + image captions corpus
- pdf_text: str = "" # Full PDF text for summary/keywords

- # ── Setup: directories ─────────────────────────────────────────────────────────
FIGURES_DIR = "figures"
if os.path.exists(FIGURES_DIR):
    shutil.rmtree(FIGURES_DIR)
- os.makedirs(FIGURES_DIR, exist_ok=True)

# ── Clients & Models ───────────────────────────────────────────────────────────
- hf = InferenceClient() # uses HUGGINGFACEHUB_API_TOKEN env var

- # BLIP captioner
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
- blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")


def generate_caption(image_path: str) -> str:
-     """Generate caption for image via BLIP."""
    image = Image.open(image_path).convert("RGB")
    inputs = blip_processor(image, return_tensors="pt")
    out = blip_model.generate(**inputs)
@@ -47,159 +55,118 @@ def generate_caption(image_path: str) -> str:


def embed_texts(texts: List[str]) -> List[List[float]]:
-     """Call HF inference embeddings endpoint."""
-     resp = hf.embeddings(model="google/Gemma-Embeddings-v1.0", inputs=texts)
-     return resp["embeddings"]


- def process_pdf(pdf_file):
-     """
-     Reads & extracts text and images from the PDF, captions images,
-     splits & embeds chunks, builds FAISS index, and stores full text.
-     Returns filename, status, and enables Q&A box.
-     """
-     global retriever, current_pdf_name, combined_texts, pdf_text

    if pdf_file is None:
        return None, "❌ Please upload a PDF file.", gr.update(interactive=False)

    current_pdf_name = os.path.basename(pdf_file.name)
-     # extract full text for summary/keywords
    from pypdf import PdfReader
    reader = PdfReader(pdf_file.name)
    pages = [page.extract_text() or "" for page in reader.pages]
    pdf_text = "\n\n".join(pages)

-     # parse with unstructured for images
    try:
-         elements = partition_pdf(
            filename=pdf_file.name,
            strategy=PartitionStrategy.HI_RES,
-             extract_image_block_types=["Image", "Table"],
            extract_image_block_output_dir=FIGURES_DIR,
        )
-         text_elements = [el.text for el in elements if el.category not in ["Image","Table"] and el.text]
-         image_files = [os.path.join(FIGURES_DIR, f) for f in os.listdir(FIGURES_DIR)
-                        if f.lower().endswith((".png",".jpg",".jpeg"))]
-     except Exception:
-         text_elements = pages
-         image_files = []
- 
-     captions = [generate_caption(img) for img in image_files]
-     # split text elements into chunks
    from langchain.text_splitter import CharacterTextSplitter
    splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    chunks = []
-     for t in text_elements:
        chunks.extend(splitter.split_text(t))
-     combined_texts = chunks + captions
- 
-     vectors = embed_texts(combined_texts)
-     index = FAISS.from_embeddings(texts=combined_texts, embeddings=vectors)
    retriever = index.as_retriever(search_kwargs={"k":2})
- 
-     status = f"✅ Indexed '{current_pdf_name}' — {len(chunks)} text chunks + {len(captions)} image captions"
    return current_pdf_name, status, gr.update(interactive=True)


- def ask_question(pdf_name, question):
-     """Retrieve relevant chunks and generate answer via remote LLM."""
    global retriever
    if retriever is None:
-         return "❌ Please upload and index a PDF first."
    if not question.strip():
-         return "❌ Please enter a question."
- 
    docs = retriever.get_relevant_documents(question)
-     context = "\n\n".join(doc.page_content for doc in docs)
-     prompt = (
-         "Use the following document excerpts to answer the question.\n\n"
-         f"{context}\n\nQuestion: {question}\nAnswer:"
-     )
-     response = hf.chat_completion(
-         model="google/gemma-3-27b-it",
-         messages=[{"role":"user","content":prompt}],
-         max_tokens=128,
-         temperature=0.5,
-     )
-     return response["choices"][0]["message"]["content"].strip()
- 
- 
- def generate_summary():
-     """Ask remote LLM for concise summary using full text."""
-     if not pdf_text:
-         return "❌ Please upload and index a PDF first."
-     ctx = pdf_text[:2000]
-     resp = hf.chat_completion(
-         model="google/gemma-3-27b-it",
-         messages=[{"role":"user","content":f"Summarize concisely:\n\n{ctx}..."}],
-         max_tokens=150,
-         temperature=0.5,
-     )
-     return resp["choices"][0]["message"]["content"].strip()
- 
- 
- def extract_keywords():
-     """Ask remote LLM to extract key terms from full text."""
-     if not pdf_text:
-         return "❌ Please upload and index a PDF first."
-     ctx = pdf_text[:2000]
-     resp = hf.chat_completion(
-         model="google/gemma-3-27b-it",
-         messages=[{"role":"user","content":f"Extract 10-15 key terms:\n\n{ctx}..."}],
-         max_tokens=60,
-         temperature=0.5,
-     )
-     return resp["choices"][0]["message"]["content"].strip()


def clear_interface():
-     """Reset state and clear extracted images."""
-     global retriever, current_pdf_name, combined_texts, pdf_text
-     retriever = None
-     current_pdf_name = None
-     combined_texts = []
-     pdf_text = ""
-     shutil.rmtree(FIGURES_DIR, ignore_errors=True)
-     os.makedirs(FIGURES_DIR, exist_ok=True)
    return None, "", gr.update(interactive=False)

- # ── Gradio UI ────────────────────────────────────────────────────────────────
- theme = gr.themes.Soft(primary_hue="indigo", secondary_hue="blue")
- with gr.Blocks(theme=theme, css="""
- .container { border-radius: 10px; padding: 15px; }
- .pdf-active { border-left: 3px solid #6366f1; padding-left: 10px; background-color: rgba(99,102,241,0.1); }
- .footer { text-align: center; margin-top: 30px; font-size: 0.8em; color: #666; }
- .main-title { text-align: center; font-size: 64px; font-weight: bold; margin-bottom: 20px; }
- """) as demo:
-     gr.Markdown("<div class='main-title'>DocQueryAI (Multimodal RAG)</div>")
    with gr.Row():
        with gr.Column():
-             gr.Markdown("## 📄 Document Input")
-             pdf_display = gr.Textbox(label="Active Document", interactive=False, elem_classes="pdf-active")
-             pdf_file = gr.File(file_types=[".pdf"], type="filepath")
-             upload_button = gr.Button("📤 Process Document", variant="primary")
-             status_box = gr.Textbox(label="Status", interactive=False)
        with gr.Column():
-             gr.Markdown("## ❓ Ask Questions")
-             question_input = gr.Textbox(lines=3, placeholder="Enter your question here…", interactive=False)
-             ask_button = gr.Button("🔍 Ask Question", variant="primary", interactive=False)
-             answer_output = gr.Textbox(label="Answer", lines=8, interactive=False)
-     with gr.Row():
-         summary_button = gr.Button("📋 Generate Summary", variant="secondary", interactive=False)
-         summary_output = gr.Textbox(label="Summary", lines=4, interactive=False)
-         keywords_button = gr.Button("🏷️ Extract Keywords", variant="secondary", interactive=False)
-         keywords_output = gr.Textbox(label="Keywords", lines=4, interactive=False)
-     clear_button = gr.Button("🗑️ Clear All", variant="secondary")
-     gr.Markdown("<div class='footer'>Powered by HF Inference + FAISS + BLIP | Gradio</div>")
- 
-     upload_button.click(process_pdf, [pdf_file], [pdf_display, status_box, question_input])
-     ask_button.click(ask_question, [pdf_display, question_input], answer_output)
-     summary_button.click(generate_summary, [], summary_output)
-     keywords_button.click(extract_keywords, [], keywords_output)
-     clear_button.click(clear_interface, [], [pdf_display, status_box, question_input])
- 
- if __name__ == "__main__":
-     demo.launch(debug=True)
- 
- 
 
import shutil
from typing import List

+ import torch
import gradio as gr
from PIL import Image

+ # Unstructured for PDF parsing
from unstructured.partition.pdf import partition_pdf
from unstructured.partition.utils.constants import PartitionStrategy

# Vision-language captioning (BLIP)
+ from transformers import BlipProcessor, BlipForConditionalGeneration, CLIPProcessor, CLIPModel

+ # Hugging Face Inference client for LLM
from huggingface_hub import InferenceClient

# FAISS vectorstore
+ from langchain_community.vectorstores import FAISS
+ 
+ # Text embeddings
+ from langchain_huggingface import HuggingFaceEmbeddings

# ── Globals ───────────────────────────────────────────────────────────────────
+ retriever = None
+ current_pdf_name = None
+ combined_texts: List[str] = [] # text chunks + captions
+ combined_vectors: List[List[float]] = []
+ pdf_text: str = ""

+ # ── Setup ─────────────────────────────────────────────────────────────────────
FIGURES_DIR = "figures"
if os.path.exists(FIGURES_DIR):
    shutil.rmtree(FIGURES_DIR)
+ else:
+     os.makedirs(FIGURES_DIR, exist_ok=True)

# ── Clients & Models ───────────────────────────────────────────────────────────
+ hf = InferenceClient() # for chat completions
+ txt_emb = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
+ blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
+ 
+ clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+ clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")


def generate_caption(image_path: str) -> str:
    image = Image.open(image_path).convert("RGB")
    inputs = blip_processor(image, return_tensors="pt")
    out = blip_model.generate(**inputs)



def embed_texts(texts: List[str]) -> List[List[float]]:
+     return txt_emb.embed_documents(texts)


+ def embed_images(image_paths: List[str]) -> List[List[float]]:
+     feats = []
+     for p in image_paths:
+         img = Image.open(p).convert("RGB")
+         inputs = clip_processor(images=img, return_tensors="pt")
+         with torch.no_grad():
+             v = clip_model.get_image_features(**inputs)
+         feats.append(v[0].cpu().tolist())
+     return feats
+ 

+ def process_pdf(pdf_file):
+     global retriever, current_pdf_name, combined_texts, combined_vectors, pdf_text
    if pdf_file is None:
        return None, "❌ Please upload a PDF file.", gr.update(interactive=False)

    current_pdf_name = os.path.basename(pdf_file.name)
+     # extract full text
    from pypdf import PdfReader
    reader = PdfReader(pdf_file.name)
    pages = [page.extract_text() or "" for page in reader.pages]
    pdf_text = "\n\n".join(pages)

+     # rich parsing for images
    try:
+         els = partition_pdf(
            filename=pdf_file.name,
            strategy=PartitionStrategy.HI_RES,
+             extract_image_block_types=["Image","Table"],
            extract_image_block_output_dir=FIGURES_DIR,
        )
+         texts = [e.text for e in els if e.category not in ["Image","Table"] and e.text]
+         imgs = [os.path.join(FIGURES_DIR,f) for f in os.listdir(FIGURES_DIR)
+                 if f.lower().endswith((".png",".jpg",".jpeg"))]
+     except:
+         texts = pages
+         imgs = []
+ 
+     # split text chunks
    from langchain.text_splitter import CharacterTextSplitter
    splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    chunks = []
+     for t in texts:
        chunks.extend(splitter.split_text(t))
+     caps = [generate_caption(i) for i in imgs]
+ 
+     # embed
+     tvecs = embed_texts(chunks + caps)
+     ivecs = embed_images(imgs)
+     # align dims: captions embedded twice? simple: drop caps embeddings from tvecs
+     text_count = len(chunks)
+     cap_count = len(caps)
+     # use text embeddings for text and clip for images
+     combined_texts = chunks + caps
+     combined_vectors = tvecs[:text_count] + ivecs
+ 
+     index = FAISS.from_embeddings(texts=combined_texts, embeddings=combined_vectors)
    retriever = index.as_retriever(search_kwargs={"k":2})
+     status = f"✅ Indexed '{current_pdf_name}' — {len(chunks)} text chunks + {len(imgs)} images"
    return current_pdf_name, status, gr.update(interactive=True)


+ def ask_question(pdf_name,question):
    global retriever
    if retriever is None:
+         return "❌ Please process a PDF first."
    if not question.strip():
+         return "❌ Enter a question."
    docs = retriever.get_relevant_documents(question)
+     ctx = "\n\n".join(d.page_content for d in docs)
+     prompt = f"Use contexts:\n{ctx}\nQuestion:{question}\nAnswer:"
+     res = hf.chat_completion(model="google/gemma-3-27b-it",messages=[{"role":"user","content":prompt}],max_tokens=128)
+     return res["choices"][0]["message"]["content"].strip()


+ def generate_summary(): return ask_question(None,"Summarize:\n"+pdf_text[:2000])
+ 
+ def extract_keywords(): return ask_question(None,"Extract keywords:\n"+pdf_text[:2000])
+ 
def clear_interface():
+     global retriever,combined_texts,combined_vectors,pdf_text
+     retriever=None
+     combined_texts=[]
+     combined_vectors=[]
+     pdf_text=""
+     shutil.rmtree(FIGURES_DIR,ignore_errors=True)
+     os.makedirs(FIGURES_DIR,exist_ok=True)
    return None, "", gr.update(interactive=False)

+ # UI
+ theme=gr.themes.Soft(primary_hue="indigo",secondary_hue="blue")
+ with gr.Blocks(theme=theme) as demo:
+     gr.Markdown("# DocQueryAI (True Multimodal RAG)")
    with gr.Row():
        with gr.Column():
+             pdf_disp=gr.Textbox(label="Active Document",interactive=False)
+             pdf_file=gr.File(file_types=[".pdf"],type="filepath")
+             btn_process=gr.Button("Process PDF")
+             status=gr.Textbox(interactive=False)
        with gr.Column():
+             q_in=gr.Textbox(lines=3,interactive=False)
+             btn_ask=gr.Button("Ask")
+             ans=gr.Textbox(interactive=False)
+     btn_sum=gr.Button("Summary",interactive=False);out_sum=gr.Textbox(interactive=False)
+     btn_key=gr.Button("Keywords",interactive=False);out_key=gr.Textbox(interactive=False)
+     btn_clear=gr.Button("Clear All")
+     btn_process.click(process_pdf,[pdf_file],[pdf_disp,status,q_in])
+     btn_ask.click(ask_question,[pdf_disp,q_in],ans)
+     btn_sum.click(generate_summary,[],out_sum)
+     btn_key.click(extract_keywords,[],out_key)
+     btn_clear.click(clear_interface,[],[pdf_disp,status,q_in])
+ if __name__=="__main__": demo.launch()
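
For reference, here is a minimal standalone sketch of the retrieval path the new code relies on: indexing precomputed vectors with FAISS through langchain_community and querying the resulting retriever. It is not the committed app.py; the corpus strings are placeholders, everything is embedded with the same text model so all vectors share one dimension, and it assumes the langchain_community signature in which FAISS.from_embeddings takes (text, vector) pairs plus an Embeddings object that is reused for query embedding.

```python
# Standalone sketch (assumed API, placeholder data), not the committed app.py.
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings

emb = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

# Placeholder corpus: PDF text chunks plus BLIP-style captions for extracted figures.
texts = [
    "First text chunk from the PDF.",
    "Second text chunk from the PDF.",
    "A bar chart comparing accuracy across three models.",  # image caption
]

# Embed every entry with the same text model so the index holds one vector dimension.
vectors = emb.embed_documents(texts)

# from_embeddings takes (text, vector) pairs and an Embeddings object used later
# to embed incoming queries.
index = FAISS.from_embeddings(
    text_embeddings=list(zip(texts, vectors)),
    embedding=emb,
)
retriever = index.as_retriever(search_kwargs={"k": 2})

docs = retriever.get_relevant_documents("Which chart compares model accuracy?")
print("\n\n".join(d.page_content for d in docs))
```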