# NOTE: this file was recovered from a Hugging Face Spaces page scrape; the
# original header lines ("Spaces:" / "Runtime error") were page chrome, not code.
# Import required libraries
import requests  # For making HTTP requests
import os  # For accessing environment variables
import google.generativeai as genai  # For interacting with Google's Generative AI APIs
from typing import List  # For type annotations
from utils import encode_image  # Utility function to encode images as base64
from PIL import Image  # For image processing
class Rag:
    """
    Interact with Generative AI models (Gemini and OpenAI) to retrieve answers
    based on user queries and associated images.

    Both query methods log the query and the answer to stdout and swallow
    errors rather than raising, so callers never see an exception.
    """

    def get_answer_from_openai(self, query: str, imagesPaths: List[str]) -> str:
        """
        Query OpenAI's GPT model with a text query and associated images.

        Args:
            query (str): The user's query.
            imagesPaths (List[str]): List of file paths to images.

        Returns:
            str: The response text from OpenAI, or None if the request failed.
        """
        print(f"Querying OpenAI for query={query}, imagesPaths={imagesPaths}")
        try:
            # Prepare the API payload with the query and images
            payload = self.__get_openai_api_payload(query, imagesPaths)
            headers = {
                "Content-Type": "application/json",
                # API key comes from the environment; KeyError here is caught below
                "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"
            }
            # Send a POST request to the OpenAI API.
            # timeout added so a stalled connection fails instead of hanging forever.
            response = requests.post(
                url="https://api.openai.com/v1/chat/completions",
                headers=headers,
                json=payload,
                timeout=120,
            )
            response.raise_for_status()  # Raise an error for unsuccessful requests
            # Extract the content of the first choice in the response
            answer = response.json()["choices"][0]["message"]["content"]
            print(answer)  # Log the answer
            return answer
        except Exception as e:
            # NOTE(review): this path returns None on failure while the Gemini path
            # returns an "Error: ..." string — kept as-is since callers may test for None.
            print(f"An error occurred while querying OpenAI: {e}")
            return None

    def get_answer_from_gemini(self, query: str, imagePaths: List[str]) -> str:
        """
        Query the Gemini model with a text query and associated images.

        Args:
            query (str): The user's query.
            imagePaths (List[str]): List of file paths to images.

        Returns:
            str: The response text from the Gemini model, or an
            "Error: ..." string if the call failed.
        """
        print(f"Querying Gemini for query={query}, imagePaths={imagePaths}")
        try:
            # Configure the Gemini API client using the API key from environment variables
            genai.configure(api_key=os.environ['GEMINI_API_KEY'])
            # Initialize the Gemini generative model
            model = genai.GenerativeModel('gemini-1.5-flash')
            # Load images from the given paths, skipping any missing files
            images = []
            for path in imagePaths:
                if os.path.exists(path):
                    images.append(Image.open(path))
                else:
                    print(f"Warning: Image not found {path}, skipping.")
            # Start a new chat session
            chat = model.start_chat()
            # Text-only message when no image loaded; otherwise images first, then query
            input_data = [query] if not images else [*images, query]
            # Send the query (and images, if any) to the model
            response = chat.send_message(input_data)
            answer = response.text
            print(answer)  # Log the answer
            return answer
        except Exception as e:
            # Surface the failure to the caller as a readable string
            print(f"An error occurred while querying Gemini: {e}")
            return f"Error: {str(e)}"

    def __get_openai_api_payload(self, query: str, imagesPaths: List[str]) -> dict:
        """
        Prepare the JSON payload for the OpenAI chat-completions request.

        Args:
            query (str): The user's query.
            imagesPaths (List[str]): List of file paths to images.

        Returns:
            dict: The request body: one user message containing the text part
            followed by one image_url part per image (base64 data URL).
        """
        image_payload = []  # One image_url content part per input image
        for imagePath in imagesPaths:
            base64_image = encode_image(imagePath)  # Encode image in base64
            image_payload.append({
                "type": "image_url",
                "image_url": {
                    # assumes JPEG content — TODO confirm encode_image/source formats
                    "url": f"data:image/jpeg;base64,{base64_image}"
                }
            })
        payload = {
            "model": "gpt-4o",  # Specify the OpenAI model
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": query  # Include the user's query
                        },
                        *image_payload  # Include the image data
                    ]
                }
            ],
            "max_tokens": 1024  # Limit the response length
        }
        return payload