# NOTE: web file-viewer residue (file-size banner, blame commit hashes, line-number gutter) removed; the script starts below.
# Must be FIRST import and FIRST Streamlit command
import streamlit as st
# Page chrome: browser-tab icon and title, with a centered content column.
st.set_page_config(page_icon="📖", page_title="Magic Story Generator", layout="centered")

# Other imports AFTER Streamlit config
import re
import time
import tempfile
from PIL import Image
from gtts import gTTS
from transformers import pipeline

# --- Page Header ---
# Main heading rendered at the top of the page.
st.title("📖✨ Turn Images into Children's Stories")

# --- Model Loading (Cached) ---
@st.cache_resource(show_spinner=False)
def load_models():
    """Build the two Hugging Face pipelines the app relies on.

    The function is wrapped in ``st.cache_resource`` (with its spinner
    disabled), so the returned pipelines are cached by Streamlit.

    Returns:
        tuple: ``(captioner, storyteller)`` -- the BLIP image-captioning
        pipeline followed by the Qwen3-0.6B text-generation pipeline.
    """
    # Default length / sampling settings attached to the story pipeline.
    generation_defaults = dict(
        max_new_tokens=230,
        temperature=0.8,
        top_k=50,
        top_p=0.85,
        repetition_penalty=1.15,
        # NOTE(review): presumably the id of Qwen's <|im_end|> token -- confirm.
        eos_token_id=151645,
    )

    # Image -> caption (device=-1 selects the CPU; 0 would select the first GPU).
    image_captioner = pipeline(
        "image-to-text",
        model="Salesforce/blip-image-captioning-base",
        device=-1,
    )

    # Caption -> story, using the Qwen3-0.6B checkpoint.
    story_writer = pipeline(
        "text-generation",
        model="Qwen/Qwen3-0.6B",
        device_map="auto",
        trust_remote_code=True,
        torch_dtype="auto",
        **generation_defaults,
    )

    return image_captioner, story_writer

caption_pipe, story_pipe = load_models()

# --- Main Application Flow ---
# Hard cap on the displayed / narrated story length, in characters.
MAX_STORY_CHARS = 800
# Punctuation that already terminates a sentence.
_SENTENCE_ENDINGS = (".", "!", "?")


def _build_story_prompt(caption):
    """Return a complete ChatML prompt requesting a story about ``caption``.

    The system turn is closed, the request goes in a user turn, and an
    assistant turn is opened so the model answers with the story instead of
    continuing the instructions. The empty ``<think>`` block mirrors what the
    Qwen3 chat template emits when thinking is disabled, so the limited token
    budget is spent on the story rather than on reasoning.
    """
    return (
        "<|im_start|>system\n"
        "You are a children's book author.<|im_end|>\n"
        "<|im_start|>user\n"
        f"Create a 100-150 word story based on: {caption}<|im_end|>\n"
        "<|im_start|>assistant\n"
        "<think>\n\n</think>\n\n"
    )


def _extract_story(generated):
    """Strip chat markup from the model's completion and return plain text."""
    text = generated.split("<|im_start|>assistant\n")[-1]
    # Drop any reasoning the model emitted despite the pre-filled think block.
    text = text.split("</think>")[-1]
    text = text.split("<|im_start|>")[0]  # Remove any new turns
    text = text.replace("<|im_end|>", "").strip()
    # Remove leading "assistant:" style role echoes.
    return re.sub(
        r'^(assistant[:>]?\s*)+',
        '',
        text,
        flags=re.IGNORECASE
    ).strip()


def _format_story(text, max_chars=MAX_STORY_CHARS):
    """Tidy sentence punctuation/capitalisation and cap the length.

    Args:
        text: Cleaned story text.
        max_chars: Maximum number of characters to keep.

    Returns:
        The formatted story. When it exceeds ``max_chars`` it is cut back to
        the last complete sentence that fits (or hard-cut if none does).
    """
    sentences = []
    for piece in text.split(". "):
        piece = piece.strip()
        if not piece:
            continue
        # Only add a period when the sentence is not already terminated
        # (ignoring a trailing closing quote or bracket), so "Hooray!" does
        # not become "Hooray!.".
        if not piece.rstrip("\"'”’)").endswith(_SENTENCE_ENDINGS):
            piece += "."
        sentences.append(piece[0].upper() + piece[1:])

    story = " ".join(sentences)
    # Collapse an accidental doubled period without mangling a "..." ellipsis.
    story = re.sub(r"(?<!\.)\.\.(?!\.)", ".", story)

    if len(story) <= max_chars:
        return story
    clipped = story[:max_chars]
    last_end = max(clipped.rfind(mark) for mark in _SENTENCE_ENDINGS)
    # Prefer ending on a complete sentence over cutting mid-word.
    return clipped[:last_end + 1] if last_end > 0 else clipped


def _render_story(uploaded):
    """Caption the uploaded image, write a story about it, and narrate it.

    Each failure is reported with ``st.error`` and the function returns early,
    so the rest of the page (e.g. the footer) is still rendered.

    Args:
        uploaded: The file object returned by ``st.file_uploader``.
    """
    # Process image
    try:
        image = Image.open(uploaded).convert("RGB")
    except OSError:
        # Pillow raises OSError (or a subclass) for unreadable/corrupt files.
        st.error("❌ Couldn't read this image file. Please try another!")
        return
    st.image(image, use_container_width=True)

    # Generate caption
    with st.spinner("🔍 Analyzing image..."):
        caption_result = caption_pipe(image)
    image_caption = (
        caption_result[0].get("generated_text", "").strip()
        if caption_result
        else ""
    )

    if not image_caption:
        st.error("❌ Couldn't understand this image. Please try another!")
        return

    st.success(f"**Image Understanding:** {image_caption}")

    # Generate story
    with st.spinner("📝 Crafting magical story..."):
        story_result = story_pipe(
            _build_story_prompt(image_caption),
            do_sample=True,
            num_return_sequences=1,
            pad_token_id=151645,
            # Return only the completion; otherwise the prompt is echoed back
            # and has to be stripped out of the story again.
            return_full_text=False,
        )

    final_story = _format_story(_extract_story(story_result[0]['generated_text']))
    if not final_story:
        st.error("❌ Couldn't write a story for this image. Please try another!")
        return

    # Display story
    st.subheader("✨ Your Magical Story")
    st.write(final_story)

    # Audio conversion
    with st.spinner("🔊 Creating audio version..."):
        try:
            speech = gTTS(text=final_story, lang="en", slow=False)
            # Buffer the MP3 in a self-deleting temp file so nothing is left
            # behind on disk after each run.
            with tempfile.TemporaryFile() as mp3_buffer:
                speech.write_to_fp(mp3_buffer)
                mp3_buffer.seek(0)
                audio_bytes = mp3_buffer.read()
        except Exception as e:
            # Best-effort feature: report the failure but keep the story shown.
            st.error(f"❌ Audio conversion failed: {str(e)}")
        else:
            st.audio(audio_bytes, format="audio/mp3")


uploaded_image = st.file_uploader(
    "Upload a children's book style image:",
    type=["jpg", "jpeg", "png"]
)

if uploaded_image:
    _render_story(uploaded_image)

# Footer: a horizontal rule followed by the credit line.
for footer_markdown in (
    "---",
    "📚 Made with ♥ by The Story Wizard • [Report Issues](https://example.com)",
):
    st.markdown(footer_markdown)