mayf committed on
Commit
4d1f328
·
verified ·
1 Parent(s): 395cd70

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -82
app.py CHANGED
@@ -1,21 +1,18 @@
1
- # Must be FIRST import and FIRST Streamlit command
2
- import streamlit as st
3
- st.set_page_config(
4
- page_title="Magic Story Generator",
5
- layout="centered",
6
- page_icon="📖"
7
- )
8
-
9
- # Other imports AFTER Streamlit config
10
  import re
11
  import time
12
  import tempfile
 
13
  from PIL import Image
14
  from gtts import gTTS
15
  from transformers import pipeline
16
 
17
- # --- Constants & Setup ---
18
- st.title("📖✨ Turn Images into Children's Stories")
 
 
 
 
19
 
20
  # --- Model Loading (Cached) ---
21
  @st.cache_resource(show_spinner=False)
@@ -24,113 +21,134 @@ def load_models():
24
  captioner = pipeline(
25
  "image-to-text",
26
  model="Salesforce/blip-image-captioning-base",
27
- device=-1 # Use -1 for CPU, 0 for GPU
28
  )
29
 
30
- # Story generation model (Qwen3-1.7B)
31
  storyteller = pipeline(
32
  "text-generation",
33
  model="Qwen/Qwen3-1.7B",
34
  device_map="auto",
35
  trust_remote_code=True,
36
  torch_dtype="auto",
37
- max_new_tokens=250,
38
- temperature=0.7,
39
- top_p=0.85,
40
- repetition_penalty=1.15,
41
- eos_token_id=151645
 
 
42
  )
43
 
44
  return captioner, storyteller
45
 
46
- caption_pipe, story_pipe = load_models()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
- # --- Main Application Flow ---
49
  uploaded_image = st.file_uploader(
50
  "Upload a children's book style image:",
51
  type=["jpg", "jpeg", "png"]
52
  )
53
 
54
  if uploaded_image:
55
- # Process image
56
  image = Image.open(uploaded_image).convert("RGB")
57
- st.image(image, use_container_width=True)
58
-
59
- # Generate caption
60
- with st.spinner("🔍 Analyzing image..."):
61
- caption_result = caption_pipe(image)
62
- image_caption = caption_result[0].get("generated_text", "").strip()
63
 
64
- if not image_caption:
65
- st.error("❌ Couldn't understand this image. Please try another!")
66
- st.stop()
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
- st.success(f"**Image Understanding:** {image_caption}")
69
-
70
  # Create story prompt
71
  story_prompt = (
72
  f"<|im_start|>system\n"
73
- f"You are a children's book author. Create a 100-150 word story based on: {image_caption}\n"
74
- "Use simple language, friendly characters, and a positive lesson.<|im_end|>\n"
 
 
 
 
 
75
  f"<|im_start|>user\n"
76
- f"Write a child-friendly story with a clear beginning, middle, and end.<|im_end|>\n"
77
  f"<|im_start|>assistant\n"
78
  )
79
-
80
- # Generate story
81
- with st.spinner("📝 Crafting magical story..."):
82
- start_time = time.time()
83
- story_result = story_pipe(
84
- story_prompt,
85
- do_sample=True,
86
- num_return_sequences=1,
87
- pad_token_id=151645
88
- )
89
- generation_time = time.time() - start_time
90
-
91
- # Process output
92
- raw_story = story_result[0]['generated_text']
93
 
94
- # Clean up story text
95
- clean_story = raw_story.split("<|im_start|>assistant\n")[-1]
96
- clean_story = clean_story.split("<|im_start|>")[0] # Remove any new turns
97
- clean_story = clean_story.replace("<|im_end|>", "").strip()
 
 
 
 
 
 
 
 
 
98
 
99
- # Remove assistant mentions using regex
100
- clean_story = re.sub(
101
- r'^(assistant[:>]?\s*)+',
102
- '',
103
- clean_story,
104
- flags=re.IGNORECASE
105
- ).strip()
106
-
107
- # Format story punctuation
108
- final_story = []
109
- for sentence in clean_story.split(". "):
110
- sentence = sentence.strip()
111
- if not sentence:
112
- continue
113
- if not sentence.endswith('.'):
114
- sentence += '.'
115
- final_story.append(sentence[0].upper() + sentence[1:])
116
 
117
- final_story = " ".join(final_story).replace("..", ".")[:800]
118
-
119
- # Display story
120
- st.subheader("✨ Your Magical Story")
121
  st.write(final_story)
122
-
123
- # Audio conversion
124
  with st.spinner("🔊 Creating audio version..."):
125
  try:
126
- audio = gTTS(text=final_story, lang="en", slow=False)
127
- with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
128
- audio.save(tmp_file.name)
129
- st.audio(tmp_file.name, format="audio/mp3")
130
  except Exception as e:
131
- st.error(f" Audio conversion failed: {str(e)}")
132
 
133
  # Footer
134
  st.markdown("---")
135
- st.markdown("📚 Made with ♥ by The Story Wizard • [Report Issues](https://example.com)")
136
 
 
1
+ # story_generator.py
 
 
 
 
 
 
 
 
2
  import re
3
  import time
4
  import tempfile
5
+ import streamlit as st
6
  from PIL import Image
7
  from gtts import gTTS
8
  from transformers import pipeline
9
 
10
+ # --- Initialize Streamlit Config ---
11
+ st.set_page_config(
12
+ page_title="Magic Story Generator",
13
+ layout="centered",
14
+ page_icon="📖"
15
+ )
16
 
17
  # --- Model Loading (Cached) ---
18
  @st.cache_resource(show_spinner=False)
 
21
  captioner = pipeline(
22
  "image-to-text",
23
  model="Salesforce/blip-image-captioning-base",
24
+ device=-1
25
  )
26
 
27
+ # Story generation model with optimized settings
28
  storyteller = pipeline(
29
  "text-generation",
30
  model="Qwen/Qwen3-1.7B",
31
  device_map="auto",
32
  trust_remote_code=True,
33
  torch_dtype="auto",
34
+ model_kwargs={
35
+ "revision": "main",
36
+ "temperature": 0.7,
37
+ "top_p": 0.9,
38
+ "repetition_penalty": 1.1,
39
+ "pad_token_id": 151645
40
+ }
41
  )
42
 
43
  return captioner, storyteller
44
 
45
+ # --- Text Processing Utilities ---
46
# Leading chat-role residue such as "assistant:", "assistant>" or "assistant\n".
# The \b stops the pattern from eating the start of real words ("Assistants").
_ROLE_PREFIX_RE = re.compile(r'^(?:assistant\b[\s\-:>]*)+', re.IGNORECASE)

# Qwen3 checkpoints can emit a <think>...</think> reasoning block before the
# actual answer; it must not end up in the displayed or spoken story.
_THINK_BLOCK_RE = re.compile(r'<think>.*?</think>', re.DOTALL)

# Sentence boundary: one or more spaces that follow terminal punctuation.
_SENTENCE_GAP_RE = re.compile(r'(?<=[.!?]) +')


def clean_generated_text(raw_text):
    """Extract a tidy story from raw chat-formatted model output.

    Steps, in order:
      1. Keep only the text after the first ``<|im_start|>assistant\\n`` marker
         (the whole input is kept when the marker is absent).
      2. Drop anything from the next ``<|im_start|>`` onward (extra chat turns).
      3. Remove ``<|im_end|>`` tokens and complete ``<think>...</think>`` blocks.
      4. Strip leading "assistant" role residue (whole word only).
      5. Split into sentences on spaces following ``.``, ``!`` or ``?``;
         capitalize each sentence and append ``.`` when terminal punctuation
         is missing.

    Args:
        raw_text: Full generated text, possibly including the prompt.
            ``None`` or an empty string is accepted.

    Returns:
        The cleaned story as a single string with sentences joined by one
        space; an empty string when nothing usable remains.
    """
    if not raw_text:
        return ""

    # Isolate the assistant's reply from the prompt and any later turns.
    text = raw_text.split("<|im_start|>assistant\n", 1)[-1]
    text = text.split("<|im_start|>")[0]

    # Remove special tokens and reasoning blocks, then role residue.
    text = text.replace("<|im_end|>", "")
    text = _THINK_BLOCK_RE.sub("", text).strip()
    text = _ROLE_PREFIX_RE.sub("", text).strip()

    # Normalize punctuation and capitalization sentence by sentence.
    sentences = []
    for sent in _SENTENCE_GAP_RE.split(text):
        sent = sent.strip()
        if not sent:
            continue
        if sent[-1] not in ".!?":
            sent += "."
        sentences.append(sent[0].upper() + sent[1:])

    return " ".join(sentences)
75
+
76
+ # --- Main Application UI ---
77
st.title("📖✨ Magic Story Generator")

uploaded_image = st.file_uploader(
    "Upload a children's book style image:",
    type=["jpg", "jpeg", "png"]
)

if uploaded_image:
    # Decode the upload. A corrupt or mislabeled file makes PIL raise
    # (UnidentifiedImageError is an OSError subclass); report it instead of
    # letting the script run crash.
    try:
        image = Image.open(uploaded_image).convert("RGB")
    except (OSError, ValueError):
        st.error("❌ Failed to analyze image. Please try another.")
        st.stop()
    st.image(image, use_column_width=True)

    # Load models only when needed (cached across reruns by load_models).
    caption_pipe, story_pipe = load_models()

    # Generate image caption. Only the model call sits inside the try; an
    # empty caption is handled by the same failure path below rather than by
    # raising an exception just to catch it again.
    with st.spinner("🔍 Analyzing image..."):
        try:
            caption_result = caption_pipe(image)
            image_caption = caption_result[0].get("generated_text", "").strip()
        except Exception:
            image_caption = ""

    if not image_caption:
        st.error("❌ Failed to analyze image. Please try another.")
        st.stop()

    st.success(f"**Image Understanding:** {image_caption}")

    # Create story prompt (ChatML-style markup, matching what
    # clean_generated_text strips from the output).
    story_prompt = (
        "<|im_start|>system\n"
        f"You are a children's book author. Create a 150-word story based on: {image_caption}\n"
        "Include these elements:\n"
        "- Friendly characters\n"
        "- Simple vocabulary\n"
        "- Positive lesson\n"
        "- Clear story structure\n"
        "<|im_end|>\n"
        "<|im_start|>user\n"
        "Write an engaging story suitable for ages 6-8.<|im_end|>\n"
        "<|im_start|>assistant\n"
    )

    # Generate story text.
    with st.spinner("📝 Crafting magical story..."):
        try:
            story_result = story_pipe(
                story_prompt,
                max_new_tokens=300,
                do_sample=True,
                num_return_sequences=1
            )
            raw_story = story_result[0]['generated_text']
        except Exception:
            st.error("❌ Story generation failed. Please try again.")
            st.stop()

    # Process and display story. An empty result after cleanup is treated as
    # a generation failure so no blank story is shown or sent to gTTS.
    final_story = clean_generated_text(raw_story)
    if not final_story:
        st.error("❌ Story generation failed. Please try again.")
        st.stop()

    st.subheader(" Your Story")
    st.write(final_story)

    # Generate audio version. The MP3 is written into an anonymous temporary
    # file that is removed automatically when the `with` block exits, so
    # repeated runs do not accumulate leftover .mp3 files on disk.
    with st.spinner("🔊 Creating audio version..."):
        try:
            tts = gTTS(text=final_story, lang='en', slow=False)
            with tempfile.TemporaryFile(suffix=".mp3") as fp:
                tts.write_to_fp(fp)
                fp.seek(0)  # rewind before reading back what was written
                st.audio(fp.read(), format="audio/mp3")
        except Exception:
            # Best effort: the text version above is still usable.
            st.warning("⚠️ Audio conversion failed. Text version still available.")

# Footer
st.markdown("---")
st.caption("Made with ♥ by The Story Wizard • [Report Issues](https://example.com)")
154