mayf committed on
Commit
f913ab4
·
verified ·
1 Parent(s): d179ebe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -61
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # FIRST import and FIRST Streamlit command
2
  import streamlit as st
3
  st.set_page_config(
4
  page_title="Magic Story Generator",
@@ -13,38 +13,64 @@ import torch
13
  import tempfile
14
  from PIL import Image
15
  from gtts import gTTS
16
- from transformers import pipeline
17
 
18
  # --- Constants & Setup ---
19
  st.title("📖✨ Turn Images into Children's Stories")
20
 
21
- # --- Model Loading (Cached) ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  @st.cache_resource(show_spinner=False)
23
  def load_models():
24
- # Image captioning model
25
  captioner = pipeline(
26
  "image-to-text",
27
  model="Salesforce/blip-image-captioning-base",
28
  device=0 if torch.cuda.is_available() else -1
29
  )
30
 
31
- # Optimized story generation model
32
- storyteller = pipeline(
 
 
 
 
 
 
33
  "text-generation",
34
  model="Qwen/Qwen3-0.6B",
 
35
  device_map="auto",
36
  torch_dtype=torch.float16,
37
- max_new_tokens=200,
38
- temperature=0.9,
39
- top_k=50,
40
  top_p=0.9,
41
- repetition_penalty=1.1,
42
- eos_token_id=151645
 
43
  )
44
 
45
- return captioner, storyteller
46
-
47
- caption_pipe, story_pipe = load_models()
48
 
49
  # --- Main Application Flow ---
50
  uploaded_image = st.file_uploader(
@@ -53,7 +79,6 @@ uploaded_image = st.file_uploader(
53
  )
54
 
55
  if uploaded_image:
56
- # Process image
57
  image = Image.open(uploaded_image).convert("RGB")
58
  st.image(image, use_column_width=True)
59
 
@@ -72,73 +97,54 @@ if uploaded_image:
72
 
73
  st.success(f"**Image Understanding:** {image_caption}")
74
 
75
- # Create story prompt
76
- story_prompt = (
77
- f"<|im_start|>system\n"
78
- f"You're a children's author. Create a short story (100-150 words) based on: {image_caption}\n"
79
- f"Use simple language and include a moral lesson.<|im_end|>\n"
80
- f"<|im_start|>assistant\n"
81
- )
82
-
83
- # Generate story with progress
84
- progress_bar = st.progress(0)
85
- status_text = st.empty()
86
-
 
 
 
 
87
  try:
88
  with st.spinner("📝 Crafting magical story..."):
89
  start_time = time.time()
90
 
91
- def update_progress(step):
92
- progress = min(step/5, 1.0) # Simulate progress steps
93
- progress_bar.progress(progress)
94
- status_text.text(f"Step {int(step)}/5: {'📖'*int(step)}")
95
-
96
- update_progress(1)
97
  story_result = story_pipe(
98
  story_prompt,
99
- do_sample=True,
100
- num_return_sequences=1
101
  )
102
 
103
- update_progress(4)
104
- generation_time = time.time() - start_time
105
- st.info(f"Story generated in {generation_time:.1f} seconds")
106
-
107
- # Process output
108
  raw_story = story_result[0]['generated_text']
109
- clean_story = raw_story.split("<|im_start|>assistant\n")[-1]
110
- clean_story = re.sub(r'<\|.*?\|>', '', clean_story).strip()
111
 
112
- # Format story text
113
- sentences = []
114
- for sent in re.split(r'(?<=[.!?]) +', clean_story):
115
- sent = sent.strip()
116
- if sent:
117
- if len(sent) > 1 and not sent.endswith(('.','!','?')):
118
- sent += '.'
119
- sentences.append(sent[0].upper() + sent[1:])
120
-
121
- final_story = ' '.join(sentences)[:600] # Limit length
122
-
123
- update_progress(5)
124
- time.sleep(0.5) # Final progress pause
125
 
126
  except Exception as e:
127
  st.error(f"❌ Story generation failed: {str(e)}")
128
  st.stop()
129
 
130
- finally:
131
- progress_bar.empty()
132
- status_text.empty()
133
-
134
  # Display story
135
  st.subheader("✨ Your Magical Story")
136
- st.write(final_story)
137
 
138
  # Audio conversion
139
  with st.spinner("🔊 Creating audio version..."):
140
  try:
141
- audio = gTTS(text=final_story, lang="en", slow=False)
142
  with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
143
  audio.save(tmp_file.name)
144
  st.audio(tmp_file.name, format="audio/mp3")
 
1
+ # Import Streamlit first
2
  import streamlit as st
3
  st.set_page_config(
4
  page_title="Magic Story Generator",
 
13
  import tempfile
14
  from PIL import Image
15
  from gtts import gTTS
16
+ from transformers import pipeline, AutoTokenizer
17
 
18
  # --- Constants & Setup ---
19
  st.title("📖✨ Turn Images into Children's Stories")
20
 
21
+ # --- Enhanced Cleaning Functions ---
22
def clean_story_text(raw_text, max_paragraphs=3):
    """Strip chat-template and reasoning artifacts from a generated story.

    The cleaning runs in stages:

    1. Drop complete ``<|im_start|>...<|im_end|>`` chat turns.
    2. Drop ``<think>...</think>`` reasoning blocks (an unterminated
       ``<think>`` swallows everything to the end of the text, since the
       story itself was never reached).
    3. Drop passages whose *line* starts with a typical planning phrase
       ("Okay, I need", "Let me start", ...).
    4. Drop leftover ``<|...|>`` tokens, ``[bracketed]`` notes and ``**``.
    5. Keep the first ``max_paragraphs`` non-empty lines, joined by a
       blank line.

    Args:
        raw_text: Model output to clean. ``None`` or an empty string is
            treated as "no story".
        max_paragraphs: Maximum number of paragraphs to keep (default 3).

    Returns:
        The cleaned story as a string; ``""`` when nothing usable remains.
    """
    if not raw_text:
        return ""

    # Stage 1: complete chat-template turns echoed back by the model.
    text = re.sub(
        r'<\|im_start\|>.*?<\|im_end\|>', '', raw_text, flags=re.DOTALL
    )

    # Stage 2: explicit reasoning blocks. These tags are not of the
    # ``<|...|>`` form, so the special-token pattern below never sees them.
    text = re.sub(r'<think>.*?(?:</think>|\Z)', '', text, flags=re.DOTALL)
    text = text.replace('</think>', '')

    # Stage 3: planning passages. Anchored to the start of a line so that a
    # story sentence merely *containing* "maybe" or "first," is not cut off
    # mid-sentence; the word boundaries keep "I shouldn't" etc. intact.
    text = re.sub(
        r'^[ \t]*(?:Okay, I need|Let me start|First,|Maybe\b|I should\b|How to\b)'
        r'.*?(?=\n\w|\Z)',
        '',
        text,
        flags=re.DOTALL | re.IGNORECASE | re.MULTILINE,
    )

    # Stage 4: remaining special tokens, bracketed notes and bold markers.
    text = re.sub(r'<\|.*?\|>|\[.*?\]|\*\*', '', text)

    # Stage 5: every non-empty line is one paragraph.
    paragraphs = [line.strip() for line in text.split('\n') if line.strip()]
    return '\n\n'.join(paragraphs[:max_paragraphs])
41
+
42
# --- Optimized Model Loading ---
@st.cache_resource(show_spinner=False)
def load_models():
    """Build the image-captioning and story-generation pipelines.

    Decorated with ``st.cache_resource`` so the pipelines are meant to be
    built once and shared rather than reconstructed on each script run.

    Returns:
        tuple: ``(captioner, story_pipe)`` where ``captioner`` is an
        ``image-to-text`` pipeline for
        ``Salesforce/blip-image-captioning-base`` and ``story_pipe`` is a
        ``text-generation`` pipeline for ``Qwen/Qwen3-0.6B`` preconfigured
        with the sampling settings below.
    """
    story_model_id = "Qwen/Qwen3-0.6B"
    has_cuda = torch.cuda.is_available()

    # Image captioning: first GPU when available, otherwise CPU (-1).
    captioner = pipeline(
        "image-to-text",
        model="Salesforce/blip-image-captioning-base",
        device=0 if has_cuda else -1,
    )

    # Story generator with Qwen-specific config.
    tokenizer = AutoTokenizer.from_pretrained(
        story_model_id,
        trust_remote_code=True,
        pad_token='<|endoftext|>',
    )

    story_pipe = pipeline(
        "text-generation",
        model=story_model_id,
        tokenizer=tokenizer,
        device_map="auto",
        # Half precision only on GPU: the captioner already falls back to
        # CPU, and float16 on CPU is slow or unsupported for some ops.
        torch_dtype=torch.float16 if has_cuda else torch.float32,
        max_new_tokens=300,  # Increased for better story flow
        temperature=0.7,  # Lower temperature for more focused output
        top_p=0.9,
        repetition_penalty=1.2,
        do_sample=True,
        eos_token_id=tokenizer.eos_token_id,
    )

    return captioner, story_pipe


# Bind the cached pipelines at module scope; the application flow below
# calls ``story_pipe`` directly.
caption_pipe, story_pipe = load_models()
 
 
74
 
75
  # --- Main Application Flow ---
76
  uploaded_image = st.file_uploader(
 
79
  )
80
 
81
  if uploaded_image:
 
82
  image = Image.open(uploaded_image).convert("RGB")
83
  st.image(image, use_column_width=True)
84
 
 
97
 
98
  st.success(f"**Image Understanding:** {image_caption}")
99
 
100
+ # Enhanced prompt engineering
101
+ story_prompt = f"""<|im_start|>system
102
+ You are a children's story writer. Create a SHORT STORY based on this image description: "{image_caption}"
103
+
104
+ RULES:
105
+ 1. Use simple language (Grade 2 level)
106
+ 2. Include a magical element
107
+ 3. Add a moral lesson about kindness
108
+ 4. NO internal thoughts/explanations
109
+ 5. 3 paragraphs maximum<|im_end|>
110
+ <|im_start|>user
111
+ Write the story<|im_end|>
112
+ <|im_start|>assistant
113
+ """
114
+
115
+ # Generate story
116
  try:
117
  with st.spinner("📝 Crafting magical story..."):
118
  start_time = time.time()
119
 
 
 
 
 
 
 
120
  story_result = story_pipe(
121
  story_prompt,
122
+ num_return_sequences=1,
123
+ stopping_criteria=[lambda _: False] # Disable default stopping
124
  )
125
 
126
+ # Enhanced post-processing
 
 
 
 
127
  raw_story = story_result[0]['generated_text']
128
+ clean_story = clean_story_text(raw_story.split("<|im_start|>assistant")[-1])
 
129
 
130
+ # Format paragraphs
131
+ formatted_story = "\n\n".join(
132
+ [f"<p style='font-size:18px; line-height:1.6'>{p}</p>"
133
+ for p in clean_story.split("\n\n")]
134
+ )
 
 
 
 
 
 
 
 
135
 
136
  except Exception as e:
137
  st.error(f"❌ Story generation failed: {str(e)}")
138
  st.stop()
139
 
 
 
 
 
140
  # Display story
141
  st.subheader("✨ Your Magical Story")
142
+ st.markdown(formatted_story, unsafe_allow_html=True)
143
 
144
  # Audio conversion
145
  with st.spinner("🔊 Creating audio version..."):
146
  try:
147
+ audio = gTTS(text=clean_story, lang="en", slow=False)
148
  with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
149
  audio.save(tmp_file.name)
150
  st.audio(tmp_file.name, format="audio/mp3")