1
File size: 4,112 Bytes
8367fb2
 
 
 
 
db1550f
8367fb2
db1550f
 
 
 
 
 
 
 
8367fb2
db1550f
8367fb2
 
db1550f
8367fb2
504dc12
db1550f
 
8367fb2
db1550f
 
8367fb2
 
db1550f
 
 
8367fb2
db1550f
8367fb2
 
db1550f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8367fb2
db1550f
 
 
 
 
 
 
8367fb2
db1550f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8367fb2
db1550f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8367fb2
db1550f
 
1c165f8
db1550f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import streamlit as st
from PIL import Image
from transformers import pipeline
from gtts import gTTS
import tempfile
import os

# —––––––– Page config —–––––––
# Page metadata and layout are gathered in one mapping and applied in a single call.
_page_settings = {
    "page_title": "Storyteller for Kids",
    "page_icon": "📚",
    "layout": "centered",
    "initial_sidebar_state": "collapsed",
}
st.set_page_config(**_page_settings)

# Heading rendered at the top of the app.
st.title("🖼️➡️📖 Interactive Storyteller")

# —––––––– Cache model loading —–––––––
@st.cache_resource
def load_pipelines():
    """Build the two Hugging Face pipelines the app relies on.

    The function is wrapped in ``st.cache_resource`` so the pipelines are
    constructed through Streamlit's resource cache rather than directly on
    every call.

    Returns:
        tuple: ``(captioner, storyteller)`` where ``captioner`` is the
        ``image-to-text`` pipeline and ``storyteller`` is the
        ``text2text-generation`` pipeline.
    """
    # Captioning model: turns the uploaded picture into a short description.
    image_captioner = pipeline(
        "image-to-text",
        model="Salesforce/blip-image-captioning-base",
        max_new_tokens=50,
    )

    # Story model: requested with 8-bit loading and automatic device placement.
    quantization_options = {"load_in_8bit": True}
    story_generator = pipeline(
        "text2text-generation",
        model="google/flan-t5-xxl",
        device_map="auto",
        model_kwargs=quantization_options,
    )

    return image_captioner, story_generator

# —––––––– Main workflow —–––––––
def main():
    """Run the image-to-story workflow for the current Streamlit script pass.

    Loads the pipelines via ``load_pipelines()``, shows an image uploader and,
    once a file has been uploaded:

    1. displays the image,
    2. captions it with the image-to-text pipeline,
    3. generates a children's story from the caption,
    4. converts the story to MP3 with gTTS and offers playback and download.

    Any exception raised while processing the upload is reported to the user
    through ``st.error`` instead of propagating.
    """
    captioner, storyteller = load_pipelines()
    
    # —––––––– Image upload —–––––––
    uploaded = st.file_uploader(
        "Upload an image:",
        type=["jpg", "jpeg", "png"],
        help="Max size: 5MB"
    )
    
    if uploaded:
        try:
            # —––––––– Display image —–––––––
            image = Image.open(uploaded).convert("RGB")
            st.image(image, caption="Your Image", use_column_width=True)

            # —––––––– Generate caption —–––––––
            with st.spinner("🔍 Analyzing image content..."):
                cap_outputs = captioner(image)
                cap = cap_outputs[0].get("generated_text", "").strip()
                
            st.subheader("Image Understanding")
            st.info(f"**Detected:** {cap}")

            # —––––––– Generate story —–––––––
            st.subheader("Story Creation")
            # NOTE: the indentation inside this literal is part of the prompt
            # text sent to the model; it is kept exactly as-is on purpose.
            prompt = f"""Create a children's story (3-10 years old) based on this description:
            
            {cap}
            
            Requirements:
            - 50-100 words
            - Playful and imaginative
            - Positive message
            - Simple vocabulary
            - Include animal characters
            
            Story:"""
            
            with st.spinner("✍️ Crafting a magical story..."):
                story_output = storyteller(
                    prompt,
                    max_length=300,
                    do_sample=True,
                    top_p=0.95,
                    temperature=0.85,
                    num_beams=4,
                    repetition_penalty=1.2
                )
                story = story_output[0]["generated_text"].strip()
                
            st.success("**Generated Story:**")
            st.write(story)

            # —––––––– Text-to-Speech —–––––––
            st.subheader("Audio Version")
            if not story:
                # gTTS refuses empty text; skip audio instead of surfacing
                # an opaque error for an otherwise successful run.
                st.warning("The story is empty, so no audio was generated.")
            else:
                with st.spinner("🔊 Generating audio..."):
                    # Bound before the try so the cleanup in ``finally`` never
                    # hits an unbound name when gTTS or the temp-file creation
                    # fails early (which would mask the real error).
                    tmp_path = None
                    try:
                        tts = gTTS(text=story, lang="en", slow=False)
                        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp:
                            # Record the path before writing: if the write
                            # fails, the file still gets removed below.
                            tmp_path = tmp.name
                            tts.write_to_fp(tmp)

                        st.audio(tmp_path, format="audio/mp3")

                        # Add download button
                        with open(tmp_path, "rb") as f:
                            st.download_button(
                                label="Download Audio Story",
                                data=f,
                                file_name="kids_story.mp3",
                                mime="audio/mpeg"
                            )

                    finally:
                        # delete=False above means the file must be removed by hand.
                        if tmp_path is not None and os.path.exists(tmp_path):
                            os.remove(tmp_path)

        except Exception as e:
            # Top-level UI boundary: show the failure to the user rather than
            # letting the script crash.
            st.error(f"Error processing your request: {str(e)}")

# Invoke main() only when this module is executed as the main script,
# not when it is imported from another module.
if __name__ == "__main__":
    main()