Spaces:
Running
on
T4
Running
on
T4
File size: 5,458 Bytes
43811b3 d281662 43811b3 cd98abf 43811b3 cd98abf 43811b3 cd98abf 43811b3 cd98abf 43811b3 d281662 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
from constants import IMAGE_PER_CONVERSATION_LIMIT, DEFAULT_SYSTEM_PREAMBLE_TOKEN_COUNT, VISION_COHERE_MODEL_NAME, VISION_MODEL_TOKEN_LIMIT
from prompt_examples import AYA_VISION_PROMPT_EXAMPLES
import base64
from io import BytesIO
from PIL import Image
import logging
import cohere
import os
import traceback
import random
import gradio as gr
from google.cloud.sql.connector import Connector, IPTypes
import pg8000
from datetime import datetime
import sqlalchemy
# from dotenv import load_dotenv
# load_dotenv()
# Cohere API key for the Aya Vision model; expected to be set in the Space's
# environment (see the commented-out dotenv lines above for local development).
MULTIMODAL_API_KEY = os.getenv("AYA_VISION_API_KEY")
logger = logging.getLogger(__name__)
# Module-level v2 client shared by all request handlers in this file.
# client_name tags outgoing requests so this Space's traffic is identifiable.
aya_vision_client = cohere.ClientV2(
    api_key=MULTIMODAL_API_KEY,
    client_name="c4ai-aya-vision-hf-space"
)
def cohere_vision_chat(chat_history, model=VISION_COHERE_MODEL_NAME):
    """Send the accumulated chat history to the Cohere vision model.

    Returns the text of the first content item of the model's reply
    (non-streaming counterpart of the streaming path used elsewhere).
    """
    api_response = aya_vision_client.chat(model=model, messages=chat_history)
    return api_response.message.content[0].text
def get_aya_vision_prompt_example(language):
    """Return the (prompt, image) example pair registered for *language*.

    Raises KeyError if *language* has no entry in AYA_VISION_PROMPT_EXAMPLES.
    """
    example_pair = AYA_VISION_PROMPT_EXAMPLES[language]
    return example_pair[0], example_pair[1]
def get_base64_from_local_file(file_path):
    """Read a local file and return its contents as a base64 string.

    Parameters
    ----------
    file_path : str
        Path to the file on disk (used here for uploaded images).

    Returns
    -------
    str or None
        UTF-8 decoded base64 payload, or None if the file could not be read.
        Callers treat None as "no image available".
    """
    try:
        with open(file_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")
    except OSError as e:
        # Narrowed from `except Exception` to file-system errors, and raised
        # the log level from DEBUG (effectively invisible in production) to
        # ERROR so failed uploads actually show up in the logs. The stray
        # debug print() was removed.
        logger.error(f"Error converting local image to base64 string: {e}")
        return None
def get_aya_vision_response(incoming_message, image_filepath, max_size_mb=5):
    """Stream an Aya Vision chat response for a user message plus image.

    Parameters
    ----------
    incoming_message : str or None
        User's text prompt; empty/None is replaced by "." because the
        Cohere API rejects empty messages.
    image_filepath : str
        Local path of the uploaded image.
    max_size_mb : int
        Maximum allowed image size, applied to the base64 payload.

    Yields
    ------
    str
        The accumulated response text after each streamed content delta.

    Raises
    ------
    gr.Error
        If the image format is unsupported or the image exceeds the limit.
    """
    max_size_bytes = max_size_mb * 1024 * 1024

    # Map the file extension to a MIME type for the data URL.
    lowered = image_filepath.lower()
    if lowered.endswith((".jpg", ".jpeg")):
        image_type = "image/jpeg"
    elif lowered.endswith(".png"):
        image_type = "image/png"
    elif lowered.endswith(".webp"):
        image_type = "image/webp"
    elif lowered.endswith(".gif"):
        image_type = "image/gif"
    else:
        # Previously an unrecognized extension left image_type unbound and
        # crashed with NameError; surface a user-visible error instead.
        raise gr.Error("Please upload a JPEG, PNG, WEBP or GIF image")

    print("converting image to base 64")
    base64_image = get_base64_from_local_file(image_filepath)
    image = f"data:{image_type};base64,{base64_image}"

    # Reject oversized images before spending an API call on them.
    # gr.Error must be *raised* (the original only constructed it, so the
    # size limit was never enforced).
    if get_base64_image_size(image) >= max_size_bytes:
        raise gr.Error("Please upload image with size under 5MB")

    # to prevent Cohere API from throwing error for empty message
    if incoming_message == "" or incoming_message is None:
        incoming_message = "."

    chat_history = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": incoming_message},
                {"type": "image_url", "image_url": {"url": image}},
            ],
        }
    ]

    res = aya_vision_client.chat_stream(
        messages=chat_history, model=VISION_COHERE_MODEL_NAME
    )
    output = ""
    for event in res:
        if event and event.type == "content-delta":
            output += event.delta.message.content.text
            yield output
def get_base64_image_size(base64_string):
    """Estimate the decoded byte size of a base64 payload.

    Accepts either a bare base64 string or a full data URL
    ("data:<mime>;base64,<payload>"); anything before the first comma is
    treated as the data-URL header and ignored. Whitespace inside the
    payload is stripped before measuring, and '=' padding characters are
    subtracted from the 3/4 length estimate.
    """
    # Keep only the payload portion when given a data URL.
    header, sep, rest = base64_string.partition(',')
    payload = rest if sep else header
    # One C-level pass to drop newlines, carriage returns and spaces.
    cleaned = payload.translate(str.maketrans('', '', ' \r\n'))
    pad_chars = cleaned.count('=')
    # Each 4 base64 chars encode 3 bytes; padding chars carry no data.
    return (len(cleaned) * 3) // 4 - pad_chars
def insert_aya_audio(connection, user_prompt, text_response, audio_response_file_path):
    """Persist one Aya audio interaction to the aya_audio table.

    Runs inside an explicit transaction so the row is committed; the
    timestamp is recorded at insert time.
    """
    statement = sqlalchemy.text(
        """
        INSERT INTO aya_audio (user_prompt, text_response, audio_response_file_path, timestamp)
        VALUES (:user_prompt, :text_response, :audio_response_file_path, :timestamp)
        """
    )
    row = {
        "user_prompt": user_prompt,
        "text_response": text_response,
        "audio_response_file_path": audio_response_file_path,
        "timestamp": datetime.now(),
    }
    with connection.begin():
        connection.execute(statement, row)
def insert_aya_image(connection, user_prompt, generated_img_desc, image_response_file_path):
    """Persist one Aya image interaction to the aya_image table.

    Runs inside an explicit transaction so the row is committed; the
    timestamp is recorded at insert time.
    """
    statement = sqlalchemy.text(
        """
        INSERT INTO aya_image (user_prompt, generated_img_desc, image_response_file_path, timestamp)
        VALUES (:user_prompt, :generated_img_desc, :image_response_file_path, :timestamp)
        """
    )
    row = {
        "user_prompt": user_prompt,
        "generated_img_desc": generated_img_desc,
        "image_response_file_path": image_response_file_path,
        "timestamp": datetime.now(),
    }
    with connection.begin():
        connection.execute(statement, row)
def connect_with_connector() -> sqlalchemy.engine.Connection:
    """Open a connection to a Cloud SQL Postgres instance via the connector.

    Reads INSTANCE_CONNECTION_NAME, DB_USER, DB_PASS and DB_NAME from the
    environment (KeyError if missing); connects over private IP when
    PRIVATE_IP is set, public IP otherwise.

    Returns
    -------
    sqlalchemy.engine.Connection
        A live connection checked out from a pg8000-backed engine. The
        original annotation claimed an Engine, but callers (the insert_*
        helpers) receive and use a Connection, so the annotation is fixed
        to match actual behavior.
    """
    instance_connection_name = os.environ["INSTANCE_CONNECTION_NAME"]
    db_user = os.environ["DB_USER"]
    db_pass = os.environ["DB_PASS"]
    db_name = os.environ["DB_NAME"]

    ip_type = IPTypes.PRIVATE if os.environ.get("PRIVATE_IP") else IPTypes.PUBLIC

    # LAZY refresh defers credential refreshes until a connection is needed,
    # which suits serverless / Spaces-style environments.
    connector = Connector(refresh_strategy="LAZY")

    def getconn() -> pg8000.dbapi.Connection:
        # Factory handed to SQLAlchemy: creates a fresh pg8000 DB-API
        # connection through the Cloud SQL connector.
        conn: pg8000.dbapi.Connection = connector.connect(
            instance_connection_name,
            "pg8000",
            user=db_user,
            password=db_pass,
            db=db_name,
            ip_type=ip_type,
        )
        return conn

    pool = sqlalchemy.create_engine(
        "postgresql+pg8000://",
        creator=getconn,
    )
    connection = pool.connect()
    return connection