import os
import tempfile

import gradio as gr
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import UnstructuredPDFLoader
from langchain.chains import RetrievalQA
from langchain.llms import HuggingFaceHub
from PIL import Image
from transformers import pipeline
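# Note: these are the pre-0.2 langchain import paths; newer releases expose
# the same classes from the langchain_community package.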

# Directory for extracted figure images
FIGURES_DIR = tempfile.mkdtemp(prefix="figures_")
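# (Only the optional image-extraction sketch below writes into this
# directory; the main demo path never uses it.)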

# Configure Hugging Face
HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")

# Initialize embeddings and vector store
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
vector_store = None

# Initialize image captioning pipeline
captioner = pipeline("image-to-text", model="Salesforce/blip2-flan-t5-xl", use_auth_token=HUGGINGFACEHUB_API_TOKEN)
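# (The captioner is only exercised by the optional sketch below; process_pdf
# itself does not extract images.)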

# Initialize LLM for QA
llm = HuggingFaceHub(
    repo_id="google/flan-t5-xxl",
    # Some hosted Inference API backends reject a temperature of exactly 0,
    # so use a small positive value for near-deterministic output.
    model_kwargs={"temperature": 0.1, "max_length": 256},
    huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN
)

# Helper functions
def process_pdf(pdf_file):
    # Gradio passes either a tempfile wrapper or a plain path, depending on version
    pdf_path = pdf_file.name if hasattr(pdf_file, "name") else pdf_file
    # Load text content
    loader = UnstructuredPDFLoader(pdf_path)
    docs = loader.load()
    raw_text = "\n".join([d.page_content for d in docs])
    # Image captions would be merged in here; this demo skips actual image
    # extraction (see the optional sketch after this function)
    captions = []
    # Combine text and captions
    combined = raw_text + "\n\n" + "\n".join(captions)
    return combined
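
# A minimal sketch of the image-extraction step skipped above, assuming
# PyMuPDF (the "pymupdf" package, imported as fitz) is installed. The
# function name and its wiring into FIGURES_DIR and captioner are
# illustrative additions, not part of the original app.
def caption_pdf_images(pdf_path):
    import fitz  # PyMuPDF
    captions = []
    doc = fitz.open(pdf_path)
    for page in doc:
        for img in page.get_images(full=True):
            xref = img[0]  # the image's cross-reference number in the PDF
            pix = fitz.Pixmap(doc, xref)
            if pix.n - pix.alpha > 3:  # convert CMYK and similar to RGB
                pix = fitz.Pixmap(fitz.csRGB, pix)
            img_path = os.path.join(FIGURES_DIR, f"fig_{xref}.png")
            pix.save(img_path)
            # The image-to-text pipeline accepts a file path and returns
            # [{"generated_text": ...}]
            captions.append(captioner(img_path)[0]["generated_text"])
    return captions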

def build_index(text):
    global vector_store
    # Split into overlapping chunks
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    chunks = splitter.split_text(text)
    # Build a fresh FAISS index over the chunks (replaces any previous one)
    vector_store = FAISS.from_texts(chunks, embeddings)
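
# Optional: persist the index between sessions via FAISS's
# save_local/load_local (the path here is illustrative):
#   vector_store.save_local("faiss_index")
#   vector_store = FAISS.load_local("faiss_index", embeddings)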

def answer_query(query):
    # Guard against questions asked before any PDF has been indexed
    if vector_store is None:
        return "Please upload a PDF first."
    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",  # stuff all retrieved chunks into one prompt
        retriever=vector_store.as_retriever()
    )
    return qa.run(query)
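
# A variant that also surfaces the retrieved chunks, using RetrievalQA's
# return_source_documents option (sketch only):
#   qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff",
#                                    retriever=vector_store.as_retriever(),
#                                    return_source_documents=True)
#   result = qa({"query": query})  # result["result"], result["source_documents"]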

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# Multimodal RAG QA App")
    with gr.Row():
        pdf_input = gr.File(label="Upload PDF", file_types=[".pdf"])
        question_input = gr.Textbox(label="Ask a question", placeholder="Enter your question here...")
    output = gr.Textbox(label="Answer", interactive=False)

    def on_submit(pdf, question):
        if pdf is not None:
            text = process_pdf(pdf)
            build_index(text)
        if not question:
            return "Please enter a question."
        return answer_query(question)

    submit_btn = gr.Button("Get Answer")
    submit_btn.click(on_submit, inputs=[pdf_input, question_input], outputs=output)

if __name__ == "__main__":
    demo.launch()