import os
import gradio as gr
from langchain_core.prompts import PromptTemplate
from langchain_core.documents import Document
from langchain_community.document_loaders import PyPDFLoader
from langchain_google_genai import ChatGoogleGenerativeAI
import google.generativeai as genai
from langchain.chains.question_answering import load_qa_chain
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import BlipProcessor, BlipForConditionalGeneration
from PIL import Image
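
# Assumed pip dependencies, inferred from the imports above (package names only, versions
# not pinned here): gradio, torch, transformers, pillow, pypdf, google-generativeai,
# langchain, langchain-community, langchain-google-genai.
# GOOGLE_API_KEY must be set in the environment before launching the app.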

# Configure Gemini API
genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))

# Load Mistral model
model_path = "nvidia/Mistral-NeMo-Minitron-8B-Base"
mistral_tokenizer = AutoTokenizer.from_pretrained(model_path)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.bfloat16
mistral_model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=dtype, device_map=device)

# Load BLIP model for image processing
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)

def process_image(image):
    # Convert PIL Image to tensor
    inputs = blip_processor(images=image, return_tensors="pt").to(device)
    # Generate caption from image
    caption_ids = blip_model.generate(**inputs)
    caption = blip_processor.decode(caption_ids[0], skip_special_tokens=True)
    return caption

def initialize(file_path, image, question):
    try:
        model = ChatGoogleGenerativeAI(model="gemini-pro", temperature=0.3)
        prompt_template = """Answer the question as precise as possible using the provided context. If the answer is
                              not contained in the context, say "answer not available in context" \n\n
                              Context: \n {context}?\n
                              Question: \n {question} \n
                              Answer:
                            """
        prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
        
        context = ""
        
        if file_path and os.path.exists(file_path):
            pdf_loader = PyPDFLoader(file_path)
            pages = pdf_loader.load_and_split()
            context += "\n".join(str(page.page_content) for page in pages[:30])
        
        if image:
            image_context = process_image(image)
            context += f"\nImage Context: {image_context}"
        
        if context:
            stuff_chain = load_qa_chain(model, chain_type="stuff", prompt=prompt)
            stuff_answer = stuff_chain({"input_documents": [], "question": question, "context": context}, return_only_outputs=True)
            gemini_answer = stuff_answer['output_text']
            
            # Use Mistral model for additional text generation
            mistral_prompt = f"Based on this answer: {gemini_answer}\nGenerate a follow-up question:"
            mistral_inputs = mistral_tokenizer.encode(mistral_prompt, return_tensors='pt').to(device)
            with torch.no_grad():
                # max_new_tokens bounds only the generated continuation; max_length would also
                # count the prompt tokens and can cut generation short for longer prompts.
                mistral_outputs = mistral_model.generate(mistral_inputs, max_new_tokens=50)
            # Decode only the newly generated tokens so the follow-up does not echo the prompt.
            mistral_output = mistral_tokenizer.decode(
                mistral_outputs[0][mistral_inputs.shape[-1]:], skip_special_tokens=True
            )
            
            combined_output = f"Gemini Answer: {gemini_answer}\n\nMistral Follow-up: {mistral_output}"
            return combined_output
        else:
            return "Error: No valid context provided. Please upload a valid PDF or image."
    except Exception as e:
        return f"An error occurred: {str(e)}"

# Define Gradio Interface
input_file = gr.File(label="Upload PDF File")
input_image = gr.Image(type="pil", label="Upload Image")
input_question = gr.Textbox(label="Ask about the document")
output_text = gr.Textbox(label="Answer - Combined Gemini and Mistral")

def multimodal_qa(file, image, question):
    if file is None and image is None:
        return "Please upload a PDF file or an image first."
    # Newer Gradio versions pass the uploaded file as a filepath string; older ones pass a
    # tempfile-like object exposing .name, so handle both.
    file_path = file if isinstance(file, str) else (file.name if file else None)
    return initialize(file_path, image, question)

# Create Gradio Interface
gr.Interface(
    fn=multimodal_qa,
    inputs=[input_file, input_image, input_question],
    outputs=output_text,
    title="Multi-modal RAG with Gemini API and Mistral Model",
    description="Upload a PDF or an image and ask questions about the content."
).launch()