mgbam committed on
Commit
247b2e3
·
verified ·
1 Parent(s): 59e152e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -49
app.py CHANGED
@@ -1,8 +1,9 @@
1
- # Copyright 2025 Google LLC. Based on work by Yousif Ahmed.
 
2
  # Concept: ChronoWeave – Branching Narrative Generation
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
  # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
 
7
  import streamlit as st
8
  import google.generativeai as genai
@@ -35,11 +36,14 @@ import typing_extensions as typing
35
  import nest_asyncio
36
  nest_asyncio.apply()
37
 
38
- # Import Vertex AI SDK for image generation (Preview API)
39
  import vertexai
40
  from vertexai.preview.vision_models import ImageGenerationModel
41
  from google.oauth2 import service_account
42
 
 
 
 
43
  # --- Logging Setup ---
44
  logging.basicConfig(
45
  level=logging.INFO,
@@ -59,7 +63,7 @@ Generate multiple, branching story timelines from a single theme using AI, compl
59
  TEXT_MODEL_ID = "models/gemini-1.5-flash"
60
  AUDIO_MODEL_ID = "models/gemini-1.5-flash"
61
  AUDIO_SAMPLING_RATE = 24000
62
- # Pretrained Imagen model identifier for Vertex AI preview
63
  IMAGE_MODEL_ID = "imagen-3.0-generate-002"
64
  DEFAULT_ASPECT_RATIO = "1:1"
65
  VIDEO_FPS = 24
@@ -80,14 +84,14 @@ except KeyError:
80
  st.error("🚨 **Google API Key Not Found!** Please configure it.", icon="🚨")
81
  st.stop()
82
 
83
- # Vertex AI configuration: load PROJECT_ID and LOCATION from secrets or environment.
84
  PROJECT_ID = st.secrets.get("PROJECT_ID") or os.environ.get("PROJECT_ID")
85
  LOCATION = st.secrets.get("LOCATION") or os.environ.get("LOCATION", "us-central1")
86
  if not PROJECT_ID:
87
  st.error("🚨 **PROJECT_ID not set!** Please add PROJECT_ID to your secrets.", icon="🚨")
88
  st.stop()
89
 
90
- # Load service account JSON from environment (secret name: SERVICE_ACCOUNT_JSON)
91
  try:
92
  service_account_info = json.loads(os.environ["SERVICE_ACCOUNT_JSON"])
93
  credentials = service_account.Credentials.from_service_account_info(service_account_info)
@@ -170,46 +174,26 @@ def wave_file_writer(filename: str, channels: int = 1, rate: int = AUDIO_SAMPLIN
170
  except Exception as e_close:
171
  logger.error(f"Error closing wave file {filename}: {e_close}")
172
 
 
 
173
  async def generate_audio_live_async(api_text: str, output_filename: str, voice: Optional[str] = None) -> Optional[str]:
174
- """Generates audio using Gemini Live API (async version) via the GenerativeModel."""
175
- collected_audio = bytearray()
 
 
176
  task_id = os.path.basename(output_filename).split('.')[0]
177
- logger.info(f"πŸŽ™οΈ [{task_id}] Requesting audio: '{api_text[:60]}...'")
178
  try:
179
- config = {
180
- "response_modalities": ["AUDIO"],
181
- "audio_encoding": "LINEAR16",
182
- "sample_rate_hertz": AUDIO_SAMPLING_RATE,
183
- }
184
- directive_prompt = f"Narrate directly: \"{api_text}\""
185
- async with live_model.connect(config=config) as session:
186
- await session.send_request([directive_prompt])
187
- async for response in session.stream_content():
188
- if response.audio_chunk and response.audio_chunk.data:
189
- collected_audio.extend(response.audio_chunk.data)
190
- if hasattr(response, 'error') and response.error:
191
- logger.error(f"❌ [{task_id}] Audio stream error: {response.error}")
192
- st.error(f"Audio stream error {task_id}: {response.error}", icon="πŸ”Š")
193
- return None
194
- if not collected_audio:
195
- logger.warning(f"⚠️ [{task_id}] No audio data received.")
196
- st.warning(f"No audio data for {task_id}.", icon="πŸ”Š")
197
- return None
198
- with wave_file_writer(output_filename, rate=AUDIO_SAMPLING_RATE) as wf:
199
- wf.writeframes(bytes(collected_audio))
200
- logger.info(f"βœ… [{task_id}] Audio saved: {os.path.basename(output_filename)} ({len(collected_audio)} bytes)")
201
- return output_filename
202
- except genai.types.generation_types.BlockedPromptException as bpe:
203
- logger.error(f"❌ [{task_id}] Audio blocked: {bpe}")
204
- st.error(f"Audio blocked {task_id}.", icon="πŸ”‡")
205
- return None
206
- except TypeError as te:
207
- logger.exception(f"❌ [{task_id}] Audio config TypeError: {te}")
208
- st.error(f"Audio config error {task_id} (TypeError): {te}. Check library/config.", icon="βš™οΈ")
209
- return None
210
  except Exception as e:
211
- logger.exception(f"❌ [{task_id}] Audio failed: {e}")
212
- st.error(f"Audio failed {task_id}: {e}", icon="πŸ”Š")
213
  return None
214
 
215
  def generate_story_sequence_chrono(theme: str, num_scenes: int, num_timelines: int, divergence_prompt: str = "") -> Optional[ChronoWeaveResponse]:
@@ -274,7 +258,7 @@ def generate_image_imagen(prompt: str, aspect_ratio: str = "1:1", task_id: str =
274
  """
275
  Generates an image using Vertex AI's Imagen model via the Vertex AI preview API.
276
 
277
- This function loads the pretrained Imagen model "imagen-3.0-generate-002" and attempts to generate an image.
278
  If authentication fails, it provides guidance on how to resolve the issue.
279
  """
280
  logger.info(f"πŸ–ΌοΈ [{task_id}] Requesting image: '{prompt[:70]}...' (Aspect: {aspect_ratio})")
@@ -297,11 +281,9 @@ def generate_image_imagen(prompt: str, aspect_ratio: str = "1:1", task_id: str =
297
  if "Unable to authenticate" in error_str:
298
  error_msg = (
299
  "Authentication error: Unable to authenticate your request. "
300
- "If running locally, please run `!gcloud auth login`. "
301
- "If running in Colab, try:\n"
302
- " from google.colab import auth\n"
303
- " auth.authenticate_user()\n"
304
- "If using a service account or other environment, please refer to https://cloud.google.com/docs/authentication for guidance."
305
  )
306
  else:
307
  error_msg = f"Image generation for {task_id} failed: {e}"
@@ -400,6 +382,7 @@ if generate_button:
400
  generated_audio_path: Optional[str] = None
401
  if not scene_has_error:
402
  with st.spinner(f"[{task_id}] Generating audio... πŸ”Š"):
 
403
  audio_path_temp = os.path.join(temp_dir, f"{task_id}_audio.wav")
404
  try:
405
  generated_audio_path = asyncio.run(generate_audio_live_async(segment.audio_text, audio_path_temp, audio_voice))
@@ -417,7 +400,7 @@ if generate_button:
417
  temp_audio_files[scene_id] = generated_audio_path
418
  try:
419
  with open(generated_audio_path, 'rb') as ap:
420
- st.audio(ap.read(), format='audio/wav')
421
  except Exception as e:
422
  logger.warning(f"⚠️ [{task_id}] Audio preview error: {e}")
423
  else:
 
1
+ # Copyright 2025 Google LLC.
2
+ # Based on work by Yousif Ahmed.
3
  # Concept: ChronoWeave – Branching Narrative Generation
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
  # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at: https://www.apache.org/licenses/LICENSE-2.0
7
 
8
  import streamlit as st
9
  import google.generativeai as genai
 
36
  import nest_asyncio
37
  nest_asyncio.apply()
38
 
39
+ # Import Vertex AI SDK and service account credentials support
40
  import vertexai
41
  from vertexai.preview.vision_models import ImageGenerationModel
42
  from google.oauth2 import service_account
43
 
44
+ # Import gTTS for audio generation
45
+ from gtts import gTTS
46
+
47
  # --- Logging Setup ---
48
  logging.basicConfig(
49
  level=logging.INFO,
 
63
  TEXT_MODEL_ID = "models/gemini-1.5-flash"
64
  AUDIO_MODEL_ID = "models/gemini-1.5-flash"
65
  AUDIO_SAMPLING_RATE = 24000
66
+ # Pretrained Imagen model identifier
67
  IMAGE_MODEL_ID = "imagen-3.0-generate-002"
68
  DEFAULT_ASPECT_RATIO = "1:1"
69
  VIDEO_FPS = 24
 
84
  st.error("🚨 **Google API Key Not Found!** Please configure it.", icon="🚨")
85
  st.stop()
86
 
87
+ # Vertex AI configuration: PROJECT_ID and LOCATION
88
  PROJECT_ID = st.secrets.get("PROJECT_ID") or os.environ.get("PROJECT_ID")
89
  LOCATION = st.secrets.get("LOCATION") or os.environ.get("LOCATION", "us-central1")
90
  if not PROJECT_ID:
91
  st.error("🚨 **PROJECT_ID not set!** Please add PROJECT_ID to your secrets.", icon="🚨")
92
  st.stop()
93
 
94
+ # Load service account JSON from the secret
95
  try:
96
  service_account_info = json.loads(os.environ["SERVICE_ACCOUNT_JSON"])
97
  credentials = service_account.Credentials.from_service_account_info(service_account_info)
 
174
  except Exception as e_close:
175
  logger.error(f"Error closing wave file {filename}: {e_close}")
176
 
177
+ # --- Audio Generation using gTTS ---
178
+ # We replace the previous failing method with gTTS.
179
  async def generate_audio_live_async(api_text: str, output_filename: str, voice: Optional[str] = None) -> Optional[str]:
180
+ """
181
+ Generates audio using gTTS (Google Text-to-Speech).
182
+ Saves an MP3 file; MoviePy supports MP3 playback.
183
+ """
184
  task_id = os.path.basename(output_filename).split('.')[0]
185
+ logger.info(f"πŸŽ™οΈ [{task_id}] Generating audio via gTTS for text: '{api_text[:60]}...'")
186
  try:
187
+ # Generate audio using gTTS
188
+ tts = gTTS(text=api_text, lang="en")
189
+ # Replace .wav with .mp3
190
+ mp3_filename = output_filename.replace(".wav", ".mp3")
191
+ tts.save(mp3_filename)
192
+ logger.info(f"βœ… [{task_id}] Audio saved: {os.path.basename(mp3_filename)}")
193
+ return mp3_filename
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  except Exception as e:
195
+ logger.exception(f"❌ [{task_id}] Audio generation error: {e}")
196
+ st.error(f"Audio generation failed for {task_id}: {e}", icon="πŸ”Š")
197
  return None
198
 
199
  def generate_story_sequence_chrono(theme: str, num_scenes: int, num_timelines: int, divergence_prompt: str = "") -> Optional[ChronoWeaveResponse]:
 
258
  """
259
  Generates an image using Vertex AI's Imagen model via the Vertex AI preview API.
260
 
261
+ This function loads the pretrained Imagen model "imagen-3.0-generate-002" and generates an image.
262
  If authentication fails, it provides guidance on how to resolve the issue.
263
  """
264
  logger.info(f"πŸ–ΌοΈ [{task_id}] Requesting image: '{prompt[:70]}...' (Aspect: {aspect_ratio})")
 
281
  if "Unable to authenticate" in error_str:
282
  error_msg = (
283
  "Authentication error: Unable to authenticate your request. "
284
+ "Ensure your service account JSON is loaded correctly. "
285
+ "For example, on Hugging Face Spaces, set SERVICE_ACCOUNT_JSON in your repository secrets. "
286
+ "If running locally, run `!gcloud auth login`."
 
 
287
  )
288
  else:
289
  error_msg = f"Image generation for {task_id} failed: {e}"
 
382
  generated_audio_path: Optional[str] = None
383
  if not scene_has_error:
384
  with st.spinner(f"[{task_id}] Generating audio... πŸ”Š"):
385
+ # The temp path is named .wav, but gTTS writes MP3: the audio helper returns a path with the .mp3 extension instead
386
  audio_path_temp = os.path.join(temp_dir, f"{task_id}_audio.wav")
387
  try:
388
  generated_audio_path = asyncio.run(generate_audio_live_async(segment.audio_text, audio_path_temp, audio_voice))
 
400
  temp_audio_files[scene_id] = generated_audio_path
401
  try:
402
  with open(generated_audio_path, 'rb') as ap:
403
+ st.audio(ap.read(), format='audio/mp3')
404
  except Exception as e:
405
  logger.warning(f"⚠️ [{task_id}] Audio preview error: {e}")
406
  else: