import os
import subprocess
from typing import List

import google.generativeai as genai
import psutil
import requests
import torch
from ollama import chat  # used by get_answer_from_openai; missing from the original imports
from PIL import Image
from transformers import AutoModel, AutoTokenizer

from utils import encode_image


class Rag:
    def get_answer_from_gemini(self, query, imagePaths):
        print(f"Querying Gemini for query={query}, imagePaths={imagePaths}")

        try:
            # Read the key from the environment instead of hard-coding it in source
            # (GEMINI_API_KEY is an assumed variable name).
            genai.configure(api_key=os.environ["GEMINI_API_KEY"])
            model = genai.GenerativeModel('gemini-1.5-flash')

            images = [Image.open(path) for path in imagePaths]

            chat_session = model.start_chat()
            response = chat_session.send_message([*images, query])
            answer = response.text

            print(answer)
            return answer

        except Exception as e:
            print(f"An error occurred while querying Gemini: {e}")
            return f"Error: {str(e)}"

    def get_answer_from_openai(self, query, imagesPaths):
        import dotenv

        # Pull runtime settings (model name, flash-attention flag, temperature) from .env.
        dotenv_file = dotenv.find_dotenv()
        dotenv.load_dotenv(dotenv_file)

        """ Local HF inference, kept for reference but disabled: the transformers
        version it needs is incompatible with the colpali requirement. Ollama is
        more reliable, easier to use, and already web-server ready.

        print(f"Querying for query={query}, imagesPaths={imagesPaths}")

        model = AutoModel.from_pretrained(
            'openbmb/MiniCPM-o-2_6-int4',
            trust_remote_code=True,
            attn_implementation='flash_attention_2',  # sdpa or flash_attention_2
            torch_dtype=torch.bfloat16,
            init_vision=True,
        )

        model = model.eval().cuda()
        tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-V-2_6-int4', trust_remote_code=True)
        image = Image.open(imagesPaths[0]).convert('RGB')

        msgs = [{'role': 'user', 'content': [image, query]}]
        answer = model.chat(
            image=None,
            msgs=msgs,
            tokenizer=tokenizer
        )
        print(answer)
        return answer
        """

        torch.cuda.empty_cache()

        # Expected environment variables (normally set via the .env file loaded above):
        # 'flashattn' ("0"/"1"), 'ollama' (model name), 'temperature' (float).
        os.environ['OLLAMA_FLASH_ATTENTION'] = os.environ['flashattn']
        if os.environ['ollama'] == "minicpm-v":
            # Expand the short alias to the full Ollama tag.
            os.environ['ollama'] = "minicpm-v:8b-2.6-q8_0"

        print(f"Querying Ollama for query={query}, imagesPaths={imagesPaths}")

        try:
            response = chat(
                model=os.environ['ollama'],
                messages=[
                    {
                        'role': 'user',
                        'content': query,
                        'images': imagesPaths,
                    }
                ],
                # Temperature must be passed via options; inside the message dict it is ignored.
                options={'temperature': float(os.environ['temperature'])},
            )

            answer = response.message.content
            print(answer)
            return answer

        except Exception as e:
            print(f"An error occurred while querying Ollama: {e}")
            return None

    def __get_openai_api_payload(self, query: str, imagesPaths: List[str]):
        # Build an OpenAI-style chat payload: the text query plus every image,
        # base64-encoded as a data URL.
        image_payload = []
        for imagePath in imagesPaths:
            base64_image = encode_image(imagePath)
            image_payload.append({
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{base64_image}"
                }
            })

        payload = {
            "model": "Llama3.2-vision",
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": query
                        },
                        *image_payload
                    ]
                }
            ],
            "max_tokens": 1024
        }

        return payload
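

# A minimal sketch (not called anywhere in this module) of how the payload built by
# __get_openai_api_payload could be POSTed with the already-imported `requests`.
# The default URL assumes Ollama's OpenAI-compatible endpoint on localhost; both the
# function name and the URL are illustrative assumptions, not part of the project.
def _post_openai_payload(payload: dict, url: str = "http://localhost:11434/v1/chat/completions") -> str:
    # Send the chat-completions request and return the assistant's reply text.
    resp = requests.post(url, json=payload, timeout=120)
    resp.raise_for_status()
    return resp.json()["choices"][0]["message"]["content"]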
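

# Minimal usage sketch, assuming the referenced image files exist and that the
# environment variables used above (ollama, flashattn, temperature, GEMINI_API_KEY)
# are set; the file names below are placeholders, not paths the project ships with.
if __name__ == "__main__":
    rag = Rag()
    pages = ["retrieved_page_1.jpg", "retrieved_page_2.jpg"]
    # Local model served by Ollama:
    print(rag.get_answer_from_openai("Summarize these pages.", pages))
    # Hosted Gemini model:
    print(rag.get_answer_from_gemini("Summarize these pages.", pages))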