1
File size: 3,713 Bytes
8367fb2
 
7d2ac1c
 
422a749
6b1de29
8367fb2
 
c5def56
8e5f097
 
8367fb2
c4ae250
8367fb2
7d2ac1c
 
422a749
e1594b2
422a749
 
 
 
 
 
 
 
 
e1594b2
422a749
8367fb2
c5def56
c4ae250
422a749
 
301b896
422a749
 
301b896
422a749
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb9cd5a
c5def56
301b896
 
422a749
 
 
c4ae250
 
422a749
 
 
 
301b896
8e5f097
 
 
 
 
 
 
 
 
dfb3989
cc355a8
8e5f097
422a749
8e5f097
 
258bc7e
8e5f097
 
c2c4e19
8e5f097
 
 
 
c2c4e19
422a749
7d2ac1c
c5def56
8e5f097
96d517c
422a749
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import streamlit as st
from PIL import Image
from io import BytesIO
from huggingface_hub import InferenceApi
from transformers import pipeline
from gtts import gTTS
import tempfile

# —––––––– Page Config —–––––––
# Configure the browser tab and page layout before rendering any other element.
st.set_page_config(
    layout="centered",
    page_title="Magic Story Generator",
)
st.title("📖✨ Turn Images into Children's Stories")

# —––––––– Clients (cached) —–––––––
@st.cache_resource
def load_clients():
    """Build the captioning client and the story-generation pipeline.

    Decorated with ``st.cache_resource`` so the same objects are reused across
    Streamlit reruns instead of reloading the model on every interaction.

    Returns:
        tuple: ``(caption_client, story_generator)`` -- an ``InferenceApi``
        client for ``Salesforce/blip-image-captioning-base`` and a
        ``transformers`` text-generation pipeline for
        ``deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B``.
    """
    # torch is the backend the transformers pipeline runs on; imported locally
    # because it is only needed to pick the device.
    import torch

    hf_token = st.secrets["HF_TOKEN"]
    # Image captioning goes through the Hugging Face Inference API.
    # NOTE(review): InferenceApi is deprecated upstream in favour of
    # InferenceClient; migrating also requires updating generate_caption().
    caption_client = InferenceApi("Salesforce/blip-image-captioning-base", token=hf_token)
    # Story generation runs in-process. Use the first CUDA GPU when one exists
    # and fall back to CPU (-1) so the app still starts on CPU-only hosts
    # (a hard-coded device=0 fails there).
    device = 0 if torch.cuda.is_available() else -1
    story_generator = pipeline(
        "text-generation",
        model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        tokenizer="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        token=hf_token,  # replaces the deprecated `use_auth_token` argument
        device=device,
    )
    return caption_client, story_generator

caption_client, story_generator = load_clients()

# —––––––– Helper: Generate Caption —–––––––
def generate_caption(img):
    """Return a one-line caption for *img* from the BLIP captioning client.

    The image is re-encoded as JPEG in memory and sent to the module-level
    ``caption_client``.

    Args:
        img: A PIL image. JPEG cannot store alpha/palette modes, so non-RGB
            images are converted to RGB first.

    Returns:
        The stripped caption text, or ``""`` when captioning failed (an error
        is shown in the Streamlit UI in that case).
    """
    if img.mode != "RGB":
        img = img.convert("RGB")
    buf = BytesIO()
    img.save(buf, format="JPEG")
    try:
        out = caption_client(data=buf.getvalue())
    except Exception as e:  # UI boundary: report network/API failures to the user
        st.error(f"Caption error: {e}")
        return ""
    # The API may answer with a dict such as {"error": "..."} (e.g. while the
    # model is loading) instead of a list of results; indexing it with [0]
    # would only produce an unhelpful KeyError.
    if isinstance(out, dict):
        st.error(f"Caption error: {out.get('error', out)}")
        return ""
    if not out or not isinstance(out[0], dict):
        st.error("Caption error: empty response from captioning model")
        return ""
    return (out[0].get("generated_text") or "").strip()

# —––––––– Helper: Generate Story via pipeline —–––––––
# Punctuation that already closes a sentence (avoids "!." / "?." endings).
_SENTENCE_ENDINGS = (".", "!", "?")
# Closing quote characters that may follow sentence-ending punctuation.
_CLOSING_QUOTES = "\"'\u201d\u2019"


def _clean_story_text(text: str, max_words: int) -> str:
    """Strip reasoning/prompt residue from raw model output and cap it at *max_words* words."""
    # DeepSeek-R1 distilled models may emit a <think>...</think> reasoning
    # section before the answer; keep only what follows the last closing tag.
    if "</think>" in text:
        text = text.rsplit("</think>", 1)[1]
    # An unclosed <think> means the reasoning was cut off (e.g. by the token
    # budget) before any story was written: drop it rather than show it.
    text = text.split("<think>", 1)[0].strip()
    # Keep only what follows a "Story:" header if one is present.
    if "Story:" in text:
        text = text.split("Story:", 1)[1].strip()
    words = text.split()
    if len(words) > max_words:
        text = " ".join(words[:max_words])
        # Close the cut-off text with a period unless it already ends a
        # sentence (optionally followed by a closing quote).
        if not text.rstrip(_CLOSING_QUOTES).endswith(_SENTENCE_ENDINGS):
            text += "."
    return text


def generate_story(
    prompt: str,
    *,
    max_words: int = 100,
    max_new_tokens: int = 200,
    generator=None,
) -> str:
    """Generate a short children's story from *prompt*.

    Args:
        prompt: Full text prompt sent to the text-generation pipeline.
        max_words: Word cap applied to the cleaned story (positive int).
        max_new_tokens: Token budget for generation; ~200 tokens leaves margin
            for a 100-word story.
        generator: Callable with the ``transformers`` text-generation pipeline
            calling convention (returns ``[{"generated_text": str}, ...]``).
            Defaults to the module-level ``story_generator``.

    Returns:
        The cleaned story, truncated to *max_words* words. May be ``""`` if
        the model produced no usable story text.
    """
    gen = story_generator if generator is None else generator
    outputs = gen(
        prompt,
        max_new_tokens=max_new_tokens,
        temperature=0.8,
        top_p=0.95,
        repetition_penalty=1.15,
        no_repeat_ngram_size=2,
        do_sample=True,
        # Return only the continuation, not the echoed prompt.
        return_full_text=False,
    )
    return _clean_story_text(outputs[0]["generated_text"], max_words)

# —––––––– Main App Flow —–––––––
uploaded = st.file_uploader("Upload an image:", type=["jpg", "png", "jpeg"])
if uploaded:
    try:
        # Image.open() is lazy; convert("RGB") forces a full decode, so
        # corrupt data fails here rather than later.
        img = Image.open(uploaded).convert("RGB")
    except (OSError, Image.DecompressionBombError) as e:
        # Unreadable/unsupported or oversized files get a friendly message
        # instead of a raw traceback.
        st.error(f"😢 Couldn't open this image: {e}")
        st.stop()
    # Cap the longest side at 2048 px; thumbnail() preserves the aspect ratio.
    if max(img.size) > 2048:
        img.thumbnail((2048, 2048))
    st.image(img, use_container_width=True)

    caption = generate_caption(img)
    if not caption:
        st.error("😢 Couldn't understand this image. Try another one!")
        st.stop()
    st.success(f"**Caption:** {caption}")

    story_prompt = (
        f"Image description: {caption}\n\n"
        "Write a 50-100 word children's story that:\n"
        "1. Features the main subject as a friendly character\n"
        "2. Includes a simple adventure or discovery\n"
        "3. Ends with a happy or funny conclusion\n"
        "4. Uses simple language for ages 3-8\n\n"
        "Story:\n"
    )

    with st.spinner("📝 Writing magical story..."):
        story = generate_story(story_prompt)
    if not story:
        # Nothing to show, and gTTS refuses to speak empty text.
        st.error("😢 The story came out empty. Please try again!")
        st.stop()
    st.subheader("📚 Your Magical Story")
    st.write(story)

    # Audio conversion: synthesise the MP3 into memory so no temporary files
    # accumulate on disk across Streamlit reruns.
    with st.spinner("🔊 Adding story voice..."):
        audio_buf = BytesIO()
        try:
            gTTS(text=story, lang="en", slow=False).write_to_fp(audio_buf)
        except Exception as e:
            # Best effort: the story is already displayed; audio is optional.
            st.warning(f"⚠️ Couldn't make audio version: {e}")
        else:
            st.audio(audio_buf.getvalue(), format="audio/mp3")

# Footer
# Footer credit line, rendered as Markdown below a horizontal rule.
FOOTER_MARKDOWN = "---\n*Made with ❤️ by your friendly story wizard*"
st.markdown(FOOTER_MARKDOWN)