CR7CAD committed on
Commit
fc13d66
·
verified ·
1 Parent(s): 5518670

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +126 -53
app.py CHANGED
@@ -5,14 +5,51 @@ from PIL import Image
5
  import torch
6
  import os
7
  import tempfile
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  # For TTS, try multiple options in order of preference
10
  try:
11
- # Try gTTS first
12
  from gtts import gTTS
13
-
14
- def text2audio(story_text):
15
- # Create a temporary file
 
 
 
 
 
 
 
 
 
16
  temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp3')
17
  temp_filename = temp_file.name
18
  temp_file.close()
@@ -29,44 +66,34 @@ try:
29
  os.unlink(temp_filename)
30
 
31
  return audio_bytes, 'audio/mp3'
32
-
33
- except ImportError:
34
- st.warning("gTTS not available. Using alternative text-to-speech method.")
35
-
36
- # Define alternative TTS using built-in transformers pipeline
37
- def text2audio(story_text):
38
- # Use a different TTS method
39
- from transformers import pipeline
40
-
41
- # Try a simple TTS model that should work with base transformers
42
- synthesizer = pipeline("text-to-speech", model="facebook/mms-tts-eng")
43
-
44
- # Generate speech
45
- speech = synthesizer(story_text)
46
 
47
  # Return the audio data
48
  if 'audio' in speech:
49
  return speech['audio'], speech.get('sampling_rate', 16000)
50
  elif 'audio_array' in speech:
51
  return speech['audio_array'], speech.get('sampling_rate', 16000)
52
- else:
53
- # In case of failure, return an error message
54
- raise Exception("Failed to generate audio with any available method")
55
 
56
- # Simple image-to-text function
 
57
  def img2text(image):
58
- image_to_text = pipeline("image-to-text", model="sooh-j/blip-image-captioning-base")
59
- text = image_to_text(image)[0]["generated_text"]
60
- return text
61
 
62
  # Helper function to count words
63
  def count_words(text):
64
  return len(text.split())
65
 
66
  # Improved text-to-story function without "Once upon a time" constraint
 
67
  def text2story(text):
68
- generator = pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
69
-
70
  # Ask for a story without specifying how to start
71
  prompt = f"""Write a children's story based on this: {text}.
72
  The story should have a clear beginning, middle, and end.
@@ -74,7 +101,7 @@ def text2story(text):
74
  """
75
 
76
  # Generate a longer text to ensure we get a complete story
77
- story_result = generator(
78
  prompt,
79
  max_length=500,
80
  num_return_sequences=1,
@@ -160,8 +187,38 @@ def text2story(text):
160
 
161
  # Basic Streamlit interface
162
  st.title("Image to Audio Story")
163
- uploaded_file = st.file_uploader("Upload an image")
164
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  if uploaded_file is not None:
166
  # Display image
167
  st.image(uploaded_file, caption="Uploaded Image")
@@ -169,29 +226,45 @@ if uploaded_file is not None:
169
  # Convert to PIL Image
170
  image = Image.open(uploaded_file)
171
 
172
- # Image to Text
173
- with st.spinner("Generating caption..."):
174
- caption = img2text(image)
175
- st.write(f"Caption: {caption}")
176
-
177
- # Text to Story
178
- with st.spinner("Creating story..."):
179
- story = text2story(caption)
180
- # Display word count for transparency
181
- word_count = len(story.split())
182
- st.write(f"Story ({word_count} words):")
183
- st.write(story)
184
-
185
- # Text to Audio
186
- with st.spinner("Generating audio..."):
 
 
 
 
 
 
 
187
  try:
188
- audio_data, audio_format = text2audio(story)
189
-
190
- # Play audio
191
- if isinstance(audio_format, str) and audio_format.startswith('audio/'):
192
- st.audio(audio_data, format=audio_format)
193
- else:
194
- st.audio(audio_data, sample_rate=audio_format)
195
  except Exception as e:
196
- st.error(f"Error generating or playing audio: {e}")
197
- st.info("There was an issue with the text-to-speech conversion.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import torch
6
  import os
7
  import tempfile
8
+ import time
9
+
10
+ # Use Streamlit's caching mechanisms to optimize model loading
11
@st.cache_resource
def load_image_to_text_pipeline():
    """Build the BLIP image-captioning pipeline once per session and reuse it."""
    captioner = pipeline("image-to-text", model="sooh-j/blip-image-captioning-base")
    return captioner
15
+
16
@st.cache_resource
def load_text_generation_pipeline():
    """Build the TinyLlama text-generation pipeline once per session and reuse it."""
    generator = pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
    return generator
20
+
21
@st.cache_resource
def load_tts_pipeline():
    """Load and cache the transformers text-to-speech pipeline used as a fallback.

    Returns:
        The "text-to-speech" pipeline, or None when the model cannot be
        loaded so callers can detect that no fallback TTS is available.
    """
    try:
        return pipeline("text-to-speech", model="facebook/mms-tts-eng")
    except Exception:
        # Narrowed from a bare `except:` — a bare clause would also swallow
        # KeyboardInterrupt/SystemExit. Any model-loading failure still
        # results in a None sentinel rather than crashing the app.
        return None
29
+
30
# Warm every model cache up front so later user interactions are not
# blocked by first-time model downloads; @st.cache_resource makes the
# calls no-ops on subsequent reruns.
with st.spinner("Loading models (this may take a moment the first time)..."):
    img2text_model = load_image_to_text_pipeline()
    story_generator_model = load_text_generation_pipeline()
    tts_fallback_model = load_tts_pipeline()
36
 
37
# For TTS, try multiple options in order of preference
try:
    # gTTS is the preferred engine when the package is installed.
    from gtts import gTTS
except ImportError:
    has_gtts = False
    # Warn only when the transformers fallback is also unavailable.
    if tts_fallback_model is None:
        st.warning("No text-to-speech capability available. Audio generation will be disabled.")
else:
    has_gtts = True
46
+
47
+ # Cache the text-to-audio conversion
48
+ @st.cache_data
49
+ def text2audio(story_text):
50
+ """Convert text to audio with caching to avoid regenerating the same audio"""
51
+ if has_gtts:
52
+ # Use gTTS
53
  temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp3')
54
  temp_filename = temp_file.name
55
  temp_file.close()
 
66
  os.unlink(temp_filename)
67
 
68
  return audio_bytes, 'audio/mp3'
69
+ elif tts_fallback_model is not None:
70
+ # Use transformers TTS
71
+ speech = tts_fallback_model(story_text)
 
 
 
 
 
 
 
 
 
 
 
72
 
73
  # Return the audio data
74
  if 'audio' in speech:
75
  return speech['audio'], speech.get('sampling_rate', 16000)
76
  elif 'audio_array' in speech:
77
  return speech['audio_array'], speech.get('sampling_rate', 16000)
78
+
79
+ # If we got here, no TTS method worked
80
+ raise Exception("No text-to-speech capability available")
81
 
82
# Simple image-to-text function using the cached captioning model.
@st.cache_data
def img2text(image):
    """Generate a caption for *image* via the cached BLIP pipeline.

    Results are memoized by st.cache_data so re-running the script on the
    same image does not re-invoke the model.
    """
    captions = img2text_model(image)
    return captions[0]["generated_text"]
88
 
89
# Helper function to count words
def count_words(text):
    """Return the number of whitespace-separated tokens in *text*."""
    words = text.split()
    return len(words)
92
 
93
  # Improved text-to-story function without "Once upon a time" constraint
94
+ @st.cache_data
95
  def text2story(text):
96
+ """Generate a story from text with caching"""
 
97
  # Ask for a story without specifying how to start
98
  prompt = f"""Write a children's story based on this: {text}.
99
  The story should have a clear beginning, middle, and end.
 
101
  """
102
 
103
  # Generate a longer text to ensure we get a complete story
104
+ story_result = story_generator_model(
105
  prompt,
106
  max_length=500,
107
  num_return_sequences=1,
 
187
 
188
# Basic Streamlit interface
st.title("Image to Audio Story")

# Reusable placeholder for surfacing processing-status messages.
status_container = st.empty()

# Seed session state with a fresh progress record on the first run so the
# pipeline stages (caption -> story -> audio) survive Streamlit reruns.
if 'progress' not in st.session_state:
    st.session_state.progress = dict(
        caption_generated=False,
        story_generated=False,
        audio_generated=False,
        caption='',
        story='',
        audio_data=None,
        audio_format=None,
    )
205
+
206
# Function to reset progress when a new file is uploaded. Defined before
# the uploader so it can be passed directly as the on_change callback
# (the original wrapped it in a lambda and referenced it before its
# definition, which relied on late name resolution).
def reset_progress():
    """Clear all cached pipeline progress so a new image is processed from scratch."""
    st.session_state.progress = {
        'caption_generated': False,
        'story_generated': False,
        'audio_generated': False,
        'caption': '',
        'story': '',
        'audio_data': None,
        'audio_format': None
    }

# File uploader; choosing a new file resets the pipeline progress.
uploaded_file = st.file_uploader("Upload an image", on_change=reset_progress)
220
+
221
# Process the image if uploaded. Each stage records its output in
# st.session_state.progress so Streamlit reruns (e.g. button clicks)
# do not re-run the expensive model calls.
if uploaded_file is not None:
    # Display image
    st.image(uploaded_file, caption="Uploaded Image")

    # Convert to PIL Image
    image = Image.open(uploaded_file)

    # Image to Text (if not already done this session for this upload)
    if not st.session_state.progress['caption_generated']:
        status_container.info("Generating caption...")
        st.session_state.progress['caption'] = img2text(image)
        st.session_state.progress['caption_generated'] = True

    st.write(f"Caption: {st.session_state.progress['caption']}")

    # Text to Story (if not already done); feeds the caption into the generator
    if not st.session_state.progress['story_generated']:
        status_container.info("Creating story...")
        st.session_state.progress['story'] = text2story(st.session_state.progress['caption'])
        st.session_state.progress['story_generated'] = True

    # Display word count for transparency
    word_count = count_words(st.session_state.progress['story'])
    st.write(f"Story ({word_count} words):")
    st.write(st.session_state.progress['story'])

    # Pre-generate audio (if not already done) so the "Play" button is instant.
    # Skipped entirely when neither gTTS nor the transformers fallback exists.
    if not st.session_state.progress['audio_generated'] and (has_gtts or tts_fallback_model is not None):
        status_container.info("Pre-generating audio in background...")
        try:
            st.session_state.progress['audio_data'], st.session_state.progress['audio_format'] = text2audio(st.session_state.progress['story'])
            st.session_state.progress['audio_generated'] = True
            status_container.success("Ready to play audio!")
        except Exception as e:
            # On failure, audio_generated stays False; the button branch below
            # reports the problem instead of playing.
            status_container.error(f"Error pre-generating audio: {e}")

    # Button to play audio
    if st.button("Play the audio"):
        if st.session_state.progress['audio_generated']:
            # Display the audio player. audio_format is either a MIME string
            # (gTTS path, e.g. 'audio/mp3') or a numeric sampling rate
            # (transformers fallback path) — dispatch on which one it is.
            if isinstance(st.session_state.progress['audio_format'], str) and st.session_state.progress['audio_format'].startswith('audio/'):
                st.audio(st.session_state.progress['audio_data'], format=st.session_state.progress['audio_format'])
            else:
                st.audio(st.session_state.progress['audio_data'], sample_rate=st.session_state.progress['audio_format'])
        else:
            # Handle case where audio generation failed or is not available
            st.error("Unable to play audio. Audio generation was not successful.")
else:
    status_container.info("Upload an image to begin")