import os
import shutil
import gradio as gr
from PIL import Image

# Unstructured for rich PDF parsing
from unstructured.partition.pdf import partition_pdf
from unstructured.partition.utils.constants import PartitionStrategy

# Vision-language captioning (BLIP)
from transformers import BlipProcessor, BlipForConditionalGeneration

# LangChain vectorstore and embeddings
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
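# (HuggingFaceEmbeddings is moving to the langchain_huggingface package in newer
#  LangChain releases; the community import still works but may emit a deprecation warning)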

# HF Inference client for chat completions
from huggingface_hub import InferenceClient

# ── Globals ───────────────────────────────────────────────────────────────────
retriever = None               # FAISS retriever for multimodal content
current_pdf_name = None        # Name of the currently loaded PDF
combined_texts = None          # Combined text + image captions corpus

# ── Setup: directories ─────────────────────────────────────────────────────────
FIGURES_DIR = "figures"
if os.path.exists(FIGURES_DIR):
    shutil.rmtree(FIGURES_DIR)
os.makedirs(FIGURES_DIR, exist_ok=True)

# ── Models & Clients ───────────────────────────────────────────────────────────
# Chat model (Mistral-7B-Instruct)
chat_client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.3")
# Text embeddings (BAAI BGE)
embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-base-en-v1.5")
# Image captioning (BLIP)
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
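
# Note: the hosted Inference API generally requires an HF access token (picked up
# from the HF_TOKEN env var or via `huggingface-cli login`). BLIP runs locally on
# CPU here; move it with blip_model.to("cuda") if a GPU is available.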


def generate_caption(image_path: str) -> str:
    """
    Generates a natural-language caption for an image using BLIP.
    """
    image = Image.open(image_path).convert('RGB')
    inputs = blip_processor(image, return_tensors="pt")
    out = blip_model.generate(**inputs)
    caption = blip_processor.decode(out[0], skip_special_tokens=True)
    return caption
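
# Example (hypothetical filename; actual names depend on partition_pdf's output):
#   generate_caption("figures/figure-1-1.jpg")  # -> e.g. "a bar chart comparing two groups"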


def process_pdf(pdf_file) -> tuple:
    """
    Parses the uploaded PDF into text chunks and image captions,
    builds a FAISS index, and prepares the retriever.
    Returns a status message and the active document name.
    """
    global current_pdf_name, retriever, combined_texts
    if pdf_file is None:
        return "❌ Please upload a PDF file.", ""

    # Gradio may hand us a path string or a tempfile-like object, depending on version
    pdf_path = pdf_file if isinstance(pdf_file, str) else pdf_file.name
    current_pdf_name = os.path.basename(pdf_path)

    # Start from a clean figures directory so images from a previous PDF are not re-indexed
    shutil.rmtree(FIGURES_DIR, ignore_errors=True)
    os.makedirs(FIGURES_DIR, exist_ok=True)

    # Extract text, table, and image blocks
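    # (the HI_RES strategy runs a layout-detection model and typically needs extra
    #  system dependencies installed, e.g. poppler and tesseract; see the unstructured docs)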
    elements = partition_pdf(
        filename=pdf_path,
        strategy=PartitionStrategy.HI_RES,
        extract_image_block_types=["Image", "Table"],
        extract_image_block_output_dir=FIGURES_DIR
    )

    # Separate text and image elements
    text_elements = [el.text for el in elements if el.category not in ["Image", "Table"] and el.text]
    image_files = [os.path.join(FIGURES_DIR, f)
                   for f in os.listdir(FIGURES_DIR)
                   if f.lower().endswith((".png", ".jpg", ".jpeg"))]

    # Generate captions for each image
    captions = []
    for img in image_files:
        cap = generate_caption(img)
        captions.append(cap)

    # Combine all pieces for indexing
    combined_texts = text_elements + captions
    if not combined_texts:
        return "❌ No extractable text or images were found in the PDF.", ""

    # Create FAISS index and retriever
    index = FAISS.from_texts(combined_texts, embeddings)
    retriever = index.as_retriever(search_kwargs={"k": 2})
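    # (k=2 keeps the retrieved context compact; raise search_kwargs["k"] for broader recall)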

    status = f"βœ… Indexed '{current_pdf_name}' β€” {len(text_elements)} text blocks + {len(captions)} image captions"
    return status


def ask_question(question: str) -> str:
    """
    Retrieves relevant chunks from the FAISS index and generates an answer via chat model.
    """
    global retriever
    if retriever is None:
        return "❌ Please upload and process a PDF first."
    if not question.strip():
        return "❌ Please enter a question."

    docs = retriever.invoke(question)  # preferred over the deprecated get_relevant_documents()
    context = "\n\n".join(doc.page_content for doc in docs)

    prompt = (
        "Use the following document excerpts to answer the question.\n\n"
        f"{context}\n\n"
        f"Question: {question}\n"
        "Answer:"
    )

    response = chat_client.chat_completion(
        messages=[{"role": "user", "content": prompt}],
        max_tokens=128,
        temperature=0.5
    )
    answer = response.choices[0].message.content.strip()
    return answer
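
# Example (assumes a PDF has already been processed):
#   ask_question("What does the figure on page 3 show?")  # -> answer grounded in retrieved excerpts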


def clear_interface():
    """Resets global state, clears the figures directory, and empties the UI fields."""
    global retriever, current_pdf_name, combined_texts
    retriever = None
    current_pdf_name = None
    combined_texts = None
    shutil.rmtree(FIGURES_DIR, ignore_errors=True)
    os.makedirs(FIGURES_DIR, exist_ok=True)
    return "", "", ""

# ── Gradio UI ────────────────────────────────────────────────────────────────
theme = gr.themes.Soft(primary_hue="indigo", secondary_hue="blue")
with gr.Blocks(theme=theme, css="""
    .container { border-radius: 10px; padding: 15px; }
    .pdf-active { border-left: 3px solid #6366f1; padding-left: 10px; background-color: rgba(99,102,241,0.1); }
    .footer { text-align: center; margin-top: 30px; font-size: 0.8em; color: #666; }
    .main-title { text-align: center; font-size: 64px; font-weight: bold; margin-bottom: 20px; }
""") as demo:
    gr.Markdown("<div class='main-title'>DocQueryAI (Multimodal)</div>")

    with gr.Row():
        with gr.Column():
            gr.Markdown("## πŸ“„ Document Input")
            pdf_display = gr.Textbox(label="Active Document", interactive=False, elem_classes="pdf-active")
            pdf_file = gr.File(file_types=[".pdf"], type="filepath")  # type="file" was removed in Gradio 4.x
            process_btn = gr.Button("📤 Process Document", variant="primary")
            status_box = gr.Textbox(label="Status", interactive=False)

        with gr.Column():
            gr.Markdown("## ❓ Ask Questions")
            question_input = gr.Textbox(label="Question", lines=3, placeholder="Enter your question here…")
            ask_btn = gr.Button("🔍 Ask Question", variant="primary")
            answer_output = gr.Textbox(label="Answer", lines=8, interactive=False)

    clear_btn = gr.Button("🗑️ Clear All", variant="secondary")
    gr.Markdown("<div class='footer'>Powered by LangChain + Mistral 7B + FAISS + BLIP | Gradio</div>")

    process_btn.click(fn=process_pdf, inputs=[pdf_file], outputs=[status_box, pdf_display])
    ask_btn.click(fn=ask_question, inputs=[question_input], outputs=[answer_output])
    clear_btn.click(fn=clear_interface, outputs=[status_box, answer_output, pdf_display])

if __name__ == "__main__":
    demo.launch(debug=True, share=True)