# multimodal_rag / rag.py
# Import required libraries
import requests # For making HTTP requests
import os # For accessing environment variables
import google.generativeai as genai # For interacting with Google's Generative AI APIs
from typing import List # For type annotations
from utils import encode_image # Utility function to encode images as base64
from PIL import Image # For image processing
class Rag:
    """
    A class for interacting with Generative AI models (Gemini and OpenAI) to retrieve
    answers based on user queries and associated images.

    NOTE(review): error handling is intentionally asymmetric in the original design —
    the OpenAI path returns None on failure while the Gemini path returns an
    "Error: ..." string. Preserved here so existing callers keep working.
    """

    def get_answer_from_openai(self, query: str, imagesPaths: List[str]) -> Optional[str]:
        """
        Query OpenAI's GPT model with a text query and associated images.

        Args:
            query (str): The user's query.
            imagesPaths (List[str]): List of file paths to images.

        Returns:
            Optional[str]: The response text from OpenAI, or None if the request failed.
        """
        print(f"Querying OpenAI for query={query}, imagesPaths={imagesPaths}")
        try:
            # Build the request body (user text plus base64-encoded images).
            payload = self.__get_openai_api_payload(query, imagesPaths)

            # HTTP headers for the OpenAI API request; key comes from the environment.
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"
            }

            # Send a POST request to the OpenAI API.
            # timeout prevents hanging indefinitely on a stalled connection
            # (requests has no default timeout).
            response = requests.post(
                url="https://api.openai.com/v1/chat/completions",
                headers=headers,
                json=payload,
                timeout=60
            )
            response.raise_for_status()  # Raise an error for unsuccessful requests

            # Extract the assistant's reply from the first choice.
            answer = response.json()["choices"][0]["message"]["content"]
            print(answer)  # Log the answer
            return answer
        except Exception as e:
            # Log and signal failure to the caller via None instead of raising.
            print(f"An error occurred while querying OpenAI: {e}")
            return None

    def get_answer_from_gemini(self, query: str, imagePaths: List[str]) -> str:
        """
        Query the Gemini model with a text query and associated images.

        Args:
            query (str): The user's query.
            imagePaths (List[str]): List of file paths to images. Paths that do not
                exist are skipped with a warning rather than raising.

        Returns:
            str: The response text from the Gemini model, or an "Error: ..." string
                if the request failed.
        """
        print(f"Querying Gemini for query={query}, imagePaths={imagePaths}")
        try:
            # Configure the Gemini API client using the API key from environment variables.
            genai.configure(api_key=os.environ['GEMINI_API_KEY'])

            # Initialize the Gemini generative model.
            model = genai.GenerativeModel('gemini-1.5-flash')

            # Load images from the given paths, skipping any missing files so a
            # single bad path does not abort the whole query.
            images = []
            for path in imagePaths:
                if os.path.exists(path):
                    images.append(Image.open(path))
                else:
                    print(f"Warning: Image not found {path}, skipping.")

            # Start a new chat session.
            chat = model.start_chat()

            # Construct the input for the model (handle cases with and without images).
            input_data = [query] if not images else [*images, query]

            # Send the query (and images, if any) to the model.
            response = chat.send_message(input_data)

            # Extract the response text.
            answer = response.text
            print(answer)  # Log the answer
            return answer
        except Exception as e:
            # Log and return the error as a string so callers always get a str back.
            print(f"An error occurred while querying Gemini: {e}")
            return f"Error: {str(e)}"

    def __get_openai_api_payload(self, query: str, imagesPaths: List[str]) -> dict:
        """
        Prepare the payload for the OpenAI chat-completions API request.

        Args:
            query (str): The user's query.
            imagesPaths (List[str]): List of file paths to images.

        Returns:
            dict: The JSON-serializable payload for the OpenAI API request.
        """
        # Encode each image as base64 and wrap it in the API's image_url format.
        image_payload = []
        for imagePath in imagesPaths:
            base64_image = encode_image(imagePath)  # Encode image in base64
            image_payload.append({
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{base64_image}"  # Embed image data as a data URL
                }
            })

        # Complete request body: model choice, a single user message containing
        # the text followed by the images, and a response-length cap.
        payload = {
            "model": "gpt-4o",  # Specify the OpenAI model
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": query  # Include the user's query
                        },
                        *image_payload  # Include the image data
                    ]
                }
            ],
            "max_tokens": 1024  # Limit the response length
        }
        return payload