mayf committed on
Commit
9862828
·
verified ·
1 Parent(s): c5b69e3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -68
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # Must be FIRST import and FIRST Streamlit command
2
  import streamlit as st
3
  st.set_page_config(
4
  page_title="Magic Story Generator",
@@ -6,9 +6,10 @@ st.set_page_config(
6
  page_icon="📖"
7
  )
8
 
9
- # Other imports AFTER Streamlit config
10
  import re
11
  import time
 
12
  import tempfile
13
  from PIL import Image
14
  from gtts import gTTS
@@ -24,21 +25,22 @@ def load_models():
24
  captioner = pipeline(
25
  "image-to-text",
26
  model="Salesforce/blip-image-captioning-base",
27
- device=-1 # Use -1 for CPU, 0 for GPU
28
  )
29
 
30
- # Story generation model (Qwen3-1.7B)
31
  storyteller = pipeline(
32
  "text-generation",
33
- model="Qwen/Qwen3-0.6B",
34
  device_map="auto",
35
  trust_remote_code=True,
36
- torch_dtype="auto",
37
- max_new_tokens=230,
38
- temperature=0.8,
 
39
  top_k=50,
40
- top_p=0.85,
41
- repetition_penalty=1.15,
42
  eos_token_id=151645
43
  )
44
 
@@ -55,12 +57,16 @@ uploaded_image = st.file_uploader(
55
  if uploaded_image:
56
  # Process image
57
  image = Image.open(uploaded_image).convert("RGB")
58
- st.image(image, use_container_width=True)
59
 
60
  # Generate caption
61
  with st.spinner("🔍 Analyzing image..."):
62
- caption_result = caption_pipe(image)
63
- image_caption = caption_result[0].get("generated_text", "").strip()
 
 
 
 
64
 
65
  if not image_caption:
66
  st.error("❌ Couldn't understand this image. Please try another!")
@@ -71,62 +77,38 @@ if uploaded_image:
71
  # Create story prompt
72
  story_prompt = (
73
  f"<|im_start|>system\n"
74
- f"You are a children's book author. Create a 100-150 word story based on: {image_caption}\n"
75
  )
76
 
77
- # Generate story
78
- with st.spinner("📝 Crafting magical story..."):
79
- start_time = time.time()
80
- story_result = story_pipe(
81
- story_prompt,
82
- do_sample=True,
83
- num_return_sequences=1,
84
- pad_token_id=151645
85
- )
86
- generation_time = time.time() - start_time
87
-
88
- # Process output
89
- raw_story = story_result[0]['generated_text']
90
-
91
- # Clean up story text
92
- clean_story = raw_story.split("<|im_start|>assistant\n")[-1]
93
- clean_story = clean_story.split("<|im_start|>")[0] # Remove any new turns
94
- clean_story = clean_story.replace("<|im_end|>", "").strip()
95
-
96
- # Remove assistant mentions using regex
97
- clean_story = re.sub(
98
- r'^(assistant[:>]?\s*)+',
99
- '',
100
- clean_story,
101
- flags=re.IGNORECASE
102
- ).strip()
103
-
104
- # Format story punctuation
105
- final_story = []
106
- for sentence in clean_story.split(". "):
107
- sentence = sentence.strip()
108
- if not sentence:
109
- continue
110
- if not sentence.endswith('.'):
111
- sentence += '.'
112
- final_story.append(sentence[0].upper() + sentence[1:])
113
 
114
- final_story = " ".join(final_story).replace("..", ".")[:800]
115
-
116
- # Display story
117
- st.subheader("✨ Your Magical Story")
118
- st.write(final_story)
119
-
120
- # Audio conversion
121
- with st.spinner("🔊 Creating audio version..."):
122
- try:
123
- audio = gTTS(text=final_story, lang="en", slow=False)
124
- with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
125
- audio.save(tmp_file.name)
126
- st.audio(tmp_file.name, format="audio/mp3")
127
- except Exception as e:
128
- st.error(f"❌ Audio conversion failed: {str(e)}")
 
 
 
 
129
 
130
- # Footer
131
- st.markdown("---")
132
- st.markdown("📚 Made with ♥ by The Story Wizard • [Report Issues](https://example.com)")
 
 
 
 
 
 
1
+ # FIRST import and FIRST Streamlit command
2
  import streamlit as st
3
  st.set_page_config(
4
  page_title="Magic Story Generator",
 
6
  page_icon="📖"
7
  )
8
 
9
+ # Other imports
10
  import re
11
  import time
12
+ import torch
13
  import tempfile
14
  from PIL import Image
15
  from gtts import gTTS
 
25
  captioner = pipeline(
26
  "image-to-text",
27
  model="Salesforce/blip-image-captioning-base",
28
+ device=0 if torch.cuda.is_available() else -1
29
  )
30
 
31
+ # Optimized story generation model
32
  storyteller = pipeline(
33
  "text-generation",
34
+ model="Qwen/Qwen3-0.5B",
35
  device_map="auto",
36
  trust_remote_code=True,
37
+ model_kwargs={"load_in_8bit": True},
38
+ torch_dtype=torch.float16,
39
+ max_new_tokens=200,
40
+ temperature=0.9,
41
  top_k=50,
42
+ top_p=0.9,
43
+ repetition_penalty=1.1,
44
  eos_token_id=151645
45
  )
46
 
 
57
  if uploaded_image:
58
  # Process image
59
  image = Image.open(uploaded_image).convert("RGB")
60
+ st.image(image, use_column_width=True)
61
 
62
  # Generate caption
63
  with st.spinner("🔍 Analyzing image..."):
64
+ try:
65
+ caption_result = caption_pipe(image)
66
+ image_caption = caption_result[0].get("generated_text", "").strip()
67
+ except Exception as e:
68
+ st.error(f"❌ Image analysis failed: {str(e)}")
69
+ st.stop()
70
 
71
  if not image_caption:
72
  st.error("❌ Couldn't understand this image. Please try another!")
 
77
  # Create story prompt
78
  story_prompt = (
79
  f"<|im_start|>system\n"
80
+ f"You're a children's author. Create a short story (100-150 words) based on: {image_caption}\n"
81
  )
82
 
83
+ # Generate story with progress
84
+ progress_bar = st.progress(0)
85
+ status_text = st.empty()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
+ try:
88
+ with st.spinner("📝 Crafting magical story..."):
89
+ start_time = time.time()
90
+
91
+ def update_progress(step):
92
+ progress = min(step/5, 1.0) # Simulate progress steps
93
+ progress_bar.progress(progress)
94
+ status_text.text(f"Step {int(step)}/5: {'📖'*int(step)}")
95
+
96
+ update_progress(1)
97
+ story_result = story_pipe(
98
+ story_prompt,
99
+ do_sample=True,
100
+ num_return_sequences=1
101
+ )
102
+
103
+ update_progress(4)
104
+ generation_time = time.time() - start_time
105
+ st.info(f"Story generated in {generation_time:.1f} seconds")
106
 
107
+ # Process output
108
+ raw_story = story_result[0]['generated_text']
109
+ clean_story = raw_story.split("<|im_start|>assistant\n")[-1]
110
+ clean_story = re.sub(r'<\|.*?\|>', '', clean_story).strip()
111
+
112
+ # Format story text
113
+ sentences = []
114
+ for sent in re