1
File size: 2,746 Bytes
0dcd353
e508bdf
 
8367fb2
fd1d947
e508bdf
 
8367fb2
8087810
 
fd1d947
8367fb2
8087810
e508bdf
8087810
c876f7b
fd1d947
 
 
c876f7b
b3abd21
c876f7b
 
e508bdf
8087810
c876f7b
 
91713d8
c876f7b
e5b9c42
c876f7b
6adb177
fd1d947
 
8087810
e508bdf
6adb177
fd1d947
e508bdf
 
e5b9c42
b3abd21
8087810
e508bdf
8087810
c876f7b
e508bdf
 
 
 
 
c876f7b
91713d8
e508bdf
8087810
c876f7b
8087810
 
 
c876f7b
 
 
 
 
 
8087810
c876f7b
b3abd21
1573779
e508bdf
6adb177
b3abd21
c876f7b
e508bdf
 
6adb177
c876f7b
 
 
e508bdf
c876f7b
2aae3c9
e508bdf
e616e4e
2c0fb69
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
import time
import streamlit as st
from PIL import Image
from transformers import pipeline
from gtts import gTTS
import tempfile

# —––––––– Page Setup —–––––––
# Configure the page first, then render the main heading.
st.set_page_config(
    layout="centered",
    page_title="Magic Story Generator",
)
st.title("📖✨ Turn Images into Children's Stories")

# —––––––– Load Pipelines (cached) —–––––––
@st.cache_resource(show_spinner=False)
def load_pipelines():
    """Construct the two Hugging Face pipelines the app relies on.

    The function is wrapped in ``st.cache_resource`` (spinner disabled), so
    the models are intended to be built once and reused afterwards.

    Returns:
        tuple: ``(captioner, storyteller)`` -- the ``image-to-text`` pipeline
        followed by the ``text2text-generation`` pipeline.
    """
    # NOTE(review): -1 is the transformers convention for "run on CPU" --
    # confirm against the installed transformers version.
    device_index = -1

    # Generation settings handed to the story pipeline when it is created.
    generation_options = {
        "max_length": 200,
        "temperature": 0.7,
        "do_sample": True,
    }

    # Pipeline 1: describes the uploaded image in text.
    image_captioner = pipeline(
        task="image-to-text",
        model="Salesforce/blip-image-captioning-base",
        device=device_index,
    )

    # Pipeline 2: expands a text prompt into a story. The tokenizer is named
    # explicitly rather than taken from the model repository.
    story_writer = pipeline(
        task="text2text-generation",
        model="laxya007/story-generator-t5-small",
        tokenizer="t5-small",
        device=device_index,
        **generation_options,
    )

    return image_captioner, story_writer

captioner, storyteller = load_pipelines()

# Instruction prefix placed in front of the caption to build the prompt; it is
# also stripped from the generated text if the output starts with it.
STORY_PROMPT_PREFIX = "generate story:"
# Maximum number of words of the story that are displayed and spoken.
MAX_STORY_WORDS = 100


def _clean_story(raw_text):
    """Return ``raw_text`` tidied up for display.

    Removes literal ``<pad>`` and ``</s>`` markers, drops a leading
    ``STORY_PROMPT_PREFIX`` if present, and truncates the result to at most
    ``MAX_STORY_WORDS`` whitespace-separated words.
    """
    text = raw_text.replace("<pad>", "").replace("</s>", "").strip()
    if text.startswith(STORY_PROMPT_PREFIX):
        # Slice by the prefix length instead of a hard-coded offset so the
        # two can never drift apart.
        text = text[len(STORY_PROMPT_PREFIX):].strip()
    words = text.split()
    if len(words) > MAX_STORY_WORDS:
        text = " ".join(words[:MAX_STORY_WORDS])
    return text


# —––––––– Main App —–––––––
uploaded = st.file_uploader("Upload an image:", type=["jpg", "png", "jpeg"])
if uploaded:
    img = Image.open(uploaded).convert("RGB")
    # NOTE(review): recent Streamlit releases deprecate `use_column_width`;
    # confirm the target Streamlit version before switching parameters.
    st.image(img, use_column_width=True)

    # Generate caption
    with st.spinner("🔍 Generating caption..."):
        cap = captioner(img)
        # Guard against an empty result list before indexing into it.
        caption = cap[0].get("generated_text", "").strip() if cap else ""
    if not caption:
        st.error("😢 Couldn't understand this image. Try another one!")
        st.stop()
    st.success(f"**Caption:** {caption}")

    # Generate story
    prompt = f"{STORY_PROMPT_PREFIX} {caption}"
    with st.spinner("📝 Writing story..."):
        # perf_counter is a monotonic clock, so the elapsed time cannot be
        # skewed by system clock adjustments.
        start = time.perf_counter()
        story = storyteller(prompt)[0]['generated_text']
        gen_time = time.perf_counter() - start
        st.text(f"⏱ Generated in {gen_time:.1f}s")

    # Format story output and enforce the word limit
    story = _clean_story(story)
    if not story:
        # Nothing to show or speak; report it directly instead of letting the
        # audio step fail with a less helpful message.
        st.warning("😢 The story came out empty. Try another image!")
        st.stop()

    # Display story
    st.subheader("📚 Your Magical Story")
    st.write(story)

    # Audio conversion
    with st.spinner("🔊 Converting to audio..."):
        tmp_path = None
        try:
            tts = gTTS(text=story, lang="en", slow=False)
            # Reserve a path, then close the handle before gTTS writes to it.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
                tmp_path = tmp.name
            tts.save(tmp_path)
            # Hand the raw bytes to Streamlit so the file can be deleted
            # immediately afterwards.
            with open(tmp_path, "rb") as audio_file:
                audio_bytes = audio_file.read()
            st.audio(audio_bytes, format="audio/mp3")
        except Exception as e:
            # Audio is best-effort: the story is already on screen.
            st.warning(f"⚠️ Audio conversion failed: {str(e)}")
        finally:
            # Remove the temp file so repeated runs do not accumulate MP3s.
            if tmp_path is not None:
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass  # best-effort cleanup; nothing useful to report

# Footer: a `---` line followed by an italicised credit, rendered as Markdown.
footer_markdown = "---\n*Made with ❤️ by your friendly story wizard*"
st.markdown(footer_markdown)