mayf committed on
Commit
6523fb1
·
verified ·
1 Parent(s): f913ab4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -109
app.py CHANGED
@@ -13,66 +13,58 @@ import torch
13
  import tempfile
14
  from PIL import Image
15
  from gtts import gTTS
16
- from transformers import pipeline, AutoTokenizer
17
 
18
- # --- Constants & Setup ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  st.title("📖✨ Turn Images into Children's Stories")
20
 
21
- # --- Enhanced Cleaning Functions ---
22
  def clean_story_text(raw_text):
23
- """Multi-stage cleaning pipeline for generated stories"""
24
- # Remove chat template artifacts
25
- clean = re.sub(r'<\|im_start\|>.*?<\|im_end\|>', '', raw_text, flags=re.DOTALL)
26
-
27
- # Remove thinking chain patterns
28
- clean = re.sub(
29
- r'(Okay, I need|Let me start|First,|Maybe|I should|How to)(.*?)(?=\n\w|\Z)',
30
- '',
31
- clean,
32
- flags=re.DOTALL|re.IGNORECASE
33
- )
34
-
35
- # Remove special tokens and markdown
36
- clean = re.sub(r'<\|.*?\|>|\[.*?\]|\*\*', '', clean)
37
-
38
- # Split and clean paragraphs
39
- paragraphs = [p.strip() for p in clean.split('\n') if p.strip()]
40
- return '\n\n'.join(paragraphs[:3]) # Keep max 3 paragraphs
41
-
42
- # --- Optimized Model Loading ---
43
- @st.cache_resource(show_spinner=False)
44
- def load_models():
45
- # Image captioning
46
- captioner = pipeline(
47
- "image-to-text",
48
- model="Salesforce/blip-image-captioning-base",
49
- device=0 if torch.cuda.is_available() else -1
50
- )
51
-
52
- # Story generator with Qwen-specific config
53
- tokenizer = AutoTokenizer.from_pretrained(
54
- "Qwen/Qwen3-0.6B",
55
- trust_remote_code=True,
56
- pad_token='<|endoftext|>'
57
- )
58
-
59
- story_pipe = pipeline(
60
- "text-generation",
61
- model="Qwen/Qwen3-0.6B",
62
- tokenizer=tokenizer,
63
- device_map="auto",
64
- torch_dtype=torch.float16,
65
- max_new_tokens=300, # Increased for better story flow
66
- temperature=0.7, # Lower temperature for more focused output
67
- top_p=0.9,
68
- repetition_penalty=1.2,
69
- do_sample=True,
70
- eos_token_id=tokenizer.eos_token_id
71
- )
72
-
73
- return captioner, story_pipe
74
-
75
- # --- Main Application Flow ---
76
  uploaded_image = st.file_uploader(
77
  "Upload a children's book style image:",
78
  type=["jpg", "jpeg", "png"]
@@ -80,76 +72,50 @@ uploaded_image = st.file_uploader(
80
 
81
  if uploaded_image:
82
  image = Image.open(uploaded_image).convert("RGB")
83
- st.image(image, use_column_width=True)
 
84
 
85
- # Generate caption
86
  with st.spinner("🔍 Analyzing image..."):
87
  try:
88
  caption_result = caption_pipe(image)
89
- image_caption = caption_result[0].get("generated_text", "").strip()
 
90
  except Exception as e:
91
  st.error(f"❌ Image analysis failed: {str(e)}")
92
  st.stop()
93
-
94
- if not image_caption:
95
- st.error("❌ Couldn't understand this image. Please try another!")
96
- st.stop()
97
-
98
- st.success(f"**Image Understanding:** {image_caption}")
99
-
100
- # Enhanced prompt engineering
101
- story_prompt = f"""<|im_start|>system
102
- You are a children's story writer. Create a SHORT STORY based on this image description: "{image_caption}"
103
-
104
- RULES:
105
- 1. Use simple language (Grade 2 level)
106
- 2. Include a magical element
107
- 3. Add a moral lesson about kindness
108
- 4. NO internal thoughts/explanations
109
- 5. 3 paragraphs maximum<|im_end|>
110
- <|im_start|>user
111
- Write the story<|im_end|>
112
- <|im_start|>assistant
113
- """
114
-
115
- # Generate story
116
  try:
117
  with st.spinner("📝 Crafting magical story..."):
118
- start_time = time.time()
119
-
120
  story_result = story_pipe(
121
  story_prompt,
122
- num_return_sequences=1,
123
- stopping_criteria=[lambda _: False] # Disable default stopping
 
124
  )
125
 
126
- # Enhanced post-processing
127
  raw_story = story_result[0]['generated_text']
128
- clean_story = clean_story_text(raw_story.split("<|im_start|>assistant")[-1])
129
-
130
- # Format paragraphs
131
- formatted_story = "\n\n".join(
132
- [f"<p style='font-size:18px; line-height:1.6'>{p}</p>"
133
- for p in clean_story.split("\n\n")]
134
- )
135
 
136
- except Exception as e:
137
- st.error(f"❌ Story generation failed: {str(e)}")
138
- st.stop()
139
 
140
- # Display story
141
- st.subheader(" Your Magical Story")
142
- st.markdown(formatted_story, unsafe_allow_html=True)
 
 
 
143
 
144
- # Audio conversion
145
- with st.spinner("🔊 Creating audio version..."):
146
- try:
147
- audio = gTTS(text=clean_story, lang="en", slow=False)
148
- with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
149
- audio.save(tmp_file.name)
150
- st.audio(tmp_file.name, format="audio/mp3")
151
- except Exception as e:
152
- st.error(f"❌ Audio conversion failed: {str(e)}")
153
 
154
  # Footer
155
  st.markdown("---")
 
13
  import tempfile
14
  from PIL import Image
15
  from gtts import gTTS
16
+ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
17
 
18
# --- Initialize Models First ---
@st.cache_resource(show_spinner=False)
def load_models():
    """Load and cache both inference pipelines at app startup.

    Returns:
        tuple: ``(caption_pipe, story_pipe)`` — a BLIP image-to-text
        pipeline and a Qwen text-generation pipeline.

    On any loading failure this reports the error in the UI and halts
    the Streamlit script run (``st.stop()``) instead of raising.
    """
    try:
        use_cuda = torch.cuda.is_available()

        # 1. Image Captioning Model
        caption_pipe = pipeline(
            "image-to-text",
            model="Salesforce/blip-image-captioning-base",
            device=0 if use_cuda else -1,
        )

        # 2. Story Generation Model
        story_tokenizer = AutoTokenizer.from_pretrained(
            "Qwen/Qwen3-0.6B",
            trust_remote_code=True,
        )

        # Fix: requesting float16 unconditionally breaks (or crawls) on
        # CPU-only hosts — only use half precision when CUDA is present.
        story_model = AutoModelForCausalLM.from_pretrained(
            "Qwen/Qwen3-0.6B",
            device_map="auto",
            torch_dtype=torch.float16 if use_cuda else torch.float32,
        )

        story_pipe = pipeline(
            "text-generation",
            model=story_model,
            tokenizer=story_tokenizer,
            max_new_tokens=300,
            temperature=0.7,
        )

        return caption_pipe, story_pipe

    except Exception as e:
        # Surface the failure to the user and stop this script run.
        st.error(f"🚨 Model loading failed: {str(e)}")
        st.stop()

# Initialize models immediately when app starts
caption_pipe, story_pipe = load_models()
58
+
59
# --- Page header ---
st.title("📖✨ Turn Images into Children's Stories")
61
 
 
62
def clean_story_text(raw_text):
    """Strip model artifacts from a generated story.

    Removes chat special tokens such as ``<|im_start|>`` and a leading
    "Okay, I need..." thinking chain, then trims surrounding whitespace.
    """
    without_tokens = re.sub(r'<\|.*?\|>', '', raw_text)
    without_thinking = re.sub(
        r'Okay, I need.*?(?=\n|$)',
        '',
        without_tokens,
        flags=re.DOTALL,
    )
    return without_thinking.strip()
67
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
# --- Main Application Flow ---
uploaded_image = st.file_uploader(
    "Upload a children's book style image:",
    type=["jpg", "jpeg", "png"]
)

if uploaded_image:
    image = Image.open(uploaded_image).convert("RGB")
    # use_container_width replaces the deprecated use_column_width
    st.image(image, use_container_width=True)

    # Step 1: caption the image with the BLIP pipeline.
    with st.spinner("🔍 Analyzing image..."):
        try:
            caption_result = caption_pipe(image)
            image_caption = caption_result[0].get("generated_text", "").strip()
        except Exception as e:
            st.error(f"❌ Image analysis failed: {str(e)}")
            st.stop()

    # Guard: an empty caption would produce a meaningless story prompt.
    if not image_caption:
        st.error("❌ Couldn't understand this image. Please try another!")
        st.stop()

    st.success(f"**Image Understanding:** {image_caption}")

    # Step 2: build the story prompt; "Story:" is used below as the
    # split marker to drop the echoed prompt from the model output.
    story_prompt = f"""Write a children's story about: {image_caption}
Rules:
- Use simple words (Grade 2 level)
- Exclude thinking processes
- 3 paragraphs maximum
Story:"""

    try:
        with st.spinner("📝 Crafting magical story..."):
            story_result = story_pipe(
                story_prompt,
                do_sample=True,
                top_p=0.9,
                repetition_penalty=1.2
            )

        raw_story = story_result[0]['generated_text']
        final_story = clean_story_text(raw_story.split("Story:")[-1])

        st.subheader("✨ Your Magical Story")
        st.write(final_story)

        # Step 3: narrate the story with gTTS.
        # NOTE(review): delete=False leaks one temp .mp3 per run; kept so
        # st.audio can read the file after the handle closes.
        with st.spinner("🔊 Creating audio version..."):
            audio = gTTS(text=final_story, lang="en", slow=False)
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
                audio.save(tmp_file.name)
                st.audio(tmp_file.name, format="audio/mp3")

    except Exception as e:
        # Fix: restore the "❌" prefix for consistency with other errors.
        st.error(f"❌ Story generation failed: {str(e)}")

# Footer
st.markdown("---")