Praneetha N committed
Commit e05fa86 · 0 Parent(s)

Initial clean commit

Files changed (6):
  1. .gitattributes +3 -0
  2. .gitignore +16 -0
  3. README.md +0 -0
  4. app.py +151 -0
  5. rag_pipeline.py +272 -0
  6. requirements.txt +17 -0
.gitattributes ADDED
@@ -0,0 +1,3 @@
+ *.png filter=lfs diff=lfs merge=lfs -text
+ *.jpg filter=lfs diff=lfs merge=lfs -text
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,16 @@
+ .DS_Store
+ __pycache__/
+ *.pyc
+ .venv/
+ venv/
+ models/
+ screenshots/
+ data/*.pdf
+ *.pdf
+ *.png
+ *.jpg
+ *.jpeg
+ *.bin
+ *.safetensors
+ *.onnx
+ *.pt
README.md ADDED
File without changes
app.py ADDED
@@ -0,0 +1,151 @@
+ # app.py
+ import os
+ import json
+ import time
+ import tempfile
+ import streamlit as st
+
+ from doc_loader import load_document
+ from figure_extractor import extract_figures
+ from rag_pipeline import (
+     build_rag_pipeline,
+     query_rag_full,
+     evaluate_rag,
+ )
+
+ st.set_page_config(page_title="📑 Visual Document RAG", layout="wide")
+ st.title("📑 Visual Document RAG")
+
+ domain = st.selectbox(
+     "Domain focus",
+     ["Finance", "Healthcare", "Law", "Education", "Multimodal"],
+     index=0,
+     help="Used to lightly steer retrieval/answers",
+ )
+
+ uploaded_file = st.file_uploader("📂 Upload a PDF or Image", type=["pdf", "png", "jpg"])
+ query = st.text_input("🔎 Ask a question about the document:")
+ summarize = st.checkbox("📝 Summarize the whole document")
+
+ if uploaded_file:
+     # persist to temp path for libs
+     with tempfile.NamedTemporaryFile(
+         delete=False, suffix=f".{uploaded_file.name.split('.')[-1]}"
+     ) as tmp:
+         tmp.write(uploaded_file.read())
+         tmp_path = tmp.name
+
+     with st.spinner("Processing document..."):
+         docs_text, sections = load_document(tmp_path, return_sections=True)
+
+     # Optional figure extraction (PDF only)
+     figures_meta = []
+     extra_docs = []
+     if uploaded_file.type == "application/pdf":
+         try:
+             figures_meta = extract_figures(tmp_path, out_dir="figures", lang="eng") or []
+             if figures_meta:
+                 sections["Figures (OCR+captions)"] = "\n\n".join(
+                     [
+                         f"p.{f['page']} — {f.get('caption') or '(no caption)'}\n"
+                         f"{(f.get('ocr_text') or '')[:200]}"
+                         for f in figures_meta
+                     ]
+                 )
+                 # vectorizable figure docs
+                 for f in figures_meta:
+                     content = (
+                         f"FIGURE p.{f['page']}: {f.get('caption') or ''}\n"
+                         f"OCR: {f.get('ocr_text') or ''}\n"
+                         f"TAGS: {' '.join(f.get('tags', []))}"
+                     ).strip()
+                     metadata = {
+                         "type": "figure",
+                         "page": f.get("page"),
+                         "path": f.get("path"),
+                         "caption": f.get("caption"),
+                         "tags": f.get("tags", []),
+                     }
+                     extra_docs.append({"content": content, "metadata": metadata})
+         except Exception as e:
+             st.warning(f"Figure extraction skipped: {e}")
+
+     # Build vector index (now indexes figures too)
+     if docs_text.strip():
+         db = build_rag_pipeline(docs_text, extra_docs=extra_docs or None)
+         st.success(f"✅ Document indexed! (Domain: {domain})")
+
+         # Show extracted sections
+         with st.expander("📂 Extracted Document Content"):
+             tab_names = list(sections.keys())
+             tabs = st.tabs(tab_names)
+             for i, name in enumerate(tab_names):
+                 with tabs[i]:
+                     st.text_area(f"{name}", sections[name], height=230)
+
+         # Actions (Summarize or Q&A)
+         answer_text, retrieved_docs, latency = None, None, None
+         colA, colB = st.columns([1, 1])
+
+         with colA:
+             if summarize:
+                 if st.button("📝 Summarize Document"):
+                     start = time.time()
+                     q = f"[Domain: {domain}] Summarize this document briefly with key points and numbers."
+                     answer_text, _, retrieved_docs = query_rag_full(
+                         db, q, domain=domain
+                     )
+                     latency = round(time.time() - start, 3)
+             else:
+                 if query and st.button("💡 Get Answer"):
+                     start = time.time()
+                     q = f"[Domain: {domain}] {query}"
+                     answer_text, _, retrieved_docs = query_rag_full(
+                         db, q, domain=domain
+                     )
+                     latency = round(time.time() - start, 3)
+
+         # Show results
+         if answer_text is not None:
+             st.subheader("💡 Answer")
+             st.write(answer_text)
+             st.caption(f"⏱️ Latency: {latency}s")
+
+             with st.expander("🔍 Retrieved Contexts"):
+                 if retrieved_docs:
+                     for i, d in enumerate(retrieved_docs, 1):
+                         meta = getattr(d, "metadata", {}) or {}
+                         if meta.get("type") == "figure" and meta.get("path"):
+                             st.write(
+                                 f"Figure (page {meta.get('page')}): "
+                                 f"{meta.get('caption') or '(no caption)'}"
+                             )
+                             st.image(meta["path"], use_container_width=True)
+                         else:
+                             st.info(f"Chunk {i}:\n\n{d.page_content}")
+
+         # Evaluation
+         st.markdown("---")
+         st.subheader("📊 Evaluation")
+         if st.button("Evaluate (LLM-based: Faithfulness & Relevancy)"):
+             try:
+                 raw = evaluate_rag(
+                     answer_text,
+                     [d.page_content for d in (retrieved_docs or [])],
+                     f"[Domain: {domain}] {query or 'Summary'}",
+                 )
+                 try:
+                     payload = json.loads(raw) if isinstance(raw, str) else raw
+                 except Exception:
+                     payload = {"raw": raw}
+                 st.json(payload)
+                 st.download_button(
+                     "⬇️ Download Evaluation JSON",
+                     data=json.dumps(payload, indent=2),
+                     file_name="evaluation.json",
+                     mime="application/json",
+                 )
+             except Exception as e:
+                 st.warning(f"Evaluation unavailable: {e}")
+     else:
+         st.error("❌ No text could be extracted from the uploaded file.")
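
For quick local testing, the same flow can be driven without Streamlit. The sketch below is only illustrative: it assumes the doc_loader module that app.py imports (not part of this commit) is importable, and "sample.pdf" is a placeholder path.

# headless_demo.py: illustrative sketch; doc_loader is assumed to exist and
# "sample.pdf" is a placeholder path, not part of this commit.
from doc_loader import load_document
from rag_pipeline import build_rag_pipeline, query_rag_full, evaluate_rag

docs_text, sections = load_document("sample.pdf", return_sections=True)

if docs_text.strip():
    db = build_rag_pipeline(docs_text)  # extra_docs (figure chunks) are optional
    question = "[Domain: Finance] What was total revenue?"
    answer, contexts, retrieved = query_rag_full(db, question, domain="Finance")
    print(answer)

    # LLM-based scoring; returns an explanatory dict when no OPENAI_API_KEY is set.
    print(evaluate_rag(answer, contexts, question))
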
rag_pipeline.py ADDED
@@ -0,0 +1,272 @@
+ # rag_pipeline.py
+ import os
+ from pathlib import Path
+ from typing import List, Tuple, Optional
+ from langchain_community.vectorstores import DocArrayInMemorySearch
+
+
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
+ from langchain.schema import Document
+
+ # Prefer FAISS if available; otherwise use a pure-Python fallback (no native build)
+ try:
+     from langchain_community.vectorstores import FAISS
+     _HAS_FAISS = True
+ except Exception:
+     _HAS_FAISS = False
+
+ # Embeddings + LLM
+ try:
+     from langchain_openai import OpenAIEmbeddings, ChatOpenAI
+     _HAS_OPENAI = True
+ except Exception:
+     _HAS_OPENAI = False
+
+ try:
+     from langchain_community.embeddings import HuggingFaceEmbeddings
+     _HAS_HF = True
+ except Exception:
+     _HAS_HF = False
+
+ # Optional cross-encoder re-ranker
+ try:
+     from sentence_transformers import CrossEncoder
+     _HAS_RERANK = True
+ except Exception:
+     _HAS_RERANK = False
+
+
+ def _has_openai_key() -> bool:
+     return bool(os.getenv("OPENAI_API_KEY"))
+
+
+ def _hf_offline() -> bool:
+     # Honor explicit offline flags: either env var set to anything other than "", "0", "false".
+     return str(os.getenv("HF_HUB_OFFLINE", "")).strip() not in ("", "0", "false") or \
+         str(os.getenv("TRANSFORMERS_OFFLINE", "")).strip() not in ("", "0", "false")
+
+
+ def _resolve_local_dir(env_var: str, default_subdir: str) -> Optional[str]:
+     """
+     Return an absolute path to a local model dir if it exists, else None.
+     env_var (e.g., EMB_LOCAL_DIR / RERANK_LOCAL_DIR) takes priority.
+     default_subdir is relative to this file's directory (e.g., 'models/<name>').
+     """
+     # 1) explicit env
+     p = os.getenv(env_var)
+     if p and Path(p).is_dir():
+         return str(Path(p).resolve())
+
+     # 2) project-relative
+     here = Path(__file__).parent
+     candidate = here / default_subdir
+     if candidate.is_dir():
+         return str(candidate.resolve())
+     return None
+
+
+ def _get_embeddings():
+     """
+     Prefer OpenAI embeddings if available; otherwise use HuggingFace with a local directory if present.
+     If offline and no local dir, raise a friendly error.
+     """
+     if _HAS_OPENAI and _has_openai_key():
+         return OpenAIEmbeddings(model="text-embedding-3-small")
+
+     if not _HAS_HF:
+         raise RuntimeError(
+             "No embeddings backend available. Install `langchain-openai` (and set OPENAI_API_KEY) "
+             "or install `langchain_community` + `sentence-transformers`."
+         )
+
+     # Try local first
+     local_dir = _resolve_local_dir(
+         env_var="EMB_LOCAL_DIR",
+         default_subdir="models/paraphrase-MiniLM-L6-v2",
+     )
+     if local_dir:
+         return HuggingFaceEmbeddings(model_name=local_dir)
+
+     # No local dir; fall back to the hub name only if not offline
+     if _hf_offline():
+         raise RuntimeError(
+             "HF offline mode is enabled and no local embedding model was found.\n"
+             "Set EMB_LOCAL_DIR to a downloaded folder, or place the model at "
+             "<repo_root>/models/paraphrase-MiniLM-L6-v2.\n"
+             "Example download (online machine):\n"
+             "  hf download sentence-transformers/paraphrase-MiniLM-L6-v2 "
+             "--local-dir models/paraphrase-MiniLM-L6-v2"
+         )
+     return HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-MiniLM-L6-v2")
+
+
+ def _get_llm(model_name: str = "gpt-4o-mini", temperature: float = 0):
+     if _HAS_OPENAI and _has_openai_key():
+         return ChatOpenAI(model=model_name, temperature=temperature)
+     return None
+
+
+ def build_rag_pipeline(
+     docs: str,
+     extra_docs: Optional[List[dict]] = None,
+     chunk_size: int = 1000,
+     chunk_overlap: int = 120,
+ ):
+     """
+     Build a vector index with metadata-aware Documents.
+     Falls back to DocArrayInMemorySearch when FAISS isn't available.
+     - docs: merged plain text from loader
+     - extra_docs: list of {"content": str, "metadata": dict} for figures, etc.
+     """
+     splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+     text_chunks = splitter.split_text(docs) if docs else []
+
+     documents: List[Document] = []
+     for i, ch in enumerate(text_chunks):
+         documents.append(
+             Document(page_content=ch, metadata={"type": "text", "chunk_id": i})
+         )
+
+     if extra_docs:
+         for ed in extra_docs:
+             documents.append(
+                 Document(
+                     page_content=(ed.get("content") or "").strip(),
+                     metadata={**(ed.get("metadata") or {}), "type": (ed.get("metadata") or {}).get("type", "extra")},
+                 )
+             )
+
+     if not documents:
+         raise ValueError("No content to index.")
+
+     embeddings = _get_embeddings()
+     if _HAS_FAISS:
+         return FAISS.from_documents(documents, embeddings)
+     return DocArrayInMemorySearch.from_documents(documents, embeddings)
+
+
+ def _domain_prompt(domain: str) -> str:
+     base = (
+         "You are an AI assistant specialized in {domain}. "
+         "Answer strictly using the provided context (including tables/figures). "
+         "Provide clear numbers and cite section/table/figure if possible. "
+         'If the answer is not in the context, reply exactly: "I don\'t have enough information from the document."'
+     )
+     return base.format(domain=domain)
+
+
+ def query_rag_full(
+     db,
+     query: str,
+     top_k: int = 12,
+     rerank_keep: int = 5,
+     domain: str = "Finance",
+ ) -> Tuple[str, List[str], List[Document]]:
+     """
+     Returns (answer_text, retrieved_texts, retrieved_docs)
+     - Retrieves Documents with metadata
+     - Optional cross-encoder re-ranking (local-only if offline)
+     - LLM synthesis if available, else stitched fallback
+     """
+     retriever = db.as_retriever(search_kwargs={"k": top_k})
+     retrieved_docs: List[Document] = retriever.get_relevant_documents(query) or []
+
+     # Optional re-rank
+     top_docs = retrieved_docs
+     if _HAS_RERANK and retrieved_docs:
+         rerank_local = _resolve_local_dir("RERANK_LOCAL_DIR", "models/msmarco-MiniLM-L-6-v2")
+         try:
+             if rerank_local:
+                 model = CrossEncoder(rerank_local)
+             elif not _hf_offline():
+                 model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
+             else:
+                 model = None  # offline + no local reranker → skip
+         except Exception:
+             model = None
+
+         if model is not None:
+             pairs = [[query, d.page_content] for d in retrieved_docs]
+             scores = model.predict(pairs)
+             idx_sorted = sorted(range(len(scores)), key=lambda i: float(scores[i]), reverse=True)
+             keep = max(1, min(rerank_keep, len(idx_sorted)))
+             top_docs = [retrieved_docs[i] for i in idx_sorted[:keep]]
+
+     retrieved_texts = [d.page_content for d in top_docs]
+     if not retrieved_texts:
+         return "I couldn't find anything relevant in the document.", [], []
+
+     llm = _get_llm()
+     sys = _domain_prompt(domain)
+     if llm:
+         context = "\n\n".join([f"[{i+1}] {d.page_content[:4000]}" for i, d in enumerate(top_docs)])
+         cite_hints = []
+         for i, d in enumerate(top_docs):
+             m = d.metadata or {}
+             if m.get("type") == "figure" and m.get("page"):
+                 cite_hints.append(f"[{i+1}] Figure p.{m.get('page')}")
+             elif m.get("type") == "text":
+                 cite_hints.append(f"[{i+1}] Text chunk {m.get('chunk_id')}")
+         hints = "; ".join(cite_hints)
+
+         prompt = f"""{sys}
+
+ Context:
+ {context}
+
+ Hints for citations: {hints}
+
+ Question: {query}
+
+ Answer (include brief citations like [1] or 'Figure p.X' when appropriate):"""
+         answer = llm.invoke(prompt).content.strip()
+         return answer, retrieved_texts, top_docs
+
+     # Offline fallback
+     stitched = " ".join(retrieved_texts)[:1500]
+     answer = f"Answer (from retrieved context): {stitched}"
+     return answer, retrieved_texts, top_docs
+
+
+ def query_rag(db, query: str, top_k: int = 4) -> Tuple[str, List[str]]:
+     ans, texts, _docs = query_rag_full(db, query, top_k=top_k)
+     return ans, texts
+
+
+ def evaluate_rag(answer: str, retrieved_docs: List[str], query: str):
+     llm = _get_llm(model_name="gpt-4o-mini", temperature=0)
+     if not llm:
+         return {
+             "faithfulness": None,
+             "relevancy": None,
+             "explanation": (
+                 "Evaluation requires an LLM (OpenAI). Set OPENAI_API_KEY and install `langchain-openai`."
+             ),
+             "mode": "offline-fallback",
+         }
+     docs_text = "\n".join(retrieved_docs)
+     eval_prompt = f"""
+ You are an impartial evaluator. Given a question, an assistant's answer, and the retrieved context,
+ score the response on:
+
+ 1) Faithfulness (0-5): Is every claim supported by the retrieved context?
+ 2) Relevancy (0-5): Do the retrieved docs directly address the question?
+
+ Return STRICT JSON ONLY:
+ {{
+   "faithfulness": <0-5 integer>,
+   "relevancy": <0-5 integer>,
+   "explanation": "<one-sentence reason>"
+ }}
+
+ ---
+ Question: {query}
+
+ Retrieved Context:
+ {docs_text}
+
+ Answer:
+ {answer}
+ """
+     raw = llm.invoke(eval_prompt).content.strip()
+     return raw
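
As a sanity check for the offline path above, the sketch below builds an index and queries it with no OPENAI_API_KEY set, so query_rag returns the stitched-context fallback rather than an LLM answer. The EMB_LOCAL_DIR value is only an example of where a pre-downloaded sentence-transformers folder might live.

import os
from rag_pipeline import build_rag_pipeline, query_rag

# Example only: point at a pre-downloaded sentence-transformers folder so
# _get_embeddings() can run fully offline (see _resolve_local_dir above).
os.environ["EMB_LOCAL_DIR"] = "models/paraphrase-MiniLM-L6-v2"
os.environ["HF_HUB_OFFLINE"] = "1"

text = "Revenue in 2023 was $1.2M. Revenue in 2024 was $1.5M."
db = build_rag_pipeline(text, chunk_size=200, chunk_overlap=20)

# Without an OpenAI key this returns the "Answer (from retrieved context): ..." fallback.
answer, contexts = query_rag(db, "What was revenue in 2024?", top_k=2)
print(answer)
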
requirements.txt ADDED
@@ -0,0 +1,17 @@
+ streamlit
+ langchain
+ langchain-openai
+ langchain-community
+ langchain-text-splitters
+ faiss-cpu
+ pymupdf
+ pdfplumber
+ camelot-py
+ opencv-python
+ pillow
+ pytesseract
+ sentence-transformers
+ huggingface_hub
+ ragas
+ datasets
+ evaluate
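
Several of these packages back optional code paths (FAISS, the OpenAI bindings, the cross-encoder re-ranker), and rag_pipeline.py degrades gracefully when they are missing. A small, hypothetical check script that mirrors its try/except import probes can confirm which optional pieces resolve in a given environment:

# env_check.py: hypothetical helper mirroring the import probes in rag_pipeline.py.
def probe(label, module_name):
    try:
        __import__(module_name)
        print(f"[ok]   {label}")
    except Exception as exc:
        print(f"[miss] {label}: {exc.__class__.__name__}")

probe("FAISS (faiss-cpu)", "faiss")
probe("OpenAI bindings (langchain-openai)", "langchain_openai")
probe("HF embeddings / re-ranker (sentence-transformers)", "sentence_transformers")
probe("Streamlit UI", "streamlit")
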