1
File size: 3,916 Bytes
f913ab4
4d1f328
8151df4
 
 
 
 
 
9862828
8151df4
 
9862828
8151df4
c83a777
 
6523fb1
c83a777
6523fb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac11067
 
 
 
 
 
6523fb1
 
 
 
 
 
 
 
 
 
 
 
8151df4
c83a777
f913ab4
6523fb1
 
 
 
 
8151df4
 
 
 
fd1d947
982555a
8151df4
6523fb1
 
8151df4
 
9862828
 
6523fb1
 
9862828
 
 
6523fb1
 
 
 
 
 
 
 
 
9862828
 
 
 
6523fb1
 
 
9862828
 
 
6523fb1
eec20c9
6523fb1
 
eec20c9
6523fb1
 
 
 
 
 
eec20c9
6523fb1
 
eec20c9
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# Import Streamlit first: st.set_page_config must be the very first
# Streamlit command executed in the script, so it is called here before
# the remaining imports and any other st.* call.
import streamlit as st
st.set_page_config(
    page_title="Magic Story Generator",
    layout="centered",
    page_icon="📖"
)

# Other imports (stdlib first, then third-party)
import io
import re
import tempfile
import time

import torch
from gtts import gTTS
from PIL import Image
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer

# --- Initialize Models First ---
@st.cache_resource(show_spinner=False)
def load_models():
    """Load and return both models at startup"""
    try:
        # 1. Image Captioning Model
        caption_pipe = pipeline(
            "image-to-text",
            model="Salesforce/blip-image-captioning-base",
            device=0 if torch.cuda.is_available() else -1
        )

        # 2. Story Generation Model
        story_tokenizer = AutoTokenizer.from_pretrained(
            "Qwen/Qwen3-0.6B",
            trust_remote_code=True
        )
        
        story_model = AutoModelForCausalLM.from_pretrained(
            "Qwen/Qwen3-0.6B",
            device_map="auto",
            torch_dtype=torch.float16
        )

        story_pipe = pipeline(
            "text-generation",
            model=story_model,
            tokenizer=story_tokenizer,
            max_new_tokens=230,
            temperature=0.9,
            top_k=50,
            top_p=0.9,
            repetition_penalty=1.1,
            eos_token_id=151645
        )

        return caption_pipe, story_pipe

    except Exception as e:
        st.error(f"🚨 Model loading failed: {str(e)}")
        st.stop()

# Initialize models immediately when app starts.
# load_models() is cached, so this is cheap on every rerun after the first.
caption_pipe, story_pipe = load_models()

# --- Rest of Application ---
st.title("📖✨ Turn Images into Children's Stories")

def clean_story_text(raw_text):
    """Strip model artifacts from generated text and return the clean story.

    Removes, in order:
      1. <think>...</think> reasoning blocks — Qwen3 is a thinking model
         and emits these; the previous version never stripped them, so
         chain-of-thought leaked into the displayed story and the audio.
         An unterminated trailing <think> block is removed as well.
      2. Chat special tokens of the form <|...|> (e.g. <|im_end|>).
      3. Leftover "Okay, I need..." thinking-chain lines the model
         sometimes produces outside a <think> block.
      4. Runs of 3+ newlines, collapsed to one blank line.

    Args:
        raw_text: Raw text produced by the generation pipeline.

    Returns:
        The cleaned story with leading/trailing whitespace stripped.
    """
    clean = re.sub(r'<think>.*?</think>', '', raw_text, flags=re.DOTALL)
    clean = re.sub(r'<think>.*\Z', '', clean, flags=re.DOTALL)  # unclosed block
    clean = re.sub(r'<\|.*?\|>', '', clean)  # special tokens like <|im_end|>
    clean = re.sub(r'Okay, I need.*?(?=\n|$)', '', clean, flags=re.DOTALL)
    clean = re.sub(r'\n{3,}', '\n\n', clean)  # tidy excess blank lines
    return clean.strip()

uploaded_image = st.file_uploader(
    "Upload a children's book style image:",
    type=["jpg", "jpeg", "png"]
)

if uploaded_image:
    # Normalize to RGB so captioning works for RGBA/grayscale uploads too.
    image = Image.open(uploaded_image).convert("RGB")
    st.image(image, use_container_width=True)

    # --- Step 1: caption the image ---
    with st.spinner("🔍 Analyzing image..."):
        try:
            caption_result = caption_pipe(image)
            image_caption = caption_result[0].get("generated_text", "")
            st.success(f"**Image Understanding:** {image_caption}")
        except Exception as e:
            st.error(f"❌ Image analysis failed: {str(e)}")
            st.stop()

    # --- Step 2: turn the caption into a story ---
    story_prompt = f"""Write a children's story about: {image_caption}
Rules:
- Use simple words (Grade 2 level)
- Exclude thinking processes
- 3 paragraphs maximum
Story:"""

    try:
        with st.spinner("📝 Crafting magical story..."):
            # do_sample enables the temperature/top_k configured on the
            # pipeline; top_p / repetition_penalty here override the
            # pipeline-level defaults for this call.
            story_result = story_pipe(
                story_prompt,
                do_sample=True,
                top_p=0.9,
                repetition_penalty=1.2
            )

            # The model echoes the prompt; keep only text after "Story:".
            raw_story = story_result[0]['generated_text']
            final_story = clean_story_text(raw_story.split("Story:")[-1])

            if not final_story:
                # Cleaning can strip everything if the model produced only
                # thinking/special tokens; gTTS raises on empty text, so
                # fail loudly with a clear message instead.
                st.error("❌ Story generation produced no usable text. Please try again.")
                st.stop()

            st.subheader("✨ Your Magical Story")
            st.write(final_story)

            # --- Step 3: audio narration ---
            # Render the MP3 into memory instead of a NamedTemporaryFile:
            # the previous delete=False temp file was never removed (one
            # leaked MP3 per run), and reopening an open temp file by name
            # fails on Windows. st.audio accepts a file-like object.
            with st.spinner("🔊 Creating audio version..."):
                audio_buffer = io.BytesIO()
                gTTS(text=final_story, lang="en", slow=False).write_to_fp(audio_buffer)
                audio_buffer.seek(0)
                st.audio(audio_buffer, format="audio/mp3")

    except Exception as e:
        st.error(f"❌ Story generation failed: {str(e)}")

# Footer
st.markdown("---")
st.markdown("📚 Made with ♥ by The Story Wizard • [Report Issues](https://example.com)")