frankai98's picture
Update app.py
dd1055c verified
raw
history blame
5.16 kB
import streamlit as st
from transformers import pipeline
import torch
from gtts import gTTS
import io
import time
from streamlit.components.v1 import html
# Seed st.session_state with one slot per pipeline stage so reruns can
# tell whether results from a previous pass already exist.
if 'processed_data' not in st.session_state:
    st.session_state.processed_data = {
        'scenario': None,  # image caption produced by img2text
        'story': None,     # story text produced by text2story
        'audio': None,     # in-memory MP3 payload produced by text2audio
    }
# JavaScript timer component
def timer():
    """Build an HTML/JS fragment showing a live elapsed-time readout.

    The script starts a 1-second interval on load that rewrites the
    ``#timer`` div with a zero-padded MM:SS counter, and registers a
    ``beforeunload`` handler that clears the interval.

    Returns:
        str: the HTML/JS snippet, meant for ``streamlit.components.v1.html``.
    """
    return """
    <div id="timer" style="font-size:16px;color:#666;margin-bottom:10px;">⏱️ Elapsed: 00:00</div>
    <script>
    function updateTimer() {
        var start = Date.now();
        var timerElement = document.getElementById('timer');
        var interval = setInterval(function() {
            var elapsed = Date.now() - start;
            var minutes = Math.floor(elapsed / 60000);
            var seconds = Math.floor((elapsed % 60000) / 1000);
            timerElement.innerHTML = '⏱️ Elapsed: ' +
                (minutes < 10 ? '0' : '') + minutes + ':' +
                (seconds < 10 ? '0' : '') + seconds;
        }, 1000);
        // Hand back a disposer so the interval can be stopped later.
        return function() {
            clearInterval(interval);
        }
    }
    var cleanup = updateTimer();
    // Stop the interval when the component iframe unloads.
    window.addEventListener('beforeunload', function() {
        cleanup();
    });
    </script>
    """
# Page setup
# Configure the browser tab (title + parrot icon) and render the app header.
st.set_page_config(page_title="Your Image to Audio Story", page_icon="🦜")
st.header("Turn Your Image to a Short Audio Story for Children")
# Model loading
@st.cache_resource
def load_models():
    """Instantiate the two HF pipelines, cached once per server process.

    Returns:
        dict: ``img_model`` (image captioning) and ``story_model``
        (chat-style text generation).
    """
    captioner = pipeline("image-to-text", "cnmoro/tiny-image-captioning")
    storyteller = pipeline("text-generation", "Qwen/Qwen2.5-0.5B-Instruct")
    return {"img_model": captioner, "story_model": storyteller}

models = load_models()
# Processing functions
def img2text(url):
    """Caption the image at *url* with the cached image-to-text pipeline.

    Args:
        url: path (or URL) of the image file to caption.

    Returns:
        str: the first generated caption.
    """
    predictions = models["img_model"](url)
    return predictions[0]["generated_text"]
def text2story(text):
    """Turn an image caption into a short children's story.

    Args:
        text: the image caption produced by ``img2text``.

    Returns:
        str: the assistant's generated story text.
    """
    prompt = f"Generate a brief 100-word story about: {text}"
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt}
    ]
    response = models["story_model"](
        messages,
        max_new_tokens=128,
        do_sample=True,
        temperature=0.7
    )[0]["generated_text"]
    # The chat pipeline echoes the whole conversation with the assistant's
    # reply appended last. Index from the end instead of hard-coding 2 so
    # this survives a change in the number of input messages.
    return response[-1]["content"]
def text2audio(story_text):
    """Synthesize *story_text* to MP3 speech via gTTS (network call).

    Args:
        story_text: English text to speak.

    Returns:
        dict: ``audio`` — an in-memory MP3 buffer rewound to position 0,
        and ``sampling_rate`` — a nominal value.
        NOTE(review): no caller in this file reads 'sampling_rate';
        st.audio plays the MP3 bytes directly — confirm before relying on it.
    """
    speech = gTTS(text=story_text, lang='en', slow=False)
    buffer = io.BytesIO()
    speech.write_to_fp(buffer)
    buffer.seek(0)
    return {'audio': buffer, 'sampling_rate': 16000}
# UI components
uploaded_file = st.file_uploader("Select an Image After the Models are Loaded...")

if uploaded_file is not None:
    # Placeholders that the stages below update in place.
    status_text = st.empty()
    progress_bar = st.progress(0)

    # Start the client-side elapsed-time display.
    html(timer(), height=50)

    try:
        # Persist the upload to disk so the captioning pipeline can read it
        # by filename.
        bytes_data = uploaded_file.getvalue()
        with open(uploaded_file.name, "wb") as file:
            file.write(bytes_data)

        # Track which file is being processed across reruns.
        if st.session_state.get('current_file') != uploaded_file.name:
            st.session_state.current_file = uploaded_file.name

        # Display image
        st.image(uploaded_file, caption="Uploaded Image", use_container_width=True)

        # Stage 1: Image to Text (emoji repaired from mojibake in the original)
        status_text.markdown("**🖼️ Analyzing image...**")
        progress_bar.progress(0)
        st.session_state.processed_data['scenario'] = img2text(uploaded_file.name)
        progress_bar.progress(33)

        # Stage 2: Text to Story
        status_text.markdown("**📖 Generating story...**")
        progress_bar.progress(33)
        st.session_state.processed_data['story'] = text2story(
            st.session_state.processed_data['scenario']
        )
        progress_bar.progress(66)

        # Stage 3: Story to Audio
        status_text.markdown("**🔊 Synthesizing audio...**")
        progress_bar.progress(66)
        st.session_state.processed_data['audio'] = text2audio(
            st.session_state.processed_data['story']
        )
        progress_bar.progress(100)

        # Final status
        status_text.success("**✅ Generation complete!**")
        # NOTE(review): this script runs in a NEW component iframe and likely
        # cannot reach the #timer element rendered by the earlier html() call
        # — confirm it actually recolors the timer in deployment.
        html("<script>document.getElementById('timer').style.color = '#00cc00';</script>")

        # Show results
        st.subheader("Results")
        st.write("**Caption:**", st.session_state.processed_data['scenario'])
        st.write("**Story:**", st.session_state.processed_data['story'])
    except Exception as e:
        # NOTE(review): same cross-iframe caveat as above for this removal.
        html("<script>document.getElementById('timer').remove();</script>")
        status_text.error(f"**❌ Error:** {str(e)}")
        progress_bar.empty()
        # Bare raise preserves the original traceback (idiomatic re-raise).
        raise
    finally:
        pass  # Timer cleanup handled by JavaScript
# Audio playback
if st.button("Play Audio of the Story Generated"):
    audio_payload = st.session_state.processed_data.get('audio')
    if audio_payload:
        # Feed the raw MP3 bytes from the in-memory buffer to the player.
        st.audio(
            audio_payload['audio'].getvalue(),
            format="audio/mp3"
        )
    else:
        st.warning("Please generate a story first!")