1
File size: 3,616 Bytes
8367fb2
 
7d2ac1c
 
6b1de29
8367fb2
 
8e5f097
 
8367fb2
c4ae250
8367fb2
7d2ac1c
 
8e5f097
 
 
8367fb2
121e41f
7d2ac1c
8367fb2
c4ae250
 
 
 
 
 
 
 
 
 
 
 
 
 
 
301b896
 
 
c4ae250
301b896
 
 
c4ae250
301b896
c4ae250
301b896
 
 
c4ae250
 
301b896
 
 
 
 
 
8e5f097
 
 
 
 
 
 
 
 
 
dfb3989
cc355a8
8e5f097
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6bc44b9
8e5f097
 
 
258bc7e
8e5f097
 
c2c4e19
8e5f097
 
 
 
c2c4e19
8e5f097
7d2ac1c
8e5f097
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import streamlit as st
from PIL import Image
from io import BytesIO
from huggingface_hub import InferenceApi
from gtts import gTTS
import tempfile

# Configure the browser tab title and a centered layout, then show the app heading.
st.set_page_config(
    layout="centered",
    page_title="Magic Story Generator",
)
st.title("📖✨ Turn Images into Children's Stories")

# —––––––– Clients (cached) —–––––––
@st.cache_resource
def load_clients():
    """Build the two Hugging Face Inference API clients used by the app.

    The access token is read from the Streamlit secret ``HF_TOKEN``. Because the
    function is decorated with ``st.cache_resource``, the clients are created
    once and reused across script reruns instead of being rebuilt each time.

    Returns:
        A ``(caption_client, story_client)`` pair: the BLIP image-captioning
        client and the DeepSeek-R1-Distill-Qwen-1.5B client used to write the
        story.
    """
    token = st.secrets["HF_TOKEN"]
    captioner = InferenceApi("Salesforce/blip-image-captioning-base", token=token)
    storyteller = InferenceApi(
        "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", token=token
    )
    return captioner, storyteller


caption_client, story_client = load_clients()

def generate_caption(img):
    """
    Run the BLIP caption model on a PIL image and return the generated text.

    The image is JPEG-encoded in memory and sent to ``caption_client``. Any
    failure is reported with ``st.error`` and an empty string is returned, so
    the caller can decide how to react. Failures include encoding errors,
    exceptions raised by the client, and an ``{"error": ...}`` payload returned
    by the API.

    Args:
        img: A ``PIL.Image.Image``. Modes other than RGB are converted to RGB
            first, because JPEG cannot store alpha or palette modes (e.g. RGBA, P).

    Returns:
        The stripped caption text, or ``""`` if no caption could be produced.
    """
    try:
        if img.mode != "RGB":
            img = img.convert("RGB")
        img_bytes = BytesIO()
        # Kept inside the try: encoding can fail too, and should be reported
        # like any other caption error instead of crashing the app.
        img.save(img_bytes, format="JPEG")
        result = caption_client(data=img_bytes.getvalue())
    except Exception as e:
        st.error(f"Caption generation error: {e}")
        return ""

    # NOTE(review): the legacy InferenceApi can hand back an error payload such
    # as {"error": "...model is loading..."} instead of raising; surface it
    # rather than misreporting it as an uncaptionable image.
    if isinstance(result, dict) and "error" in result:
        st.error(f"Caption generation error: {result['error']}")
        return ""
    if isinstance(result, list) and result and isinstance(result[0], dict):
        text = result[0].get("generated_text") or ""
        return text.strip() if isinstance(text, str) else ""
    return ""

def process_image(uploaded_file, max_side=2048):
    """
    Open an uploaded image, apply its EXIF orientation, convert to RGB and cap its size.

    Args:
        uploaded_file: Anything ``PIL.Image.open`` accepts. Here it is the
            object returned by ``st.file_uploader``.
        max_side: Longest allowed edge in pixels. Larger images are shrunk in
            place with their aspect ratio preserved. Defaults to 2048, the
            previously hard-coded limit.

    Returns:
        The processed ``PIL.Image.Image`` in RGB mode.

    On any failure the error is shown with ``st.error`` and the script run is
    halted with ``st.stop()``.
    """
    from PIL import ImageOps

    try:
        img = Image.open(uploaded_file)
        # Some images record their rotation in an EXIF Orientation tag instead
        # of rotating the pixels. Apply it so the image is displayed and
        # captioned upright.
        img = ImageOps.exif_transpose(img)
        img = img.convert("RGB")
        # thumbnail() only ever shrinks, so no explicit size check is needed.
        img.thumbnail((max_side, max_side))
        return img
    except Exception as e:
        st.error(f"Image processing error: {e}")
        st.stop()

# Generation settings for the story model, sent in the request's ``params``
# payload. InferenceApi.__call__ only accepts (inputs, params, data,
# raw_response); passing these as keyword arguments raises TypeError, which
# made every story request fail.
# ``no_repeat_ngram_size`` is intentionally omitted. It is not a documented
# text-generation parameter of the Inference API, and a value of 2 forbids
# repeating any word pair (e.g. "the bunny"), which mangles short stories.
# ``return_full_text=False`` asks the API not to echo the prompt back.
_STORY_PARAMS = {
    "max_new_tokens": 200,
    "temperature": 0.8,
    "top_p": 0.95,
    "repetition_penalty": 1.15,
    "do_sample": True,
    "return_full_text": False,
}

_SENTENCE_ENDINGS = ".!?"
_CLOSING_QUOTES = "\"'\u201d\u2019"


def _build_story_prompt(caption):
    """Return the story-writing prompt for an image caption."""
    return (
        f"Image description: {caption}\n\n"
        "Write a 50-100 word children's story that:\n"
        "1. Features the main subject as a friendly character\n"
        "2. Includes a simple adventure or discovery\n"
        "3. Ends with a happy or funny conclusion\n"
        "4. Uses simple language for ages 3-8\n\n"
        "Story:\n"
    )


def _extract_generated_text(response):
    """Return the ``generated_text`` field of a text-generation response.

    Raises:
        RuntimeError: if the API returned an ``{"error": ...}`` payload or a
            response shape this app does not recognise.
    """
    if isinstance(response, dict):
        if "error" in response:
            raise RuntimeError(response["error"])
        return response.get("generated_text") or ""
    if isinstance(response, list) and response and isinstance(response[0], dict):
        return response[0].get("generated_text") or ""
    raise RuntimeError(f"unexpected response from story model: {response!r}")


def _clean_story(text):
    """Turn raw model output into the story text shown to the user.

    - Drops a ``<think>...</think>`` reasoning section, which DeepSeek-R1
      distill models may emit. If reasoning was opened but never closed within
      the token budget, there is no usable story and ``""`` is returned.
    - Keeps only the text after the last ``"Story:"`` marker. This also
      removes an echoed prompt if the endpoint ignores ``return_full_text``.
    - Trims a trailing unfinished sentence by cutting after the last ``.``,
      ``!`` or ``?``, plus any closing quotes. Endings such as
      ``"Let's play again!"`` are therefore kept. Text with no sentence
      punctuation is returned unchanged.
    """
    if "</think>" in text:
        text = text.rsplit("</think>", 1)[1]
    elif "<think>" in text:
        return ""
    story = text.split("Story:")[-1].strip()
    end = max(story.rfind(ch) for ch in _SENTENCE_ENDINGS)
    if end != -1:
        end += 1
        while end < len(story) and story[end] in _CLOSING_QUOTES:
            end += 1
        story = story[:end]
    return story


def _story_audio(text):
    """Synthesize ``text`` with gTTS and return the MP3 data as bytes.

    Rendering into memory means no temporary file is left behind on disk.
    """
    buf = BytesIO()
    gTTS(text=text, lang="en", slow=False).write_to_fp(buf)
    return buf.getvalue()


uploaded = st.file_uploader("Upload an image:", type=["jpg", "png", "jpeg"])
if uploaded:
    img = process_image(uploaded)
    st.image(img, use_container_width=True)

    with st.spinner("Generating caption..."):
        caption = generate_caption(img)
        if not caption:
            st.error("😢 Couldn't understand this image. Try another one!")
            st.stop()
    st.success(f"**Caption:** {caption}")

    story_prompt = _build_story_prompt(caption)

    # Generate Story
    with st.spinner("📝 Writing magical story..."):
        try:
            story_response = story_client(inputs=story_prompt, params=_STORY_PARAMS)
            story = _clean_story(_extract_generated_text(story_response))
        except Exception as e:
            st.error(f"🚨 Story magic failed: {str(e)}")
            st.stop()

    # An empty story would show a blank section, and gTTS would reject it.
    if not story:
        st.error("🚨 Story magic failed: the model returned no story. Try again!")
        st.stop()

    # Display Story
    st.subheader("📚 Your Magical Story")
    st.write(story)

    # Audio Conversion (best effort: the story is already shown if this fails)
    with st.spinner("🔊 Adding story voice..."):
        try:
            st.audio(_story_audio(story), format="audio/mp3")
        except Exception as e:
            st.warning("⚠️ Couldn't make audio version: " + str(e))

st.markdown("---\n*Made with ❤️ by your friendly story wizard*")