shivalikasingh committed
Commit dbb8f6f · verified · 1 Parent(s): befde28

Update code (#9)


- update code (d281662a0810be41652a2f14d20401324325c29f)
- add fix (f14e66e1c82b6e7d51eea0b0e4254485ee737490)


Co-authored-by: Shivalika Singh <[email protected]>

Files changed (3)
  1. app.py +52 -15
  2. aya_vision_utils.py +57 -9
  3. requirements.txt +5 -1
app.py CHANGED
@@ -25,11 +25,14 @@ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
 from prompt_examples import TEXT_CHAT_EXAMPLES, IMG_GEN_PROMPT_EXAMPLES, AUDIO_EXAMPLES, TEXT_CHAT_EXAMPLES_LABELS, IMG_GEN_PROMPT_EXAMPLES_LABELS, AUDIO_EXAMPLES_LABELS, AYA_VISION_PROMPT_EXAMPLES
 from preambles import CHAT_PREAMBLE, AUDIO_RESPONSE_PREAMBLE, IMG_DESCRIPTION_PREAMBLE
 from constants import LID_LANGUAGES, NEETS_AI_LANGID_MAP, AYA_MODEL_NAME, BATCH_SIZE, USE_ELVENLABS, USE_REPLICATE
-from aya_vision_utils import get_aya_vision_response, get_aya_vision_prompt_example
+from aya_vision_utils import get_aya_vision_response, get_aya_vision_prompt_example, insert_aya_audio, insert_aya_image, connect_with_connector
+from google.cloud import storage
+
 # from dotenv import load_dotenv

 # load_dotenv()

+
 HF_API_TOKEN = os.getenv("HF_API_KEY")
 ELEVEN_LABS_KEY = os.getenv("ELEVEN_LABS_KEY")
 NEETS_AI_API_KEY = os.getenv("NEETS_AI_API_KEY")

@@ -62,6 +65,17 @@ eleven_labs_client = ElevenLabs(
     api_key=ELEVEN_LABS_KEY,
 )

+BUCKET_NAME = os.getenv("BUCKET_NAME")
+AUDIO_BUCKET = os.getenv("AUDIO_BUCKET")
+IMAGE_STORAGE_PATH = os.getenv("IMAGE_STORAGE_PATH")
+AUDIO_STORAGE_PATH = os.getenv("AUDIO_STORAGE_PATH")
+SAVING_ENABLED = True
+
+storage_client = storage.Client()
+bucket = storage_client.bucket(BUCKET_NAME)
+audio_bucket = storage_client.bucket(AUDIO_BUCKET)
+connection = connect_with_connector()
+
 # Language identification
 lid_model_path = hf_hub_download(repo_id="facebook/fasttext-language-identification", filename="model.bin")
 LID_model = fasttext.load_model(lid_model_path)

@@ -102,20 +116,34 @@ def replicate_api_inference(input_prompt):
     image = Image.open(image[0])
     return image

-def generate_image(input_prompt, model_id="black-forest-labs/FLUX.1-schnell"):
-    if input_prompt:
+def generate_image(input_prompt, generated_img_desc, model_id="black-forest-labs/FLUX.1-schnell"):
+    if input_prompt and generated_img_desc:
         if USE_REPLICATE:
             print("using replicate for image generation")
-            image = replicate_api_inference(input_prompt)
+            image = replicate_api_inference(generated_img_desc)
         else:
             try:
                 print("using HF inference API for image generation")
-                image_bytes = get_hf_inference_api_response({ "inputs": input_prompt}, model_id)
+                image_bytes = get_hf_inference_api_response({ "inputs": generated_img_desc}, model_id)
                 image = np.array(Image.open(io.BytesIO(image_bytes)))
             except Exception as e:
                 print("HF API error:", e)
                 # generate image with help replicate in case of error
-                image = replicate_api_inference(input_prompt)
+                image = replicate_api_inference(generated_img_desc)
+
+        # save image to local file
+        image_path = "generated_image.png"
+        image.save(image_path)
+
+        if SAVING_ENABLED:
+            unique_id = str(uuid.uuid4())
+
+            blob = bucket.blob(IMAGE_STORAGE_PATH + unique_id + "_" + image_path)
+            blob.upload_from_filename(image_path)
+            gcp_image_path = f"gs://{BUCKET_NAME}/{IMAGE_STORAGE_PATH}{unique_id}_{image_path}"
+
+            insert_aya_image(connection, input_prompt, generated_img_desc, gcp_image_path)
+
         return image
     else:
         return None

@@ -246,7 +274,7 @@ def clean_text(text, remove_bullets=False, remove_newline=False):

     return cleaned_text

-def convert_text_to_speech(text, language="english"):
+def convert_text_to_speech(transcript, text, language="english"):

     # do language detection to determine voice of speech response
     if text:

@@ -268,19 +296,28 @@ def convert_text_to_speech(text, language="english"):
         else:
             # use elevenlabs for TTS
             audio_path = elevenlabs_generate_audio(text)
+
+        if SAVING_ENABLED:
+            unique_id = str(uuid.uuid4())
+
+            blob = audio_bucket.blob(AUDIO_STORAGE_PATH + unique_id + "_" + audio_path)
+            blob.upload_from_filename(audio_path)
+            gcp_audio_path = f"gs://{BUCKET_NAME}/{AUDIO_STORAGE_PATH}{unique_id}_{audio_path}"
+
+            insert_aya_audio(connection, transcript, text, gcp_audio_path)

         return audio_path
     else:
         return None

 def elevenlabs_generate_audio(text):
-    audio = eleven_labs_client.generate(
+    audio = eleven_labs_client.text_to_speech.convert(
         text=text,
-        voice="River",
-        model="eleven_turbo_v2_5", #"eleven_multilingual_v2"
-    )
-    # save audio
-    audio_path = "./audio.mp3"
+        voice_id="21m00Tcm4TlvDq8ikWAM", #Rachel
+        model_id="eleven_multilingual_v2",
+        output_format="mp3_44100_128",
+    )
+    audio_path = "audio.mp3"
     save(audio, audio_path)
     return audio_path

@@ -534,7 +571,7 @@ with demo:

     generated_img_desc.change(
         generate_image, #run_flux,
-        inputs=[generated_img_desc],
+        inputs=[input_img_prompt, generated_img_desc],
         outputs=[generated_img],
         show_progress="full",
     )

@@ -558,7 +595,7 @@ with demo:
         show_progress="full",
     ).then(
         convert_text_to_speech,
-        inputs=[e2e_audio_file_aya_response],
+        inputs=[e2e_audio_file_trans, e2e_audio_file_aya_response],
         outputs=[e2e_aya_audio_response],
         show_progress="full",
     )
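Note: generate_image and convert_text_to_speech now rely on the module-level bucket, audio_bucket, connection, and SAVING_ENABLED objects created above (plus uuid). A minimal sketch of the new upload-and-log path exercised in isolation, assuming the same environment variables this diff reads and an existing aya_image table; the prompt strings are placeholders, not part of this commit:

# Sketch only: exercises the GCS upload + Cloud SQL logging path added above.
import os
import uuid
from google.cloud import storage
from aya_vision_utils import connect_with_connector, insert_aya_image

bucket = storage.Client().bucket(os.environ["BUCKET_NAME"])
connection = connect_with_connector()

image_path = "generated_image.png"  # a file previously written by generate_image()
blob_name = os.environ["IMAGE_STORAGE_PATH"] + str(uuid.uuid4()) + "_" + image_path
bucket.blob(blob_name).upload_from_filename(image_path)

gcp_image_path = f"gs://{os.environ['BUCKET_NAME']}/{blob_name}"
insert_aya_image(connection, "example user prompt", "example generated description", gcp_image_path)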
aya_vision_utils.py CHANGED
@@ -9,7 +9,10 @@ import os
 import traceback
 import random
 import gradio as gr
-
+from google.cloud.sql.connector import Connector, IPTypes
+import pg8000
+from datetime import datetime
+import sqlalchemy
 # from dotenv import load_dotenv
 # load_dotenv()

@@ -32,9 +35,6 @@ def cohere_vision_chat(chat_history, model=VISION_COHERE_MODEL_NAME):

 def get_aya_vision_prompt_example(language):
     example = AYA_VISION_PROMPT_EXAMPLES[language]
-    print("example:", example)
-    print("example prompt:", example[0])
-    print("example image:", example[1])
     return example[0], example[1]

 def get_base64_from_local_file(file_path):

@@ -42,7 +42,6 @@ def get_base64_from_local_file(file_path):
         print("loading image")
         with open(file_path, "rb") as image_file:
             base64_image = base64.b64encode(image_file.read()).decode('utf-8')
-        print("converted image")
         return base64_image
     except Exception as e:
         logger.debug(f"Error converting local image to base64 string: {e}")

@@ -50,8 +49,6 @@ def get_base64_from_local_file(file_path):


 def get_aya_vision_response(incoming_message, image_filepath, max_size_mb=5):
-    print("incoming message:", incoming_message)
-    print("image_filepath:", image_filepath)
     max_size_bytes = max_size_mb * 1024 * 1024

     image_ext = image_filepath.lower()

@@ -70,7 +67,6 @@ def get_aya_vision_response(incoming_message, image_filepath, max_size_mb=5):
     print("converting image to base 64")
     base64_image = get_base64_from_local_file(image_filepath)
     image = f"data:{image_type};base64,{base64_image}"
-    print("Image base64:", image[:30])

     # to prevent Cohere API from throwing error for empty message
     if incoming_message=="" or incoming_message is None:

@@ -108,4 +104,56 @@ def get_base64_image_size(base64_string):
     base64_data = base64_data.replace('\n', '').replace('\r', '').replace(' ', '')
     padding = base64_data.count('=')
     size_bytes = (len(base64_data) * 3) // 4 - padding
-    return size_bytes
+    return size_bytes
+
+
+def insert_aya_audio(connection, user_prompt, text_response, audio_response_file_path):
+    with connection.begin():
+        connection.execute(
+            sqlalchemy.text("""
+                INSERT INTO aya_audio (user_prompt, text_response, audio_response_file_path, timestamp)
+                VALUES (:user_prompt, :text_response, :audio_response_file_path, :timestamp)
+            """),
+            {"user_prompt": user_prompt, "text_response": text_response, "audio_response_file_path": audio_response_file_path, "timestamp": datetime.now()}
+        )
+
+def insert_aya_image(connection, user_prompt, generated_img_desc, image_response_file_path):
+    with connection.begin():
+        connection.execute(
+            sqlalchemy.text("""
+                INSERT INTO aya_image (user_prompt, generated_img_desc, image_response_file_path, timestamp)
+                VALUES (:user_prompt, :generated_img_desc, :image_response_file_path, :timestamp)
+            """),
+            {"user_prompt": user_prompt, "generated_img_desc": generated_img_desc, "image_response_file_path": image_response_file_path, "timestamp": datetime.now()}
+        )
+
+def connect_with_connector() -> sqlalchemy.engine.base.Engine:
+    instance_connection_name = os.environ[
+        "INSTANCE_CONNECTION_NAME"
+    ]
+    db_user = os.environ["DB_USER"]
+    db_pass = os.environ["DB_PASS"]
+    db_name = os.environ["DB_NAME"]
+
+    ip_type = IPTypes.PRIVATE if os.environ.get("PRIVATE_IP") else IPTypes.PUBLIC
+
+    connector = Connector(refresh_strategy="LAZY")
+
+    def getconn() -> pg8000.dbapi.Connection:
+        conn: pg8000.dbapi.Connection = connector.connect(
+            instance_connection_name,
+            "pg8000",
+            user=db_user,
+            password=db_pass,
+            db=db_name,
+            ip_type=ip_type,
+        )
+        return conn
+
+    pool = sqlalchemy.create_engine(
+        "postgresql+pg8000://",
+        creator=getconn,
+    )
+
+    connection = pool.connect()
+    return connection
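The INSERT helpers above assume that aya_audio and aya_image tables already exist in the target database; this commit does not create them. A hedged sketch of compatible DDL, issued over the same connection, with column types inferred from the INSERT statements (the actual schema is not specified anywhere in the diff):

# Sketch only: possible tables matching insert_aya_audio / insert_aya_image.
# Column types are assumptions; the commit does not define the schema.
import sqlalchemy
from aya_vision_utils import connect_with_connector

connection = connect_with_connector()
with connection.begin():
    connection.execute(sqlalchemy.text("""
        CREATE TABLE IF NOT EXISTS aya_audio (
            user_prompt TEXT,
            text_response TEXT,
            audio_response_file_path TEXT,
            timestamp TIMESTAMP
        )
    """))
    connection.execute(sqlalchemy.text("""
        CREATE TABLE IF NOT EXISTS aya_image (
            user_prompt TEXT,
            generated_img_desc TEXT,
            image_response_file_path TEXT,
            timestamp TIMESTAMP
        )
    """))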
requirements.txt CHANGED
@@ -10,4 +10,8 @@ groq
 replicate
 fasttext
 cutlet
-fugashi[unidic-lite]
+fugashi[unidic-lite]
+python-dotenv
+SQLAlchemy
+google-cloud-storage
+cloud-sql-python-connector[pg8000]