"""Helpers for the Aya Vision demo: Cohere vision chat calls and Cloud SQL logging."""

# NOTE(review): several imports below (random, traceback, BytesIO, Image and some
# constants) are unused in this module; they are kept in case other modules
# import them from here -- confirm before removing.
import base64
import logging
import os
import random
import traceback
from datetime import datetime
from io import BytesIO

import cohere
import gradio as gr
import pg8000
import sqlalchemy
from google.cloud.sql.connector import Connector, IPTypes
from PIL import Image

from constants import (
    DEFAULT_SYSTEM_PREAMBLE_TOKEN_COUNT,
    IMAGE_PER_CONVERSATION_LIMIT,
    VISION_COHERE_MODEL_NAME,
    VISION_MODEL_TOKEN_LIMIT,
)
from prompt_examples import AYA_VISION_PROMPT_EXAMPLES

# For local development, uncomment to load environment variables from a .env file:
# from dotenv import load_dotenv
# load_dotenv()

MULTIMODAL_API_KEY = os.getenv("AYA_VISION_API_KEY")

logger = logging.getLogger(__name__)

aya_vision_client = cohere.ClientV2(
    api_key=MULTIMODAL_API_KEY,
    client_name="c4ai-aya-vision-hf-space",
)

# File-name suffixes accepted by get_aya_vision_response, mapped to the MIME
# type placed in the data URL sent to the API. Matched case-insensitively.
_IMAGE_MIME_TYPES = {
    ".jpg": "image/jpeg",
    ".jpeg": "image/jpeg",
    ".png": "image/png",
    ".webp": "image/webp",
    ".gif": "image/gif",
}


def cohere_vision_chat(chat_history, model=VISION_COHERE_MODEL_NAME):
    """Send *chat_history* to the Cohere chat API and return the reply text.

    Args:
        chat_history: List of message dicts in the format accepted by
            ``ClientV2.chat(messages=...)``.
        model: Name of the model to query.

    Returns:
        The text of the first content item of the response message.
    """
    response = aya_vision_client.chat(messages=chat_history, model=model)
    return response.message.content[0].text


def get_aya_vision_prompt_example(language):
    """Return the first two entries of the prompt example for *language*.

    NOTE(review): presumably ``(prompt_text, image_path)`` -- confirm against
    ``prompt_examples.AYA_VISION_PROMPT_EXAMPLES``.

    Raises:
        KeyError: If *language* has no entry in ``AYA_VISION_PROMPT_EXAMPLES``.
    """
    example = AYA_VISION_PROMPT_EXAMPLES[language]
    return example[0], example[1]


def get_base64_from_local_file(file_path):
    """Read *file_path* and return its bytes base64-encoded as a ``str``.

    Returns:
        The base64 text, or ``None`` if the file could not be read or encoded
        (the error is logged with its traceback).
    """
    logger.debug("loading image")
    try:
        with open(file_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")
    except Exception:
        # Best-effort helper: callers check for None instead of catching.
        logger.exception("Error converting local image to base64 string: %s", file_path)
        return None


def get_aya_vision_response(incoming_message, image_filepath, max_size_mb=5):
    """Stream Aya Vision's reply to a text prompt about a local image.

    Args:
        incoming_message: User prompt. An empty or ``None`` prompt is replaced
            with ``"."`` because the Cohere API errors on empty messages.
        image_filepath: Path to a local .jpg/.jpeg/.png/.webp/.gif image.
        max_size_mb: Images whose decoded size is at least this many megabytes
            are rejected.

    Yields:
        The accumulated response text after each streamed content delta.

    Raises:
        gr.Error: If no image path is given, the format is unsupported, the
            file cannot be read, or the image is too large.
    """
    if not image_filepath:
        raise gr.Error("Please upload an image")

    lowered_path = image_filepath.lower()
    image_type = next(
        (mime for suffix, mime in _IMAGE_MIME_TYPES.items() if lowered_path.endswith(suffix)),
        None,
    )
    if image_type is None:
        raise gr.Error("Unsupported image format; please upload a JPEG, PNG, WEBP or GIF image")

    logger.debug("converting image to base 64")
    base64_image = get_base64_from_local_file(image_filepath)
    if base64_image is None:
        raise gr.Error("Could not read the uploaded image")
    image = f"data:{image_type};base64,{base64_image}"

    # gr.Error only aborts the request and reaches the user when it is raised;
    # check before calling the API so oversized images are never sent.
    max_size_bytes = max_size_mb * 1024 * 1024
    if get_base64_image_size(image) >= max_size_bytes:
        raise gr.Error(f"Please upload image with size under {max_size_mb}MB")

    # To prevent the Cohere API from throwing an error for an empty message.
    if not incoming_message:
        incoming_message = "."

    chat_history = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": incoming_message},
                {"type": "image_url", "image_url": {"url": image}},
            ],
        }
    ]

    stream = aya_vision_client.chat_stream(
        messages=chat_history,
        model=VISION_COHERE_MODEL_NAME,
    )
    output = ""
    for event in stream:
        if event and event.type == "content-delta":
            output += event.delta.message.content.text
            yield output


def get_base64_image_size(base64_string):
    """Return the decoded size, in bytes, of a base64 string or ``data:`` URL.

    Everything up to the first comma (a data-URL header such as
    ``data:image/png;base64,``) is ignored, all whitespace is removed, and
    trailing ``=`` padding is subtracted from the 3-bytes-per-4-chars estimate.
    """
    _, separator, payload = base64_string.partition(",")
    base64_data = payload if separator else base64_string
    base64_data = "".join(base64_data.split())
    padding = len(base64_data) - len(base64_data.rstrip("="))
    return (len(base64_data) * 3) // 4 - padding


def insert_aya_audio(connection, user_prompt, text_response, audio_response_file_path):
    """Insert one row into ``aya_audio`` inside its own transaction.

    The ``timestamp`` column is set to ``datetime.now()`` (naive local time).
    """
    with connection.begin():
        connection.execute(
            sqlalchemy.text("""
                INSERT INTO aya_audio (user_prompt, text_response, audio_response_file_path, timestamp)
                VALUES (:user_prompt, :text_response, :audio_response_file_path, :timestamp)
            """),
            {
                "user_prompt": user_prompt,
                "text_response": text_response,
                "audio_response_file_path": audio_response_file_path,
                "timestamp": datetime.now(),
            },
        )


def insert_aya_image(connection, user_prompt, generated_img_desc, image_response_file_path):
    """Insert one row into ``aya_image`` inside its own transaction.

    The ``timestamp`` column is set to ``datetime.now()`` (naive local time).
    """
    with connection.begin():
        connection.execute(
            sqlalchemy.text("""
                INSERT INTO aya_image (user_prompt, generated_img_desc, image_response_file_path, timestamp)
                VALUES (:user_prompt, :generated_img_desc, :image_response_file_path, :timestamp)
            """),
            {
                "user_prompt": user_prompt,
                "generated_img_desc": generated_img_desc,
                "image_response_file_path": image_response_file_path,
                "timestamp": datetime.now(),
            },
        )


def connect_with_connector() -> sqlalchemy.engine.base.Connection:
    """Open a connection to the Cloud SQL Postgres database named by env vars.

    Reads the required variables INSTANCE_CONNECTION_NAME, DB_USER, DB_PASS
    and DB_NAME. Any non-empty PRIVATE_IP value selects a private IP;
    otherwise the public IP is used.

    Returns:
        An open SQLAlchemy ``Connection`` (not the ``Engine``) drawn from a
        pool whose DBAPI connections are created by the Cloud SQL connector.

    Raises:
        KeyError: If a required environment variable is not set.

    NOTE(review): the Connector and engine are never closed here; presumably
    the returned connection lives for the whole process -- confirm.
    """
    instance_connection_name = os.environ["INSTANCE_CONNECTION_NAME"]
    db_user = os.environ["DB_USER"]
    db_pass = os.environ["DB_PASS"]
    db_name = os.environ["DB_NAME"]
    ip_type = IPTypes.PRIVATE if os.environ.get("PRIVATE_IP") else IPTypes.PUBLIC

    connector = Connector(refresh_strategy="LAZY")

    def getconn() -> pg8000.dbapi.Connection:
        # Called by the SQLAlchemy pool whenever it needs a new DBAPI connection.
        conn: pg8000.dbapi.Connection = connector.connect(
            instance_connection_name,
            "pg8000",
            user=db_user,
            password=db_pass,
            db=db_name,
            ip_type=ip_type,
        )
        return conn

    pool = sqlalchemy.create_engine(
        "postgresql+pg8000://",
        creator=getconn,
    )
    connection = pool.connect()
    return connection