1
File size: 4,622 Bytes
4d1f328
e537b6d
982555a
 
4d1f328
c83a777
 
982555a
c83a777
4d1f328
 
 
 
 
 
c83a777
982555a
c83a777
982555a
 
c83a777
982555a
c83a777
4d1f328
b3abd21
c876f7b
4d1f328
c83a777
982555a
c83a777
c4110d1
c83a777
 
4d1f328
 
 
 
 
 
 
6adb177
c4110d1
c83a777
 
4d1f328
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
982555a
 
 
 
 
fd1d947
982555a
4d1f328
982555a
4d1f328
982555a
4d1f328
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
982555a
 
 
c83a777
4d1f328
 
 
 
 
 
 
c83a777
4d1f328
c83a777
c4110d1
e537b6d
4d1f328
 
 
 
 
 
 
 
 
 
 
 
 
e537b6d
4d1f328
 
c4110d1
4d1f328
982555a
4d1f328
 
982555a
c83a777
4d1f328
 
 
 
c83a777
4d1f328
2aae3c9
c83a777
982555a
4d1f328
982555a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
# story_generator.py
import re
import time
import tempfile
import streamlit as st
from PIL import Image
from gtts import gTTS
from transformers import pipeline

# --- Page configuration: title, icon and centered layout ---
st.set_page_config(page_icon="📖", page_title="Magic Story Generator", layout="centered")

# --- Cached model construction ---
@st.cache_resource(show_spinner=False)
def load_models():
    """Build the two Hugging Face pipelines the app relies on.

    The function is wrapped in ``st.cache_resource`` so the pipelines are
    built once and reused instead of being reconstructed on each call.

    Returns:
        tuple: ``(captioner, storyteller)`` -- an ``image-to-text`` pipeline
        followed by a ``text-generation`` pipeline.
    """
    # NOTE(review): these sampling settings travel through ``model_kwargs``,
    # i.e. they are handed to model loading rather than to ``generate()``.
    # Confirm they actually influence generation with the installed
    # transformers version.
    storyteller_model_kwargs = {
        "revision": "main",
        "temperature": 0.7,
        "top_p": 0.9,
        "repetition_penalty": 1.1,
        "pad_token_id": 151645,
    }

    # BLIP captioning model, pinned to device -1.
    caption_pipeline = pipeline(
        "image-to-text",
        model="Salesforce/blip-image-captioning-base",
        device=-1,
    )

    # Qwen3 text generator; device placement and dtype are left on "auto".
    story_pipeline = pipeline(
        "text-generation",
        model="Qwen/Qwen3-1.7B",
        device_map="auto",
        trust_remote_code=True,
        torch_dtype="auto",
        model_kwargs=storyteller_model_kwargs,
    )

    return caption_pipeline, story_pipeline

# --- Text Processing Utilities ---
def clean_generated_text(raw_text):
    """Turn raw chat-formatted model output into a tidy story string.

    Processing steps, in order:

    1. Keep only the text after the first ``<|im_start|>assistant\\n`` marker
       (the whole input is kept when the marker is absent).
    2. Drop any later chat turn and every ``<|im_end|>`` token.
    3. Remove ``<think>...</think>`` reasoning blocks. An unterminated
       ``<think>`` (generation stopped mid-reasoning) discards everything
       from the tag onward, because no answer text follows it.
    4. Strip a leading ``assistant`` label left over from the template.
    5. Capitalise the first character of every sentence and make sure the
       text ends with terminal punctuation.

    Args:
        raw_text: Text returned by the generation pipeline. May include the
            prompt, chat-template markers and reasoning blocks.

    Returns:
        str: The cleaned story, or ``""`` when nothing usable remains
        (including when ``raw_text`` is empty or ``None``).
    """
    if not raw_text:
        return ""

    # Isolate the assistant's reply from the surrounding chat template.
    text = raw_text.split("<|im_start|>assistant\n", 1)[-1]
    text = text.split("<|im_start|>")[0]
    text = text.replace("<|im_end|>", "")

    # Reasoning is not part of the story: remove complete blocks first, then
    # whatever follows a dangling opening tag.
    text = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL)
    text = re.sub(r"<think>.*", "", text, flags=re.DOTALL).strip()

    # Regex cleanup for a leftover role label at the very start.
    text = re.sub(
        r'^(assistant[\s\-\:>]*)+',
        '',
        text,
        flags=re.IGNORECASE
    ).strip()

    # Characters that may legitimately follow the final punctuation mark,
    # e.g. the closing quote of a line of dialogue.
    closing_marks = "\"'”’)]"

    sentences = []
    for chunk in re.split(r'(?<=[.!?]) +', text):
        chunk = chunk.strip()
        if not chunk:
            continue
        # Look past closing quotes/brackets so '"Hello!"' is not turned
        # into '"Hello!".'.
        if not chunk.rstrip(closing_marks).endswith(('.', '!', '?')):
            chunk += '.'
        sentences.append(chunk[0].upper() + chunk[1:])

    return ' '.join(sentences)

# --- Main Application UI ---
# Flow: upload an image -> caption it -> generate a story from the caption
# -> clean the text -> show it and offer an audio version.
st.title("📖✨ Magic Story Generator")

uploaded_image = st.file_uploader(
    "Upload a children's book style image:",
    type=["jpg", "jpeg", "png"]
)

if uploaded_image:
    # Decode the upload. A corrupt or mislabelled file makes PIL raise, so
    # report it to the user instead of crashing the script with a traceback.
    try:
        image = Image.open(uploaded_image).convert("RGB")
    except (OSError, ValueError):
        st.error("❌ Failed to analyze image. Please try another.")
        st.stop()
    st.image(image, use_column_width=True)

    # Load models only when needed
    caption_pipe, story_pipe = load_models()

    # Generate image caption
    with st.spinner("🔍 Analyzing image..."):
        try:
            caption_result = caption_pipe(image)
            image_caption = caption_result[0].get("generated_text", "").strip()

            # An empty caption is treated like any other captioning failure.
            if not image_caption:
                raise ValueError("Couldn't generate caption")

            st.success(f"**Image Understanding:** {image_caption}")
        except Exception:
            st.error("❌ Failed to analyze image. Please try another.")
            st.stop()

    # Create story prompt (ChatML layout written out by hand).
    # The assistant turn is pre-filled with an empty <think> block, which is
    # how Qwen3's chat template disables thinking mode; otherwise the
    # 300-token budget below is largely consumed by reasoning text.
    story_prompt = (
        "<|im_start|>system\n"
        f"You are a children's book author. Create a 150-word story based on: {image_caption}\n"
        "Include these elements:\n"
        "- Friendly characters\n"
        "- Simple vocabulary\n"
        "- Positive lesson\n"
        "- Clear story structure\n"
        "<|im_end|>\n"
        "<|im_start|>user\n"
        "Write an engaging story suitable for ages 6-8.<|im_end|>\n"
        "<|im_start|>assistant\n"
        "<think>\n\n</think>\n\n"
    )

    # Generate story text
    with st.spinner("📝 Crafting magical story..."):
        try:
            story_result = story_pipe(
                story_prompt,
                max_new_tokens=300,
                do_sample=True,
                num_return_sequences=1,
                # Return only the newly generated text so neither the prompt
                # nor the pre-filled <think> block reaches the story.
                return_full_text=False
            )
            raw_story = story_result[0]['generated_text']
        except Exception:
            st.error("❌ Story generation failed. Please try again.")
            st.stop()

    # Process and display story
    final_story = clean_generated_text(raw_story)

    # Nothing usable came back: report it rather than showing an empty story
    # and sending empty text to the speech service.
    if not final_story:
        st.error("❌ Story generation failed. Please try again.")
        st.stop()

    st.subheader("✨ Your Story")
    st.write(final_story)

    # Generate audio version (best effort: the text stays available on failure)
    with st.spinner("🔊 Creating audio version..."):
        try:
            tts = gTTS(text=final_story, lang='en', slow=False)
            # Write into an anonymous temp file that is removed on close, so
            # no MP3 is left behind per run and the file is never re-opened
            # by name while this handle is still open.
            with tempfile.TemporaryFile(suffix=".mp3") as fp:
                tts.write_to_fp(fp)
                fp.seek(0)
                st.audio(fp.read(), format="audio/mp3")
        except Exception:
            st.warning("⚠️ Audio conversion failed. Text version still available.")

# Footer
st.markdown("---")
st.caption("Made with ♥ by The Story Wizard • [Report Issues](https://example.com)")