1
File size: 3,878 Bytes
982555a
c83a777
982555a
 
 
 
 
 
 
 
 
c83a777
 
982555a
c83a777
982555a
c83a777
 
982555a
c83a777
982555a
 
c83a777
982555a
c83a777
982555a
b3abd21
c876f7b
982555a
c83a777
982555a
c83a777
c4110d1
c83a777
 
 
 
982555a
 
 
6adb177
c4110d1
c83a777
 
982555a
 
 
 
 
 
 
fd1d947
982555a
 
 
 
e508bdf
c83a777
982555a
 
 
 
 
 
c83a777
982555a
 
c83a777
982555a
 
c83a777
982555a
 
c83a777
982555a
c83a777
c4110d1
982555a
 
 
 
 
 
c83a777
 
 
982555a
 
 
c83a777
982555a
 
 
 
 
 
 
 
 
 
 
c4110d1
982555a
b3abd21
c83a777
982555a
 
e508bdf
982555a
 
c83a777
982555a
 
 
 
c83a777
982555a
2aae3c9
c83a777
982555a
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# Must be FIRST import and FIRST Streamlit command
import streamlit as st
# Browser-tab title/icon and the centred page layout for the app.
_PAGE_CONFIG = {
    "page_title": "Magic Story Generator",
    "page_icon": "📖",
    "layout": "centered",
}
st.set_page_config(**_PAGE_CONFIG)

# Other imports AFTER Streamlit config
import io
import tempfile
import time

from PIL import Image
from gtts import gTTS
from transformers import pipeline

# --- Constants & Setup ---
# Heading rendered at the top of the page.
APP_TITLE = "📖✨ Turn Images into Children's Stories"
st.title(APP_TITLE)

# --- Model Loading (Cached) ---
@st.cache_resource(show_spinner=False)
def load_models():
    """Build the captioning and story-writing pipelines.

    Wrapped in ``st.cache_resource`` so Streamlit reruns reuse the same
    pipeline objects instead of constructing them again.

    Returns:
        tuple: ``(captioner, storyteller)``. The first is the BLIP
        ``image-to-text`` pipeline; the second is the Qwen3-1.7B
        ``text-generation`` pipeline with default decoding settings baked in.
    """
    # Decoding defaults handed to the text-generation pipeline at build time.
    story_generation_kwargs = {
        "max_new_tokens": 150,
        "temperature": 0.7,
        "top_p": 0.85,
        "repetition_penalty": 1.15,
        "eos_token_id": 151645,  # Qwen3's specific EOS token
    }

    pipelines = (
        # Image captioning model, pinned to CPU (device=-1; 0 would be a GPU).
        pipeline(
            "image-to-text",
            model="Salesforce/blip-image-captioning-base",
            device=-1,
        ),
        # Story generation model (Qwen3-1.7B); placement chosen by device_map.
        pipeline(
            "text-generation",
            model="Qwen/Qwen3-1.7B",
            device_map="auto",
            trust_remote_code=True,
            torch_dtype="auto",
            **story_generation_kwargs,
        ),
    )
    return pipelines


caption_pipe, story_pipe = load_models()

# --- Story helpers ---
# Qwen3 reasons inside <think>...</think> unless the assistant turn starts with
# an empty think block; this is the same prefix Qwen3's chat template emits
# for enable_thinking=False.
_NO_THINK_PREFILL = "<think>\n\n</think>\n\n"
_SENTENCE_END = (".", "!", "?")
_CLOSING_QUOTES = "\"'”’)"
_STORY_CHAR_LIMIT = 600  # Character limit safeguard for display and TTS


def _build_story_prompt(image_caption):
    """Return the ChatML prompt asking for a short story about *image_caption*.

    The assistant turn is pre-filled with an empty think block so the model
    spends its token budget on the story rather than on hidden reasoning.
    """
    return (
        "<|im_start|>system\n"
        f"You are a children's book author. Create a 50-100 word story based on this image description: {image_caption}\n"
        "Use simple language, friendly characters, and a positive lesson.<|im_end|>\n"
        "<|im_start|>user\n"
        "Write a short, child-friendly story with a clear beginning, middle, and end.<|im_end|>\n"
        f"<|im_start|>assistant\n{_NO_THINK_PREFILL}"
    )


def _extract_story(generated_text):
    """Strip prompt echo, chat markup and any reasoning block from model output."""
    # Harmless when only the completion is returned; drops an echoed prompt otherwise.
    text = generated_text.split("<|im_start|>assistant\n")[-1]
    # Keep what follows the last closed think block, and discard an
    # unterminated one (generation may stop at max_new_tokens mid-thought).
    text = text.rpartition("</think>")[2].split("<think>")[0]
    return text.replace("<|im_end|>", "").strip()


def _tidy_story(text, limit=_STORY_CHAR_LIMIT):
    """Capitalise sentences, ensure final punctuation and cap the length.

    Sentences are delimited by ". ". A story that already ends in ".", "!" or
    "?" (optionally followed by a closing quote/parenthesis) gets no extra
    period. A story longer than *limit* characters is cut at the last
    sentence end that fits, falling back to the last word boundary.

    Returns:
        str: the tidied story; empty if *text* contained nothing usable.
    """
    sentences = [s.strip() for s in text.split(". ")]
    story = ". ".join(s[0].upper() + s[1:] for s in sentences if s)
    if story and not story.rstrip(_CLOSING_QUOTES).endswith(_SENTENCE_END):
        story += "."
    if len(story) <= limit:
        return story
    clipped = story[:limit]
    cut = max(clipped.rfind(mark) for mark in _SENTENCE_END)
    if cut > 0:
        return clipped[: cut + 1]
    head = clipped.rpartition(" ")[0] or clipped
    return head.rstrip(" ,;:") + "."


# --- Main Application Flow ---
uploaded_image = st.file_uploader(
    "Upload a children's book style image:",
    type=["jpg", "jpeg", "png"]
)

if uploaded_image:
    # Process image
    image = Image.open(uploaded_image).convert("RGB")
    st.image(image, use_container_width=True)

    # Generate caption
    with st.spinner("🔍 Analyzing image..."):
        caption_result = caption_pipe(image)
        image_caption = caption_result[0].get("generated_text", "").strip()

    if not image_caption:
        st.error("❌ Couldn't understand this image. Please try another!")
        st.stop()

    st.success(f"**Image Understanding:** {image_caption}")

    # Generate story
    with st.spinner("📝 Crafting magical story..."):
        start_time = time.perf_counter()  # monotonic clock for durations
        story_result = story_pipe(
            _build_story_prompt(image_caption),
            do_sample=True,
            num_return_sequences=1,
            return_full_text=False,  # completion only, without the prompt
        )
        generation_time = time.perf_counter() - start_time
        st.text(f"⏱ Generation time: {generation_time:.1f}s")

    # Process output
    final_story = _tidy_story(_extract_story(story_result[0]["generated_text"]))
    if not final_story:
        st.error("❌ The storyteller came back empty-handed. Please try again!")
        st.stop()

    # Display story
    st.subheader("✨ Your Magical Story")
    st.write(final_story)

    # Audio conversion: synthesise into memory so no temp files pile up on disk.
    with st.spinner("🔊 Creating audio version..."):
        try:
            audio_buffer = io.BytesIO()
            gTTS(text=final_story, lang="en", slow=False).write_to_fp(audio_buffer)
            st.audio(audio_buffer.getvalue(), format="audio/mp3")
        except Exception as e:  # UI boundary: report any TTS/network failure
            st.error(f"❌ Audio conversion failed: {str(e)}")

# Footer: a horizontal rule followed by the credits line.
for footer_markdown in (
    "---",
    "📚 Made with ♥ by The Story Wizard • [Report Issues](https://example.com)",
):
    st.markdown(footer_markdown)