1
File size: 4,236 Bytes
0dcd353
ed4df47
8367fb2
 
7d2ac1c
0dcd353
422a749
6b1de29
8367fb2
 
c5def56
2aae3c9
 
8367fb2
c4ae250
ed4df47
7d2ac1c
 
0dcd353
2aae3c9
0dcd353
 
 
2aae3c9
ed4df47
 
 
 
2aae3c9
0dcd353
 
 
 
 
2aae3c9
ed4df47
422a749
0dcd353
422a749
 
2aae3c9
422a749
ed4df47
0dcd353
422a749
e1594b2
422a749
8367fb2
0dcd353
c5def56
0dcd353
422a749
 
301b896
ed4df47
 
 
301b896
0dcd353
 
 
422a749
 
ed4df47
 
 
2aae3c9
ed4df47
 
2aae3c9
 
ed4df47
2aae3c9
ed4df47
2aae3c9
ed4df47
 
 
422a749
 
ed4df47
 
 
 
 
422a749
 
2aae3c9
0dcd353
ed4df47
2aae3c9
0dcd353
 
2aae3c9
422a749
 
 
 
 
 
bb9cd5a
0dcd353
c5def56
2aae3c9
301b896
422a749
 
 
c4ae250
 
ed4df47
0dcd353
422a749
 
 
301b896
8e5f097
 
ed4df47
0dcd353
8e5f097
 
258bc7e
ed4df47
c2c4e19
8e5f097
 
 
 
c2c4e19
ed4df47
7d2ac1c
c5def56
8e5f097
2aae3c9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import os
import time
import streamlit as st
from PIL import Image
from io import BytesIO
from huggingface_hub import InferenceApi, login
from transformers import pipeline
from gtts import gTTS
import tempfile

# —––––––– Page Config —–––––––
# Set the browser-tab title and centered layout, then render the page heading.
# NOTE(review): Streamlit has historically required set_page_config to run
# before any other st.* command — keep it first unless the pinned version
# is confirmed to allow otherwise.
st.set_page_config(page_title="Magic Story Generator", layout="centered")
st.title("📖✨ Turn Images into Children's Stories")

# —––––––– Clients (cached) —–––––––
@st.cache_resource(show_spinner=False)
def load_clients():
    """Create the captioning client and the story-generation pipeline.

    Reads ``HF_TOKEN`` from ``st.secrets``, exports it through the
    environment, logs in to the Hugging Face Hub, prepares a local cache
    directory, and then builds both model handles. The function is wrapped
    in ``st.cache_resource``, so the returned objects are meant to be reused
    instead of being rebuilt.

    Returns:
        tuple: ``(caption_client, story_generator)`` — the ``InferenceApi``
        client for the BLIP captioning model and the text-generation
        ``pipeline`` object.

    Raises:
        KeyError: presumably, if ``HF_TOKEN`` is missing from the secrets
            (raised by the ``st.secrets`` lookup) — confirm against the
            installed Streamlit version.
    """
    token = st.secrets["HF_TOKEN"]
    story_model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
    local_cache = "./hf_cache"

    # Expose the token via the environment and via an explicit Hub login.
    os.environ["HUGGINGFACEHUB_API_TOKEN"] = token
    login(token)

    # Create the cache folder and point TRANSFORMERS_CACHE at it.
    # NOTE(review): transformers is already imported at module top by the
    # time this runs — confirm the variable is still honoured when set here.
    os.makedirs(local_cache, exist_ok=True)
    os.environ["TRANSFORMERS_CACHE"] = local_cache

    # Remote captioning client (BLIP).
    captioner = InferenceApi(
        repo_id="Salesforce/blip-image-captioning-base",
        token=token,
    )

    # Local text-generation pipeline; device=-1 forces CPU. The load is
    # timed so the duration can be reported in the UI.
    started = time.time()
    generator = pipeline(
        task="text-generation",
        model=story_model_id,
        tokenizer=story_model_id,
        device=-1,
    )
    elapsed = time.time() - started
    st.text(f"✅ Story model loaded in {elapsed:.1f}s (cached thereafter)")

    return captioner, generator

# Module-level handles used by generate_caption() and generate_story().
# load_clients() is wrapped in st.cache_resource, so this call is intended to
# reuse the already-built clients rather than rebuild them.
caption_client, story_generator = load_clients()


# —––––––– Helper: Generate Caption —–––––––
def generate_caption(img: Image.Image) -> str:
    """Return a caption for *img* from the hosted captioning model.

    The image is JPEG-encoded in memory and sent through the module-level
    ``caption_client``. Any failure is reported with ``st.error`` and turned
    into an empty string, so callers can treat ``""`` as "no caption".

    Args:
        img: Image to describe. Modes other than RGB, L and CMYK (for
            example RGBA or P) are converted to RGB on a local copy before
            encoding; the caller's image object is not modified.

    Returns:
        The stripped ``generated_text`` of the first result, or ``""`` when
        the request fails, the API reports an error, or the response has an
        unexpected shape.
    """
    try:
        # JPEG cannot store alpha or palette data; saving such modes raises
        # OSError, so normalise first and keep the encode inside the try.
        if img.mode not in ("RGB", "L", "CMYK"):
            img = img.convert("RGB")
        buf = BytesIO()
        img.save(buf, format="JPEG")

        resp = caption_client(data=buf.getvalue())

        # The Inference API can answer with a dict carrying an "error"
        # message (e.g. while the model is still loading). Surface it instead
        # of silently returning an empty caption.
        if isinstance(resp, dict) and resp.get("error"):
            st.error(f"Caption generation error: {resp['error']}")
            return ""

        if isinstance(resp, list) and resp:
            return resp[0].get("generated_text", "").strip()
    except Exception as e:
        # UI boundary: show the problem to the user rather than crash the app.
        st.error(f"Caption generation error: {e}")
    return ""


# —––––––– Helper: Generate Story via pipeline —–––––––
def _extract_story(generated: str, prompt: str, max_words: int = 100) -> str:
    """Drop an echoed *prompt* from *generated* and cap it at *max_words* words."""
    # Compare BEFORE stripping whitespace: the prompt begins with a newline,
    # so an already-stripped string could never start with it.
    if generated.startswith(prompt):
        generated = generated[len(prompt):]
    else:
        # Fallback for output whose surrounding whitespace was altered.
        bare_prompt = prompt.strip()
        candidate = generated.lstrip()
        if bare_prompt and candidate.startswith(bare_prompt):
            generated = candidate[len(bare_prompt):]

    text = generated.strip()
    words = text.split()
    if len(words) > max_words:
        # Hard cap; drop a dangling separator so the result does not end in ",."
        text = " ".join(words[:max_words]).rstrip(",;:")
        if not text.endswith((".", "!", "?")):
            text += "."
    return text


def generate_story(caption: str) -> str:
    """Generate a short children's story from an image caption.

    Builds the instruction prompt around *caption*, runs the module-level
    ``story_generator`` text-generation pipeline with sampling enabled, and
    reports the generation time in the UI via ``st.text``.

    Args:
        caption: Description of the image that the story should be based on.

    Returns:
        The generated story with the prompt removed, trimmed of surrounding
        whitespace and limited to 100 words.
    """
    prompt = f"""
You are a creative children’s-story author.
Below is an image description:
{caption}

Write a coherent 50–100 word story that:
1. Introduces the main character.
2. Shows a simple problem or discovery.
3. Has a happy resolution.
4. Uses clear language for ages 3–8.
5. Keeps sentences under 20 words.
Story:
"""
    t0 = time.time()
    outputs = story_generator(
        prompt,
        max_new_tokens=120,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.1,
        no_repeat_ngram_size=3,
        do_sample=True,
        # Ask for the completion only; the helper below still strips an
        # echoed prompt defensively in case the full text comes back.
        return_full_text=False,
    )
    st.text(f"⏱ Generated in {time.time() - t0:.1f}s on CPU")

    return _extract_story(outputs[0]["generated_text"], prompt)


# —––––––– Main App Flow —–––––––
# Flow: upload -> preview -> caption -> story -> narration.
uploaded = st.file_uploader("Upload an image:", type=["jpg","png","jpeg"])
if uploaded:
    # A corrupt or mislabeled upload should produce a friendly message,
    # not a raw traceback.
    try:
        img = Image.open(uploaded).convert("RGB")
    except (OSError, ValueError) as e:
        st.error(f"😢 Couldn't read this image: {e}")
        st.stop()

    # Bound the size of what is displayed and sent to the captioning API.
    if max(img.size) > 2048:
        img.thumbnail((2048, 2048))
    st.image(img, use_container_width=True)

    with st.spinner("🔍 Generating caption..."):
        caption = generate_caption(img)
    if not caption:
        st.error("😢 Couldn't understand this image. Try another one!")
        st.stop()
    st.success(f"**Caption:** {caption}")

    with st.spinner("📝 Writing magical story..."):
        story = generate_story(caption)

    st.subheader("📚 Your Magical Story")
    st.write(story)

    with st.spinner("🔊 Converting to audio..."):
        try:
            tts = gTTS(text=story, lang="en", slow=False)
            # Synthesize into memory: the previous delete=False temp file was
            # never removed (one leaked mp3 per rerun) and was re-opened by
            # name while still open, which fails on some platforms.
            audio_buf = BytesIO()
            tts.write_to_fp(audio_buf)
            st.audio(audio_buf.getvalue(), format="audio/mp3")
        except Exception as e:
            # Narration is best-effort; the story text is already on screen.
            st.warning(f"⚠️ TTS failed: {e}")

# Footer
st.markdown("---\n*Made with ❤️ by your friendly story wizard*")