# app.py
# Committed by: mayf
# Last change: commit 6a2dbfc "Update app.py" (verified)
# FIRST import and FIRST Streamlit command
import streamlit as st
# Page-level settings. This is issued as the very first Streamlit command,
# right after the streamlit import, before any other module is imported.
st.set_page_config(
    page_title="Magic Story Generator",
    page_icon="📖",
    layout="centered",
)
# Other imports (standard library first, then third-party)
import io
import re
import tempfile
import time

import torch
from gtts import gTTS
from PIL import Image
from transformers import AutoTokenizer, pipeline
# --- Constants & Setup ---
# Headline rendered at the top of the page.
APP_TITLE = "📖✨ Turn Images into Children's Stories"
st.title(APP_TITLE)
# --- Model Loading (Cached) ---
# Hugging Face model ids. The story model id is used for both the tokenizer and
# the generation pipeline, so it is defined once to keep the two in sync.
CAPTION_MODEL_ID = "Salesforce/blip-image-captioning-base"
STORY_MODEL_ID = "Deepthoughtworks/gpt-neo-2.7B__low-cpu"


@st.cache_resource(show_spinner=False)
def load_models():
    """Build the image-captioning and story-generation pipelines.

    Wrapped in ``st.cache_resource`` (without Streamlit's own spinner) so
    script reruns reuse the loaded models instead of reloading them.

    Returns:
        tuple: ``(captioner, storyteller)`` -- an ``image-to-text`` pipeline
        and a ``text-generation`` pipeline configured with fixed sampling
        settings.
    """
    # Image captioning: first CUDA device when available, otherwise CPU (-1).
    captioner = pipeline(
        "image-to-text",
        model=CAPTION_MODEL_ID,
        device=0 if torch.cuda.is_available() else -1,
    )

    # Story generation. The tokenizer is loaded explicitly so its EOS token id
    # can be passed as pad_token_id (presumably because the model has no
    # dedicated pad token -- confirm against the model card).
    tokenizer = AutoTokenizer.from_pretrained(STORY_MODEL_ID)
    storyteller = pipeline(
        "text-generation",
        model=STORY_MODEL_ID,
        tokenizer=tokenizer,
        device_map="auto",
        torch_dtype=torch.float32,  # float32 for better CPU compatibility
        # Generation settings given at construction; the pipeline forwards
        # them to generation on every call.
        max_new_tokens=150,  # kept short so generation stays fast
        temperature=0.85,
        top_k=40,
        top_p=0.92,
        repetition_penalty=1.15,
        pad_token_id=tokenizer.eos_token_id,
    )
    return captioner, storyteller


caption_pipe, story_pipe = load_models()
# --- Main Application Flow ---
# File extensions the uploader will accept.
ACCEPTED_IMAGE_TYPES = ["jpg", "jpeg", "png"]
uploaded_image = st.file_uploader(
    label="Upload a children's book style image:",
    type=ACCEPTED_IMAGE_TYPES,
)
def _format_story(text):
    """Tidy raw generated text into capitalised, punctuated paragraphs.

    Every run of newlines becomes a single paragraph break. Within a
    paragraph, text is split into sentences after '.', '!' or '?' followed by
    spaces. Each sentence has its first character upper-cased and, when it is
    longer than one character and lacks terminal punctuation, gets a '.'.

    Returns:
        str: paragraphs joined by blank lines, with no leading or trailing
        whitespace ("" when *text* has no visible content).
    """
    paragraphs = []
    for paragraph in re.split(r"\n+", text):
        paragraph = paragraph.strip()
        if not paragraph:
            continue
        sentences = []
        for sent in re.split(r"(?<=[.!?]) +", paragraph):
            sent = sent.strip()
            if not sent:
                continue
            # Generation is capped by a token budget, so the last sentence may
            # be cut off mid-way; close it so display and speech read cleanly.
            if len(sent) > 1 and not sent.endswith((".", "!", "?")):
                sent += "."
            sentences.append(sent[0].upper() + sent[1:])
        paragraphs.append(" ".join(sentences))
    return "\n\n".join(paragraphs)


def _story_to_mp3(text):
    """Render *text* as English speech with gTTS and return the MP3 bytes.

    The audio is kept in an in-memory buffer, so no temporary file is created
    on disk (and none can be left behind).
    """
    buffer = io.BytesIO()
    gTTS(text=text, lang="en", slow=False).write_to_fp(buffer)
    return buffer.getvalue()


if uploaded_image:
    # Process image: decode the upload and show it back to the user.
    try:
        image = Image.open(uploaded_image).convert("RGB")
    except Exception as e:
        st.error(f"❌ Couldn't open this image: {e}")
        st.stop()
    st.image(image, use_container_width=True)

    # Generate caption
    with st.spinner("🔍 Analyzing image..."):
        try:
            caption_result = caption_pipe(image)
            image_caption = caption_result[0].get("generated_text", "").strip()
        except Exception as e:
            st.error(f"❌ Image analysis failed: {e}")
            st.stop()
    if not image_caption:
        st.error("❌ Couldn't understand this image. Please try another!")
        st.stop()
    st.success(f"**Image Understanding:** {image_caption}")

    # Create story prompt
    story_prompt = (
        f"Write a 50 to 100 words children's story based on: {image_caption}\n"
        "Requirements:\n"
        "- Exclude your thinking process\n"
        "Story:"
    )

    # Generate story with a coarse, step-based progress indicator.
    progress_bar = st.progress(0)
    status_text = st.empty()

    def update_progress(step):
        """Show *step* out of 5 on the progress bar and status line."""
        progress_bar.progress(min(step / 5, 1.0))
        status_text.text(f"Step {int(step)}/5: {'📖' * int(step)}")

    try:
        with st.spinner("📝 Crafting magical story..."):
            start_time = time.perf_counter()
            update_progress(1)
            # return_full_text=False returns only the newly generated text, so
            # the echoed prompt never has to be split off by hand.
            story_result = story_pipe(
                story_prompt,
                do_sample=True,
                num_return_sequences=1,
                return_full_text=False,
            )
            update_progress(4)
            generation_time = time.perf_counter() - start_time
            st.info(f"Story generated in {generation_time:.1f} seconds")

            final_story = _format_story(story_result[0]["generated_text"])
            update_progress(5)
            time.sleep(0.5)  # let the completed bar be visible briefly
    except Exception as e:
        st.error(f"❌ Story generation failed: {e}")
        st.stop()
    finally:
        progress_bar.empty()
        status_text.empty()

    # Checked outside the try so st.stop() can never be swallowed by the
    # generic exception handler above.
    if not final_story:
        st.error("❌ The story came out empty. Please try again!")
        st.stop()

    # Display story
    st.subheader("✨ Your Magical Story")
    st.write(final_story)

    # Audio conversion
    with st.spinner("🔊 Creating audio version..."):
        try:
            st.audio(_story_to_mp3(final_story), format="audio/mp3")
        except Exception as e:
            st.error(f"❌ Audio conversion failed: {e}")
# Footer: a horizontal rule followed by the credits line.
# TODO(review): "Report Issues" still points at the example.com placeholder.
FOOTER_LINES = (
    "---",
    "📚 Made with ♥ by The Story Wizard • [Report Issues](https://example.com)",
)
for footer_line in FOOTER_LINES:
    st.markdown(footer_line)