|
|
|
import streamlit as st |
|
# Page metadata is configured before any other Streamlit call in this script
# (Streamlit has historically required set_page_config to come first).
st.set_page_config(page_title="Magic Story Generator", page_icon="📖", layout="centered")
|
|
|
|
|
import re |
|
import time |
|
import torch |
|
import tempfile |
|
from PIL import Image |
|
from gtts import gTTS |
|
from transformers import pipeline, AutoTokenizer |
|
|
|
|
|
# Main page heading shown above the uploader.
_APP_TITLE = "📖✨ Turn Images into Children's Stories"
st.title(_APP_TITLE)
|
|
|
|
|
# Hugging Face model identifiers used by load_models().
_CAPTION_MODEL_ID = "Salesforce/blip-image-captioning-base"
_STORY_MODEL_ID = "Deepthoughtworks/gpt-neo-2.7B__low-cpu"


@st.cache_resource(show_spinner=False)
def load_models():
    """Build the image-captioning and story-writing pipelines.

    Wrapped in ``st.cache_resource`` so Streamlit reruns reuse the loaded
    pipelines instead of constructing them again.

    Returns:
        tuple: ``(captioner, storyteller)`` -- an "image-to-text" pipeline
        and a "text-generation" pipeline.
    """
    # GPU 0 when CUDA is available, otherwise CPU (-1).
    caption_device = 0 if torch.cuda.is_available() else -1
    captioner = pipeline("image-to-text", model=_CAPTION_MODEL_ID, device=caption_device)

    story_tokenizer = AutoTokenizer.from_pretrained(_STORY_MODEL_ID)

    # Generation settings baked into the pipeline; the call site adds
    # do_sample=True so these sampling parameters take effect.
    sampling_defaults = {
        "max_new_tokens": 150,
        "temperature": 0.85,
        "top_k": 40,
        "top_p": 0.92,
        "repetition_penalty": 1.15,
        # Reuse EOS as the padding token.
        "pad_token_id": story_tokenizer.eos_token_id,
    }
    storyteller = pipeline(
        "text-generation",
        model=_STORY_MODEL_ID,
        tokenizer=story_tokenizer,
        device_map="auto",
        torch_dtype=torch.float32,
        **sampling_defaults,
    )

    return captioner, storyteller
|
|
|
# load_models() is wrapped in st.cache_resource, so reruns reuse the pipelines.
caption_pipe, story_pipe = load_models()

# File extensions the uploader accepts.
_ACCEPTED_IMAGE_TYPES = ["jpg", "jpeg", "png"]
uploaded_image = st.file_uploader("Upload a children's book style image:", type=_ACCEPTED_IMAGE_TYPES)
|
|
|
def _update_progress(progress_bar, status_text, step, total=5):
    """Show ``step`` of ``total`` on the progress bar and the status line.

    Args:
        progress_bar: The ``st.progress`` element to advance.
        status_text: The ``st.empty`` placeholder holding the step label.
        step: Current step number; the bar fraction is capped at 1.0.
        total: Number of steps the bar represents.
    """
    progress_bar.progress(min(step / total, 1.0))
    status_text.text(f"Step {int(step)}/{total}: {'📖' * int(step)}")


# Characters that may legitimately follow a sentence terminator, e.g. the
# closing quote in: He said "Bye."
_CLOSING_CHARS = "\"'”’)]"

# Greedy match from the start of the text up to (and including) the LAST
# sentence terminator plus any closing quotes/brackets right after it.
_UP_TO_LAST_SENTENCE_END = re.compile(
    r".*[.!?][" + re.escape(_CLOSING_CHARS) + r"]*", re.DOTALL
)


def _tidy_story(raw_story):
    """Turn raw model output into capitalised, punctuated paragraphs.

    Generation is capped by a token limit, so the text can end mid-sentence;
    when at least one complete sentence exists, the trailing fragment after
    the last one is dropped. Runs of newlines become paragraph breaks, each
    sentence is capitalised, and a sentence lacking terminal punctuation
    gets a '.' appended.

    Args:
        raw_story: Text produced by the story pipeline.

    Returns:
        Paragraphs joined by blank lines, or "" if nothing usable remains.
    """
    text = raw_story.strip()
    complete = _UP_TO_LAST_SENTENCE_END.match(text)
    if complete:
        text = complete.group(0)
    text = re.sub(r"\n+", "\n\n", text)

    paragraphs = []
    for paragraph in text.split("\n\n"):
        sentences = []
        for sentence in re.split(r"(?<=[.!?]) +", paragraph.strip()):
            sentence = sentence.strip()
            if not sentence:
                continue
            # Look past closing quotes so 'Bye!"' is not turned into 'Bye!".'
            ends_cleanly = sentence.rstrip(_CLOSING_CHARS).endswith((".", "!", "?"))
            if len(sentence) > 1 and not ends_cleanly:
                sentence += "."
            sentences.append(sentence[0].upper() + sentence[1:])
        if sentences:
            paragraphs.append(" ".join(sentences))
    return "\n\n".join(paragraphs)


def _story_audio(story_text):
    """Synthesise ``story_text`` to MP3 with gTTS and return the raw bytes.

    The audio is written to an in-memory buffer, so no temporary files are
    left behind on the server.
    """
    import io

    buffer = io.BytesIO()
    gTTS(text=story_text, lang="en", slow=False).write_to_fp(buffer)
    return buffer.getvalue()


if uploaded_image:
    try:
        image = Image.open(uploaded_image).convert("RGB")
    except (OSError, Image.DecompressionBombError) as e:
        # PIL raises UnidentifiedImageError (an OSError) for undecodable files.
        st.error(f"❌ Couldn't open this image: {e}")
        st.stop()

    st.image(image, use_container_width=True)

    # 1) Caption the image.
    with st.spinner("🔍 Analyzing image..."):
        try:
            caption_result = caption_pipe(image)
            image_caption = caption_result[0].get("generated_text", "").strip()
        except Exception as e:
            st.error(f"❌ Image analysis failed: {str(e)}")
            st.stop()

    if not image_caption:
        st.error("❌ Couldn't understand this image. Please try another!")
        st.stop()

    st.success(f"**Image Understanding:** {image_caption}")

    # 2) Generate a story from the caption.
    story_prompt = (
        f"Write a 50 to 100 words children's story based on: {image_caption}\n\n"
        "Requirements:\n\n"
        "- Exclude your thinking process\n\n"
        "Story:"
    )

    progress_bar = st.progress(0)
    status_text = st.empty()

    try:
        with st.spinner("📝 Crafting magical story..."):
            start_time = time.perf_counter()
            _update_progress(progress_bar, status_text, 1)

            # return_full_text=False yields only the continuation, so the
            # echoed prompt never has to be split back out of the result.
            story_result = story_pipe(
                story_prompt,
                do_sample=True,
                num_return_sequences=1,
                return_full_text=False,
            )

            _update_progress(progress_bar, status_text, 4)
            generation_time = time.perf_counter() - start_time
            st.info(f"Story generated in {generation_time:.1f} seconds")

            final_story = _tidy_story(story_result[0]["generated_text"])

            _update_progress(progress_bar, status_text, 5)
            # Keep the completed bar visible briefly before `finally` clears it.
            time.sleep(0.5)
    except Exception as e:
        st.error(f"❌ Story generation failed: {str(e)}")
        st.stop()
    finally:
        progress_bar.empty()
        status_text.empty()

    # An empty result would render a blank story, and gTTS rejects empty text.
    if not final_story:
        st.error("❌ The storyteller came back empty-handed. Please try again!")
        st.stop()

    st.subheader("✨ Your Magical Story")
    st.write(final_story)

    # 3) Read the story aloud.
    with st.spinner("🔊 Creating audio version..."):
        try:
            st.audio(_story_audio(final_story), format="audio/mpeg")
        except Exception as e:
            st.error(f"❌ Audio conversion failed: {str(e)}")
|
|
|
|
|
# Page footer: a horizontal rule followed by the credits line.
for _footer_markdown in (
    "---",
    "📚 Made with ♥ by The Story Wizard • [Report Issues](https://example.com)",
):
    st.markdown(_footer_markdown)
|
|
|
|