aya_expanse / aya_vision_utils.py
shivalika's picture
update code
d281662
raw
history blame
5.46 kB
from constants import IMAGE_PER_CONVERSATION_LIMIT, DEFAULT_SYSTEM_PREAMBLE_TOKEN_COUNT, VISION_COHERE_MODEL_NAME, VISION_MODEL_TOKEN_LIMIT
from prompt_examples import AYA_VISION_PROMPT_EXAMPLES
import base64
from io import BytesIO
from PIL import Image
import logging
import cohere
import os
import traceback
import random
import gradio as gr
from google.cloud.sql.connector import Connector, IPTypes
import pg8000
from datetime import datetime
import sqlalchemy
# from dotenv import load_dotenv
# load_dotenv()
# API key for the Aya Vision endpoint, taken from the environment
# (None when the variable is not set).
MULTIMODAL_API_KEY = os.environ.get("AYA_VISION_API_KEY")

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Single Cohere v2 client shared by every request made from this module.
aya_vision_client = cohere.ClientV2(
    client_name="c4ai-aya-vision-hf-space",
    api_key=MULTIMODAL_API_KEY,
)
def cohere_vision_chat(chat_history, model=VISION_COHERE_MODEL_NAME):
    """Send a conversation to the Aya Vision client and return the reply text.

    Args:
        chat_history: Messages forwarded unchanged as ``messages`` to the
            client's ``chat`` call.
        model: Model name forwarded to the client; defaults to
            ``VISION_COHERE_MODEL_NAME``.

    Returns:
        The ``text`` attribute of the first content item of the response
        message.
    """
    reply = aya_vision_client.chat(model=model, messages=chat_history)
    first_part = reply.message.content[0]
    return first_part.text
def get_aya_vision_prompt_example(language):
    """Return the first two items of the prompt example stored for *language*.

    Args:
        language: Key used to index ``AYA_VISION_PROMPT_EXAMPLES``.

    Returns:
        A 2-tuple ``(example[0], example[1])`` of the selected entry.
        Any lookup or indexing error propagates to the caller.
    """
    entry = AYA_VISION_PROMPT_EXAMPLES[language]
    first_item = entry[0]
    second_item = entry[1]
    return first_item, second_item
def get_base64_from_local_file(file_path):
    """Read a local file and return its contents as a base64 string.

    Args:
        file_path: Path of the file to read in binary mode.

    Returns:
        The base64 encoding of the file's bytes as a ``str``, or ``None``
        when the file cannot be opened or read.
    """
    print("loading image")
    try:
        with open(file_path, "rb") as image_file:
            raw_bytes = image_file.read()
    except (OSError, TypeError, ValueError) as e:
        # OSError: missing/unreadable file; TypeError: e.g. a None path;
        # ValueError: e.g. an embedded null byte in the path.
        # Logged at ERROR (not DEBUG) so the failure is visible under the
        # default logging configuration instead of silently returning None.
        logger.error("Error converting local image to base64 string: %s", e)
        return None
    return base64.b64encode(raw_bytes).decode("utf-8")
def get_aya_vision_response(incoming_message, image_filepath, max_size_mb=5):
    """Stream the Aya Vision model's answer about a local image.

    This is a generator: validation and the API call happen when iteration
    starts, and the accumulated answer is yielded after every text delta.

    Args:
        incoming_message: User text sent with the image. ``"."`` is
            substituted when it is empty or ``None``.
        image_filepath: Path of the image file. The name must end in
            ``.jpg``, ``.jpeg``, ``.png``, ``.webp`` or ``.gif``
            (case-insensitive).
        max_size_mb: Size limit in megabytes; images whose decoded size is
            at or above this limit are rejected.

    Yields:
        The response text accumulated so far, once per ``content-delta``
        event received from the streaming call.

    Raises:
        gr.Error: If the file extension is not supported, the file cannot
            be read, or the image reaches the size limit.
    """
    mime_by_suffix = {
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".png": "image/png",
        ".webp": "image/webp",
        ".gif": "image/gif",
    }
    lowered_path = (image_filepath or "").lower()
    image_type = next(
        (
            mime
            for suffix, mime in mime_by_suffix.items()
            if lowered_path.endswith(suffix)
        ),
        None,
    )
    if image_type is None:
        # Previously an unknown extension left image_type unbound and the
        # function crashed with UnboundLocalError.
        raise gr.Error("Please upload an image in JPEG, PNG, WebP or GIF format")

    print("converting image to base 64")
    base64_image = get_base64_from_local_file(image_filepath)
    if base64_image is None:
        # The helper returns None on read failure; do not embed the text
        # "None" in the data URL.
        raise gr.Error("The uploaded image could not be read")

    image = f"data:{image_type};base64,{base64_image}"

    # Reject oversize images before building the request. gr.Error is an
    # exception and only takes effect when raised.
    max_size_bytes = max_size_mb * 1024 * 1024
    if get_base64_image_size(image) >= max_size_bytes:
        raise gr.Error(f"Please upload image with size under {max_size_mb}MB")

    # to prevent Cohere API from throwing error for empty message
    if incoming_message == "" or incoming_message is None:
        incoming_message = "."

    chat_history = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": incoming_message},
                {"type": "image_url", "image_url": {"url": image}},
            ],
        }
    ]

    res = aya_vision_client.chat_stream(
        messages=chat_history,
        model=VISION_COHERE_MODEL_NAME,
    )
    output = ""
    for event in res:
        if event and event.type == "content-delta":
            output += event.delta.message.content.text
            yield output
def get_base64_image_size(base64_string):
    """Compute the decoded byte size of a base64 string or data URL.

    Args:
        base64_string: Base64 text. When it contains a comma, only the part
            after the first comma is measured (e.g. the payload of a
            ``data:...;base64,...`` URL).

    Returns:
        ``len(payload) * 3 // 4`` minus the number of ``=`` characters,
        where ``payload`` has newlines, carriage returns and spaces removed.
    """
    _prefix, comma, after_comma = base64_string.partition(",")
    payload = after_comma if comma else base64_string
    # Drop line breaks and spaces so they do not inflate the length.
    payload = payload.translate({ord(ch): None for ch in "\n\r "})
    pad_count = payload.count("=")
    return (3 * len(payload)) // 4 - pad_count
def insert_aya_audio(connection, user_prompt, text_response, audio_response_file_path):
    """Insert one row into the ``aya_audio`` table.

    The statement runs inside a ``connection.begin()`` block and stores the
    current local time (``datetime.now()``) in the ``timestamp`` column.

    Args:
        connection: Object providing ``begin()`` and ``execute()``.
        user_prompt: Value for the ``user_prompt`` column.
        text_response: Value for the ``text_response`` column.
        audio_response_file_path: Value for the
            ``audio_response_file_path`` column.
    """
    with connection.begin():
        statement = sqlalchemy.text("""
INSERT INTO aya_audio (user_prompt, text_response, audio_response_file_path, timestamp)
VALUES (:user_prompt, :text_response, :audio_response_file_path, :timestamp)
""")
        row = dict(
            user_prompt=user_prompt,
            text_response=text_response,
            audio_response_file_path=audio_response_file_path,
            timestamp=datetime.now(),
        )
        connection.execute(statement, row)
def insert_aya_image(connection, user_prompt, generated_img_desc, image_response_file_path):
    """Insert one row into the ``aya_image`` table.

    The statement runs inside a ``connection.begin()`` block and stores the
    current local time (``datetime.now()``) in the ``timestamp`` column.

    Args:
        connection: Object providing ``begin()`` and ``execute()``.
        user_prompt: Value for the ``user_prompt`` column.
        generated_img_desc: Value for the ``generated_img_desc`` column.
        image_response_file_path: Value for the
            ``image_response_file_path`` column.
    """
    with connection.begin():
        statement = sqlalchemy.text("""
INSERT INTO aya_image (user_prompt, generated_img_desc, image_response_file_path, timestamp)
VALUES (:user_prompt, :generated_img_desc, :image_response_file_path, :timestamp)
""")
        row = dict(
            user_prompt=user_prompt,
            generated_img_desc=generated_img_desc,
            image_response_file_path=image_response_file_path,
            timestamp=datetime.now(),
        )
        connection.execute(statement, row)
def connect_with_connector() -> sqlalchemy.engine.base.Connection:
    """Open a database connection through the Cloud SQL Python Connector.

    Configuration is read from the environment:

    * ``INSTANCE_CONNECTION_NAME``, ``DB_USER``, ``DB_PASS``, ``DB_NAME`` -
      required.
    * ``PRIVATE_IP`` - optional; any non-empty value selects
      ``IPTypes.PRIVATE``, otherwise ``IPTypes.PUBLIC`` is used.

    Returns:
        The connection obtained from ``engine.connect()`` (a connection,
        not the engine itself), suitable for the ``insert_aya_*`` helpers
        that call ``begin()`` and ``execute()`` on it.

    Raises:
        KeyError: If a required environment variable is missing.

    Note:
        Every call builds a new ``Connector`` and a new engine; neither is
        closed here, and only the connection is handed back to the caller.
    """
    instance_connection_name = os.environ["INSTANCE_CONNECTION_NAME"]
    db_user = os.environ["DB_USER"]
    db_pass = os.environ["DB_PASS"]
    db_name = os.environ["DB_NAME"]

    ip_type = IPTypes.PRIVATE if os.environ.get("PRIVATE_IP") else IPTypes.PUBLIC

    connector = Connector(refresh_strategy="LAZY")

    def getconn() -> pg8000.dbapi.Connection:
        # Invoked by SQLAlchemy (via ``creator=``) whenever the engine needs
        # a new DBAPI connection.
        conn: pg8000.dbapi.Connection = connector.connect(
            instance_connection_name,
            "pg8000",
            user=db_user,
            password=db_pass,
            db=db_name,
            ip_type=ip_type,
        )
        return conn

    # The URL carries only the dialect/driver; the actual connection comes
    # from ``getconn``.
    pool = sqlalchemy.create_engine(
        "postgresql+pg8000://",
        creator=getconn,
    )
    connection = pool.connect()
    return connection