|
import io
import os
import tempfile
import time

import streamlit as st
from gtts import gTTS
from llama_cpp import Llama
from PIL import Image
from transformers import pipeline
|
|
|
|
|
|
|
|
|
|
|
# Page setup: configure the browser tab and layout before any other
# Streamlit element is rendered by this script.
st.set_page_config(
    layout="centered",
    page_title="Magic Story Generator",
)
st.title("📖✨ Turn Images into Children's Stories")
|
|
|
|
|
@st.cache_resource(show_spinner=False)
def load_models():
    """Load the image captioner and the story LLM, cached by Streamlit.

    ``st.cache_resource`` keeps the returned objects alive across reruns, so
    the model weights are only loaded the first time this is called.

    The story model can be overridden with environment variables:

    * ``STORY_MODEL`` - either a path to a local ``.gguf`` file or a Hugging
      Face repo id (default: the DavidAU GGUF repo below).
    * ``STORY_MODEL_FILE`` - glob selecting which ``.gguf`` file to download
      when ``STORY_MODEL`` is a repo id.

    Returns:
        tuple: ``(captioner, storyteller)`` - a transformers "image-to-text"
        pipeline and a ``llama_cpp.Llama`` instance.
    """
    captioner = pipeline(
        "image-to-text",
        model="Salesforce/blip-image-captioning-base",
    )

    # NOTE(review): the default model's name suggests it is tuned for dark
    # fiction; confirm it is appropriate for children's stories.
    model_ref = os.environ.get(
        "STORY_MODEL", "DavidAU/L3-Grand-Story-Darkness-MOE-4X8-24.9B-e32-GGUF"
    )
    # 2048-token context, 4 CPU threads, no layers offloaded to a GPU.
    llama_kwargs = dict(n_ctx=2048, n_threads=4, n_gpu_layers=0)

    if os.path.isfile(model_ref):
        storyteller = Llama(model_path=model_ref, **llama_kwargs)
    else:
        # Llama(model_path=...) only opens an existing local file; a Hub repo
        # id passed there fails with "Model path does not exist".
        # from_pretrained downloads (and caches) a matching .gguf first and
        # requires the huggingface_hub package.
        # NOTE(review): the default glob assumes the repo ships exactly one
        # Q4_K_M quant file (matched case-insensitively) - confirm against the
        # repo's file list.
        storyteller = Llama.from_pretrained(
            repo_id=model_ref,
            filename=os.environ.get("STORY_MODEL_FILE", "*[Qq]4_[Kk]_[Mm]*.gguf"),
            **llama_kwargs,
        )

    return captioner, storyteller
|
|
|
captioner, storyteller = load_models()

uploaded = st.file_uploader("Upload an image:", type=["jpg", "png", "jpeg"])

if uploaded:
    # Normalise to RGB so palette/RGBA PNGs reach the captioner as 3-channel images.
    img = Image.open(uploaded).convert("RGB")
    # NOTE(review): `use_column_width` is deprecated in newer Streamlit releases
    # in favour of `use_container_width`; switch once the pinned version allows.
    st.image(img, use_column_width=True)

    with st.spinner("🔍 Generating caption..."):
        cap = captioner(img)
        caption = cap[0]['generated_text']
        st.success(f"**Caption:** {caption}")

    prompt = (
        "Below is an image description. Write a children's story based on it.\n\n"
        f"Image Description: {caption}\n\n"
        "Story:"
    )

    with st.spinner("📝 Crafting magical story..."):
        # perf_counter is monotonic, so wall-clock adjustments can't skew the timing.
        start = time.perf_counter()
        output = storyteller(
            prompt=prompt,
            max_tokens=500,
            temperature=0.7,
            top_p=0.9,
            repeat_penalty=1.1,
        )
        gen_time = time.perf_counter() - start
        story = output['choices'][0]['text'].strip()
        st.text(f"⏱ Generated in {gen_time:.1f}s")

    # Drop everything from the first "###" onwards (e.g. if the completion runs
    # on into another section after the story).
    story = story.split("###")[0].strip()

    st.subheader("📚 Your Magical Story")

    if not story:
        # Nothing to display or narrate; gTTS would reject empty text anyway.
        st.warning("⚠️ The model returned an empty story. Try another image.")
    else:
        st.write(story)

        with st.spinner("🔊 Converting to audio..."):
            try:
                # Synthesize into memory instead of a delete=False temp file,
                # which would never be cleaned up (one MP3 per run) and cannot
                # be reopened by name on Windows while still open.
                audio_buf = io.BytesIO()
                gTTS(text=story, lang="en", slow=False).write_to_fp(audio_buf)
            except Exception as e:
                # Best-effort: audio is optional, so surface the error and carry on.
                st.warning(f"⚠️ Audio conversion failed: {str(e)}")
            else:
                st.audio(audio_buf.getvalue(), format="audio/mp3")

st.markdown("---\n*Made with ❤️ by your friendly story wizard*")
|
|