from constants import IMAGE_PER_CONVERSATION_LIMIT, DEFAULT_SYSTEM_PREAMBLE_TOKEN_COUNT, VISION_COHERE_MODEL_NAME, VISION_MODEL_TOKEN_LIMIT
from prompt_examples import AYA_VISION_PROMPT_EXAMPLES
import base64
from io import BytesIO
from PIL import Image
import logging
import cohere
import os
import traceback
import random
import gradio as gr
from google.cloud.sql.connector import Connector, IPTypes
import pg8000
from datetime import datetime
import sqlalchemy

# from dotenv import load_dotenv
# load_dotenv()

MULTIMODAL_API_KEY = os.getenv("AYA_VISION_API_KEY")

logger = logging.getLogger(__name__)

aya_vision_client = cohere.ClientV2(
    api_key=MULTIMODAL_API_KEY,
    client_name="c4ai-aya-vision-hf-space"
)

def cohere_vision_chat(chat_history, model=VISION_COHERE_MODEL_NAME):
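    """Send a chat history to the Aya Vision model via the Cohere v2 API and return the text of the first content block of the reply (non-streaming)."""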
    response = aya_vision_client.chat(
        messages=chat_history,
        model=model,
    )
    return response.message.content[0].text

def get_aya_vision_prompt_example(language):
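    """Return the two fields of the prompt example registered for the given language in AYA_VISION_PROMPT_EXAMPLES."""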
    example = AYA_VISION_PROMPT_EXAMPLES[language]
    return example[0], example[1]

def get_base64_from_local_file(file_path):
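    """Read an image file from disk and return its contents as a base64-encoded string, or None if the file cannot be read."""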
    try:
        logger.debug(f"Loading image from {file_path}")
        with open(file_path, "rb") as image_file:
            base64_image = base64.b64encode(image_file.read()).decode("utf-8")
        return base64_image
    except Exception as e:
        logger.error(f"Error converting local image to base64 string: {e}")
        return None

def get_aya_vision_response(incoming_message, image_filepath, max_size_mb=5):
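    """Build a single-turn multimodal message from the user prompt and uploaded image, then stream the Aya Vision reply, yielding the accumulated text after each content delta."""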
    max_size_bytes = max_size_mb * 1024 * 1024

    # Map the file extension to a MIME type for the data URL
    image_ext = image_filepath.lower()
    if image_ext.endswith(".jpg") or image_ext.endswith(".jpeg"):
        image_type = "image/jpeg"
    elif image_ext.endswith(".png"):
        image_type = "image/png"
    elif image_ext.endswith(".webp"):
        image_type = "image/webp"
    elif image_ext.endswith(".gif"):
        image_type = "image/gif"
    else:
        raise gr.Error("Unsupported image format. Please upload a JPEG, PNG, WebP or GIF image.")

    chat_history = []

    logger.debug("Converting image to base64")
    base64_image = get_base64_from_local_file(image_filepath)
    if base64_image is None:
        raise gr.Error("Could not read the uploaded image. Please try again.")
    image = f"data:{image_type};base64,{base64_image}"

    # Prevent the Cohere API from throwing an error for an empty message
    if incoming_message is None or incoming_message == "":
        incoming_message = "."

    chat_history.append(
        {
            "role": "user",
            "content": [
                {"type": "text", "text": incoming_message},
                {"type": "image_url", "image_url": {"url": image}},
            ],
        }
    )

    image_size_bytes = get_base64_image_size(image)
    if image_size_bytes >= max_size_bytes:
        raise gr.Error(f"Please upload an image with size under {max_size_mb}MB")

    # Non-streaming alternative:
    # response = cohere_vision_chat(chat_history, model=VISION_COHERE_MODEL_NAME)
    # return response

    res = aya_vision_client.chat_stream(messages=chat_history, model=VISION_COHERE_MODEL_NAME)
    output = ""
    for event in res:
        if event and event.type == "content-delta":
            output += event.delta.message.content.text
            yield output

def get_base64_image_size(base64_string):
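    """Estimate the decoded size in bytes of a base64 string, which may include a data-URL prefix."""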
    # Strip an optional "data:...;base64," prefix and any whitespace
    if "," in base64_string:
        base64_data = base64_string.split(",", 1)[1]
    else:
        base64_data = base64_string
    base64_data = base64_data.replace("\n", "").replace("\r", "").replace(" ", "")

    # Every 4 base64 characters encode 3 bytes; '=' padding characters carry no data
    padding = base64_data.count("=")
    size_bytes = (len(base64_data) * 3) // 4 - padding
    return size_bytes

def insert_aya_audio(connection, user_prompt, text_response, audio_response_file_path):
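    """Insert one audio interaction (user prompt, text response, audio file path, timestamp) into the aya_audio table."""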
    with connection.begin():
        connection.execute(
            sqlalchemy.text("""
                INSERT INTO aya_audio (user_prompt, text_response, audio_response_file_path, timestamp)
                VALUES (:user_prompt, :text_response, :audio_response_file_path, :timestamp)
            """),
            {"user_prompt": user_prompt, "text_response": text_response, "audio_response_file_path": audio_response_file_path, "timestamp": datetime.now()},
        )

def insert_aya_image(connection, user_prompt, generated_img_desc, image_response_file_path):
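    """Insert one image interaction (user prompt, generated image description, image file path, timestamp) into the aya_image table."""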
    with connection.begin():
        connection.execute(
            sqlalchemy.text("""
                INSERT INTO aya_image (user_prompt, generated_img_desc, image_response_file_path, timestamp)
                VALUES (:user_prompt, :generated_img_desc, :image_response_file_path, :timestamp)
            """),
            {"user_prompt": user_prompt, "generated_img_desc": generated_img_desc, "image_response_file_path": image_response_file_path, "timestamp": datetime.now()},
        )

def connect_with_connector() -> sqlalchemy.engine.Connection:
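    """Connect to the Cloud SQL Postgres instance through the Cloud SQL Python Connector (pg8000 driver), reading instance and credential settings from environment variables."""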
    instance_connection_name = os.environ["INSTANCE_CONNECTION_NAME"]
    db_user = os.environ["DB_USER"]
    db_pass = os.environ["DB_PASS"]
    db_name = os.environ["DB_NAME"]

    ip_type = IPTypes.PRIVATE if os.environ.get("PRIVATE_IP") else IPTypes.PUBLIC

    connector = Connector(refresh_strategy="LAZY")

    def getconn() -> pg8000.dbapi.Connection:
        conn: pg8000.dbapi.Connection = connector.connect(
            instance_connection_name,
            "pg8000",
            user=db_user,
            password=db_pass,
            db=db_name,
            ip_type=ip_type,
        )
        return conn

    pool = sqlalchemy.create_engine(
        "postgresql+pg8000://",
        creator=getconn,
    )
    connection = pool.connect()
    return connection
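

# Hedged usage sketch (not part of the original Space code): a minimal local check of the
# streaming helper above. It assumes AYA_VISION_API_KEY is set and that "sample.jpg" is a
# hypothetical local test image; the actual Space wires get_aya_vision_response into a Gradio UI.
if __name__ == "__main__":
    final = ""
    for partial in get_aya_vision_response("Describe this image.", "sample.jpg"):
        final = partial  # each yield is the accumulated response so far
    print(final)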