Spaces:
Running
Running
File size: 7,023 Bytes
67a56f6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 |
import os
import shutil
import PyPDF2
import gradio as gr
from PIL import Image
# Unstructured for rich PDF parsing
from unstructured.partition.pdf import partition_pdf
from unstructured.partition.utils.constants import PartitionStrategy
# Vision-language captioning (BLIP)
from transformers import BlipProcessor, BlipForConditionalGeneration
# LangChain vectorstore and embeddings
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
# HF Inference client for chat completions
from huggingface_hub import InferenceClient
# ── Globals ──────────────────────────────────────────────────────────────────
# Module-level state shared by the Gradio callbacks; every slot starts empty.
current_pdf_name = None  # basename of the most recently processed PDF
retriever = None  # FAISS retriever over the multimodal (text + caption) corpus
combined_texts = None  # list of text blocks + image captions that were indexed
# ── Setup: directories ───────────────────────────────────────────────────────
# partition_pdf() writes extracted image/table crops into this folder. Any
# copy left over from a previous run is removed so the app starts empty.
FIGURES_DIR = "figures"
try:
    shutil.rmtree(FIGURES_DIR)
except FileNotFoundError:
    pass  # first run: nothing to remove
os.makedirs(FIGURES_DIR, exist_ok=True)
# ── Models & Clients ─────────────────────────────────────────────────────────
# Chat model (Mistral-7B-Instruct) reached through the HF InferenceClient.
chat_client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.3")

# Text embeddings (BAAI BGE) used to build the FAISS index.
embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-base-en-v1.5")

# Image captioning (BLIP). The processor and the model must be loaded from the
# same checkpoint, so the ID is defined once and shared by both.
BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-base"
blip_processor = BlipProcessor.from_pretrained(BLIP_CHECKPOINT)
blip_model = BlipForConditionalGeneration.from_pretrained(BLIP_CHECKPOINT)
def generate_caption(image_path: str) -> str:
    """
    Generate a natural-language caption for the image at ``image_path`` using BLIP.

    The image is converted to RGB, run through ``blip_processor`` and
    ``blip_model.generate``, and the first generated sequence is decoded with
    special tokens removed.
    """
    # Open in a context manager so the file handle is released even if
    # decoding fails; convert() returns an independent in-memory image.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    inputs = blip_processor(image, return_tensors="pt")
    output_ids = blip_model.generate(**inputs)
    return blip_processor.decode(output_ids[0], skip_special_tokens=True)
def process_pdf(pdf_file) -> str:
    """
    Parse the uploaded PDF into text blocks and image captions, build a FAISS
    index over them, and install the resulting retriever in module state.

    Args:
        pdf_file: The value produced by the ``gr.File`` component, either a
            path string or an object exposing the file path as ``.name``
            (the tempfile wrapper returned for ``type="file"``).

    Returns:
        A human-readable status message for the UI.
    """
    global current_pdf_name, retriever, combined_texts
    if pdf_file is None:
        return "❌ Please upload a PDF file."

    # Accept a plain path as well as a tempfile-like wrapper.
    pdf_path = pdf_file if isinstance(pdf_file, str) else pdf_file.name

    # Empty the figures folder first; otherwise images extracted from a
    # previously processed PDF would be captioned and indexed again.
    try:
        shutil.rmtree(FIGURES_DIR)
    except FileNotFoundError:
        pass
    os.makedirs(FIGURES_DIR, exist_ok=True)

    # Extract text, table, and image blocks (image/table crops go to FIGURES_DIR).
    elements = partition_pdf(
        filename=pdf_path,
        strategy=PartitionStrategy.HI_RES,
        extract_image_block_types=["Image", "Table"],
        extract_image_block_output_dir=FIGURES_DIR,
    )

    # Keep non-empty text from every element that is not an image/table block.
    text_elements = [
        el.text
        for el in elements
        if el.category not in ("Image", "Table") and el.text and el.text.strip()
    ]
    # Sorted so the caption order (and thus the index) is deterministic.
    image_files = sorted(
        os.path.join(FIGURES_DIR, f)
        for f in os.listdir(FIGURES_DIR)
        if f.lower().endswith((".png", ".jpg", ".jpeg"))
    )
    captions = [generate_caption(img) for img in image_files]

    texts = text_elements + captions
    if not texts:
        # FAISS.from_texts cannot build an index from an empty corpus.
        return "❌ No text or images could be extracted from this PDF."

    index = FAISS.from_texts(texts, embeddings)

    # Publish the new state only after indexing succeeded, so a failure above
    # never leaves a new document name paired with an old retriever.
    current_pdf_name = os.path.basename(pdf_path)
    combined_texts = texts
    retriever = index.as_retriever(search_kwargs={"k": 2})
    return (
        f"✅ Indexed '{current_pdf_name}' — {len(text_elements)} text blocks"
        f" + {len(captions)} image captions"
    )
def ask_question(question: str) -> str:
    """
    Answer ``question`` using the currently indexed PDF.

    Retrieves chunks from the module-level ``retriever``, joins their
    ``page_content`` into a prompt, and asks ``chat_client`` for a completion
    (at most 128 tokens, temperature 0.5).

    Args:
        question: The user's question. ``None``, empty, or whitespace-only
            input is rejected with a message instead of raising.

    Returns:
        The model's answer with surrounding whitespace removed, or a
        user-facing message when no document is loaded or the question is empty.
    """
    if retriever is None:
        return "❌ Please upload and process a PDF first."
    if not question or not question.strip():
        return "❌ Please enter a question."

    # invoke() is the Runnable entry point; get_relevant_documents() is
    # deprecated in current LangChain releases.
    docs = retriever.invoke(question)
    context = "\n\n".join(doc.page_content for doc in docs)
    prompt = (
        "Use the following document excerpts to answer the question.\n\n"
        f"{context}\n\n"
        f"Question: {question}\n"
        "Answer:"
    )
    response = chat_client.chat_completion(
        messages=[{"role": "user", "content": prompt}],
        max_tokens=128,
        temperature=0.5,
    )
    # The message content may be None (e.g. an empty completion); treat it as "".
    answer = response.choices[0].message.content or ""
    return answer.strip()
def clear_interface():
    """
    Reset the module-level state and empty the figures directory.

    Returns:
        A pair of empty strings, one per output component the Clear button is
        wired to (the status box and the answer box).
    """
    global retriever, current_pdf_name, combined_texts
    retriever = None
    current_pdf_name = None
    combined_texts = None
    # Recreate FIGURES_DIR empty; a missing directory is not an error here.
    try:
        shutil.rmtree(FIGURES_DIR)
    except FileNotFoundError:
        pass
    os.makedirs(FIGURES_DIR, exist_ok=True)
    return "", ""
# ── Gradio UI ────────────────────────────────────────────────────────────────
def _on_process(pdf):
    """Index the uploaded PDF; return (status message, active document name)."""
    status = process_pdf(pdf)
    return status, current_pdf_name or ""


def _on_clear():
    """Reset app state and blank every text field shown in the UI."""
    # The return value of clear_interface is ignored so this handler alone
    # decides how many outputs it fills (one value per output component).
    clear_interface()
    return "", "", "", ""


theme = gr.themes.Soft(primary_hue="indigo", secondary_hue="blue")
with gr.Blocks(theme=theme, css="""
.container { border-radius: 10px; padding: 15px; }
.pdf-active { border-left: 3px solid #6366f1; padding-left: 10px; background-color: rgba(99,102,241,0.1); }
.footer { text-align: center; margin-top: 30px; font-size: 0.8em; color: #666; }
.main-title { text-align: center; font-size: 64px; font-weight: bold; margin-bottom: 20px; }
""") as demo:
    gr.Markdown("<div class='main-title'>DocQueryAI (Multimodal)</div>")
    with gr.Row():
        with gr.Column():
            gr.Markdown("## 📄 Document Input")
            pdf_display = gr.Textbox(label="Active Document", interactive=False, elem_classes="pdf-active")
            pdf_file = gr.File(file_types=[".pdf"], type="file")
            process_btn = gr.Button("📤 Process Document", variant="primary")
            status_box = gr.Textbox(label="Status", interactive=False)
        with gr.Column():
            gr.Markdown("## ❓ Ask Questions")
            question_input = gr.Textbox(lines=3, placeholder="Enter your question here…")
            ask_btn = gr.Button("🔍 Ask Question", variant="primary")
            answer_output = gr.Textbox(label="Answer", lines=8, interactive=False)
            clear_btn = gr.Button("🗑️ Clear All", variant="secondary")
    gr.Markdown("<div class='footer'>Powered by LangChain + Mistral 7B + FAISS + BLIP | Gradio</div>")
    # Processing also fills the "Active Document" box, which was otherwise never set.
    process_btn.click(fn=_on_process, inputs=[pdf_file], outputs=[status_box, pdf_display])
    ask_btn.click(fn=ask_question, inputs=[question_input], outputs=[answer_output])
    clear_btn.click(fn=_on_clear, outputs=[status_box, answer_output, question_input, pdf_display])
if __name__ == "__main__":
    # Launch the Gradio app only when this file is executed directly,
    # never when it is imported.
    launch_options = {"share": True, "debug": True}
    demo.launch(**launch_options)
|