frankai98's picture
Update app.py
29d3b0a verified
raw
history blame
8.01 kB
import nest_asyncio
# Patch the current event loop so nested asyncio.run()/loop.run_until_complete()
# calls (used internally by some libraries) work inside Streamlit's own loop.
nest_asyncio.apply()

import streamlit as st
from transformers import pipeline
import torch
from gtts import gTTS
import io
import time
import asyncio
import datetime

# Ensure a usable (non-running) event loop is installed for libraries that
# expect asyncio.get_event_loop() to succeed on the main thread.
if not asyncio.get_event_loop().is_running():
    asyncio.set_event_loop(asyncio.new_event_loop())
# ---- Session-state bootstrap ---------------------------------------------
# Seed every key this app reads so downstream code can assume they exist.
_STATE_DEFAULTS = {
    'processed_data': {'scenario': None, 'story': None, 'audio': None},
    'image_data': None,
    'timer_start_time': None,
    'timer_frozen': False,
    'last_update_time': None,
}
for _key, _default in _STATE_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default

# ---- Page setup -----------------------------------------------------------
st.set_page_config(page_title="Your Image to Audio Story", page_icon="🦜")
st.header("Turn Your Image to a Short Audio Story for Children")
# Model loading
@st.cache_resource
def load_models():
    """Build the captioning and story-generation pipelines.

    Decorated with @st.cache_resource so both models are loaded once per
    process and shared across reruns and sessions.
    """
    captioner = pipeline("image-to-text", "cnmoro/tiny-image-captioning")
    storyteller = pipeline("text-generation", "Qwen/Qwen2.5-0.5B-Instruct")
    return {"img_model": captioner, "story_model": storyteller}
# Load (and cache) both pipelines once; reused by the helpers below.
models = load_models()
# Processing functions
def img2text(url):
    """Return the model-generated caption for the image at *url* (path or URL)."""
    captions = models["img_model"](url)
    return captions[0]["generated_text"]
def text2story(text):
    """Generate a ~100-word children's story from an image caption.

    Args:
        text: The image caption to build the story around.

    Returns:
        The assistant's generated story text (a string).
    """
    prompt = f"Generate a 100-word story about: {text}"
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt}
    ]
    response = models["story_model"](
        messages,
        max_new_tokens=100,
        do_sample=True,
        temperature=0.7
    )[0]["generated_text"]
    # The chat-style pipeline returns the whole conversation (system, user,
    # assistant, ...); the generated reply is the LAST turn. The original
    # hard-coded index 2, which breaks if the message list ever changes.
    return response[-1]["content"]
def text2audio(story_text):
    """Synthesize *story_text* to spoken English via gTTS.

    Returns a dict with an in-memory MP3 buffer under 'audio' (rewound to
    position 0) and a 'sampling_rate' field.
    NOTE(review): 16000 is carried over as-is; gTTS does not expose the
    actual MP3 sample rate here — confirm before relying on it.
    """
    buffer = io.BytesIO()
    gTTS(text=story_text, lang='en', slow=False).write_to_fp(buffer)
    buffer.seek(0)
    return {'audio': buffer, 'sampling_rate': 16000}
# Create fixed containers for UI elements.
# Placeholders are created once per run so later code can redraw content
# in place instead of appending new widgets on every rerun.
image_container = st.empty()        # the uploaded image
timer_container = st.empty()        # the stopwatch readout
status_container = st.empty()       # per-stage status messages
progress_container = st.empty()     # the progress bar
results_container = st.container()  # caption + story output
# Get current timer value
def get_formatted_timer():
    """Return the elapsed time since the timer started, as "MM:SS".

    Reads st.session_state:
      - timer_start_time: epoch seconds when the timer started (None = idle)
      - timer_frozen: when True, the readout must stop advancing
      - last_update_time: epoch seconds captured on the last live update

    Returns "00:00" when the timer has not been started.
    """
    if st.session_state.timer_start_time is None:
        return "00:00"
    if st.session_state.timer_frozen:
        # Frozen: reuse the timestamp captured on the last live update and
        # do NOT touch last_update_time. The original overwrote it on every
        # call, so the "frozen" value drifted forward one rerun at a time.
        frozen_at = st.session_state.last_update_time or time.time()
        elapsed_seconds = int(frozen_at - st.session_state.timer_start_time)
    else:
        now = time.time()
        elapsed_seconds = int(now - st.session_state.timer_start_time)
        # Remember when we last rendered a live value, so a later freeze
        # pins the display to this moment.
        st.session_state.last_update_time = now
    minutes, seconds = divmod(elapsed_seconds, 60)
    return f"{minutes:02d}:{seconds:02d}"
# UI components
# ---- UI: uploader, persisted image, timer readout --------------------------
uploaded_file = st.file_uploader("Select an Image After the Models are Loaded...")

# Re-draw the previously uploaded image on every rerun so it survives
# timer-driven refreshes.
if st.session_state.image_data is not None:
    image_container.image(
        st.session_state.image_data,
        caption="Uploaded Image",
        use_container_width=True,
    )

# Render the stopwatch; a frozen timer gets the green "done" styling.
elapsed_display = get_formatted_timer()
if st.session_state.timer_frozen:
    timer_container.markdown(
        f"<div style='font-size:16px;color:#00cc00;font-weight:bold;margin-bottom:10px;'>⏱️ Elapsed: {elapsed_display} βœ“</div>",
        unsafe_allow_html=True,
    )
else:
    timer_container.markdown(
        f"<div style='font-size:16px;color:#666;margin-bottom:10px;'>⏱️ Elapsed: {elapsed_display}</div>",
        unsafe_allow_html=True,
    )
# ---- Main pipeline: caption -> story -> audio -------------------------------
if uploaded_file is not None:
    # Persist the raw bytes so the image survives timer-driven reruns.
    bytes_data = uploaded_file.getvalue()
    st.session_state.image_data = bytes_data
    image_container.image(bytes_data, caption="Uploaded Image", use_container_width=True)

    # Only run the (expensive) three-stage pipeline when a NEW file arrives.
    # The original re-ran every model stage on EVERY Streamlit rerun, which
    # the 0.5 s timer refresh at the bottom of the file would trigger
    # continuously.
    if st.session_state.get('current_file') != uploaded_file.name:
        st.session_state.current_file = uploaded_file.name
        # Reset and start the stopwatch for this new file.
        st.session_state.timer_start_time = time.time()
        st.session_state.last_update_time = time.time()
        st.session_state.timer_frozen = False

        # Progress indicators
        status_text = status_container.empty()
        progress_bar = progress_container.progress(0)
        try:
            # The captioning pipeline takes a path/URL, so spill the upload
            # to disk first.
            # NOTE(review): writes the user-supplied filename into the CWD;
            # consider tempfile to avoid collisions/path issues.
            with open(uploaded_file.name, "wb") as file:
                file.write(bytes_data)

            # Stage 1: image -> caption
            status_text.markdown("**πŸ–ΌοΈ Generating caption...**")
            st.session_state.processed_data['scenario'] = img2text(uploaded_file.name)
            progress_bar.progress(33)

            # Stage 2: caption -> story
            status_text.markdown("**πŸ“– Generating story...**")
            st.session_state.processed_data['story'] = text2story(
                st.session_state.processed_data['scenario']
            )
            progress_bar.progress(66)

            # Stage 3: story -> audio
            status_text.markdown("**πŸ”Š Synthesizing audio...**")
            st.session_state.processed_data['audio'] = text2audio(
                st.session_state.processed_data['story']
            )
            progress_bar.progress(100)

            # Final status. The caption and story themselves are rendered by
            # the persistent display section below; the original also printed
            # them here, which duplicated the output on the processing run.
            status_text.success("**βœ… Generation complete!**")
        except Exception as e:
            status_text.error(f"**❌ Error:** {str(e)}")
            progress_bar.empty()
            # Bare raise preserves the original traceback (vs. `raise e`).
            raise
# ---- Persistent results display ---------------------------------------------
# Rendered on every rerun so the caption/story survive timer refreshes.
with results_container:
    caption = st.session_state.processed_data.get('scenario')
    if caption:
        st.write("**Caption:**", caption)
    story = st.session_state.processed_data.get('story')
    if story:
        st.write("**Story:**", story)
# ---- Audio playback (freezes the timer) --------------------------------------
if st.button("Play Audio of the Story Generated"):
    audio_payload = st.session_state.processed_data.get('audio')
    if audio_payload:
        # Keep the uploaded image on screen across this button-triggered rerun.
        if st.session_state.image_data is not None:
            image_container.image(
                st.session_state.image_data,
                caption="Uploaded Image",
                use_container_width=True,
            )
        # Stop the stopwatch and restyle the readout as "done".
        st.session_state.timer_frozen = True
        final_time = get_formatted_timer()
        timer_container.markdown(
            f"<div style='font-size:16px;color:#00cc00;font-weight:bold;margin-bottom:10px;'>⏱️ Elapsed: {final_time} βœ“</div>",
            unsafe_allow_html=True,
        )
        # Play the generated MP3 from the in-memory buffer.
        st.audio(audio_payload['audio'].getvalue(), format="audio/mp3")
    else:
        st.warning("Please generate a story first!")
# Force a rerun every second while the timer is active (not frozen)
# NOTE(review): Streamlit sanitizes HTML in st.markdown and does not execute
# <script> tags, so this postMessage trick likely never fires — confirm in the
# browser; streamlit-autorefresh or a time.sleep + st.rerun loop is the usual
# replacement.
if st.session_state.timer_start_time is not None and not st.session_state.timer_frozen:
    # Create a placeholder for our hidden component that triggers the rerun
    rerun_trigger = st.empty()
    # Add a hidden element that will automatically trigger a rerun after 0.5 seconds
    rerun_trigger.markdown(
        f"""
        <div style="display:none;">
        <script>
        setTimeout(function() {{
            window.parent.postMessage({{type: "streamlit:rerun"}}, "*");
        }}, 500);
        </script>
        </div>
        """,
        unsafe_allow_html=True
    )