File size: 5,024 Bytes
6a2dbfc 4d1f328 8151df4 9862828 8151df4 9862828 8151df4 c83a777 6a2dbfc c83a777 6a2dbfc 6523fb1 6a2dbfc 6523fb1 6a2dbfc 8151df4 fd1d947 982555a 6a2dbfc 8151df4 6a2dbfc 8151df4 6a2dbfc 8151df4 9862828 6a2dbfc 9862828 6a2dbfc 6523fb1 6a2dbfc 6523fb1 6a2dbfc 9862828 6a2dbfc 9862828 6523fb1 6a2dbfc 9862828 6a2dbfc eec20c9 6a2dbfc eec20c9 6523fb1 6a2dbfc eec20c9 6a2dbfc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
# FIRST import and FIRST Streamlit command
import streamlit as st
# Page-wide settings; kept as the very first Streamlit call (older Streamlit
# versions reject set_page_config after any other st.* command).
st.set_page_config(
    page_icon="📖",
    layout="centered",
    page_title="Magic Story Generator",
)
# Other imports
import io
import re
import tempfile
import time

import torch
from gtts import gTTS
from PIL import Image
from transformers import AutoTokenizer, pipeline
# --- Constants & Setup ---
# Hugging Face model ids, defined once so the story tokenizer and model can
# never drift apart.
CAPTION_MODEL_ID = "Salesforce/blip-image-captioning-base"
STORY_MODEL_ID = "Deepthoughtworks/gpt-neo-2.7B__low-cpu"

st.title("📖✨ Turn Images into Children's Stories")


# --- Model Loading (Cached) ---
# A cold start downloads/loads a multi-GB model, so show a message while the
# cache is being filled (the spinner only appears on a cache miss).
@st.cache_resource(show_spinner="⏳ Loading AI models (first run can take a few minutes)...")
def load_models():
    """Build and cache the captioning and story-generation pipelines.

    ``st.cache_resource`` keeps the returned objects alive across reruns, so
    the models are loaded once per server process rather than per interaction.

    Returns:
        ``(captioner, storyteller)``: an ``image-to-text`` pipeline and a
        ``text-generation`` pipeline configured with sampling defaults.
    """
    # Image captioning model: GPU 0 when CUDA is available, otherwise CPU (-1).
    captioner = pipeline(
        "image-to-text",
        model=CAPTION_MODEL_ID,
        device=0 if torch.cuda.is_available() else -1,
    )

    # Story generation model; the tokenizer is loaded explicitly so its EOS id
    # can be reused as the pad id below.
    tokenizer = AutoTokenizer.from_pretrained(STORY_MODEL_ID)
    storyteller = pipeline(
        "text-generation",
        model=STORY_MODEL_ID,
        tokenizer=tokenizer,
        device_map="auto",
        torch_dtype=torch.float32,  # float32 for broader CPU compatibility
        # Generation defaults, forwarded by the pipeline to every generate call.
        max_new_tokens=150,  # short stories keep generation time down
        temperature=0.85,
        top_k=40,
        top_p=0.92,
        repetition_penalty=1.15,
        # Presumably the tokenizer has no dedicated pad token; reuse EOS.
        pad_token_id=tokenizer.eos_token_id,
    )
    return captioner, storyteller


caption_pipe, story_pipe = load_models()
# --- Story text helpers ---
# Sentence boundary: one or more spaces that follow terminal punctuation.
_SENTENCE_BOUNDARY = re.compile(r"(?<=[.!?]) +")


def _format_story(raw_story: str) -> str:
    """Turn raw model output into tidy paragraphs of capitalised sentences.

    Keeps only the text after the last ``"Story:"`` marker (the prompt built
    below ends with that marker, and the generated text normally echoes the
    prompt), treats every run of newlines as a paragraph break, capitalises
    the first character of each sentence, and appends ``"."`` to sentences
    longer than one character that lack terminal punctuation.

    Returns:
        Paragraphs joined by blank lines, or ``""`` if nothing usable remains.
    """
    story = raw_story.split("Story:")[-1].strip()
    paragraphs = []
    for paragraph in re.split(r"\n+", story):
        sentences = []
        for sent in _SENTENCE_BOUNDARY.split(paragraph.strip()):
            sent = sent.strip()
            if not sent:
                continue
            if len(sent) > 1 and not sent.endswith((".", "!", "?")):
                sent += "."
            sentences.append(sent[0].upper() + sent[1:])
        if sentences:
            paragraphs.append(" ".join(sentences))
    return "\n\n".join(paragraphs)


# --- Main Application Flow ---
uploaded_image = st.file_uploader(
    "Upload a children's book style image:",
    type=["jpg", "jpeg", "png"],
)

if uploaded_image is not None:
    # Decode the upload; a corrupt or mislabelled file gets a friendly error
    # instead of a raw traceback.
    try:
        image = Image.open(uploaded_image).convert("RGB")
    except Exception as e:
        st.error(f"❌ Couldn't open this image: {e}")
        st.stop()
    st.image(image, use_container_width=True)

    # Generate caption
    with st.spinner("🔍 Analyzing image..."):
        try:
            caption_result = caption_pipe(image)
            image_caption = caption_result[0].get("generated_text", "").strip()
        except Exception as e:
            st.error(f"❌ Image analysis failed: {str(e)}")
            st.stop()
    if not image_caption:
        st.error("❌ Couldn't understand this image. Please try another!")
        st.stop()
    st.success(f"**Image Understanding:** {image_caption}")

    # Create story prompt; it ends with "Story:" so _format_story can find
    # where the model's own text begins.
    story_prompt = (
        f"Write a 50 to 100 words children's story based on: {image_caption}\n"
        "Requirements:\n"
        "- Exclude your thinking process\n"
        "Story:"
    )

    # Generate story with progress
    progress_bar = st.progress(0)
    status_text = st.empty()

    def update_progress(step):
        """Show ``step`` of 5 on the progress bar and its status label."""
        progress_bar.progress(min(step / 5, 1.0))
        status_text.text(f"Step {int(step)}/5: {'📖' * int(step)}")

    try:
        with st.spinner("📝 Crafting magical story..."):
            start_time = time.perf_counter()
            update_progress(1)
            story_result = story_pipe(
                story_prompt,
                do_sample=True,
                num_return_sequences=1,
            )
            update_progress(4)
            generation_time = time.perf_counter() - start_time
            st.info(f"Story generated in {generation_time:.1f} seconds")

            final_story = _format_story(story_result[0]["generated_text"])
            update_progress(5)
            time.sleep(0.5)  # let the finished bar be visible briefly
    except Exception as e:
        st.error(f"❌ Story generation failed: {str(e)}")
        st.stop()
    finally:
        # Runs on success, error, and st.stop() alike.
        progress_bar.empty()
        status_text.empty()

    # An empty story would render blank and make text-to-speech fail.
    if not final_story:
        st.error("❌ The storyteller came up empty. Please try again!")
        st.stop()

    # Display story
    st.subheader("✨ Your Magical Story")
    st.write(final_story)

    # Audio conversion: synthesise into memory so no temp files pile up on disk.
    with st.spinner("🔊 Creating audio version..."):
        try:
            audio_buffer = io.BytesIO()
            gTTS(text=final_story, lang="en", slow=False).write_to_fp(audio_buffer)
            st.audio(audio_buffer.getvalue(), format="audio/mpeg")
        except Exception as e:
            st.error(f"❌ Audio conversion failed: {str(e)}")
# --- Footer ---
# Render a horizontal rule, then the credit line with the issue-report link.
for _footer_md in (
    "---",
    "📚 Made with ♥ by The Story Wizard • [Report Issues](https://example.com)",
):
    st.markdown(_footer_md)
|