mayf committed on
Commit
982555a
·
verified ·
1 Parent(s): 2abb776

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -72
app.py CHANGED
@@ -1,115 +1,122 @@
1
- import os
2
- import time
3
  import streamlit as st
 
 
 
 
 
 
 
 
 
4
  from PIL import Image
5
- from transformers import pipeline
6
  from gtts import gTTS
7
- import tempfile
8
-
9
- # --- Requirements ---
10
- # Update requirements.txt to include:
11
- """
12
- streamlit>=1.20
13
- pillow>=9.0
14
- torch>=2.0.0
15
- transformers>=4.40
16
- sentencepiece>=0.2.0
17
- gTTS>=2.3.1
18
- accelerate>=0.30
19
- """
20
 
21
- # --- Page Setup ---
22
- st.set_page_config(page_title="Magic Story Generator", layout="centered")
23
  st.title("📖✨ Turn Images into Children's Stories")
24
 
25
- # --- Load Pipelines (cached) ---
26
  @st.cache_resource(show_spinner=False)
27
- def load_pipelines():
28
- # 1) Image-captioning pipeline (BLIP)
29
  captioner = pipeline(
30
- task="image-to-text",
31
  model="Salesforce/blip-image-captioning-base",
32
- device=-1
33
  )
34
 
35
- # 2) Modified story-generation pipeline using Qwen3-1.7B
36
  storyteller = pipeline(
37
- task="text-generation",
38
  model="Qwen/Qwen3-1.7B",
39
  device_map="auto",
40
  trust_remote_code=True,
41
  torch_dtype="auto",
42
  max_new_tokens=150,
43
  temperature=0.7,
44
- top_p=0.9,
45
- repetition_penalty=1.2,
46
- eos_token_id=151645 # Specific to Qwen3 tokenizer
47
  )
48
 
49
  return captioner, storyteller
50
 
51
- captioner, storyteller = load_pipelines()
 
 
 
 
 
 
52
 
53
- # --- Main App ---
54
- uploaded = st.file_uploader("Upload an image:", type=["jpg", "png", "jpeg"])
55
- if uploaded:
56
- # Load and display the image
57
- img = Image.open(uploaded).convert("RGB")
58
- st.image(img, use_container_width=True)
59
 
60
  # Generate caption
61
- with st.spinner("🔍 Generating caption..."):
62
- cap = captioner(img)
63
- caption = cap[0].get("generated_text", "").strip() if isinstance(cap, list) else ""
64
- if not caption:
65
- st.error("😢 Couldn't understand this image. Try another one!")
 
66
  st.stop()
67
- st.success(f"**Caption:** {caption}")
 
68
 
69
- # Build prompt and generate story
70
- prompt = (
71
  f"<|im_start|>system\n"
72
- f"You are a children's story writer. Create a 50-100 word story based on this image description: {caption}\n"
73
- f"<|im_end|>\n"
74
  f"<|im_start|>user\n"
75
- f"Write a coherent, child-friendly story that flows naturally with simple vocabulary.<|im_end|>\n"
76
  f"<|im_start|>assistant\n"
77
  )
78
-
79
- with st.spinner("📝 Writing story..."):
80
- start = time.time()
81
- out = storyteller(
82
- prompt,
 
83
  do_sample=True,
84
  num_return_sequences=1
85
  )
86
- gen_time = time.time() - start
87
- st.text(f"⏱ Generated in {gen_time:.1f}s")
88
-
89
  # Process output
90
- story = out[0]['generated_text'].split("<|im_start|>assistant\n")[-1]
91
- story = story.replace("<|im_end|>", "").strip()
 
 
 
 
 
 
 
 
 
92
 
93
- # Enforce ≤100 words and proper ending
94
- words = story.split()
95
- if len(words) > 100:
96
- story = " ".join(words[:100])
97
- if not story.endswith(('.', '!', '?')):
98
- story += '.'
99
 
100
  # Display story
101
- st.subheader("📚 Your Magical Story")
102
- st.write(story)
103
 
104
- # Convert to audio
105
- with st.spinner("🔊 Converting to audio..."):
106
  try:
107
- tts = gTTS(text=story, lang="en", slow=False)
108
- tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
109
- tts.save(tmp.name)
110
- st.audio(tmp.name, format="audio/mp3")
111
  except Exception as e:
112
- st.warning(f"⚠️ TTS failed: {e}")
113
 
114
  # Footer
115
- st.markdown("---\nMade with ❤️ by your friendly story wizard")
 
 
 
1
+ # Must be FIRST import and FIRST Streamlit command
 
2
  import streamlit as st
3
+ st.set_page_config(
4
+ page_title="Magic Story Generator",
5
+ layout="centered",
6
+ page_icon="📖"
7
+ )
8
+
9
+ # Other imports AFTER Streamlit config
10
+ import time
11
+ import tempfile
12
  from PIL import Image
 
13
  from gtts import gTTS
14
+ from transformers import pipeline
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
+ # --- Constants & Setup ---
 
17
  st.title("📖✨ Turn Images into Children's Stories")
18
 
19
+ # --- Model Loading (Cached) ---
20
  @st.cache_resource(show_spinner=False)
21
+ def load_models():
22
+ # Image captioning model
23
  captioner = pipeline(
24
+ "image-to-text",
25
  model="Salesforce/blip-image-captioning-base",
26
+ device=-1 # Use -1 for CPU, 0 for GPU
27
  )
28
 
29
+ # Story generation model (Qwen3-1.7B)
30
  storyteller = pipeline(
31
+ "text-generation",
32
  model="Qwen/Qwen3-1.7B",
33
  device_map="auto",
34
  trust_remote_code=True,
35
  torch_dtype="auto",
36
  max_new_tokens=150,
37
  temperature=0.7,
38
+ top_p=0.85,
39
+ repetition_penalty=1.15,
40
+ eos_token_id=151645 # Qwen3's specific EOS token
41
  )
42
 
43
  return captioner, storyteller
44
 
45
+ caption_pipe, story_pipe = load_models()
46
+
47
+ # --- Main Application Flow ---
48
+ uploaded_image = st.file_uploader(
49
+ "Upload a children's book style image:",
50
+ type=["jpg", "jpeg", "png"]
51
+ )
52
 
53
+ if uploaded_image:
54
+ # Process image
55
+ image = Image.open(uploaded_image).convert("RGB")
56
+ st.image(image, use_container_width=True)
 
 
57
 
58
  # Generate caption
59
+ with st.spinner("🔍 Analyzing image..."):
60
+ caption_result = caption_pipe(image)
61
+ image_caption = caption_result[0].get("generated_text", "").strip()
62
+
63
+ if not image_caption:
64
+ st.error("❌ Couldn't understand this image. Please try another!")
65
  st.stop()
66
+
67
+ st.success(f"**Image Understanding:** {image_caption}")
68
 
69
+ # Create story prompt
70
+ story_prompt = (
71
  f"<|im_start|>system\n"
72
+ f"You are a children's book author. Create a 50-100 word story based on this image description: {image_caption}\n"
73
+ "Use simple language, friendly characters, and a positive lesson.<|im_end|>\n"
74
  f"<|im_start|>user\n"
75
+ f"Write a short, child-friendly story with a clear beginning, middle, and end.<|im_end|>\n"
76
  f"<|im_start|>assistant\n"
77
  )
78
+
79
+ # Generate story
80
+ with st.spinner("📝 Crafting magical story..."):
81
+ start_time = time.time()
82
+ story_result = story_pipe(
83
+ story_prompt,
84
  do_sample=True,
85
  num_return_sequences=1
86
  )
87
+ generation_time = time.time() - start_time
88
+ st.text(f"⏱ Generation time: {generation_time:.1f}s")
89
+
90
  # Process output
91
+ raw_story = story_result[0]['generated_text']
92
+ clean_story = raw_story.split("<|im_start|>assistant\n")[-1]
93
+ clean_story = clean_story.replace("<|im_end|>", "").strip()
94
+
95
+ # Ensure proper story formatting
96
+ final_story = []
97
+ for sentence in clean_story.split(". "):
98
+ if not sentence: continue
99
+ if not sentence.endswith('.'):
100
+ sentence += '.'
101
+ final_story.append(sentence[0].upper() + sentence[1:])
102
 
103
+ final_story = " ".join(final_story).replace("..", ".")[:600] # Character limit safeguard
 
 
 
 
 
104
 
105
  # Display story
106
+ st.subheader(" Your Magical Story")
107
+ st.write(final_story)
108
 
109
+ # Audio conversion
110
+ with st.spinner("🔊 Creating audio version..."):
111
  try:
112
+ audio = gTTS(text=final_story, lang="en", slow=False)
113
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
114
+ audio.save(tmp_file.name)
115
+ st.audio(tmp_file.name, format="audio/mp3")
116
  except Exception as e:
117
+ st.error(f" Audio conversion failed: {str(e)}")
118
 
119
  # Footer
120
+ st.markdown("---")
121
+ st.markdown("📚 Made with ♥ by The Story Wizard • [Report Issues](https://example.com)")
122
+