File size: 4,187 Bytes
0dcd353 e508bdf 8367fb2 6adb177 e616e4e 88ee0a7 e508bdf 8367fb2 e508bdf 88ee0a7 8367fb2 6adb177 e508bdf 6adb177 e616e4e 6adb177 e616e4e 6adb177 b3abd21 e508bdf e616e4e 6adb177 e616e4e 88ee0a7 e616e4e 88ee0a7 e616e4e 88ee0a7 e616e4e e508bdf 6adb177 88ee0a7 6adb177 88ee0a7 e616e4e 6adb177 88ee0a7 6adb177 88ee0a7 6adb177 e616e4e 422a749 6adb177 e616e4e 88ee0a7 b3abd21 e616e4e 88ee0a7 6adb177 88ee0a7 6adb177 e508bdf 6adb177 e508bdf 88ee0a7 e508bdf b3abd21 e508bdf 6adb177 e508bdf 6adb177 b3abd21 e508bdf 6adb177 b3abd21 e508bdf 6adb177 e508bdf 2aae3c9 e508bdf e616e4e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import os
import time
import streamlit as st
from PIL import Image
from io import BytesIO
from huggingface_hub import InferenceApi, login
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
import torch
from gtts import gTTS
import tempfile
# ---- Page configuration ----
# Browser-tab title and a centered layout, followed by the on-page heading.
st.set_page_config(
    layout="centered",
    page_title="Magic Story Generator (Qwen2.5)",
)
st.title("📖✨ Turn Images into Children's Stories (Qwen2.5-Omni-7B)")
# ---- Load clients & pipelines (cached) ----
_STORY_MODEL_ID = "Qwen/Qwen2.5-Omni-7B"


def _flash_attention_kwargs() -> dict:
    """Return ``{"attn_implementation": "flash_attention_2"}`` only when it can work.

    FlashAttention-2 needs a CUDA device and the optional ``flash_attn`` package;
    requesting it without them makes ``from_pretrained`` fail. In that case an
    empty dict is returned so transformers falls back to its default attention.
    """
    import importlib.util

    if torch.cuda.is_available() and importlib.util.find_spec("flash_attn") is not None:
        return {"attn_implementation": "flash_attention_2"}
    return {}


@st.cache_resource(show_spinner=False)
def load_clients():
    """Authenticate with Hugging Face and build the captioning and story clients.

    Wrapped in ``st.cache_resource`` so the expensive model load is reused
    across script reruns instead of being repeated.

    Returns:
        A ``(caption_client, storyteller)`` tuple: an ``InferenceApi`` client for
        BLIP image captioning, and a text-generation pipeline whose output
        ``[{"generated_text": ...}]`` contains only the newly generated text
        (the prompt is not echoed back).
    """
    from transformers import AutoModelForCausalLM

    hf_token = st.secrets["HF_TOKEN"]
    # Authenticate for Hugging Face Hub: env var for libraries that read it,
    # plus an explicit login for huggingface_hub itself.
    os.environ["HUGGINGFACEHUB_API_TOKEN"] = hf_token
    login(hf_token)

    # 1) BLIP captioning via the hosted HTTP Inference API.
    caption_client = InferenceApi(
        repo_id="Salesforce/blip-image-captioning-base",
        token=hf_token,
    )

    # 2) Qwen2.5 is a decoder-only (causal) language model, so it must be loaded
    #    with a causal-LM class; the seq2seq auto-class does not map Qwen configs.
    #    NOTE(review): the Omni checkpoint may instead require its dedicated
    #    Qwen2_5OmniForConditionalGeneration class — confirm against the model card.
    t0 = time.perf_counter()
    tokenizer = AutoTokenizer.from_pretrained(
        _STORY_MODEL_ID,
        trust_remote_code=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        _STORY_MODEL_ID,
        trust_remote_code=True,
        device_map="auto",  # placement happens here; the pipeline reuses it
        torch_dtype=torch.bfloat16,
        **_flash_attention_kwargs(),
    )

    # 3) Text-generation pipeline.
    #    - do_sample=True is required for temperature/top_p to take effect
    #      (without it decoding is greedy and they are ignored).
    #    - return_full_text=False strips the prompt from the output so callers
    #      receive only the story.
    storyteller = pipeline(
        task="text-generation",
        model=model,
        tokenizer=tokenizer,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.2,
        no_repeat_ngram_size=3,
        max_new_tokens=120,
        return_full_text=False,
    )
    load_time = time.perf_counter() - t0
    st.text(f"✅ Story model loaded in {load_time:.1f}s (cached)")
    return caption_client, storyteller


caption_client, storyteller = load_clients()
# ---- Helpers ----
def generate_caption(img: Image.Image) -> str:
    """Caption *img* using the module-level ``caption_client`` (BLIP).

    The image is encoded as JPEG bytes and sent to the client. The first
    result's ``generated_text`` is returned, stripped of surrounding whitespace.

    Args:
        img: Any PIL image; non-RGB modes are converted to RGB before encoding.

    Returns:
        The caption, or "" when the response has no usable caption.
    """
    # JPEG cannot encode alpha/palette modes (e.g. RGBA, LA, P), so normalise
    # first and let callers pass any PIL image.
    if img.mode != "RGB":
        img = img.convert("RGB")

    buf = BytesIO()
    img.save(buf, format="JPEG")
    resp = caption_client(data=buf.getvalue())

    # Expected success shape: a non-empty list whose first item is a dict with
    # "generated_text". Anything else (presumably an error payload) yields "".
    if not (isinstance(resp, list) and resp and isinstance(resp[0], dict)):
        return ""
    caption = resp[0].get("generated_text")
    return caption.strip() if isinstance(caption, str) else ""
def generate_story(caption: str) -> str:
prompt = (
"You are a creative children's-story author.\n"
f"Image description: “{caption}”\n\n"
"Write a coherent 50–100 word story that:\n"
"1. Introduces the main character.\n"
"2. Shows a simple problem or discovery.\n"
"3. Has a happy resolution.\n"
"4. Uses clear language for ages 3–8.\n"
"5. Keeps each sentence under 20 words.\n"
)
t0 = time.time()
result = storyteller(prompt)
gen_time = time.time() - t0
st.text(f"⏱ Generated in {gen_time:.1f}s on GPU/CPU")
story = result[0]["generated_text"].strip()
# Enforce ≤100 words
words = story.split()
if len(words) > 100:
story = " ".join(words[:100])
if not story.endswith('.'):
story += '.'
return story
# ---- Main app ----
uploaded = st.file_uploader("Upload an image:", type=["jpg", "png", "jpeg"])
if uploaded:
    try:
        img = Image.open(uploaded).convert("RGB")
    except OSError:
        # PIL raises UnidentifiedImageError (an OSError subclass) for files it
        # cannot decode; show a friendly message instead of a traceback.
        st.error("😢 Couldn't understand this image. Try another one!")
        st.stop()

    # Cap the longest side at 2048 px (aspect ratio kept, resized in place)
    # before display and captioning.
    if max(img.size) > 2048:
        img.thumbnail((2048, 2048))
    st.image(img, use_container_width=True)

    with st.spinner("🔍 Generating caption..."):
        caption = generate_caption(img)
    if not caption:
        st.error("😢 Couldn't understand this image. Try another one!")
        st.stop()
    st.success(f"**Caption:** {caption}")

    with st.spinner("📝 Writing story..."):
        story = generate_story(caption)
    st.subheader("📚 Your Magical Story")
    st.write(story)

    with st.spinner("🔊 Converting to audio..."):
        try:
            # Synthesize into memory: a NamedTemporaryFile(delete=False) would
            # leave an .mp3 on disk every time this section runs.
            audio_buf = BytesIO()
            gTTS(text=story, lang="en", slow=False).write_to_fp(audio_buf)
            st.audio(audio_buf.getvalue(), format="audio/mp3")
        except Exception as e:  # best-effort: audio is optional, report and continue
            st.warning(f"⚠️ TTS failed: {e}")

# Footer
st.markdown("---\n*Made with ❤️ by your friendly story wizard*")
|