1
File size: 3,767 Bytes
0dcd353
e508bdf
 
8367fb2
6adb177
fd1d947
 
e508bdf
 
8367fb2
748a576
fd1d947
748a576
e508bdf
fd1d947
 
8367fb2
6adb177
e508bdf
6adb177
fd1d947
 
 
 
 
 
 
 
 
 
 
b3abd21
e508bdf
fd1d947
e508bdf
2c0fb69
fd1d947
 
 
 
 
88ee0a7
fd1d947
 
2c0fb69
 
6adb177
 
fd1d947
 
 
6adb177
 
 
fd1d947
 
 
 
6adb177
 
88ee0a7
6adb177
fd1d947
6adb177
fd1d947
 
422a749
fd1d947
6adb177
fd1d947
 
 
88ee0a7
fd1d947
b3abd21
fd1d947
 
6adb177
 
748a576
6adb177
e508bdf
6adb177
fd1d947
e508bdf
 
 
fd1d947
e508bdf
b3abd21
e508bdf
6adb177
e508bdf
 
 
 
 
 
6adb177
b3abd21
e508bdf
6adb177
b3abd21
e508bdf
 
6adb177
e508bdf
 
 
 
 
2aae3c9
e508bdf
e616e4e
2c0fb69
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os
import time
import streamlit as st
from PIL import Image
from io import BytesIO
from transformers import pipeline
from huggingface_hub import login
from gtts import gTTS
import tempfile

# —––––––– Requirements —–––––––
# pip install streamlit pillow gTTS transformers huggingface_hub
# (the transformers pipelines also need a model backend, e.g. torch)

# —––––––– Page Config —–––––––
# Page title / layout, then the visible heading of the app.
st.set_page_config(
    layout="centered",
    page_title="Magic Story Generator (Local Pipeline)",
)
st.title("📖✨ Turn Images into Children's Stories")

# —––––––– Load Clients & Pipelines (cached) —–––––––
@st.cache_resource(show_spinner=False)
def load_clients():
    """Optionally authenticate to the Hugging Face Hub and build both pipelines.

    Decorated with ``st.cache_resource`` so the models are built once and
    reused across script reruns instead of being reloaded each time.

    Returns:
        tuple: ``(captioner, storyteller)`` — the BLIP image-to-text pipeline
        and the DeepSeek-R1-Distill-Qwen text-generation pipeline.
    """
    # Prefer Streamlit secrets, but fall back to the environment: accessing
    # st.secrets raises when no secrets file is configured, which would
    # otherwise crash a plain local run before any model is loaded.
    try:
        hf_token = st.secrets.get("HF_TOKEN")
    except Exception:  # no secrets file / secrets unavailable
        hf_token = None
    hf_token = hf_token or os.environ.get("HF_TOKEN")
    if hf_token:
        os.environ["HUGGINGFACEHUB_API_TOKEN"] = hf_token
        login(token=hf_token)

    # 1) Image-captioning pipeline (BLIP)
    captioner = pipeline(
        task="image-to-text",
        model="Salesforce/blip-image-captioning-base",
        device=-1,  # CPU; change to 0 for GPU
    )

    # 2) Story-generation pipeline (DeepSeek-R1-Distill-Qwen).
    # The Qwen2 architecture ships with transformers, so trust_remote_code is
    # not needed; leaving it on would let the Hub repo run arbitrary Python.
    storyteller = pipeline(
        task="text-generation",
        model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        tokenizer="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        device=-1,  # CPU; set 0+ for GPU
        # temperature/top_p only take effect when sampling; make that explicit
        # rather than depending on the model's hub generation_config defaults.
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
        repetition_penalty=1.1,
        no_repeat_ngram_size=2,
        max_new_tokens=120,
        return_full_text=False,  # return only the continuation, not the prompt
    )

    return captioner, storyteller


captioner, storyteller = load_clients()

# —––––––– Helpers —–––––––
def generate_caption(img: Image.Image) -> str:
    """Caption *img* with the BLIP pipeline.

    Returns the stripped ``"generated_text"`` of the first pipeline result,
    or an empty string when the pipeline yields no list / an empty list.
    """
    outputs = captioner(img)
    # Anything other than a non-empty list means there is no usable caption.
    if not (isinstance(outputs, list) and outputs):
        return ""
    first = outputs[0]
    return first.get("generated_text", "").strip()


def _strip_reasoning(text: str) -> str:
    """Remove DeepSeek-R1-style ``<think>…</think>`` reasoning, keeping the answer."""
    if "</think>" in text:
        # Everything up to the last closing tag is the model's scratchpad.
        text = text.rsplit("</think>", 1)[1]
    elif "<think>" in text:
        # Reasoning was opened but never closed (e.g. cut off by
        # max_new_tokens): only the text before it can be an answer.
        text = text.split("<think>", 1)[0]
    return text.replace("<think>", "").strip()


def _truncate_words(text: str, max_words: int = 100) -> str:
    """Clip *text* to at most *max_words* words, closing it with punctuation.

    Text already within the limit is returned unchanged.
    """
    words = text.split()
    if len(words) <= max_words:
        return text
    clipped = " ".join(words[:max_words])
    # Decide based on the clipped text (not the full story) so we neither
    # leave a dangling fragment nor produce "..", "!." or "?.".
    if not clipped.endswith((".", "!", "?")):
        clipped += "."
    return clipped


def generate_story(caption: str) -> str:
    """Generate a short children's story from an image caption.

    Shows the generation time in the UI, strips any model reasoning block,
    and caps the result at 100 words.

    Args:
        caption: Text description of the uploaded image.

    Returns:
        The story text; may be empty if the model produced no usable answer.
    """
    # Build a simple prompt incorporating the caption.
    # NOTE(review): R1-distill models are tuned for their chat template;
    # passing chat messages may give cleaner output — confirm with the
    # installed transformers version before switching.
    prompt = (
        f"Image description: {caption}\n"
        "Write a coherent 50-100 word children's story that flows naturally."
    )

    t0 = time.perf_counter()  # monotonic clock for measuring a duration
    outputs = storyteller(prompt)
    gen_time = time.perf_counter() - t0
    st.text(f"⏱ Generated in {gen_time:.1f}s")

    story = outputs[0].get("generated_text", "").strip()
    story = _strip_reasoning(story)
    return _truncate_words(story, 100)

# —––––––– Main App —–––––––
uploaded = st.file_uploader("Upload an image:", type=["jpg", "png", "jpeg"])
if uploaded:
    try:
        img = Image.open(uploaded).convert("RGB")
    except OSError:
        # PIL reports corrupt/unsupported files with UnidentifiedImageError
        # (an OSError subclass); show a friendly message, not a traceback.
        st.error("😢 Couldn't open this file as an image. Try another one!")
        st.stop()
    # Cap the long side at 2048 px; thumbnail() only shrinks and keeps aspect ratio.
    if max(img.size) > 2048:
        img.thumbnail((2048, 2048))
    st.image(img, use_container_width=True)

    with st.spinner("🔍 Generating caption..."):
        caption = generate_caption(img)
    if not caption:
        st.error("😢 Couldn't understand this image. Try another one!")
        st.stop()
    st.success(f"**Caption:** {caption}")

    with st.spinner("📝 Writing story..."):
        story = generate_story(caption)
    if not story:
        # Nothing to show, and gTTS rejects empty text.
        st.error("😢 The storyteller came up empty. Please try again!")
        st.stop()

    st.subheader("📚 Your Magical Story")
    st.write(story)

    with st.spinner("🔊 Converting to audio..."):
        # Render the MP3 in memory so no temporary file is left on disk.
        audio_buf = BytesIO()
        try:
            gTTS(text=story, lang="en", slow=False).write_to_fp(audio_buf)
        except Exception as e:  # TTS is optional; report and keep the story
            st.warning(f"⚠️ TTS failed: {e}")
        else:
            st.audio(audio_buf.getvalue(), format="audio/mp3")

# Footer
st.markdown("---\n*Made with ❤️ by your friendly story wizard*")