1
File size: 2,101 Bytes
dfb3989
 
8367fb2
 
 
 
 
 
dfb3989
 
 
8367fb2
dfb3989
8367fb2
 
dfb3989
8367fb2
504dc12
dfb3989
8367fb2
dfb3989
8367fb2
 
dfb3989
8367fb2
 
 
dfb3989
8367fb2
dfb3989
 
 
 
 
8367fb2
dfb3989
 
 
 
 
 
8367fb2
dfb3989
 
dff0660
dfb3989
 
 
dd4f7ba
dff0660
 
 
 
 
 
 
 
dfb3989
 
1c165f8
dfb3989
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# app.py

import io
import tempfile

import streamlit as st
from PIL import Image
from transformers import pipeline
from gtts import gTTS

# —––––––– Page config: browser-tab title + centered layout, then the page header
st.set_page_config(
    page_title="Storyteller for Kids",
    layout="centered",
)
st.title("🖼️ ➡️ 📖 Interactive Storyteller")

# —––––––– Cache model loading
@st.cache_resource
def load_pipelines():
    """Build the captioning and story-writing pipelines.

    Decorated with ``st.cache_resource`` so the returned pipeline objects are
    reused across script reruns instead of being rebuilt each time.

    Returns:
        tuple: ``(captioner, storyteller)``, where ``captioner`` is a BLIP
        ``image-to-text`` pipeline and ``storyteller`` is a Flan-T5
        ``text2text-generation`` pipeline.
    """
    # Captioner first, storyteller second (same load order as before).
    caption_model = pipeline(
        task="image-to-text",
        model="Salesforce/blip-image-captioning-base",
    )
    story_model = pipeline(
        task="text2text-generation",
        model="google/flan-t5-base",
    )
    return caption_model, story_model

captioner, storyteller = load_pipelines()

# —––––––– Image upload
uploaded = st.file_uploader("Upload an image:", type=["jpg", "jpeg", "png"])
if uploaded:
    # Force 3-channel RGB: PNG uploads may carry an alpha channel or a palette.
    image = Image.open(uploaded).convert("RGB")
    st.image(image, caption="Your image", use_column_width=True)

    # —––––––– 1. Caption
    with st.spinner("🔍 Looking at the image..."):
        cap_outputs = captioner(image)
        # BLIP returns a list of dicts with key "generated_text"
        cap = cap_outputs[0].get("generated_text", "").strip()
    st.markdown(f"**Caption:** {cap}")

    # —––––––– 2. Story generation
    # The two literals are concatenated implicitly, so the trailing space
    # after "kids" is required (without it the prompt reads "kidsbased").
    prompt = (
        "Write an 80–100 word story for kids "
        f"based on this description:\n\n“{cap}”\n\nStory:"
    )
    with st.spinner("✍️ Writing a story..."):
        # Sampled (not greedy) decoding so repeated runs give varied stories.
        out = storyteller(
            prompt,
            max_new_tokens=120,
            do_sample=True,
            top_p=0.9,
            temperature=0.8,
            num_return_sequences=1,
        )
    story = out[0]["generated_text"].strip()
    st.markdown("**Story:**")
    st.write(story)

    # —––––––– 3. Text-to-Speech
    # gTTS refuses empty text, so only synthesize when there is a story.
    if story:
        with st.spinner("🔊 Converting to speech..."):
            tts = gTTS(story, lang="en")
            # Keep the MP3 in memory: a NamedTemporaryFile(delete=False) was
            # never closed or removed and leaked one file per run.
            audio_buf = io.BytesIO()
            tts.write_to_fp(audio_buf)
        st.audio(audio_buf.getvalue(), format="audio/mp3")
    else:
        st.warning("The model returned an empty story, so there is nothing to read aloud.")