Update code (#9)
- update code (d281662a0810be41652a2f14d20401324325c29f)
- add fix (f14e66e1c82b6e7d51eea0b0e4254485ee737490)
Co-authored-by: Shivalika Singh <[email protected]>
- app.py +52 -15
- aya_vision_utils.py +57 -9
- requirements.txt +5 -1
app.py
CHANGED
@@ -25,11 +25,14 @@ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
 from prompt_examples import TEXT_CHAT_EXAMPLES, IMG_GEN_PROMPT_EXAMPLES, AUDIO_EXAMPLES, TEXT_CHAT_EXAMPLES_LABELS, IMG_GEN_PROMPT_EXAMPLES_LABELS, AUDIO_EXAMPLES_LABELS, AYA_VISION_PROMPT_EXAMPLES
 from preambles import CHAT_PREAMBLE, AUDIO_RESPONSE_PREAMBLE, IMG_DESCRIPTION_PREAMBLE
 from constants import LID_LANGUAGES, NEETS_AI_LANGID_MAP, AYA_MODEL_NAME, BATCH_SIZE, USE_ELVENLABS, USE_REPLICATE
-from aya_vision_utils import get_aya_vision_response, get_aya_vision_prompt_example
+from aya_vision_utils import get_aya_vision_response, get_aya_vision_prompt_example, insert_aya_audio, insert_aya_image, connect_with_connector
+from google.cloud import storage
+
 # from dotenv import load_dotenv
 
 # load_dotenv()
 
+
 HF_API_TOKEN = os.getenv("HF_API_KEY")
 ELEVEN_LABS_KEY = os.getenv("ELEVEN_LABS_KEY")
 NEETS_AI_API_KEY = os.getenv("NEETS_AI_API_KEY")
@@ -62,6 +65,17 @@ eleven_labs_client = ElevenLabs(
     api_key=ELEVEN_LABS_KEY,
 )
 
+BUCKET_NAME = os.getenv("BUCKET_NAME")
+AUDIO_BUCKET = os.getenv("AUDIO_BUCKET")
+IMAGE_STORAGE_PATH = os.getenv("IMAGE_STORAGE_PATH")
+AUDIO_STORAGE_PATH = os.getenv("AUDIO_STORAGE_PATH")
+SAVING_ENABLED = True
+
+storage_client = storage.Client()
+bucket = storage_client.bucket(BUCKET_NAME)
+audio_bucket = storage_client.bucket(AUDIO_BUCKET)
+connection = connect_with_connector()
+
 # Language identification
 lid_model_path = hf_hub_download(repo_id="facebook/fasttext-language-identification", filename="model.bin")
 LID_model = fasttext.load_model(lid_model_path)
@@ -102,20 +116,34 @@ def replicate_api_inference(input_prompt):
     image = Image.open(image[0])
     return image
 
-def generate_image(input_prompt, model_id="black-forest-labs/FLUX.1-schnell"):
-    if input_prompt:
+def generate_image(input_prompt, generated_img_desc, model_id="black-forest-labs/FLUX.1-schnell"):
+    if input_prompt and generated_img_desc:
         if USE_REPLICATE:
             print("using replicate for image generation")
-            image = replicate_api_inference(input_prompt)
+            image = replicate_api_inference(generated_img_desc)
         else:
             try:
                 print("using HF inference API for image generation")
-                image_bytes = get_hf_inference_api_response({ "inputs": input_prompt}, model_id)
+                image_bytes = get_hf_inference_api_response({ "inputs": generated_img_desc}, model_id)
                 image = np.array(Image.open(io.BytesIO(image_bytes)))
             except Exception as e:
                 print("HF API error:", e)
                 # generate image with help replicate in case of error
-                image = replicate_api_inference(input_prompt)
+                image = replicate_api_inference(generated_img_desc)
+
+        # save image to local file
+        image_path = "generated_image.png"
+        image.save(image_path)
+
+        if SAVING_ENABLED:
+            unique_id = str(uuid.uuid4())
+
+            blob = bucket.blob(IMAGE_STORAGE_PATH + unique_id + "_" + image_path)
+            blob.upload_from_filename(image_path)
+            gcp_image_path = f"gs://{BUCKET_NAME}/{IMAGE_STORAGE_PATH}{unique_id}_{image_path}"
+
+            insert_aya_image(connection, input_prompt, generated_img_desc, gcp_image_path)
+
         return image
     else:
         return None
@@ -246,7 +274,7 @@ def clean_text(text, remove_bullets=False, remove_newline=False):
 
     return cleaned_text
 
-def convert_text_to_speech(text, language="english"):
+def convert_text_to_speech(transcript, text, language="english"):
 
     # do language detection to determine voice of speech response
     if text:
@@ -268,19 +296,28 @@ def convert_text_to_speech(text, language="english"):
         else:
             # use elevenlabs for TTS
             audio_path = elevenlabs_generate_audio(text)
+
+        if SAVING_ENABLED:
+            unique_id = str(uuid.uuid4())
+
+            blob = audio_bucket.blob(AUDIO_STORAGE_PATH + unique_id + "_" + audio_path)
+            blob.upload_from_filename(audio_path)
+            gcp_audio_path = f"gs://{BUCKET_NAME}/{AUDIO_STORAGE_PATH}{unique_id}_{audio_path}"
+
+            insert_aya_audio(connection, transcript, text, gcp_audio_path)
 
         return audio_path
     else:
         return None
 
 def elevenlabs_generate_audio(text):
-    audio = eleven_labs_client.
+    audio = eleven_labs_client.text_to_speech.convert(
         text=text,
-
-
-
-
-    audio_path = "
+        voice_id="21m00Tcm4TlvDq8ikWAM", #Rachel
+        model_id="eleven_multilingual_v2",
+        output_format="mp3_44100_128",
+    )
+    audio_path = "audio.mp3"
     save(audio, audio_path)
     return audio_path
 
@@ -534,7 +571,7 @@ with demo:
 
     generated_img_desc.change(
         generate_image, #run_flux,
-        inputs=[generated_img_desc],
+        inputs=[input_img_prompt, generated_img_desc],
        outputs=[generated_img],
         show_progress="full",
     )
@@ -558,7 +595,7 @@ with demo:
         show_progress="full",
    ).then(
        convert_text_to_speech,
-        inputs=[e2e_audio_file_aya_response],
+        inputs=[e2e_audio_file_trans, e2e_audio_file_aya_response],
        outputs=[e2e_aya_audio_response],
        show_progress="full",
    )
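Note on the image-generation fallback above: get_hf_inference_api_response is defined elsewhere in app.py and is not part of this diff. As a rough, hypothetical sketch of what such a helper typically does (the URL, header, and error handling here are assumptions, not the Space's actual code), it posts the {"inputs": ...} payload to the hosted Inference API for the given model and returns the raw image bytes that generate_image then decodes with PIL:

import os
import requests

HF_API_TOKEN = os.getenv("HF_API_KEY")

def get_hf_inference_api_response(payload, model_id):
    # Hypothetical sketch: POST the prompt payload to the hosted Inference API
    # for model_id and return the raw bytes of the generated image.
    api_url = f"https://api-inference.huggingface.co/models/{model_id}"
    headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
    response = requests.post(api_url, headers=headers, json=payload)
    # Raising on HTTP errors lets generate_image's except branch fall back to Replicate.
    response.raise_for_status()
    return response.content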
aya_vision_utils.py
CHANGED
@@ -9,7 +9,10 @@ import os
 import traceback
 import random
 import gradio as gr
-
+from google.cloud.sql.connector import Connector, IPTypes
+import pg8000
+from datetime import datetime
+import sqlalchemy
 # from dotenv import load_dotenv
 # load_dotenv()
 
@@ -32,9 +35,6 @@ def cohere_vision_chat(chat_history, model=VISION_COHERE_MODEL_NAME):
 
 def get_aya_vision_prompt_example(language):
     example = AYA_VISION_PROMPT_EXAMPLES[language]
-    print("example:", example)
-    print("example prompt:", example[0])
-    print("example image:", example[1])
     return example[0], example[1]
 
 def get_base64_from_local_file(file_path):
@@ -42,7 +42,6 @@ def get_base64_from_local_file(file_path):
         print("loading image")
         with open(file_path, "rb") as image_file:
             base64_image = base64.b64encode(image_file.read()).decode('utf-8')
-            print("converted image")
         return base64_image
     except Exception as e:
         logger.debug(f"Error converting local image to base64 string: {e}")
@@ -50,8 +49,6 @@ def get_base64_from_local_file(file_path):
 
 
 def get_aya_vision_response(incoming_message, image_filepath, max_size_mb=5):
-    print("incoming message:", incoming_message)
-    print("image_filepath:", image_filepath)
     max_size_bytes = max_size_mb * 1024 * 1024
 
     image_ext = image_filepath.lower()
@@ -70,7 +67,6 @@ def get_aya_vision_response(incoming_message, image_filepath, max_size_mb=5):
     print("converting image to base 64")
     base64_image = get_base64_from_local_file(image_filepath)
     image = f"data:{image_type};base64,{base64_image}"
-    print("Image base64:", image[:30])
 
     # to prevent Cohere API from throwing error for empty message
     if incoming_message=="" or incoming_message is None:
@@ -108,4 +104,56 @@ def get_base64_image_size(base64_string):
     base64_data = base64_data.replace('\n', '').replace('\r', '').replace(' ', '')
     padding = base64_data.count('=')
     size_bytes = (len(base64_data) * 3) // 4 - padding
-    return size_bytes
+    return size_bytes
+
+
+def insert_aya_audio(connection, user_prompt, text_response, audio_response_file_path):
+    with connection.begin():
+        connection.execute(
+            sqlalchemy.text("""
+                INSERT INTO aya_audio (user_prompt, text_response, audio_response_file_path, timestamp)
+                VALUES (:user_prompt, :text_response, :audio_response_file_path, :timestamp)
+            """),
+            {"user_prompt": user_prompt, "text_response": text_response, "audio_response_file_path": audio_response_file_path, "timestamp": datetime.now()}
+        )
+
+def insert_aya_image(connection, user_prompt, generated_img_desc, image_response_file_path):
+    with connection.begin():
+        connection.execute(
+            sqlalchemy.text("""
+                INSERT INTO aya_image (user_prompt, generated_img_desc, image_response_file_path, timestamp)
+                VALUES (:user_prompt, :generated_img_desc, :image_response_file_path, :timestamp)
+            """),
+            {"user_prompt": user_prompt, "generated_img_desc": generated_img_desc, "image_response_file_path": image_response_file_path, "timestamp": datetime.now()}
+        )
+
+def connect_with_connector() -> sqlalchemy.engine.base.Engine:
+    instance_connection_name = os.environ[
+        "INSTANCE_CONNECTION_NAME"
+    ]
+    db_user = os.environ["DB_USER"]
+    db_pass = os.environ["DB_PASS"]
+    db_name = os.environ["DB_NAME"]
+
+    ip_type = IPTypes.PRIVATE if os.environ.get("PRIVATE_IP") else IPTypes.PUBLIC
+
+    connector = Connector(refresh_strategy="LAZY")
+
+    def getconn() -> pg8000.dbapi.Connection:
+        conn: pg8000.dbapi.Connection = connector.connect(
+            instance_connection_name,
+            "pg8000",
+            user=db_user,
+            password=db_pass,
+            db=db_name,
+            ip_type=ip_type,
+        )
+        return conn
+
+    pool = sqlalchemy.create_engine(
+        "postgresql+pg8000://",
+        creator=getconn,
+    )
+
+    connection = pool.connect()
+    return connection
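The two insert helpers above assume that aya_audio and aya_image tables already exist in the Cloud SQL (Postgres) database; the diff does not include their schema. A minimal one-time setup sketch consistent with the column names used in the INSERT statements (the id column and the column types are assumptions) could look like:

import sqlalchemy
from aya_vision_utils import connect_with_connector

# Returns a pooled SQLAlchemy connection via the Cloud SQL Python Connector.
connection = connect_with_connector()
with connection.begin():
    connection.execute(sqlalchemy.text("""
        CREATE TABLE IF NOT EXISTS aya_audio (
            id SERIAL PRIMARY KEY,
            user_prompt TEXT,
            text_response TEXT,
            audio_response_file_path TEXT,
            timestamp TIMESTAMP
        )
    """))
    connection.execute(sqlalchemy.text("""
        CREATE TABLE IF NOT EXISTS aya_image (
            id SERIAL PRIMARY KEY,
            user_prompt TEXT,
            generated_img_desc TEXT,
            image_response_file_path TEXT,
            timestamp TIMESTAMP
        )
    """))

Note that, despite its Engine return annotation, connect_with_connector returns a single Connection from pool.connect(), so app.py shares that one connection for all inserts.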
requirements.txt
CHANGED
@@ -10,4 +10,8 @@ groq
 replicate
 fasttext
 cutlet
-fugashi[unidic-lite]
+fugashi[unidic-lite]
+python-dotenv
+SQLAlchemy
+google-cloud-storage
+cloud-sql-python-connector[pg8000]
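The new dependencies back the storage and database code above: google-cloud-storage for the GCS uploads, cloud-sql-python-connector[pg8000] and SQLAlchemy for connect_with_connector, and python-dotenv for the load_dotenv calls that are currently commented out. For local testing, a sketch of the environment the new code expects (variable names are taken from the diff; values would come from a local .env file or the Space's secrets):

from dotenv import load_dotenv

# Assumes a local .env file; in the deployed Space these come from the environment.
load_dotenv()

# Read in app.py:
#   HF_API_KEY, ELEVEN_LABS_KEY, NEETS_AI_API_KEY         (already present)
#   BUCKET_NAME, AUDIO_BUCKET                             (new: GCS buckets)
#   IMAGE_STORAGE_PATH, AUDIO_STORAGE_PATH                (new: object name prefixes)
# Read in aya_vision_utils.connect_with_connector:
#   INSTANCE_CONNECTION_NAME, DB_USER, DB_PASS, DB_NAME   (new: Cloud SQL)
#   PRIVATE_IP                                            (optional: connect over private IP)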