1
File size: 3,423 Bytes
c4110d1
 
c83a777
 
 
 
 
8367fb2
c83a777
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3abd21
c876f7b
c83a777
 
 
 
c4110d1
c83a777
 
 
 
 
 
 
6adb177
c4110d1
c83a777
 
 
fd1d947
c83a777
 
 
 
 
 
e508bdf
c83a777
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c4110d1
 
c83a777
 
 
 
 
 
 
 
 
c4110d1
c83a777
 
 
c4110d1
c83a777
 
 
 
 
 
b3abd21
c83a777
 
 
e508bdf
c83a777
 
 
 
 
 
 
 
 
2aae3c9
c83a777
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os
import time
import streamlit as st
from PIL import Image
from transformers import pipeline
from gtts import gTTS
import tempfile

# --- Requirements ---
# Update requirements.txt to include the packages below. They are listed as
# comments (not a bare string literal) so nothing is evaluated at import time.
#
#   streamlit>=1.20
#   pillow>=9.0
#   torch>=2.0.0
#   transformers>=4.51    # Qwen3 is only supported from 4.51 onwards
#   sentencepiece>=0.2.0
#   gTTS>=2.3.1
#   accelerate>=0.30      # needed for device_map="auto"

# --- Page Setup ---
# Page-level configuration is issued before any other Streamlit element.
st.set_page_config(page_title="Magic Story Generator", layout="centered")
st.title("📖✨ Turn Images into Children's Stories")

# --- Load Pipelines (cached) ---
@st.cache_resource(show_spinner=False)
def load_pipelines():
    """Build the captioning and story-writing pipelines used by the app.

    The function is wrapped in ``st.cache_resource`` so the (expensive)
    model construction is cached rather than repeated on every call.

    Returns:
        tuple: ``(captioner, storyteller)`` where ``captioner`` is the BLIP
        ``image-to-text`` pipeline and ``storyteller`` is the Qwen3-1.7B
        ``text-generation`` pipeline.
    """
    # Generation defaults handed to the story pipeline at construction time.
    generation_defaults = {
        "max_new_tokens": 150,
        "temperature": 0.7,
        "top_p": 0.9,
        "repetition_penalty": 1.2,
        "eos_token_id": 151645,  # Specific to Qwen3 tokenizer
    }

    # 1) Image captioning (BLIP), pinned to device -1.
    caption_pipe = pipeline(
        task="image-to-text",
        model="Salesforce/blip-image-captioning-base",
        device=-1,
    )

    # 2) Story generation (Qwen3-1.7B) with automatic device/dtype selection.
    # NOTE(review): trust_remote_code=True lets code from the model repo run
    # locally — confirm it is actually required for this checkpoint.
    story_pipe = pipeline(
        task="text-generation",
        model="Qwen/Qwen3-1.7B",
        device_map="auto",
        trust_remote_code=True,
        torch_dtype="auto",
        **generation_defaults,
    )

    return caption_pipe, story_pipe

captioner, storyteller = load_pipelines()

# --- Main App ---
# Flow: upload image -> caption it -> prompt the story model -> tidy the
# story text -> show it -> read it aloud.
uploaded = st.file_uploader("Upload an image:", type=["jpg", "png", "jpeg"])
if uploaded:
    # Load and display the image (converted to RGB before captioning)
    img = Image.open(uploaded).convert("RGB")
    st.image(img, use_container_width=True)

    # Generate caption
    with st.spinner("🔍 Generating caption..."):
        cap = captioner(img)
        # Treat a non-list result AND an empty list as "no caption" so that
        # indexing cap[0] can never raise IndexError.
        caption = ""
        if isinstance(cap, list) and cap:
            caption = cap[0].get("generated_text", "").strip()
    if not caption:
        st.error("😢 Couldn't understand this image. Try another one!")
        st.stop()
    st.success(f"**Caption:** {caption}")

    # Build prompt and generate story
    assistant_marker = "<|im_start|>assistant\n"
    prompt = (
        f"<|im_start|>system\n"
        f"You are a children's story writer. Create a 50-100 word story based on this image description: {caption}\n"
        f"<|im_end|>\n"
        f"<|im_start|>user\n"
        f"Write a coherent, child-friendly story that flows naturally with simple vocabulary.<|im_end|>\n"
        f"{assistant_marker}"
        # An empty think block is how the Qwen3 chat template switches the
        # model to non-thinking mode; without it the 150-token budget can be
        # consumed by reasoning text instead of the story.
        f"<think>\n\n</think>\n\n"
    )

    with st.spinner("📝 Writing story..."):
        start = time.time()
        out = storyteller(
            prompt,
            do_sample=True,
            num_return_sequences=1
        )
        gen_time = time.time() - start
        st.text(f"⏱ Generated in {gen_time:.1f}s")

    # Process output: keep only the assistant's reply
    story = out[0]['generated_text'].split(assistant_marker)[-1]
    # Drop any reasoning block: keep what follows a closed </think>, and
    # discard an unterminated <think> section entirely.
    story = story.split("</think>")[-1].split("<think>")[0]
    story = story.replace("<|im_end|>", "").strip()

    if not story:
        # Nothing usable was generated; avoid showing "." and sending empty
        # text to the TTS service.
        st.error("😢 The story came out empty. Please try again!")
        st.stop()

    # Enforce ≤100 words and proper ending
    words = story.split()
    if len(words) > 100:
        story = " ".join(words[:100])
    if not story.endswith(('.', '!', '?')):
        story += '.'

    # Display story
    st.subheader("📚 Your Magical Story")
    st.write(story)

    # Convert to audio
    with st.spinner("🔊 Converting to audio..."):
        try:
            tts = gTTS(text=story, lang="en", slow=False)
            # Write the MP3 into a self-cleaning temporary file and hand the
            # bytes to Streamlit, so no .mp3 is left behind on disk per run.
            with tempfile.TemporaryFile(suffix=".mp3") as audio_file:
                tts.write_to_fp(audio_file)
                audio_file.seek(0)
                audio_bytes = audio_file.read()
            st.audio(audio_bytes, format="audio/mp3")
        except Exception as e:
            # Best-effort feature: the story is already shown, so a TTS
            # failure is surfaced as a warning instead of crashing the app.
            st.warning(f"⚠️ TTS failed: {e}")

# Footer
st.markdown("---\nMade with ❤️ by your friendly story wizard")