1
File size: 4,360 Bytes
0dcd353
ed4df47
8367fb2
 
7d2ac1c
0dcd353
422a749
6b1de29
8367fb2
 
c5def56
ed4df47
 
8367fb2
c4ae250
ed4df47
7d2ac1c
 
0dcd353
ed4df47
0dcd353
 
 
ed4df47
 
 
 
 
 
0dcd353
 
 
 
 
ed4df47
 
422a749
0dcd353
422a749
 
ed4df47
 
422a749
ed4df47
0dcd353
422a749
e1594b2
422a749
8367fb2
0dcd353
c5def56
0dcd353
422a749
 
301b896
ed4df47
 
 
301b896
0dcd353
 
 
422a749
 
ed4df47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
422a749
 
ed4df47
 
 
 
 
422a749
 
ed4df47
 
0dcd353
ed4df47
 
0dcd353
 
ed4df47
422a749
 
 
 
 
 
bb9cd5a
0dcd353
c5def56
301b896
 
422a749
 
 
c4ae250
 
ed4df47
0dcd353
422a749
 
 
301b896
8e5f097
 
ed4df47
0dcd353
8e5f097
 
258bc7e
ed4df47
c2c4e19
8e5f097
 
 
 
c2c4e19
ed4df47
7d2ac1c
c5def56
8e5f097
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os
import time
import streamlit as st
from PIL import Image
from io import BytesIO
from huggingface_hub import InferenceApi, login
from transformers import pipeline
from gtts import gTTS
import tempfile

# —––––––– Page Config —–––––––
# Configure the page before any other element is drawn (older Streamlit
# versions require set_page_config to be the first Streamlit command).
st.set_page_config(
    layout="centered",
    page_title="Magic Story Generator (CPU)",
)
st.title("📖✨ Turn Images into Children's Stories (CPU)")

# —––––––– Clients (cached) —–––––––
@st.cache_resource(show_spinner=False)
def load_clients():
    """Create the captioning client and the CPU story generator.

    Reads the Hugging Face token from ``st.secrets["HF_TOKEN"]``, exports it
    and logs in with it, then builds:

    1. an ``InferenceApi`` client for the hosted BLIP captioning model, and
    2. a local ``text-generation`` pipeline running on CPU.

    Because of ``st.cache_resource`` the objects are built once and reused on
    later script reruns.

    Returns:
        ``(caption_client, story_generator)``.
    """
    hf_token = st.secrets["HF_TOKEN"]

    # Authenticate once so hub downloads and API calls use the token.
    os.environ["HUGGINGFACEHUB_API_TOKEN"] = hf_token
    login(hf_token)

    # Pin the download cache locally to avoid re-downloads.
    cache_dir = "./hf_cache"
    os.makedirs(cache_dir, exist_ok=True)
    # NOTE: transformers reads TRANSFORMERS_CACHE when it is first imported,
    # which has already happened at the top of this file, so this variable
    # alone does not move the cache; the explicit cache_dir below does.
    os.environ["TRANSFORMERS_CACHE"] = cache_dir

    # 1) BLIP-based image captioning client (hosted Inference API).
    # NOTE(review): InferenceApi is deprecated in huggingface_hub in favour of
    # InferenceClient; migrating also requires changing generate_caption.
    caption_client = InferenceApi(
        repo_id="Salesforce/blip-image-captioning-base",
        token=hf_token,
    )

    # 2) Text-generation pipeline forced onto CPU.
    model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
    t0 = time.time()
    story_generator = pipeline(
        task="text-generation",
        model=model_id,
        tokenizer=model_id,
        device=-1,  # CPU only
        # cache_dir is a from_pretrained() option, so it goes in model_kwargs;
        # as a bare pipeline() kwarg it may instead be forwarded to the
        # pipeline object as a generation parameter.
        model_kwargs={"cache_dir": cache_dir},
    )
    st.text(f"✅ Story model loaded in {time.time() - t0:.1f}s (cached thereafter)")

    return caption_client, story_generator


caption_client, story_generator = load_clients()


# —––––––– Helper: Generate Caption —–––––––
def generate_caption(img: Image.Image) -> str:
    """Describe *img* with the hosted BLIP captioning model.

    Args:
        img: Image to caption. Non-RGB images (e.g. RGBA or palette) are
            converted first because JPEG cannot store those modes.

    Returns:
        The stripped caption, or ``""`` when no usable caption came back.
        Failures are shown on the page with ``st.error`` instead of raised.
    """
    if img.mode != "RGB":
        img = img.convert("RGB")
    buf = BytesIO()
    img.save(buf, format="JPEG")

    try:
        resp = caption_client(data=buf.getvalue())
    except Exception as e:  # network / HTTP / response-decoding failures
        st.error(f"Caption generation error: {e}")
        return ""

    # Success looks like [{"generated_text": ...}]. The Inference API
    # typically reports failures (e.g. model still loading) as a JSON object
    # with an "error" key; surface it instead of failing silently.
    if isinstance(resp, dict) and "error" in resp:
        st.error(f"Caption generation error: {resp['error']}")
        return ""
    if isinstance(resp, list) and resp and isinstance(resp[0], dict):
        return (resp[0].get("generated_text") or "").strip()
    return ""


# —––––––– Helper: Generate Story via pipeline —–––––––
def _clean_story(text: str, prompt: str, max_words: int = 100) -> str:
    """Turn raw generated text into the final story, capped at *max_words*."""
    # Fallback in case the pipeline echoed the prompt. Check before stripping:
    # the prompt starts with a newline, so stripping first would prevent the
    # prefix from ever matching.
    if text.startswith(prompt):
        text = text[len(prompt):]
    # DeepSeek-R1 distills may emit a reasoning section closed by </think>;
    # keep only what follows it.
    if "</think>" in text:
        text = text.split("</think>", 1)[1]
    text = text.replace("<think>", "").strip()

    words = text.split()
    if len(words) > max_words:
        text = " ".join(words[:max_words])
        if not text.endswith((".", "!", "?")):
            text += "."
    return text


def generate_story(caption: str) -> str:
    """Write a short children's story (at most 100 words) about *caption*.

    Runs the CPU text-generation pipeline, shows the generation time on the
    page, and returns only the newly generated story text.
    """
    prompt = f"""
You are a creative children’s-story author.
Below is the description of an image:
{caption}

Write a coherent, 50 to 100-word story that:
1. Introduces the main character from the image.
2. Shows a simple problem or discovery.
3. Resolves it in a happy ending.
4. Uses clear language for ages 3–8.
5. Keeps each sentence under 20 words.
Story:
"""
    t0 = time.time()
    outputs = story_generator(
        prompt,
        max_new_tokens=120,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.1,
        no_repeat_ngram_size=3,
        do_sample=True,
        # Return only the continuation rather than prompt + continuation.
        return_full_text=False,
    )
    gen_time = time.time() - t0
    st.text(f"⏱ Generated in {gen_time:.1f}s on CPU")

    return _clean_story(outputs[0]["generated_text"], prompt)


# —––––––– Main App Flow —–––––––
uploaded = st.file_uploader("Upload an image:", type=["jpg", "png", "jpeg"])
if uploaded:
    try:
        img = Image.open(uploaded).convert("RGB")
    except OSError as e:
        # PIL raises UnidentifiedImageError (an OSError subclass) for corrupt
        # or mislabelled files; show a message instead of a traceback.
        st.error(f"😢 Couldn't open this image: {e}")
        st.stop()
    # Downscale very large uploads in place (aspect ratio is preserved).
    if max(img.size) > 2048:
        img.thumbnail((2048, 2048))
    st.image(img, use_container_width=True)

    with st.spinner("🔍 Generating caption..."):
        caption = generate_caption(img)
    if not caption:
        st.error("😢 Couldn't understand this image. Try another one!")
        st.stop()
    st.success(f"**Caption:** {caption}")

    with st.spinner("📝 Writing magical story..."):
        story = generate_story(caption)
    if not story:
        st.error("😢 The story generator returned nothing. Please try again!")
        st.stop()

    st.subheader("📚 Your Magical Story")
    st.write(story)

    with st.spinner("🔊 Converting to audio..."):
        try:
            # Render the MP3 in memory. A delete=False temp file would be
            # created on every script run and never removed.
            audio_buf = BytesIO()
            gTTS(text=story, lang="en", slow=False).write_to_fp(audio_buf)
            st.audio(audio_buf.getvalue(), format="audio/mp3")
        except Exception as e:
            st.warning(f"⚠️ TTS failed: {e}")

# Footer
st.markdown("---\n*Made with ❤️ by your friendly story wizard*")