Update app.py
Browse files
app.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
#
|
2 |
import streamlit as st
|
3 |
st.set_page_config(
|
4 |
page_title="Magic Story Generator",
|
@@ -13,38 +13,64 @@ import torch
|
|
13 |
import tempfile
|
14 |
from PIL import Image
|
15 |
from gtts import gTTS
|
16 |
-
from transformers import pipeline
|
17 |
|
18 |
# --- Constants & Setup ---
|
19 |
st.title("📖✨ Turn Images into Children's Stories")
|
20 |
|
21 |
-
# ---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
@st.cache_resource(show_spinner=False)
|
23 |
def load_models():
|
24 |
-
# Image captioning
|
25 |
captioner = pipeline(
|
26 |
"image-to-text",
|
27 |
model="Salesforce/blip-image-captioning-base",
|
28 |
device=0 if torch.cuda.is_available() else -1
|
29 |
)
|
30 |
|
31 |
-
#
|
32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
"text-generation",
|
34 |
model="Qwen/Qwen3-0.6B",
|
|
|
35 |
device_map="auto",
|
36 |
torch_dtype=torch.float16,
|
37 |
-
max_new_tokens=
|
38 |
-
temperature=0.
|
39 |
-
top_k=50,
|
40 |
top_p=0.9,
|
41 |
-
repetition_penalty=1.
|
42 |
-
|
|
|
43 |
)
|
44 |
|
45 |
-
return captioner,
|
46 |
-
|
47 |
-
caption_pipe, story_pipe = load_models()
|
48 |
|
49 |
# --- Main Application Flow ---
|
50 |
uploaded_image = st.file_uploader(
|
@@ -53,7 +79,6 @@ uploaded_image = st.file_uploader(
|
|
53 |
)
|
54 |
|
55 |
if uploaded_image:
|
56 |
-
# Process image
|
57 |
image = Image.open(uploaded_image).convert("RGB")
|
58 |
st.image(image, use_column_width=True)
|
59 |
|
@@ -72,73 +97,54 @@ if uploaded_image:
|
|
72 |
|
73 |
st.success(f"**Image Understanding:** {image_caption}")
|
74 |
|
75 |
-
#
|
76 |
-
story_prompt =
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
|
|
|
|
|
|
|
|
87 |
try:
|
88 |
with st.spinner("📝 Crafting magical story..."):
|
89 |
start_time = time.time()
|
90 |
|
91 |
-
def update_progress(step):
|
92 |
-
progress = min(step/5, 1.0) # Simulate progress steps
|
93 |
-
progress_bar.progress(progress)
|
94 |
-
status_text.text(f"Step {int(step)}/5: {'📖'*int(step)}")
|
95 |
-
|
96 |
-
update_progress(1)
|
97 |
story_result = story_pipe(
|
98 |
story_prompt,
|
99 |
-
|
100 |
-
|
101 |
)
|
102 |
|
103 |
-
|
104 |
-
generation_time = time.time() - start_time
|
105 |
-
st.info(f"Story generated in {generation_time:.1f} seconds")
|
106 |
-
|
107 |
-
# Process output
|
108 |
raw_story = story_result[0]['generated_text']
|
109 |
-
clean_story = raw_story.split("<|im_start|>assistant
|
110 |
-
clean_story = re.sub(r'<\|.*?\|>', '', clean_story).strip()
|
111 |
|
112 |
-
# Format
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
if len(sent) > 1 and not sent.endswith(('.','!','?')):
|
118 |
-
sent += '.'
|
119 |
-
sentences.append(sent[0].upper() + sent[1:])
|
120 |
-
|
121 |
-
final_story = ' '.join(sentences)[:600] # Limit length
|
122 |
-
|
123 |
-
update_progress(5)
|
124 |
-
time.sleep(0.5) # Final progress pause
|
125 |
|
126 |
except Exception as e:
|
127 |
st.error(f"❌ Story generation failed: {str(e)}")
|
128 |
st.stop()
|
129 |
|
130 |
-
finally:
|
131 |
-
progress_bar.empty()
|
132 |
-
status_text.empty()
|
133 |
-
|
134 |
# Display story
|
135 |
st.subheader("✨ Your Magical Story")
|
136 |
-
st.
|
137 |
|
138 |
# Audio conversion
|
139 |
with st.spinner("🔊 Creating audio version..."):
|
140 |
try:
|
141 |
-
audio = gTTS(text=
|
142 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
|
143 |
audio.save(tmp_file.name)
|
144 |
st.audio(tmp_file.name, format="audio/mp3")
|
|
|
1 |
+
# Import Streamlit first
|
2 |
import streamlit as st
|
3 |
st.set_page_config(
|
4 |
page_title="Magic Story Generator",
|
|
|
13 |
import tempfile
|
14 |
from PIL import Image
|
15 |
from gtts import gTTS
|
16 |
+
from transformers import pipeline, AutoTokenizer
|
17 |
|
18 |
# --- Constants & Setup ---
|
19 |
st.title("📖✨ Turn Images into Children's Stories")
|
20 |
|
21 |
+
# --- Enhanced Cleaning Functions ---
|
22 |
+
def clean_story_text(raw_text):
    """Strip chat-template and reasoning artifacts from a generated story.

    Cleaning stages, in order:
      1. Remove complete ``<|im_start|>...<|im_end|>`` chat turns.
      2. Remove ``<think>...</think>`` reasoning blocks, including a block
         that was never closed because generation was cut off.
      3. Remove lines that *begin* with a typical "thinking out loud"
         phrase (e.g. "Okay, I need", "Let me start").
      4. Remove leftover special tokens, bracketed notes and ``**`` markers.
      5. Keep at most the first three non-empty paragraphs.

    Args:
        raw_text: Model output to clean. ``None`` or an empty string is
            treated as "no story".

    Returns:
        The cleaned story with paragraphs separated by a blank line, or an
        empty string when nothing usable remains.
    """
    if not raw_text:
        return ""

    # Remove chat template artifacts (complete turns only).
    clean = re.sub(
        r'<\|im_start\|>.*?<\|im_end\|>', '', raw_text, flags=re.DOTALL
    )

    # Remove reasoning blocks. The `\Z` alternative covers a block left open
    # when the token budget ran out before the closing tag was produced.
    clean = re.sub(r'<think>.*?(?:</think>|\Z)', '', clean, flags=re.DOTALL)

    # Remove thinking-chain lines. The pattern is anchored to the start of a
    # line so an ordinary "maybe" or "first," inside a story sentence does
    # not wipe out the rest of that sentence.
    clean = re.sub(
        r'^[ \t]*(?:Okay, I need|Let me start|First,|Maybe|I should|How to)'
        r'.*?(?=\n\w|\Z)',
        '',
        clean,
        flags=re.DOTALL | re.IGNORECASE | re.MULTILINE,
    )

    # Remove special tokens, stray think tags, bracketed notes and markdown.
    clean = re.sub(r'<\|.*?\|>|</?think>|\[.*?\]|\*\*', '', clean)

    # Split into non-empty paragraphs and keep at most three.
    paragraphs = [p.strip() for p in clean.split('\n') if p.strip()]
    return '\n\n'.join(paragraphs[:3])
|
41 |
+
|
42 |
+
# --- Optimized Model Loading ---
|
43 |
@st.cache_resource(show_spinner=False)
def load_models():
    """Build the image-captioning and story-generation pipelines.

    The function is wrapped in ``st.cache_resource`` so the pipelines are
    created once and reused across Streamlit reruns.

    Returns:
        tuple: ``(captioner, story_pipe)`` where ``captioner`` is the
        "image-to-text" pipeline (BLIP base) and ``story_pipe`` is the
        "text-generation" pipeline (Qwen3-0.6B).
    """
    has_cuda = torch.cuda.is_available()

    # Image captioning
    captioner = pipeline(
        "image-to-text",
        model="Salesforce/blip-image-captioning-base",
        device=0 if has_cuda else -1,
    )

    # Story generator with Qwen-specific config
    tokenizer = AutoTokenizer.from_pretrained(
        "Qwen/Qwen3-0.6B",
        trust_remote_code=True,
        pad_token='<|endoftext|>',
    )

    story_pipe = pipeline(
        "text-generation",
        model="Qwen/Qwen3-0.6B",
        tokenizer=tokenizer,
        device_map="auto",
        # Half precision only pays off on a GPU; on CPU float16 kernels are
        # slow or missing on some torch builds, so fall back to float32 —
        # mirroring the CPU fallback already used for the captioner above.
        torch_dtype=torch.float16 if has_cuda else torch.float32,
        max_new_tokens=300,  # Increased for better story flow
        temperature=0.7,  # Lower temperature for more focused output
        top_p=0.9,
        repetition_penalty=1.2,
        do_sample=True,
        eos_token_id=tokenizer.eos_token_id,
    )

    return captioner, story_pipe
|
|
|
|
|
74 |
|
75 |
# --- Main Application Flow ---
|
76 |
uploaded_image = st.file_uploader(
|
|
|
79 |
)
|
80 |
|
81 |
if uploaded_image:
|
|
|
82 |
image = Image.open(uploaded_image).convert("RGB")
|
83 |
st.image(image, use_column_width=True)
|
84 |
|
|
|
97 |
|
98 |
st.success(f"**Image Understanding:** {image_caption}")
|
99 |
|
100 |
+
# Enhanced prompt engineering
|
101 |
+
story_prompt = f"""<|im_start|>system
|
102 |
+
You are a children's story writer. Create a SHORT STORY based on this image description: "{image_caption}"
|
103 |
+
|
104 |
+
RULES:
|
105 |
+
1. Use simple language (Grade 2 level)
|
106 |
+
2. Include a magical element
|
107 |
+
3. Add a moral lesson about kindness
|
108 |
+
4. NO internal thoughts/explanations
|
109 |
+
5. 3 paragraphs maximum<|im_end|>
|
110 |
+
<|im_start|>user
|
111 |
+
Write the story<|im_end|>
|
112 |
+
<|im_start|>assistant
|
113 |
+
"""
|
114 |
+
|
115 |
+
# Generate story
|
116 |
try:
|
117 |
with st.spinner("📝 Crafting magical story..."):
|
118 |
start_time = time.time()
|
119 |
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
story_result = story_pipe(
|
121 |
story_prompt,
|
122 |
+
num_return_sequences=1,
|
123 |
+
stopping_criteria=[lambda _: False] # Disable default stopping
|
124 |
)
|
125 |
|
126 |
+
# Enhanced post-processing
|
|
|
|
|
|
|
|
|
127 |
raw_story = story_result[0]['generated_text']
|
128 |
+
clean_story = clean_story_text(raw_story.split("<|im_start|>assistant")[-1])
|
|
|
129 |
|
130 |
+
# Format paragraphs
|
131 |
+
formatted_story = "\n\n".join(
|
132 |
+
[f"<p style='font-size:18px; line-height:1.6'>{p}</p>"
|
133 |
+
for p in clean_story.split("\n\n")]
|
134 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
135 |
|
136 |
except Exception as e:
|
137 |
st.error(f"❌ Story generation failed: {str(e)}")
|
138 |
st.stop()
|
139 |
|
|
|
|
|
|
|
|
|
140 |
# Display story
|
141 |
st.subheader("✨ Your Magical Story")
|
142 |
+
st.markdown(formatted_story, unsafe_allow_html=True)
|
143 |
|
144 |
# Audio conversion
|
145 |
with st.spinner("🔊 Creating audio version..."):
|
146 |
try:
|
147 |
+
audio = gTTS(text=clean_story, lang="en", slow=False)
|
148 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
|
149 |
audio.save(tmp_file.name)
|
150 |
st.audio(tmp_file.name, format="audio/mp3")
|