1
File size: 2,735 Bytes
dfb3989
 
8367fb2
 
 
6b1de29
8367fb2
 
dfb3989
 
 
8367fb2
c916589
8367fb2
 
dd489ad
8367fb2
504dc12
33fead7
dd489ad
8367fb2
dd489ad
8367fb2
b3f64ee
c916589
 
8367fb2
c916589
dd489ad
33fead7
c916589
 
 
8367fb2
 
dfb3989
8367fb2
dd489ad
dfb3989
 
dd489ad
dfb3989
c916589
 
8367fb2
6b1de29
c916589
 
dfb3989
8367fb2
dd489ad
dfb3989
dd489ad
 
dfb3989
c916589
dd4f7ba
b3f64ee
dd489ad
 
 
 
 
 
 
b3f64ee
dd489ad
 
dfb3989
 
1c165f8
dd489ad
dfb3989
6b1de29
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# app.py

import io
import tempfile

import streamlit as st
from gtts import gTTS
from PIL import Image
from transformers import pipeline

# —––––––– Page config
# NOTE: set_page_config must be the first Streamlit call in the script,
# before any other st.* usage, or Streamlit raises an error.
st.set_page_config(page_title="Storyteller for Kids", layout="centered")
st.title("🖼️ ➡️ 📖 Interactive Storyteller")

# —––––––– Model loading + warm-up
@st.cache_resource
def load_pipelines():
    """Build, warm up, and cache the captioning and story pipelines.

    Cached by Streamlit across reruns via @st.cache_resource, so the
    models are loaded only once per server process.

    Returns:
        tuple: (captioner, storyteller) Hugging Face pipelines — an
        image-to-text BLIP-base pipeline and a GPT-Neo-125M text
        generator.
    """
    import torch  # local import; transformers pipelines require torch anyway

    # Auto-detect the device instead of hard-coding GPU 0, so the app
    # runs on CPU-only machines instead of crashing at startup.
    device = 0 if torch.cuda.is_available() else -1

    # 1) Original BLIP-base for captions
    captioner = pipeline(
        "image-to-text",
        model="Salesforce/blip-image-captioning-base",
        device=device,
    )
    # 2) Small GPT-Neo for quick stories
    storyteller = pipeline(
        "text-generation",
        model="EleutherAI/gpt-neo-125M",
        device=device,
    )

    # Warm up both so the first real user request is faster
    dummy = Image.new("RGB", (384, 384), color=(128, 128, 128))
    captioner(dummy)
    storyteller("Hello", max_new_tokens=1)

    return captioner, storyteller

captioner, storyteller = load_pipelines()

# —––––––– Main UI
uploaded = st.file_uploader("Upload an image:", type=["jpg", "jpeg", "png"])
if uploaded:
    # 1) Load + resize for faster encoding (BLIP-base expects 384x384)
    image = Image.open(uploaded).convert("RGB")
    image = image.resize((384, 384), Image.LANCZOS)
    st.image(image, caption="Your image", use_container_width=True)

    # 2) Caption step
    with st.spinner("🔍 Generating caption..."):
        cap = captioner(image)[0]["generated_text"].strip()
    st.markdown(f"**Caption:** {cap}")

    # 3) Story generation (sampling + repetition control)
    prompt = (
        f"Write an 80–100 word fun story for 3–10 year-old children "
        f"based on this description:\n\n“{cap}”\n\nStory: "
    )
    with st.spinner("✍️ Generating story..."):
        out = storyteller(
            prompt,
            max_new_tokens=120,        # room for ~100 words
            do_sample=True,            # enable sampling
            temperature=0.8,           # creativity
            top_p=0.9,                 # nucleus sampling
            top_k=50,                  # limit to top 50 tokens
            repetition_penalty=1.2,    # discourage exact repeats
            no_repeat_ngram_size=3     # prevent 3-gram repeats
        )
        # strip off the prompt so only the story remains
        story = out[0]["generated_text"][len(prompt):].strip()
    st.markdown("**Story:**")
    st.write(story)

    # 4) Text-to-Speech via gTTS
    # Render the MP3 into an in-memory buffer instead of a
    # NamedTemporaryFile(delete=False): the old code leaked one temp
    # file per upload and never closed the handle (which also breaks
    # playback on Windows, where an open file can't be re-read by path).
    with st.spinner("🔊 Converting to speech..."):
        audio_buf = io.BytesIO()
        gTTS(text=story, lang="en").write_to_fp(audio_buf)
        audio_buf.seek(0)  # rewind so st.audio reads from the start
    st.audio(audio_buf, format="audio/mp3")