CR7CAD committed (verified)
Commit ab8ead3 · 1 Parent(s): 1ebc71c

Update app.py

Files changed (1)
  1. app.py +89 -80
app.py CHANGED
Old version (lines removed by this commit are marked "-"; unmarked lines are unchanged context):

@@ -1,68 +1,72 @@
-# import part - only using the two requested imports
 import streamlit as st
 from transformers import pipeline
 
-# function part
-# img2text
-def img2text(image_path):
-    image_to_text = pipeline("image-to-text", model="sooh-j/blip-image-captioning-base")
-    text = image_to_text(image_path)[0]["generated_text"]
     return text
 
-# text2story - IMPROVED to end naturally
 def text2story(text):
-    # Using a smaller text generation model
-    generator = pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
 
-    # Create a prompt for the story generation
-    prompt = f"Write a fun children's story based on this: {text}. The story should be short and end naturally with a conclusion. Once upon a time, "
 
-    # Generate the story
     story_result = generator(
         prompt,
-        max_length=250, # Increased to allow for a complete story
         num_return_sequences=1,
         temperature=0.7,
-        top_k=50,
-        top_p=0.95,
         do_sample=True
     )
 
-    # Extract the generated text
     story_text = story_result[0]['generated_text']
     story_text = story_text.replace(prompt, "Once upon a time, ")
 
-    # Find a natural ending point (end of sentence) before 100 words
-    words = story_text.split()
-    if len(words) > 100:
-        # Join the first 100 words
-        shortened_text = " ".join(words[:100])
-
-        # Find the last complete sentence
-        last_period = shortened_text.rfind('.')
-        last_question = shortened_text.rfind('?')
-        last_exclamation = shortened_text.rfind('!')
-
-        # Find the last sentence ending punctuation
-        last_end = max(last_period, last_question, last_exclamation)
-
-        if last_end > 0:
-            # Truncate at the end of the last complete sentence
-            story_text = shortened_text[:last_end + 1]
-        else:
-            # If no sentence ending found, just use the shortened text
-            story_text = shortened_text
 
     return story_text
 
-# text2audio - Using HelpingAI-TTS-v1 model
 def text2audio(story_text):
     try:
-        # Use the HelpingAI TTS model as requested
-        synthesizer = pipeline("text-to-speech", model="HelpingAI/HelpingAI-TTS-v1")
 
-        # Limit text length to avoid timeouts
-        max_chars = 500
         if len(story_text) > max_chars:
             last_period = story_text[:max_chars].rfind('.')
             if last_period > 0:
@@ -72,46 +76,57 @@ def text2audio(story_text):
 
         # Generate speech
         speech = synthesizer(story_text)
-
-        # Get output information
-        st.write(f"Speech output keys: {list(speech.keys())}")
-
         return speech
 
     except Exception as e:
         st.error(f"Error generating audio: {str(e)}")
         return None
 
-# main part
-st.set_page_config(page_title="Your Image to Audio Story", page_icon="🦜")
-st.header("Turn Your Image to Audio Story")
-uploaded_file = st.file_uploader("Select an Image...")
 
-if uploaded_file is not None:
-    # Display the uploaded image
-    st.image(uploaded_file, caption="Uploaded Image", use_container_width=True)
-
-    # Create a temporary file in memory from the uploaded file
-    image_bytes = uploaded_file.getvalue()
 
-    # Stage 1: Image to Text
-    st.text('Processing img2text...')
-    caption = img2text(image_bytes) # Pass bytes directly to pipeline
-    st.write(caption)
 
-    # Stage 2: Text to Story
-    st.text('Generating a story...')
-    story = text2story(caption)
-    st.write(story)
 
-    # Stage 3: Story to Audio data
-    st.text('Generating audio data...')
-    speech_output = text2audio(story)
 
-    # Play button
-    if st.button("Play Audio"):
         if speech_output is not None:
-            # Try to play the audio directly
             try:
                 if 'audio' in speech_output and 'sampling_rate' in speech_output:
                     st.audio(speech_output['audio'], sample_rate=speech_output['sampling_rate'])
@@ -120,21 +135,15 @@ if uploaded_file is not None:
                 elif 'waveform' in speech_output and 'sample_rate' in speech_output:
                     st.audio(speech_output['waveform'], sample_rate=speech_output['sample_rate'])
                 else:
-                    # Try the first array-like value as audio data
                     for key, value in speech_output.items():
                         if hasattr(value, '__len__') and len(value) > 1000:
-                            if 'rate' in speech_output:
-                                st.audio(value, sample_rate=speech_output['rate'])
-                            elif 'sample_rate' in speech_output:
-                                st.audio(value, sample_rate=speech_output['sample_rate'])
-                            elif 'sampling_rate' in speech_output:
-                                st.audio(value, sample_rate=speech_output['sampling_rate'])
-                            else:
-                                st.audio(value, sample_rate=24000) # Default sample rate
                             break
                     else:
-                        st.error(f"Could not find compatible audio format in: {list(speech_output.keys())}")
             except Exception as e:
                 st.error(f"Error playing audio: {str(e)}")
         else:
-            st.error("Audio generation failed. Please try again.")
New version (lines added by this commit are marked "+"; unmarked lines are unchanged context; short sketches of the main patterns introduced here follow the listing):

+# import part
 import streamlit as st
 from transformers import pipeline
+from PIL import Image
 
+# Set global caching options for Transformers
+from transformers import set_caching_enabled
+set_caching_enabled(True)
+
+# function part with caching for better performance
+@st.cache_resource
+def load_image_captioning_model():
+    return pipeline("image-to-text", model="sooh-j/blip-image-captioning-base")
+
+@st.cache_resource
+def load_text_generator():
+    return pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
+
+@st.cache_resource
+def load_tts_model():
+    return pipeline("text-to-speech", model="HelpingAI/HelpingAI-TTS-v1")
+
+# img2text - Using the original model with more constraints
+def img2text(image):
+    # Load the model (cached)
+    image_to_text = load_image_captioning_model()
+
+    # Strongly limit output length for speed
+    text = image_to_text(image, max_new_tokens=15)[0]["generated_text"]
     return text
 
+# text2story - Much more constrained for speed
 def text2story(text):
+    # Load the model (cached)
+    generator = load_text_generator()
 
+    # Very brief prompt to minimize work
+    prompt = f"Short story about {text}: Once upon a time, "
 
+    # Very constrained parameters for maximum speed
     story_result = generator(
         prompt,
+        max_new_tokens=60, # Much shorter output
         num_return_sequences=1,
         temperature=0.7,
+        top_k=10, # Lower value = faster
+        top_p=0.9, # Lower value = faster
         do_sample=True
     )
 
+    # Extract and clean text
     story_text = story_result[0]['generated_text']
     story_text = story_text.replace(prompt, "Once upon a time, ")
 
+    # Find a natural ending point
+    last_period = story_text.rfind('.')
+    if last_period > 30: # Ensure we have at least some content
+        story_text = story_text[:last_period + 1]
 
     return story_text
 
+# text2audio - Minimal text for faster processing
 def text2audio(story_text):
     try:
+        # Load the model (cached)
+        synthesizer = load_tts_model()
 
+        # Aggressively limit text length to speed up TTS
+        max_chars = 200 # Much shorter than before
         if len(story_text) > max_chars:
             last_period = story_text[:max_chars].rfind('.')
             if last_period > 0:
(lines 73-75 unchanged, not shown)
 
         # Generate speech
         speech = synthesizer(story_text)
         return speech
 
     except Exception as e:
         st.error(f"Error generating audio: {str(e)}")
         return None
 
+# Streamlined main UI
+st.set_page_config(page_title="Image to Story", page_icon="📚")
+st.header("Image to Audio Story")
 
+# Add info about processing time
+st.info("Note: Processing may take some time as the models are loading. Please be patient.")
+
+# Cache the file uploader state
+if "uploaded_file" not in st.session_state:
+    st.session_state["uploaded_file"] = None
 
+uploaded_file = st.file_uploader("Select an Image...", key="file_uploader")
+
+# Process the image if uploaded
+if uploaded_file is not None:
+    st.session_state["uploaded_file"] = uploaded_file
 
+    # Display the uploaded image
+    st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)
 
+    # Convert to PIL image
+    image = Image.open(uploaded_file)
 
+    # Optional processing toggle to let user decide
+    if st.button("Generate Story and Audio"):
+        col1, col2 = st.columns(2)
+
+        # Stage 1: Image to Text with minimal output
+        with col1:
+            with st.spinner('Captioning image...'):
+                caption = img2text(image)
+                st.write(f"**Caption:** {caption}")
+
+        # Stage 2: Text to Story with minimal length
+        with col2:
+            with st.spinner('Creating story...'):
+                story = text2story(caption)
+                st.write(f"**Story:** {story}")
+
+        # Stage 3: Audio with minimal text
+        with st.spinner('Generating audio...'):
+            speech_output = text2audio(story)
+
+        # Display audio immediately
         if speech_output is not None:
             try:
                 if 'audio' in speech_output and 'sampling_rate' in speech_output:
                     st.audio(speech_output['audio'], sample_rate=speech_output['sampling_rate'])
(lines 133-134 unchanged, not shown)
                 elif 'waveform' in speech_output and 'sample_rate' in speech_output:
                     st.audio(speech_output['waveform'], sample_rate=speech_output['sample_rate'])
                 else:
+                    # Try any array-like data
                     for key, value in speech_output.items():
                         if hasattr(value, '__len__') and len(value) > 1000:
+                            sample_rate = speech_output.get('sampling_rate', speech_output.get('sample_rate', 24000))
+                            st.audio(value, sample_rate=sample_rate)
                             break
                     else:
+                        st.error("Could not find audio data in the output")
             except Exception as e:
                 st.error(f"Error playing audio: {str(e)}")
         else:
+            st.error("Audio generation failed")