import os
import time
import streamlit as st
from PIL import Image
from io import BytesIO
from huggingface_hub import InferenceApi, login
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
import torch
from gtts import gTTS
import tempfile

# -------- Page Config --------
st.set_page_config(page_title="Magic Story Generator (Qwen2.5)", layout="centered")
st.title("📖✨ Turn Images into Children's Stories (Qwen2.5-Omni-7B)")

# -------- Load Clients & Pipelines (cached) --------
@st.cache_resource(show_spinner=False)
def load_clients():
    hf_token = st.secrets["HF_TOKEN"]
    # Authenticate for Hugging Face Hub
    os.environ["HUGGINGFACEHUB_API_TOKEN"] = hf_token
    login(hf_token)

    # 1) BLIP captioning via HTTP API
    caption_client = InferenceApi(
        repo_id="Salesforce/blip-image-captioning-base",
        token=hf_token
    )
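    # NOTE: InferenceApi is deprecated in recent huggingface_hub releases.
    # If it is missing from your installed version, a rough equivalent using
    # InferenceClient (a sketch, not tested here) would be:
    #   from huggingface_hub import InferenceClient
    #   client = InferenceClient(token=hf_token)
    #   result = client.image_to_text(image_bytes, model="Salesforce/blip-image-captioning-base")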

    # 2) Load Qwen2.5-Omni model & tokenizer
    t0 = time.time()
    tokenizer = AutoTokenizer.from_pretrained(
        "Qwen/Qwen2.5-Omni-7B",
        trust_remote_code=True
    )
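    # NOTE: Qwen2.5-Omni-7B is a multimodal ("Omni") model. Depending on the
    # installed transformers version, it may need to be loaded with its
    # dedicated class (e.g. Qwen2_5OmniForConditionalGeneration) rather than
    # AutoModelForSeq2SeqLM. flash_attention_2 also requires the flash-attn
    # package and a compatible GPU; drop attn_implementation if unavailable.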
    model = AutoModelForSeq2SeqLM.from_pretrained(
        "Qwen/Qwen2.5-Omni-7B",
        trust_remote_code=True,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2"
    )
    # 3) Build text2text pipeline
    storyteller = pipeline(
        task="text2text-generation",
        model=model,
        tokenizer=tokenizer,
        do_sample=True,  # required for temperature/top_p to take effect
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.2,
        no_repeat_ngram_size=3,
        max_new_tokens=120
    )
    load_time = time.time() - t0
    st.text(f"✅ Story model loaded in {load_time:.1f}s (cached)")
    return caption_client, storyteller

caption_client, storyteller = load_clients()

# -------- Helpers --------
def generate_caption(img: Image.Image) -> str:
    buf = BytesIO()
    img.save(buf, format="JPEG")
    resp = caption_client(data=buf.getvalue())
    if isinstance(resp, list) and resp:
        return resp[0].get("generated_text", "").strip()
    return ""


def generate_story(caption: str) -> str:
    prompt = (
        "You are a creative children's-story author.\n"
        f"Image description: “{caption}”\n\n"
        "Write a coherent 50–100 word story that:\n"
        "1. Introduces the main character.\n"
        "2. Shows a simple problem or discovery.\n"
        "3. Has a happy resolution.\n"
        "4. Uses clear language for ages 3–8.\n"
        "5. Keeps each sentence under 20 words.\n"
    )
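    # NOTE: the prompt is passed as a raw string. For Qwen instruction-tuned
    # chat models, wrapping it with tokenizer.apply_chat_template() is the
    # more usual approach; the plain prompt is kept here to match the
    # text2text-generation pipeline interface.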
    t0 = time.time()
    result = storyteller(prompt)
    gen_time = time.time() - t0
    st.text(f"⏱ Generated in {gen_time:.1f}s on GPU/CPU")

    story = result[0]["generated_text"].strip()
    # Enforce ≤100 words
    words = story.split()
    if len(words) > 100:
        story = " ".join(words[:100])
        if not story.endswith('.'):
            story += '.'
    return story

# -------- Main App --------
uploaded = st.file_uploader("Upload an image:", type=["jpg","png","jpeg"])
if uploaded:
    img = Image.open(uploaded).convert("RGB")
    if max(img.size) > 2048:
        img.thumbnail((2048,2048))
    st.image(img, use_container_width=True)

    with st.spinner("🔍 Generating caption..."):
        caption = generate_caption(img)
    if not caption:
        st.error("😢 Couldn't understand this image. Try another one!")
        st.stop()
    st.success(f"**Caption:** {caption}")

    with st.spinner("📝 Writing story..."):
        story = generate_story(caption)

    st.subheader("📚 Your Magical Story")
    st.write(story)

    with st.spinner("🔊 Converting to audio..."):
        try:
            tts = gTTS(text=story, lang="en", slow=False)
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as fp:
                tts.save(fp.name)
                st.audio(fp.name, format="audio/mp3")
        except Exception as e:
            st.warning(f"⚠️ TTS failed: {e}")

# Footer
st.markdown("---\n*Made with ❤️ by your friendly story wizard*")