File size: 5,458 Bytes
43811b3
 
 
 
 
 
 
 
 
 
 
d281662
 
 
 
43811b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd98abf
 
 
43811b3
cd98abf
43811b3
cd98abf
43811b3
cd98abf
43811b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d281662
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
from constants import IMAGE_PER_CONVERSATION_LIMIT, DEFAULT_SYSTEM_PREAMBLE_TOKEN_COUNT, VISION_COHERE_MODEL_NAME, VISION_MODEL_TOKEN_LIMIT
from prompt_examples import AYA_VISION_PROMPT_EXAMPLES
import base64
from io import BytesIO
from PIL import Image
import logging
import cohere
import os
import traceback
import random
import gradio as gr
from google.cloud.sql.connector import Connector, IPTypes
import pg8000
from datetime import datetime
import sqlalchemy
# from dotenv import load_dotenv
# load_dotenv()

# API key for the Aya Vision endpoint, read from the environment (None when unset).
MULTIMODAL_API_KEY = os.environ.get("AYA_VISION_API_KEY")

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Shared Cohere v2 client used by the chat helpers in this module.
aya_vision_client = cohere.ClientV2(
    client_name="c4ai-aya-vision-hf-space",
    api_key=MULTIMODAL_API_KEY,
)

def cohere_vision_chat(chat_history, model=VISION_COHERE_MODEL_NAME):
    """Send a conversation to the Aya Vision client and return the reply text.

    Args:
        chat_history: Messages forwarded unchanged as the ``messages`` argument.
        model: Model identifier; defaults to ``VISION_COHERE_MODEL_NAME``.

    Returns:
        The ``text`` attribute of the first content item of the response message.
    """
    reply = aya_vision_client.chat(model=model, messages=chat_history)
    first_part = reply.message.content[0]
    return first_part.text


def get_aya_vision_prompt_example(language):
    """Return the first two items of the prompt example stored for ``language``.

    Looks ``language`` up in ``AYA_VISION_PROMPT_EXAMPLES``; a missing key
    propagates the lookup error to the caller.
    """
    entry = AYA_VISION_PROMPT_EXAMPLES[language]
    first_item = entry[0]
    second_item = entry[1]
    return first_item, second_item

def get_base64_from_local_file(file_path):
    """Read a local file and return its contents as a base64 text string.

    Args:
        file_path: Path of the file to read in binary mode.

    Returns:
        The base64 encoding of the file bytes decoded as UTF-8 text, or
        ``None`` when any exception occurs while reading or encoding (the
        error is logged at debug level and swallowed).
    """
    try:
        print("loading image")
        with open(file_path, "rb") as handle:
            raw_bytes = handle.read()
            encoded_bytes = base64.b64encode(raw_bytes)
            return encoded_bytes.decode('utf-8')
    except Exception as e:
        # Best effort: callers receive None instead of an exception.
        logger.debug(f"Error converting local image to base64 string: {e}")
        return None


def get_aya_vision_response(incoming_message, image_filepath, max_size_mb=5):
    """Stream the Aya Vision model's answer about one local image.

    This is a generator: validation errors are raised when the caller starts
    iterating, not when the function is called.

    Args:
        incoming_message: User text accompanying the image. An empty or
            missing message is replaced by "." because the Cohere API
            rejects empty messages.
        image_filepath: Path to a local JPEG, PNG, WebP or GIF file; the
            MIME type is derived from the file extension.
        max_size_mb: Upper bound (exclusive) on the decoded image size, in
            megabytes.

    Yields:
        The response text accumulated so far, once per "content-delta"
        event of the stream.

    Raises:
        gr.Error: If no image path is given, the extension is not supported,
            the file cannot be read, or the image is too large.
    """
    if not image_filepath:
        raise gr.Error("Please upload an image")

    image_type = _guess_image_mime_type(image_filepath)
    if image_type is None:
        # Previously this case fell through and crashed on an unbound local.
        raise gr.Error(
            "Unsupported image format. Please upload a JPEG, PNG, WebP or GIF image"
        )

    logger.debug("converting image to base 64")
    base64_image = get_base64_from_local_file(image_filepath)
    if base64_image is None:
        # The loader returns None on failure; do not send "base64,None".
        raise gr.Error("Could not read the uploaded image")

    image = f"data:{image_type};base64,{base64_image}"

    # Enforce the limit before calling the API. The error must be raised:
    # merely constructing gr.Error has no effect.
    max_size_bytes = max_size_mb * 1024 * 1024
    if get_base64_image_size(image) >= max_size_bytes:
        raise gr.Error(f"Please upload image with size under {max_size_mb}MB")

    # to prevent Cohere API from throwing error for empty message
    if not incoming_message:
        incoming_message = "."

    chat_history = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": incoming_message},
                {"type": "image_url", "image_url": {"url": image}},
            ],
        }
    ]

    res = aya_vision_client.chat_stream(
        messages=chat_history,
        model=VISION_COHERE_MODEL_NAME,
    )

    output = ""
    for event in res:
        if event and event.type == "content-delta":
            output += event.delta.message.content.text
            yield output


def _guess_image_mime_type(image_filepath):
    """Return the MIME type implied by the file extension, or None if unsupported."""
    lowered = image_filepath.lower()
    extension_table = (
        ((".jpg", ".jpeg"), "image/jpeg"),
        ((".png",), "image/png"),
        ((".webp",), "image/webp"),
        ((".gif",), "image/gif"),
    )
    for suffixes, mime_type in extension_table:
        if lowered.endswith(suffixes):
            return mime_type
    return None

def get_base64_image_size(base64_string):
    """Estimate the decoded byte size of a base64 string or data URL.

    Everything up to and including the first comma (the data-URL header) is
    ignored when a comma is present. Newlines, carriage returns and spaces
    are dropped before counting.

    Returns:
        ``len(payload) * 3 // 4`` minus the number of ``=`` characters.
    """
    _, separator, tail = base64_string.partition(',')
    payload = tail if separator else base64_string

    # Strip the whitespace characters the original chain of replaces removed.
    payload = payload.translate(str.maketrans('', '', '\n\r '))

    pad_count = payload.count('=')
    return (len(payload) * 3) // 4 - pad_count


def insert_aya_audio(connection, user_prompt, text_response, audio_response_file_path):
    """Insert one row into the ``aya_audio`` table inside a transaction.

    Args:
        connection: Object providing ``begin()`` and ``execute()``.
        user_prompt: Value bound to the ``user_prompt`` column.
        text_response: Value bound to the ``text_response`` column.
        audio_response_file_path: Value bound to the
            ``audio_response_file_path`` column.

    The ``timestamp`` column receives ``datetime.now()`` taken at insert time.
    """
    with connection.begin():
        statement = sqlalchemy.text("""
                INSERT INTO aya_audio (user_prompt, text_response, audio_response_file_path, timestamp)
                VALUES (:user_prompt, :text_response, :audio_response_file_path, :timestamp)
            """)
        bound_values = {
            "user_prompt": user_prompt,
            "text_response": text_response,
            "audio_response_file_path": audio_response_file_path,
            "timestamp": datetime.now(),
        }
        connection.execute(statement, bound_values)

def insert_aya_image(connection, user_prompt, generated_img_desc, image_response_file_path):
    """Insert one row into the ``aya_image`` table inside a transaction.

    Args:
        connection: Object providing ``begin()`` and ``execute()``.
        user_prompt: Value bound to the ``user_prompt`` column.
        generated_img_desc: Value bound to the ``generated_img_desc`` column.
        image_response_file_path: Value bound to the
            ``image_response_file_path`` column.

    The ``timestamp`` column receives ``datetime.now()`` taken at insert time.
    """
    with connection.begin():
        statement = sqlalchemy.text("""
                INSERT INTO aya_image (user_prompt, generated_img_desc, image_response_file_path, timestamp)
                VALUES (:user_prompt, :generated_img_desc, :image_response_file_path, :timestamp)
            """)
        bound_values = {
            "user_prompt": user_prompt,
            "generated_img_desc": generated_img_desc,
            "image_response_file_path": image_response_file_path,
            "timestamp": datetime.now(),
        }
        connection.execute(statement, bound_values)

def connect_with_connector() -> sqlalchemy.engine.base.Connection:
    """Open a database connection through the Cloud SQL Python Connector.

    Reads ``INSTANCE_CONNECTION_NAME``, ``DB_USER``, ``DB_PASS`` and
    ``DB_NAME`` from the environment. If ``PRIVATE_IP`` is set to a non-empty
    value the private IP type is used, otherwise the public one.

    Returns:
        A connection obtained from a new SQLAlchemy engine. The annotation
        previously claimed ``Engine``, but the value returned has always been
        the result of ``engine.connect()``.

    Raises:
        KeyError: If one of the four required environment variables is missing.
    """
    instance_connection_name = os.environ["INSTANCE_CONNECTION_NAME"]
    db_user = os.environ["DB_USER"]
    db_pass = os.environ["DB_PASS"]
    db_name = os.environ["DB_NAME"]

    ip_type = IPTypes.PRIVATE if os.environ.get("PRIVATE_IP") else IPTypes.PUBLIC

    connector = Connector(refresh_strategy="LAZY")

    def getconn() -> pg8000.dbapi.Connection:
        """Create one raw pg8000 connection via the connector."""
        conn: pg8000.dbapi.Connection = connector.connect(
            instance_connection_name,
            "pg8000",
            user=db_user,
            password=db_pass,
            db=db_name,
            ip_type=ip_type,
        )
        return conn

    # The URL carries no host or credentials; every DBAPI connection is
    # produced by ``getconn``.
    engine = sqlalchemy.create_engine(
        "postgresql+pg8000://",
        creator=getconn,
    )

    # NOTE(review): neither the engine nor the connector is returned, so the
    # caller can only close the connection itself — confirm this is intended.
    return engine.connect()