|
import hashlib
import io
import os
import tempfile

import streamlit as st
from gtts import gTTS
from PIL import Image
from transformers import pipeline
|
|
|
|
|
# Page-level settings for the app. Applied before any other Streamlit output,
# since older Streamlit releases require set_page_config to be the first call.
_PAGE_SETTINGS = dict(
    page_title="Storyteller for Kids",
    page_icon="📚",
    layout="centered",
    initial_sidebar_state="collapsed",
)

st.set_page_config(**_PAGE_SETTINGS)
st.title("🖼️➡️📖 Interactive Storyteller")
|
|
|
|
|
@st.cache_resource
def load_pipelines():
    """Create the two Hugging Face pipelines used by the app.

    Decorated with ``st.cache_resource`` so the pipelines are built once and
    the same objects are reused on later script reruns instead of reloading
    the model weights.

    Returns:
        tuple: ``(captioner, storyteller)`` -- an ``image-to-text`` pipeline
        (BLIP base, capped at 50 new tokens per caption) followed by a
        ``text2text-generation`` pipeline (FLAN-T5-XXL).
    """
    # NOTE(review): 8-bit loading relies on the bitsandbytes package and, in
    # practice, a CUDA GPU; newer transformers releases prefer
    # ``quantization_config=BitsAndBytesConfig(load_in_8bit=True)`` over a raw
    # ``load_in_8bit`` flag -- confirm against the deployed versions.
    story_model_kwargs = {"load_in_8bit": True}

    return (
        pipeline(
            task="image-to-text",
            model="Salesforce/blip-image-captioning-base",
            max_new_tokens=50,
        ),
        pipeline(
            task="text2text-generation",
            model="google/flan-t5-xxl",
            device_map="auto",
            model_kwargs=story_model_kwargs,
        ),
    )
|
|
|
|
|
# Upload limit promised to users by the uploader's help text ("Max size: 5MB").
_MAX_UPLOAD_BYTES = 5 * 1024 * 1024

# st.session_state key holding the results generated for the current upload.
_RESULT_STATE_KEY = "_storyteller_result"


def _build_story_prompt(caption: str) -> str:
    """Return the instruction prompt asking the story model to expand *caption*."""
    return (
        "Create a children's story (3-10 years old) based on this description:\n"
        "\n"
        f"{caption}\n"
        "\n"
        "Requirements:\n"
        "- 50-100 words\n"
        "- Playful and imaginative\n"
        "- Positive message\n"
        "- Simple vocabulary\n"
        "- Include animal characters\n"
        "\n"
        "Story:"
    )


def _synthesize_audio(text: str) -> bytes:
    """Render *text* as English speech with gTTS and return the MP3 bytes.

    Writing into an in-memory buffer avoids temp-file bookkeeping entirely,
    so a gTTS failure can neither leak a file on disk nor leave a cleanup
    path variable unbound.
    """
    buffer = io.BytesIO()
    gTTS(text=text, lang="en", slow=False).write_to_fp(buffer)
    return buffer.getvalue()


def _results_for(digest: str) -> dict:
    """Return this session's result cache for the upload identified by *digest*.

    Streamlit re-executes the whole script on every widget interaction
    (including a click on the download button). Caching the caption, story
    and audio per upload keeps the displayed story stable across those
    reruns and avoids re-running the models. A different upload (different
    digest) starts a fresh, empty cache.
    """
    results = st.session_state.get(_RESULT_STATE_KEY)
    if results is None or results.get("digest") != digest:
        results = {"digest": digest}
        st.session_state[_RESULT_STATE_KEY] = results
    return results


def main():
    """Render the app: upload an image, caption it, write a story, read it aloud.

    Any failure after the upload has been validated is reported to the user
    with ``st.error`` rather than crashing the page.
    """
    captioner, storyteller = load_pipelines()

    uploaded = st.file_uploader(
        "Upload an image:",
        type=["jpg", "jpeg", "png"],
        help="Max size: 5MB",
    )
    if not uploaded:
        return

    data = uploaded.getvalue()
    # The help text advertises a 5MB limit; enforce it instead of relying on
    # Streamlit's much larger default upload cap.
    if len(data) > _MAX_UPLOAD_BYTES:
        st.error(
            f"Image is too large ({len(data) / (1024 * 1024):.1f}MB). "
            "Please upload an image of at most 5MB."
        )
        return

    try:
        image = Image.open(io.BytesIO(data)).convert("RGB")
        # NOTE(review): newer Streamlit releases deprecate use_column_width in
        # favour of use_container_width; kept for compatibility with older ones.
        st.image(image, caption="Your Image", use_column_width=True)

        results = _results_for(hashlib.sha256(data).hexdigest())

        # --- Caption -------------------------------------------------------
        if "caption" not in results:
            with st.spinner("🔍 Analyzing image content..."):
                cap_outputs = captioner(image)
                results["caption"] = cap_outputs[0].get("generated_text", "").strip()
        cap = results["caption"]

        st.subheader("Image Understanding")
        if not cap:
            st.warning("Couldn't recognize anything in this image. Please try a different picture.")
            return
        st.info(f"**Detected:** {cap}")

        # --- Story ---------------------------------------------------------
        st.subheader("Story Creation")
        if "story" not in results:
            with st.spinner("✍️ Crafting a magical story..."):
                story_output = storyteller(
                    _build_story_prompt(cap),
                    max_length=300,
                    do_sample=True,
                    top_p=0.95,
                    temperature=0.85,
                    num_beams=4,
                    repetition_penalty=1.2,
                )
                results["story"] = story_output[0]["generated_text"].strip()
        story = results["story"]

        if not story:
            # gTTS cannot speak empty text, so stop before the audio step.
            st.warning("No story was generated for this image.")
            return
        st.success("**Generated Story:**")
        st.write(story)

        # --- Audio ---------------------------------------------------------
        st.subheader("Audio Version")
        if "audio" not in results:
            with st.spinner("🔊 Generating audio..."):
                results["audio"] = _synthesize_audio(story)
        audio = results["audio"]

        st.audio(audio, format="audio/mp3")
        st.download_button(
            label="Download Audio Story",
            data=audio,
            file_name="kids_story.mp3",
            mime="audio/mpeg",
        )

    except Exception as e:
        st.error(f"Error processing your request: {e}")


if __name__ == "__main__":
    main()
|
|