|
import io
import os
import tempfile
import time

import streamlit as st
from gtts import gTTS
from PIL import Image
from transformers import pipeline
|
|
|
|
|
|
|
""" |
|
streamlit>=1.40 |
|
pillow>=9.0 |
|
torch>=2.0.0 |
|
transformers>=4.51 |
|
sentencepiece>=0.2.0 |
|
gTTS>=2.3.1 |
|
accelerate>=0.30 |
|
""" |
|
|
|
|
|
# Browser-tab title plus a centered single-column layout, then the page heading.
st.set_page_config(
    layout="centered",
    page_title="Magic Story Generator",
)
st.title("📖✨ Turn Images into Children's Stories")
|
|
|
|
|
# Token id of "<|im_end|>" in the Qwen tokenizer: generation stops as soon as
# the assistant closes its chat turn instead of rambling into a new one.
_QWEN_IM_END_ID = 151645


@st.cache_resource(show_spinner=False)
def load_pipelines():
    """Load the captioning and story-writing pipelines once and cache them.

    ``st.cache_resource`` keeps the returned objects alive across reruns and
    sessions, so the models are downloaded/loaded only on first use.

    Returns:
        tuple: ``(captioner, storyteller)`` where ``captioner`` is a BLIP
        image-to-text pipeline pinned to the CPU (``device=-1``) and
        ``storyteller`` is a Qwen3 text-generation pipeline placed by
        ``device_map="auto"``. The sampling settings passed to the storyteller
        here are forwarded as its default generation arguments.
    """
    captioner = pipeline(
        task="image-to-text",
        model="Salesforce/blip-image-captioning-base",
        device=-1,
    )

    # Qwen3 is supported natively by transformers >= 4.51, so there is no
    # reason to allow execution of arbitrary code from the model repository
    # (the previous trust_remote_code=True).
    storyteller = pipeline(
        task="text-generation",
        model="Qwen/Qwen3-1.7B",
        device_map="auto",
        torch_dtype="auto",
        max_new_tokens=150,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.2,
        eos_token_id=_QWEN_IM_END_ID,
    )

    return captioner, storyteller
|
|
|
# Both pipelines come from st.cache_resource, so this is cheap after first load.
captioner, storyteller = load_pipelines()

# Upper bound (in words) on the story that is displayed and read aloud.
MAX_STORY_WORDS = 100


def _build_prompt(caption: str) -> str:
    """Return the chat-formatted prompt asking for a children's story about *caption*.

    Qwen3 is a hybrid "thinking" model: unless thinking is disabled it opens
    its reply with a ``<think>...</think>`` block, which would consume the
    token budget and end up in the story. ``enable_thinking=False`` makes the
    Qwen3 chat template emit an empty think block so the model answers directly.
    """
    messages = [
        {
            "role": "system",
            "content": (
                "You are a children's story writer. Create a 50-100 word story "
                f"based on this image description: {caption}"
            ),
        },
        {
            "role": "user",
            "content": (
                "Write a coherent, child-friendly story that flows naturally "
                "with simple vocabulary."
            ),
        },
    ]
    return storyteller.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )


def _clean_story(text: str, max_words: int = MAX_STORY_WORDS) -> str:
    """Strip chat/thinking markup from *text* and cap it at *max_words* words.

    If the result does not end on sentence punctuation (because it was cut to
    *max_words* or generation hit the token limit mid-sentence), it is trimmed
    back to the last complete sentence; if there is none, a period is appended.
    Returns ``""`` when nothing usable remains.
    """
    # Defensive: drop any reasoning block the model emitted anyway.
    text = text.rpartition("</think>")[2]
    text = text.replace("<|im_end|>", "").strip()

    words = text.split()
    if len(words) > max_words:
        text = " ".join(words[:max_words])

    # Closing quotes after the final punctuation still count as a finished sentence.
    if text and not text.rstrip("\"'”’").endswith((".", "!", "?")):
        last_end = max(text.rfind(mark) for mark in ".!?")
        text = text[: last_end + 1] if last_end > 0 else text + "."
    return text


uploaded = st.file_uploader("Upload an image:", type=["jpg", "png", "jpeg"])

if uploaded:
    try:
        img = Image.open(uploaded).convert("RGB")
    except (OSError, ValueError):
        # PIL.UnidentifiedImageError subclasses OSError; truncated files also
        # raise OSError while decoding.
        st.error("😢 That file doesn't look like a valid image. Try another one!")
        st.stop()

    st.image(img, use_container_width=True)

    with st.spinner("🔍 Generating caption..."):
        cap = captioner(img)
    # image-to-text returns a list of {"generated_text": ...} dicts; guard
    # against an empty list as well as an unexpected type.
    caption = (
        cap[0].get("generated_text", "").strip()
        if isinstance(cap, list) and cap
        else ""
    )
    if not caption:
        st.error("😢 Couldn't understand this image. Try another one!")
        st.stop()
    st.success(f"**Caption:** {caption}")

    prompt = _build_prompt(caption)

    with st.spinner("📝 Writing story..."):
        start = time.perf_counter()
        out = storyteller(
            prompt,
            do_sample=True,
            num_return_sequences=1,
            return_full_text=False,  # only the completion, not the echoed prompt
        )
        gen_time = time.perf_counter() - start
    st.text(f"⏱ Generated in {gen_time:.1f}s")

    story = _clean_story(out[0]["generated_text"])
    if not story:
        st.error("😢 The storyteller came up empty. Please try again!")
        st.stop()

    st.subheader("📚 Your Magical Story")
    st.write(story)

    with st.spinner("🔊 Converting to audio..."):
        try:
            # Render the MP3 in memory: the previous NamedTemporaryFile(delete=False)
            # left an open handle and an orphaned file on disk on every run.
            audio = io.BytesIO()
            gTTS(text=story, lang="en", slow=False).write_to_fp(audio)
            st.audio(audio.getvalue(), format="audio/mp3")
        except Exception as e:
            # TTS is best-effort (e.g. gTTS needs network access): warn, don't crash.
            st.warning(f"⚠️ TTS failed: {e}")

st.markdown("---\nMade with ❤️ by your friendly story wizard")