File size: 3,978 Bytes
0dcd353 e508bdf 8367fb2 6adb177 88ee0a7 e508bdf 8367fb2 e508bdf 88ee0a7 8367fb2 6adb177 e508bdf 6adb177 88ee0a7 6adb177 88ee0a7 6adb177 b3abd21 e508bdf 88ee0a7 6adb177 88ee0a7 e508bdf 6adb177 88ee0a7 6adb177 88ee0a7 6adb177 88ee0a7 6adb177 88ee0a7 6adb177 bddc67c 422a749 6adb177 88ee0a7 b3abd21 88ee0a7 6adb177 88ee0a7 6adb177 e508bdf 6adb177 e508bdf 88ee0a7 e508bdf b3abd21 e508bdf 6adb177 e508bdf 6adb177 b3abd21 e508bdf 6adb177 b3abd21 e508bdf 6adb177 e508bdf 2aae3c9 e508bdf 6adb177 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
import os
import time
import streamlit as st
from PIL import Image
from io import BytesIO
from huggingface_hub import InferenceApi, login
from transformers import pipeline
import torch
from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniTokenizer
from gtts import gTTS
import tempfile
# —––––––– Page Config —–––––––
# Configure the browser-tab title and a centered layout, then render the
# main heading shown at the top of the page.
st.set_page_config(page_title="Magic Story Generator (Qwen2.5)", layout="centered")
st.title("📖✨ Turn Images into Children's Stories (Qwen2.5-Omni-7B)")
# —––––––– Load Clients & Pipelines (cached) —–––––––
@st.cache_resource(show_spinner=False)
def load_clients():
    """Create the captioning client and the story-generation pipeline.

    The function is wrapped in ``st.cache_resource`` so the model load is
    performed once and the resulting objects are reused afterwards.

    Returns:
        tuple: ``(caption_client, storyteller)``. ``caption_client`` is an
        ``InferenceApi`` handle for ``Salesforce/blip-image-captioning-base``;
        ``storyteller`` is a ``transformers`` pipeline wrapping
        ``Qwen/Qwen2.5-Omni-7B``.
    """
    # NOTE(review): a missing HF_TOKEN secret will raise here - confirm that
    # failing at startup is the desired behavior.
    hf_token = st.secrets["HF_TOKEN"]

    # Authenticate for HF Hub (environment variable + explicit login).
    os.environ["HUGGINGFACEHUB_API_TOKEN"] = hf_token
    login(hf_token)

    # 1) BLIP captioning via HF Inference API
    caption_client = InferenceApi(
        repo_id="Salesforce/blip-image-captioning-base",
        token=hf_token,
    )

    # 2) Qwen2.5-Omni story generator (timed so the load cost is visible)
    model_id = "Qwen/Qwen2.5-Omni-7B"
    t0 = time.time()
    model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        trust_remote_code=True,
    )
    tokenizer = Qwen2_5OmniTokenizer.from_pretrained(
        model_id,
        trust_remote_code=True,
    )
    storyteller = pipeline(
        task="text2text-generation",
        model=model,
        tokenizer=tokenizer,
        device_map="auto",
        # temperature / top_p are only honored in sampling mode, so request
        # it explicitly instead of relying on the model's default config.
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.2,
        no_repeat_ngram_size=3,
        max_new_tokens=120,
    )
    load_time = time.time() - t0
    st.text(f"✅ Story model loaded in {load_time:.1f}s (cached thereafter)")
    return caption_client, storyteller
# Module-level handles read by generate_caption() and generate_story().
# load_clients() is decorated with st.cache_resource, so this line does not
# rebuild the clients each time it executes.
caption_client, storyteller = load_clients()
# —––––––– Helpers —–––––––
def generate_caption(img: Image.Image) -> str:
    """Return a caption for *img* obtained from ``caption_client``.

    The image is JPEG-encoded in memory and the bytes are passed to the
    module-level ``caption_client``.

    Args:
        img: Image to describe. Non-RGB images are converted to RGB first,
            because JPEG cannot encode alpha or palette modes.

    Returns:
        The stripped ``generated_text`` of the first result, or ``""`` when
        the response is not a non-empty list whose first item is a dict
        carrying that text.
    """
    if img.mode != "RGB":
        img = img.convert("RGB")

    buf = BytesIO()
    img.save(buf, format="JPEG")
    resp = caption_client(data=buf.getvalue())

    # Anything other than a non-empty list is treated as "no caption".
    if not (isinstance(resp, list) and resp):
        return ""
    first = resp[0]
    if not isinstance(first, dict):
        return ""
    # `or ""` also covers an explicit None value for the key.
    return (first.get("generated_text") or "").strip()
def _enforce_word_limit(story: str, max_words: int = 100) -> str:
    """Cap *story* at *max_words* whitespace-separated words.

    Text within the limit is returned unchanged. Truncated text has dangling
    separators removed and gets a final period unless it already ends with
    sentence-ending punctuation (optionally followed by a closing quote or
    parenthesis).
    """
    words = story.split()
    if len(words) <= max_words:
        return story

    clipped = " ".join(words[:max_words])
    # Avoid endings such as ",." when the cut lands on a separator.
    clipped = clipped.rstrip(",;:-–—")
    # Look past closing quotes/brackets so '...end!"' is not turned into '!".'.
    if not clipped.rstrip("\"'”’)").endswith((".", "!", "?", "…")):
        clipped += "."
    return clipped


def generate_story(caption: str) -> str:
    """Generate a short children's story from an image caption.

    Builds a prompt around *caption*, runs the module-level ``storyteller``
    pipeline, reports the elapsed time in the Streamlit page, and returns the
    generated text limited to 100 words.

    Args:
        caption: Description of the image to base the story on.

    Returns:
        The story text, stripped, without an echoed prompt, and at most
        100 words long.
    """
    prompt = (
        "You are a creative children's-story author.\n"
        f"Image description: “{caption}”\n\n"
        "Write a coherent 50–100 word story\n"
    )
    t0 = time.time()
    outputs = storyteller(prompt)
    gen_time = time.time() - t0
    st.text(f"⏱ Generated in {gen_time:.1f}s on GPU/CPU")

    story = outputs[0]["generated_text"].strip()
    # NOTE(review): depending on the pipeline/model, the returned text may
    # begin with the prompt itself; drop it so it neither shows up in the
    # story nor eats into the word budget. No-op when the prompt isn't echoed.
    echoed = prompt.strip()
    if story.startswith(echoed):
        story = story[len(echoed):].lstrip()

    # Enforce ≤100 words
    return _enforce_word_limit(story, 100)
# —––––––– Main App —–––––––
# Script body: upload -> caption -> story -> audio. Streamlit executes this
# top to bottom; st.stop() ends the current run early on unrecoverable input.
uploaded = st.file_uploader("Upload an image:", type=["jpg","png","jpeg"])
if uploaded:
    try:
        img = Image.open(uploaded).convert("RGB")
    except OSError:
        # Pillow signals unreadable/corrupt image data with OSError
        # subclasses; show a friendly message instead of a traceback.
        st.error("😢 Couldn't read this file as an image. Try another one!")
        st.stop()

    # Bound the image size before display and upload to the caption API.
    if max(img.size) > 2048:
        img.thumbnail((2048,2048))
    st.image(img, use_container_width=True)

    with st.spinner("🔍 Generating caption..."):
        caption = generate_caption(img)
    if not caption:
        st.error("😢 Couldn't understand this image. Try another one!")
        st.stop()
    st.success(f"**Caption:** {caption}")

    with st.spinner("📝 Writing story..."):
        story = generate_story(caption)
    st.subheader("📚 Your Magical Story")
    st.write(story)

    with st.spinner("🔊 Converting to audio..."):
        try:
            tts = gTTS(text=story, lang="en", slow=False)
            # Keep the MP3 in memory: no temp file is left behind on disk and
            # nothing is written to a path that is still open elsewhere.
            audio_buf = BytesIO()
            tts.write_to_fp(audio_buf)
            st.audio(audio_buf.getvalue(), format="audio/mp3")
        except Exception as e:
            # Audio is best-effort; the story has already been displayed.
            st.warning(f"⚠️ TTS failed: {e}")

# Footer
st.markdown("---\n*Made with ❤️ by your friendly story wizard*")