# demo/rag.py — RAG helpers: answer image-grounded queries using Google
# Gemini or a locally served Ollama vision model.
import requests
import os
from typing import List
from utils import encode_image
from PIL import Image
import torch
import subprocess
import psutil
import torch
from transformers import AutoModel, AutoTokenizer
import google.generativeai as genai
class Rag:
    """Answer image-grounded queries via Google Gemini or a local Ollama model."""

    def get_answer_from_gemini(self, query, imagePaths):
        """Send *query* plus the images at *imagePaths* to Gemini 1.5 Flash.

        Returns the model's text answer, or an ``"Error: ..."`` string on
        failure (this method never raises).
        """
        print(f"Querying Gemini for query={query}, imagePaths={imagePaths}")
        try:
            # SECURITY FIX: the API key used to be hardcoded here (a leaked
            # secret). Read it from the environment so it never lives in
            # source control.
            genai.configure(api_key=os.environ["GEMINI_API_KEY"])
            model = genai.GenerativeModel('gemini-1.5-flash')
            images = [Image.open(path) for path in imagePaths]
            chat = model.start_chat()
            response = chat.send_message([*images, query])
            answer = response.text
            print(answer)
            return answer
        except Exception as e:
            print(f"An error occurred while querying Gemini: {e}")
            return f"Error: {str(e)}"

    def get_answer_from_openai(self, query, imagesPaths):
        """Send *query* plus images to a local Ollama vision model.

        ("OpenAI" in the name is historical — the call actually goes to
        Ollama.) Model name, flash-attention flag, and temperature come from
        environment variables loaded from the project's .env file. Returns
        the answer text, or None on error.
        """
        # Load environment variables from the nearest .env file.
        import dotenv
        dotenv.load_dotenv(dotenv.find_dotenv())

        # Release CUDA memory held by the in-process model (colpali) so the
        # Ollama server can use the GPU.
        torch.cuda.empty_cache()

        # "1" enables flash attention in the Ollama server. ROBUSTNESS FIX:
        # the original indexed os.environ directly and raised KeyError when
        # 'flashattn' was unset; default to "1" as the old comment intended.
        os.environ['OLLAMA_FLASH_ATTENTION'] = os.environ.get('flashattn', '1')
        if os.environ['ollama'] == "minicpm-v":
            # Pin the generic tag to the specific quantized build.
            os.environ['ollama'] = "minicpm-v:8b-2.6-q8_0"

        print(f"Querying OpenAI for query={query}, imagesPaths={imagesPaths}")
        try:
            # NOTE(review): `chat` is never imported in this file — it is
            # presumably `ollama.chat` (`from ollama import chat` at module
            # level). Confirm the import exists elsewhere.
            response = chat(
                model=os.environ['ollama'],
                messages=[
                    {
                        'role': 'user',
                        'content': query,
                        'images': imagesPaths,
                    }
                ],
                # BUG FIX: temperature was placed inside the message dict,
                # where the Ollama client silently ignores it; generation
                # parameters belong in `options`.
                options={'temperature': float(os.environ['temperature'])},
            )
            answer = response.message.content
            print(answer)
            return answer
        except Exception as e:
            print(f"An error occurred while querying OpenAI: {e}")
            return None

    def __get_openai_api_payload(self, query: str, imagesPaths: List[str]):
        """Build an OpenAI-style chat payload with each image inlined as a
        base64 ``data:`` URL.

        :param query: the user's text question.
        :param imagesPaths: filesystem paths of images to attach.
        :returns: dict ready to POST to an OpenAI-compatible chat endpoint.
        """
        image_payload = [
            {
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{encode_image(imagePath)}"
                },
            }
            for imagePath in imagesPaths
        ]
        return {
            "model": "Llama3.2-vision",  # change model here as needed
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": query},
                        *image_payload,
                    ],
                }
            ],
            "max_tokens": 1024,  # cap output size to bound processing time
        }
# if __name__ == "__main__":
# rag = Rag()
# query = "Based on attached images, how many new cases were reported during second wave peak"
# imagesPaths = ["covid_slides_page_8.png", "covid_slides_page_8.png"]
# rag.get_answer_from_gemini(query, imagesPaths)