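"""Streamlit app: turn an uploaded image into a short children's story with audio narration.

Pipeline: BLIP image captioning via the Hugging Face Inference API -> Qwen2.5-Omni-7B story
generation -> gTTS text-to-speech playback.
"""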
import os
import time
import streamlit as st
from PIL import Image
from io import BytesIO
from huggingface_hub import InferenceApi, login
from transformers import pipeline
import torch
from transformers import AutoTokenizer, Qwen2_5OmniForConditionalGeneration
from gtts import gTTS
import tempfile

# —––––––– Page Config —–––––––
st.set_page_config(page_title="Magic Story Generator (Qwen2.5)", layout="centered")
st.title("📖✨ Turn Images into Children's Stories (Qwen2.5-Omni-7B)")

# —––––––– Load Clients & Pipelines (cached) —–––––––
@st.cache_resource(show_spinner=False)
def load_clients():
    hf_token = st.secrets["HF_TOKEN"]
    # Authenticate for HF Hub
    os.environ["HUGGINGFACEHUB_API_TOKEN"] = hf_token
    login(hf_token)

    # 1) BLIP captioning via the HF Inference API (InferenceApi is the legacy client;
    #    newer huggingface_hub releases expose the same functionality as InferenceClient)
    caption_client = InferenceApi(
        repo_id="Salesforce/blip-image-captioning-base",
        token=hf_token
    )

    # 2) Qwen2.5-Omni story generator
    t0 = time.time()
    model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2.5-Omni-7B",
        device_map="auto",
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",  # requires the flash-attn package; drop on unsupported hardware
        trust_remote_code=True
    )
    # Load the text tokenizer via AutoTokenizer (the repo ships a standard Qwen2 tokenizer)
    tokenizer = AutoTokenizer.from_pretrained(
        "Qwen/Qwen2.5-Omni-7B",
        trust_remote_code=True
    )
    # Qwen2.5-Omni generates text decoder-style, so use the text-generation task.
    # The model already carries its device placement from device_map="auto" above.
    storyteller = pipeline(
        task="text-generation",
        model=model,
        tokenizer=tokenizer,
        do_sample=True,           # required for temperature/top_p to take effect
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.2,
        no_repeat_ngram_size=3,
        max_new_tokens=120,
        return_full_text=False    # return only the newly generated story, not the prompt
    )
    load_time = time.time() - t0
    st.text(f"✅ Story model loaded in {load_time:.1f}s (cached thereafter)")

    return caption_client, storyteller

caption_client, storyteller = load_clients()

# —––––––– Helpers —–––––––
def generate_caption(img: Image.Image) -> str:
    buf = BytesIO()
    img.save(buf, format="JPEG")
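    # On success the captioning endpoint returns a list of {"generated_text": ...} dicts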
    resp = caption_client(data=buf.getvalue())
    if isinstance(resp, list) and resp:
        return resp[0].get("generated_text", "").strip()
    return ""


def generate_story(caption: str) -> str:
    prompt = (
        "You are a creative children's-story author.\n"
        f"Image description: “{caption}”\n\n"
        "Write a coherent 50–100 word story\n"
    )
    t0 = time.time()
    outputs = storyteller(prompt)
    gen_time = time.time() - t0
    st.text(f"⏱ Generated in {gen_time:.1f}s on GPU/CPU")

    story = outputs[0]["generated_text"].strip()
    # Enforce ≤100 words
    words = story.split()
    if len(words) > 100:
        story = " ".join(words[:100])
        if not story.endswith('.'):
            story += '.'
    return story

# —––––––– Main App —–––––––
uploaded = st.file_uploader("Upload an image:", type=["jpg","png","jpeg"])
if uploaded:
    img = Image.open(uploaded).convert("RGB")
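    # Downscale very large uploads so captioning and display stay fast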
    if max(img.size) > 2048:
        img.thumbnail((2048,2048))
    st.image(img, use_container_width=True)

    with st.spinner("🔍 Generating caption..."):
        caption = generate_caption(img)
    if not caption:
        st.error("😢 Couldn't understand this image. Try another one!")
        st.stop()
    st.success(f"**Caption:** {caption}")

    with st.spinner("📝 Writing story..."):
        story = generate_story(caption)

    st.subheader("📚 Your Magical Story")
    st.write(story)

    with st.spinner("🔊 Converting to audio..."):
        try:
            tts = gTTS(text=story, lang="en", slow=False)
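            # delete=False keeps the temp MP3 on disk so st.audio can read it after gTTS writes it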
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as fp:
                tts.save(fp.name)
                st.audio(fp.name, format="audio/mp3")
        except Exception as e:
            st.warning(f"⚠️ TTS failed: {e}")

# Footer
st.markdown("---\n*Made with ❤️ by your friendly story wizard*")