Update app.py
Browse files
app.py
CHANGED
@@ -13,66 +13,58 @@ import torch
|
|
13 |
import tempfile
|
14 |
from PIL import Image
|
15 |
from gtts import gTTS
|
16 |
-
from transformers import pipeline, AutoTokenizer
|
17 |
|
18 |
-
# ---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
st.title("📖✨ Turn Images into Children's Stories")
|
20 |
|
21 |
-
# --- Enhanced Cleaning Functions ---
|
22 |
def clean_story_text(raw_text):
|
23 |
-
"""
|
24 |
-
# Remove
|
25 |
-
clean = re.sub(r'
|
26 |
-
|
27 |
-
|
28 |
-
clean = re.sub(
|
29 |
-
r'(Okay, I need|Let me start|First,|Maybe|I should|How to)(.*?)(?=\n\w|\Z)',
|
30 |
-
'',
|
31 |
-
clean,
|
32 |
-
flags=re.DOTALL|re.IGNORECASE
|
33 |
-
)
|
34 |
-
|
35 |
-
# Remove special tokens and markdown
|
36 |
-
clean = re.sub(r'<\|.*?\|>|\[.*?\]|\*\*', '', clean)
|
37 |
-
|
38 |
-
# Split and clean paragraphs
|
39 |
-
paragraphs = [p.strip() for p in clean.split('\n') if p.strip()]
|
40 |
-
return '\n\n'.join(paragraphs[:3]) # Keep max 3 paragraphs
|
41 |
-
|
42 |
-
# --- Optimized Model Loading ---
|
43 |
-
@st.cache_resource(show_spinner=False)
|
44 |
-
def load_models():
|
45 |
-
# Image captioning
|
46 |
-
captioner = pipeline(
|
47 |
-
"image-to-text",
|
48 |
-
model="Salesforce/blip-image-captioning-base",
|
49 |
-
device=0 if torch.cuda.is_available() else -1
|
50 |
-
)
|
51 |
-
|
52 |
-
# Story generator with Qwen-specific config
|
53 |
-
tokenizer = AutoTokenizer.from_pretrained(
|
54 |
-
"Qwen/Qwen3-0.6B",
|
55 |
-
trust_remote_code=True,
|
56 |
-
pad_token='<|endoftext|>'
|
57 |
-
)
|
58 |
-
|
59 |
-
story_pipe = pipeline(
|
60 |
-
"text-generation",
|
61 |
-
model="Qwen/Qwen3-0.6B",
|
62 |
-
tokenizer=tokenizer,
|
63 |
-
device_map="auto",
|
64 |
-
torch_dtype=torch.float16,
|
65 |
-
max_new_tokens=300, # Increased for better story flow
|
66 |
-
temperature=0.7, # Lower temperature for more focused output
|
67 |
-
top_p=0.9,
|
68 |
-
repetition_penalty=1.2,
|
69 |
-
do_sample=True,
|
70 |
-
eos_token_id=tokenizer.eos_token_id
|
71 |
-
)
|
72 |
-
|
73 |
-
return captioner, story_pipe
|
74 |
-
|
75 |
-
# --- Main Application Flow ---
|
76 |
uploaded_image = st.file_uploader(
|
77 |
"Upload a children's book style image:",
|
78 |
type=["jpg", "jpeg", "png"]
|
@@ -80,76 +72,50 @@ uploaded_image = st.file_uploader(
|
|
80 |
|
81 |
if uploaded_image:
|
82 |
image = Image.open(uploaded_image).convert("RGB")
|
83 |
-
|
|
|
84 |
|
85 |
-
# Generate caption
|
86 |
with st.spinner("🔍 Analyzing image..."):
|
87 |
try:
|
88 |
caption_result = caption_pipe(image)
|
89 |
-
image_caption = caption_result[0].get("generated_text", "")
|
|
|
90 |
except Exception as e:
|
91 |
st.error(f"❌ Image analysis failed: {str(e)}")
|
92 |
st.stop()
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
You are a children's story writer. Create a SHORT STORY based on this image description: "{image_caption}"
|
103 |
-
|
104 |
-
RULES:
|
105 |
-
1. Use simple language (Grade 2 level)
|
106 |
-
2. Include a magical element
|
107 |
-
3. Add a moral lesson about kindness
|
108 |
-
4. NO internal thoughts/explanations
|
109 |
-
5. 3 paragraphs maximum<|im_end|>
|
110 |
-
<|im_start|>user
|
111 |
-
Write the story<|im_end|>
|
112 |
-
<|im_start|>assistant
|
113 |
-
"""
|
114 |
-
|
115 |
-
# Generate story
|
116 |
try:
|
117 |
with st.spinner("📝 Crafting magical story..."):
|
118 |
-
start_time = time.time()
|
119 |
-
|
120 |
story_result = story_pipe(
|
121 |
story_prompt,
|
122 |
-
|
123 |
-
|
|
|
124 |
)
|
125 |
|
126 |
-
# Enhanced post-processing
|
127 |
raw_story = story_result[0]['generated_text']
|
128 |
-
|
129 |
-
|
130 |
-
# Format paragraphs
|
131 |
-
formatted_story = "\n\n".join(
|
132 |
-
[f"<p style='font-size:18px; line-height:1.6'>{p}</p>"
|
133 |
-
for p in clean_story.split("\n\n")]
|
134 |
-
)
|
135 |
|
136 |
-
|
137 |
-
|
138 |
-
st.stop()
|
139 |
|
140 |
-
|
141 |
-
|
142 |
-
|
|
|
|
|
|
|
143 |
|
144 |
-
|
145 |
-
|
146 |
-
try:
|
147 |
-
audio = gTTS(text=clean_story, lang="en", slow=False)
|
148 |
-
with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
|
149 |
-
audio.save(tmp_file.name)
|
150 |
-
st.audio(tmp_file.name, format="audio/mp3")
|
151 |
-
except Exception as e:
|
152 |
-
st.error(f"❌ Audio conversion failed: {str(e)}")
|
153 |
|
154 |
# Footer
|
155 |
st.markdown("---")
|
|
|
13 |
import tempfile
|
14 |
from PIL import Image
|
15 |
from gtts import gTTS
|
16 |
+
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
|
17 |
|
18 |
+
# --- Initialize Models First ---
@st.cache_resource(show_spinner=False)
def load_models():
    """Load and return both models at startup.

    Returns:
        tuple: ``(caption_pipe, story_pipe)`` — a BLIP image-to-text
        pipeline and a Qwen3-0.6B text-generation pipeline.

    The result is memoized by ``st.cache_resource`` so Streamlit reruns
    reuse the already-loaded pipelines. On any loading failure an error
    is shown in the UI and the script run is stopped.
    """
    try:
        # 1. Image Captioning Model (BLIP base).
        # device=0 selects the first CUDA GPU; -1 falls back to CPU.
        caption_pipe = pipeline(
            "image-to-text",
            model="Salesforce/blip-image-captioning-base",
            device=0 if torch.cuda.is_available() else -1
        )

        # 2. Story Generation Model (Qwen3-0.6B).
        # trust_remote_code allows the repo's custom tokenizer code to run.
        story_tokenizer = AutoTokenizer.from_pretrained(
            "Qwen/Qwen3-0.6B",
            trust_remote_code=True
        )

        # fp16 weights with device_map="auto" so HF accelerate places the
        # model on whatever hardware is available.
        story_model = AutoModelForCausalLM.from_pretrained(
            "Qwen/Qwen3-0.6B",
            device_map="auto",
            torch_dtype=torch.float16
        )

        # Base generation settings; sampling options (do_sample, top_p,
        # repetition_penalty) are passed per-call at generation time.
        story_pipe = pipeline(
            "text-generation",
            model=story_model,
            tokenizer=story_tokenizer,
            max_new_tokens=300,
            temperature=0.7
        )

        return caption_pipe, story_pipe

    except Exception as e:
        # The app is useless without both models: surface the error and halt.
        st.error(f"🚨 Model loading failed: {str(e)}")
        st.stop()
55 |
+
|
56 |
+
# Initialize models immediately when app starts
# (cached via st.cache_resource, so subsequent reruns are fast).
caption_pipe, story_pipe = load_models()

# --- Rest of Application ---
st.title("📖✨ Turn Images into Children's Stories")
|
61 |
|
|
|
62 |
def clean_story_text(raw_text):
    """Strip model artifacts from generated story text.

    Removes Qwen3 ``<think>...</think>`` reasoning blocks, chat-template
    special tokens such as ``<|im_start|>``/``<|im_end|>``, and leading
    "Okay, I need..." thinking chains, then trims surrounding whitespace.

    Args:
        raw_text (str): Raw text produced by the text-generation pipeline.

    Returns:
        str: The cleaned story text.
    """
    # Qwen3 emits its chain-of-thought inside <think>...</think>; the old
    # cleanup missed these blocks entirely, so they leaked into the story.
    clean = re.sub(r'<think>.*?</think>', '', raw_text, flags=re.DOTALL)
    clean = re.sub(r'<\|.*?\|>', '', clean)  # Remove special tokens
    clean = re.sub(r'Okay, I need.*?(?=\n|$)', '', clean, flags=re.DOTALL)  # Remove thinking chains
    return clean.strip()
|
67 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
uploaded_image = st.file_uploader(
    "Upload a children's book style image:",
    type=["jpg", "jpeg", "png"]
)

if uploaded_image:
    # Normalize to RGB so display and the caption model both get 3 channels.
    image = Image.open(uploaded_image).convert("RGB")
    # Updated parameter here ↓
    st.image(image, use_container_width=True)  # Changed use_column_width to use_container_width

    # --- Step 1: caption the uploaded image ---
    with st.spinner("🔍 Analyzing image..."):
        try:
            caption_result = caption_pipe(image)
            # Pipeline returns a list of dicts; use the first caption.
            image_caption = caption_result[0].get("generated_text", "")
            st.success(f"**Image Understanding:** {image_caption}")
        except Exception as e:
            st.error(f"❌ Image analysis failed: {str(e)}")
            st.stop()

    # --- Step 2: build the story prompt from the caption ---
    # The trailing "Story:" marker is used below to slice the completion
    # off the echoed prompt.
    story_prompt = f"""Write a children's story about: {image_caption}
Rules:
- Use simple words (Grade 2 level)
- Exclude thinking processes
- 3 paragraphs maximum
Story:"""

    try:
        with st.spinner("📝 Crafting magical story..."):
            story_result = story_pipe(
                story_prompt,
                do_sample=True,
                top_p=0.9,
                repetition_penalty=1.2
            )

        # generated_text includes the prompt itself; keep only what follows
        # the last "Story:" marker, then strip model artifacts.
        raw_story = story_result[0]['generated_text']
        final_story = clean_story_text(raw_story.split("Story:")[-1])

        st.subheader("✨ Your Magical Story")
        st.write(final_story)

        # --- Step 3: text-to-speech narration ---
        # NOTE(review): delete=False means this temp .mp3 is never removed;
        # consider cleaning it up after playback, or passing audio bytes to
        # st.audio directly — confirm with the Streamlit/gTTS docs.
        with st.spinner("🔊 Creating audio version..."):
            audio = gTTS(text=final_story, lang="en", slow=False)
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
                audio.save(tmp_file.name)
                st.audio(tmp_file.name, format="audio/mp3")

    except Exception as e:
        # One handler covers generation and narration failures.
        st.error(f"❌ Story generation failed: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
|
120 |
# Footer
st.markdown("---")  # closing horizontal rule at the bottom of the page
|