import os
import gradio as gr
from langchain_core.prompts import PromptTemplate
from langchain_community.document_loaders import PyPDFLoader
from langchain_google_genai import ChatGoogleGenerativeAI
import google.generativeai as genai
from langchain.chains.question_answering import load_qa_chain
from langchain_core.documents import Document
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import BlipProcessor, BlipForConditionalGeneration
from PIL import Image
# Configure Gemini API
genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
# Load Mistral model
model_path = "nvidia/Mistral-NeMo-Minitron-8B-Base"
mistral_tokenizer = AutoTokenizer.from_pretrained(model_path)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.bfloat16
mistral_model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=dtype, device_map=device)
# Load BLIP model for image processing
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)
def process_image(image):
    # Preprocess the PIL image into model-ready tensors
    inputs = blip_processor(images=image, return_tensors="pt").to(device)
    # Generate a caption from the image
    caption_ids = blip_model.generate(**inputs)
    caption = blip_processor.decode(caption_ids[0], skip_special_tokens=True)
    return caption
def initialize(file_path, image, question):
    try:
        model = ChatGoogleGenerativeAI(model="gemini-pro", temperature=0.3)
        prompt_template = """Answer the question as precisely as possible using the provided context. If the answer is
not contained in the context, say "answer not available in context" \n\n
Context: \n {context}?\n
Question: \n {question} \n
Answer:
"""
        prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
        context = ""
        if file_path and os.path.exists(file_path):
            pdf_loader = PyPDFLoader(file_path)
            pages = pdf_loader.load_and_split()
            context += "\n".join(str(page.page_content) for page in pages[:30])
        if image:
            image_context = process_image(image)
            context += f"\nImage Context: {image_context}"
        if context:
            stuff_chain = load_qa_chain(model, chain_type="stuff", prompt=prompt)
            # Wrap the combined context in a Document so the stuff chain fills the {context} variable
            stuff_answer = stuff_chain(
                {"input_documents": [Document(page_content=context)], "question": question},
                return_only_outputs=True,
            )
            gemini_answer = stuff_answer['output_text']
            # Use the Mistral model to generate a follow-up question from Gemini's answer
            mistral_prompt = f"Based on this answer: {gemini_answer}\nGenerate a follow-up question:"
            mistral_inputs = mistral_tokenizer.encode(mistral_prompt, return_tensors='pt').to(device)
            with torch.no_grad():
                mistral_outputs = mistral_model.generate(mistral_inputs, max_new_tokens=50)
            mistral_output = mistral_tokenizer.decode(mistral_outputs[0], skip_special_tokens=True)
            combined_output = f"Gemini Answer: {gemini_answer}\n\nMistral Follow-up: {mistral_output}"
            return combined_output
        else:
            return "Error: No valid context provided. Please upload a valid PDF or image."
    except Exception as e:
        return f"An error occurred: {str(e)}"
# Define Gradio Interface
input_file = gr.File(label="Upload PDF File")
input_image = gr.Image(type="pil", label="Upload Image")
input_question = gr.Textbox(label="Ask about the document")
output_text = gr.Textbox(label="Answer - Combined Gemini and Mistral")
def multimodal_qa(file, image, question):
    if file is None and image is None:
        return "Please upload a PDF file or an image first."
    file_path = file.name if file else None
    return initialize(file_path, image, question)
# Create Gradio Interface
gr.Interface(
    fn=multimodal_qa,
    inputs=[input_file, input_image, input_question],
    outputs=output_text,
    title="Multi-modal RAG with Gemini API and Mistral Model",
    description="Upload a PDF or an image and ask questions about the content."
).launch()