mayf commited on
Commit
6a2dbfc
·
verified ·
1 Parent(s): ac11067

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -74
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # Import Streamlit first
2
  import streamlit as st
3
  st.set_page_config(
4
  page_title="Magic Story Generator",
@@ -13,114 +13,142 @@ import torch
13
  import tempfile
14
  from PIL import Image
15
  from gtts import gTTS
16
- from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
17
 
18
- # --- Initialize Models First ---
 
 
 
19
  @st.cache_resource(show_spinner=False)
20
  def load_models():
21
- """Load and return both models at startup"""
22
- try:
23
- # 1. Image Captioning Model
24
- caption_pipe = pipeline(
25
- "image-to-text",
26
- model="Salesforce/blip-image-captioning-base",
27
- device=0 if torch.cuda.is_available() else -1
28
- )
29
-
30
- # 2. Story Generation Model
31
- story_tokenizer = AutoTokenizer.from_pretrained(
32
- "Qwen/Qwen3-0.6B",
33
- trust_remote_code=True
34
- )
35
-
36
- story_model = AutoModelForCausalLM.from_pretrained(
37
- "Qwen/Qwen3-0.6B",
38
- device_map="auto",
39
- torch_dtype=torch.float16
40
- )
41
-
42
- story_pipe = pipeline(
43
- "text-generation",
44
- model=story_model,
45
- tokenizer=story_tokenizer,
46
- max_new_tokens=230,
47
- temperature=0.9,
48
- top_k=50,
49
- top_p=0.9,
50
- repetition_penalty=1.1,
51
- eos_token_id=151645
52
- )
53
-
54
- return caption_pipe, story_pipe
55
-
56
- except Exception as e:
57
- st.error(f"🚨 Model loading failed: {str(e)}")
58
- st.stop()
59
 
60
- # Initialize models immediately when app starts
61
  caption_pipe, story_pipe = load_models()
62
 
63
- # --- Rest of Application ---
64
- st.title("📖✨ Turn Images into Children's Stories")
65
-
66
- def clean_story_text(raw_text):
67
- """Improved cleaning function"""
68
- clean = re.sub(r'<\|.*?\|>', '', raw_text) # Remove special tokens
69
- clean = re.sub(r'Okay, I need.*?(?=\n|$)', '', clean, flags=re.DOTALL) # Remove thinking chains
70
- return clean.strip()
71
-
72
  uploaded_image = st.file_uploader(
73
  "Upload a children's book style image:",
74
  type=["jpg", "jpeg", "png"]
75
  )
76
 
77
  if uploaded_image:
 
78
  image = Image.open(uploaded_image).convert("RGB")
79
- # Updated parameter here ↓
80
- st.image(image, use_container_width=True) # Changed use_column_width to use_container_width
81
 
 
82
  with st.spinner("🔍 Analyzing image..."):
83
  try:
84
  caption_result = caption_pipe(image)
85
- image_caption = caption_result[0].get("generated_text", "")
86
- st.success(f"**Image Understanding:** {image_caption}")
87
  except Exception as e:
88
  st.error(f"❌ Image analysis failed: {str(e)}")
89
  st.stop()
 
 
 
 
 
 
90
 
91
- # Story generation prompt
92
- story_prompt = f"""Write a children's story about: {image_caption}
93
- Rules:
94
- - Use simple words (Grade 2 level)
95
- - Exclude thinking processes
96
- - 3 paragraphs maximum
97
  Story:"""
98
 
 
 
 
 
99
  try:
100
  with st.spinner("📝 Crafting magical story..."):
 
 
 
 
 
 
 
 
101
  story_result = story_pipe(
102
  story_prompt,
103
  do_sample=True,
104
- top_p=0.9,
105
- repetition_penalty=1.2
106
  )
107
 
108
- raw_story = story_result[0]['generated_text']
109
- final_story = clean_story_text(raw_story.split("Story:")[-1])
110
-
111
- st.subheader("✨ Your Magical Story")
112
- st.write(final_story)
113
 
114
- # Audio conversion
115
- with st.spinner("🔊 Creating audio version..."):
116
- audio = gTTS(text=final_story, lang="en", slow=False)
117
- with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
118
- audio.save(tmp_file.name)
119
- st.audio(tmp_file.name, format="audio/mp3")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
  except Exception as e:
122
  st.error(f"❌ Story generation failed: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
  # Footer
125
  st.markdown("---")
126
  st.markdown("📚 Made with ♥ by The Story Wizard • [Report Issues](https://example.com)")
 
 
1
+ # FIRST import and FIRST Streamlit command
2
  import streamlit as st
3
  st.set_page_config(
4
  page_title="Magic Story Generator",
 
13
  import tempfile
14
  from PIL import Image
15
  from gtts import gTTS
16
+ from transformers import pipeline, AutoTokenizer
17
 
18
+ # --- Constants & Setup ---
19
+ st.title("📖✨ Turn Images into Children's Stories")
20
+
21
+ # --- Model Loading (Cached) ---
22
  @st.cache_resource(show_spinner=False)
23
  def load_models():
24
+ # Image captioning model
25
+ captioner = pipeline(
26
+ "image-to-text",
27
+ model="Salesforce/blip-image-captioning-base",
28
+ device=0 if torch.cuda.is_available() else -1
29
+ )
30
+
31
+ # Optimized story generation model
32
+ tokenizer = AutoTokenizer.from_pretrained("Deepthoughtworks/gpt-neo-2.7B__low-cpu")
33
+ storyteller = pipeline(
34
+ "text-generation",
35
+ model="Deepthoughtworks/gpt-neo-2.7B__low-cpu",
36
+ tokenizer=tokenizer,
37
+ device_map="auto",
38
+ torch_dtype=torch.float32, # Changed to float32 for better CPU compatibility
39
+ max_new_tokens=150, # Reduced length for faster generation
40
+ temperature=0.85,
41
+ top_k=40,
42
+ top_p=0.92,
43
+ repetition_penalty=1.15,
44
+ pad_token_id=tokenizer.eos_token_id # Added for padding control
45
+ )
46
+
47
+ return captioner, storyteller
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
 
49
  caption_pipe, story_pipe = load_models()
50
 
51
+ # --- Main Application Flow ---
 
 
 
 
 
 
 
 
52
  uploaded_image = st.file_uploader(
53
  "Upload a children's book style image:",
54
  type=["jpg", "jpeg", "png"]
55
  )
56
 
57
  if uploaded_image:
58
+ # Process image
59
  image = Image.open(uploaded_image).convert("RGB")
60
+ st.image(image, use_container_width=True) # Fixed deprecated parameter
 
61
 
62
+ # Generate caption
63
  with st.spinner("🔍 Analyzing image..."):
64
  try:
65
  caption_result = caption_pipe(image)
66
+ image_caption = caption_result[0].get("generated_text", "").strip()
 
67
  except Exception as e:
68
  st.error(f"❌ Image analysis failed: {str(e)}")
69
  st.stop()
70
+
71
+ if not image_caption:
72
+ st.error("❌ Couldn't understand this image. Please try another!")
73
+ st.stop()
74
+
75
+ st.success(f"**Image Understanding:** {image_caption}")
76
 
77
+ # Create story prompt
78
+ story_prompt = f"""Write a 50 to 100 words children's story based on: {image_caption}
79
+ Requirements:
80
+ - Exclude your thinking process
 
 
81
  Story:"""
82
 
83
+ # Generate story with progress
84
+ progress_bar = st.progress(0)
85
+ status_text = st.empty()
86
+
87
  try:
88
  with st.spinner("📝 Crafting magical story..."):
89
+ start_time = time.time()
90
+
91
+ def update_progress(step):
92
+ progress = min(step/5, 1.0)
93
+ progress_bar.progress(progress)
94
+ status_text.text(f"Step {int(step)}/5: {'📖'*int(step)}")
95
+
96
+ update_progress(1)
97
  story_result = story_pipe(
98
  story_prompt,
99
  do_sample=True,
100
+ num_return_sequences=1
 
101
  )
102
 
103
+ update_progress(4)
104
+ generation_time = time.time() - start_time
105
+ st.info(f"Story generated in {generation_time:.1f} seconds")
 
 
106
 
107
+ # Process output
108
+ raw_story = story_result[0]['generated_text']
109
+ clean_story = raw_story.split("Story:")[-1].strip()
110
+ clean_story = re.sub(r'\n+', '\n\n', clean_story) # Improve paragraph spacing
111
+
112
+ # Format story text
113
+ final_story = ""
114
+ for paragraph in clean_story.split('\n\n'):
115
+ paragraph = paragraph.strip()
116
+ if paragraph:
117
+ sentences = []
118
+ for sent in re.split(r'(?<=[.!?]) +', paragraph):
119
+ sent = sent.strip()
120
+ if sent:
121
+ if len(sent) > 1 and not sent.endswith(('.','!','?')):
122
+ sent += '.'
123
+ sentences.append(sent[0].upper() + sent[1:])
124
+ final_story += ' '.join(sentences) + '\n\n'
125
+
126
+ update_progress(5)
127
+ time.sleep(0.5)
128
 
129
  except Exception as e:
130
  st.error(f"❌ Story generation failed: {str(e)}")
131
+ st.stop()
132
+
133
+ finally:
134
+ progress_bar.empty()
135
+ status_text.empty()
136
+
137
+ # Display story
138
+ st.subheader("✨ Your Magical Story")
139
+ st.write(final_story.strip())
140
+
141
+ # Audio conversion
142
+ with st.spinner("🔊 Creating audio version..."):
143
+ try:
144
+ audio = gTTS(text=final_story, lang="en", slow=False)
145
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
146
+ audio.save(tmp_file.name)
147
+ st.audio(tmp_file.name, format="audio/mp3")
148
+ except Exception as e:
149
+ st.error(f"❌ Audio conversion failed: {str(e)}")
150
 
151
  # Footer
152
  st.markdown("---")
153
  st.markdown("📚 Made with ♥ by The Story Wizard • [Report Issues](https://example.com)")
154
+