Update app.py
Browse files
app.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
#
|
2 |
import streamlit as st
|
3 |
st.set_page_config(
|
4 |
page_title="Magic Story Generator",
|
@@ -6,9 +6,10 @@ st.set_page_config(
|
|
6 |
page_icon="📖"
|
7 |
)
|
8 |
|
9 |
-
# Other imports
|
10 |
import re
|
11 |
import time
|
|
|
12 |
import tempfile
|
13 |
from PIL import Image
|
14 |
from gtts import gTTS
|
@@ -24,21 +25,22 @@ def load_models():
|
|
24 |
captioner = pipeline(
|
25 |
"image-to-text",
|
26 |
model="Salesforce/blip-image-captioning-base",
|
27 |
-
device
|
28 |
)
|
29 |
|
30 |
-
#
|
31 |
storyteller = pipeline(
|
32 |
"text-generation",
|
33 |
-
model="Qwen/Qwen3-0.
|
34 |
device_map="auto",
|
35 |
trust_remote_code=True,
|
36 |
-
|
37 |
-
|
38 |
-
|
|
|
39 |
top_k=50,
|
40 |
-
top_p=0.
|
41 |
-
repetition_penalty=1.
|
42 |
eos_token_id=151645
|
43 |
)
|
44 |
|
@@ -55,12 +57,16 @@ uploaded_image = st.file_uploader(
|
|
55 |
if uploaded_image:
|
56 |
# Process image
|
57 |
image = Image.open(uploaded_image).convert("RGB")
|
58 |
-
st.image(image,
|
59 |
|
60 |
# Generate caption
|
61 |
with st.spinner("🔍 Analyzing image..."):
|
62 |
-
|
63 |
-
|
|
|
|
|
|
|
|
|
64 |
|
65 |
if not image_caption:
|
66 |
st.error("❌ Couldn't understand this image. Please try another!")
|
@@ -71,62 +77,38 @@ if uploaded_image:
|
|
71 |
# Create story prompt
|
72 |
story_prompt = (
|
73 |
f"<|im_start|>system\n"
|
74 |
-
f"You
|
75 |
)
|
76 |
|
77 |
-
# Generate story
|
78 |
-
|
79 |
-
|
80 |
-
story_result = story_pipe(
|
81 |
-
story_prompt,
|
82 |
-
do_sample=True,
|
83 |
-
num_return_sequences=1,
|
84 |
-
pad_token_id=151645
|
85 |
-
)
|
86 |
-
generation_time = time.time() - start_time
|
87 |
-
|
88 |
-
# Process output
|
89 |
-
raw_story = story_result[0]['generated_text']
|
90 |
-
|
91 |
-
# Clean up story text
|
92 |
-
clean_story = raw_story.split("<|im_start|>assistant\n")[-1]
|
93 |
-
clean_story = clean_story.split("<|im_start|>")[0] # Remove any new turns
|
94 |
-
clean_story = clean_story.replace("<|im_end|>", "").strip()
|
95 |
-
|
96 |
-
# Remove assistant mentions using regex
|
97 |
-
clean_story = re.sub(
|
98 |
-
r'^(assistant[:>]?\s*)+',
|
99 |
-
'',
|
100 |
-
clean_story,
|
101 |
-
flags=re.IGNORECASE
|
102 |
-
).strip()
|
103 |
-
|
104 |
-
# Format story punctuation
|
105 |
-
final_story = []
|
106 |
-
for sentence in clean_story.split(". "):
|
107 |
-
sentence = sentence.strip()
|
108 |
-
if not sentence:
|
109 |
-
continue
|
110 |
-
if not sentence.endswith('.'):
|
111 |
-
sentence += '.'
|
112 |
-
final_story.append(sentence[0].upper() + sentence[1:])
|
113 |
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
|
|
|
|
|
|
|
|
129 |
|
130 |
-
#
|
131 |
-
|
132 |
-
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# FIRST import and FIRST Streamlit command
|
2 |
import streamlit as st
|
3 |
st.set_page_config(
|
4 |
page_title="Magic Story Generator",
|
|
|
6 |
page_icon="📖"
|
7 |
)
|
8 |
|
9 |
+
# Other imports
|
10 |
import re
|
11 |
import time
|
12 |
+
import torch
|
13 |
import tempfile
|
14 |
from PIL import Image
|
15 |
from gtts import gTTS
|
|
|
25 |
captioner = pipeline(
|
26 |
"image-to-text",
|
27 |
model="Salesforce/blip-image-captioning-base",
|
28 |
+
device=0 if torch.cuda.is_available() else -1
|
29 |
)
|
30 |
|
31 |
+
# Optimized story generation model
|
32 |
storyteller = pipeline(
|
33 |
"text-generation",
|
34 |
+
model="Qwen/Qwen3-0.5B",
|
35 |
device_map="auto",
|
36 |
trust_remote_code=True,
|
37 |
+
model_kwargs={"load_in_8bit": True},
|
38 |
+
torch_dtype=torch.float16,
|
39 |
+
max_new_tokens=200,
|
40 |
+
temperature=0.9,
|
41 |
top_k=50,
|
42 |
+
top_p=0.9,
|
43 |
+
repetition_penalty=1.1,
|
44 |
eos_token_id=151645
|
45 |
)
|
46 |
|
|
|
57 |
if uploaded_image:
|
58 |
# Process image
|
59 |
image = Image.open(uploaded_image).convert("RGB")
|
60 |
+
st.image(image, use_column_width=True)
|
61 |
|
62 |
# Generate caption
|
63 |
with st.spinner("🔍 Analyzing image..."):
|
64 |
+
try:
|
65 |
+
caption_result = caption_pipe(image)
|
66 |
+
image_caption = caption_result[0].get("generated_text", "").strip()
|
67 |
+
except Exception as e:
|
68 |
+
st.error(f"❌ Image analysis failed: {str(e)}")
|
69 |
+
st.stop()
|
70 |
|
71 |
if not image_caption:
|
72 |
st.error("❌ Couldn't understand this image. Please try another!")
|
|
|
77 |
# Create story prompt
|
78 |
story_prompt = (
|
79 |
f"<|im_start|>system\n"
|
80 |
+
f"You're a children's author. Create a short story (100-150 words) based on: {image_caption}\n"
|
81 |
)
|
82 |
|
83 |
+
# Generate story with progress
|
84 |
+
progress_bar = st.progress(0)
|
85 |
+
status_text = st.empty()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
|
87 |
+
try:
|
88 |
+
with st.spinner("📝 Crafting magical story..."):
|
89 |
+
start_time = time.time()
|
90 |
+
|
91 |
+
def update_progress(step):
|
92 |
+
progress = min(step/5, 1.0) # Simulate progress steps
|
93 |
+
progress_bar.progress(progress)
|
94 |
+
status_text.text(f"Step {int(step)}/5: {'📖'*int(step)}")
|
95 |
+
|
96 |
+
update_progress(1)
|
97 |
+
story_result = story_pipe(
|
98 |
+
story_prompt,
|
99 |
+
do_sample=True,
|
100 |
+
num_return_sequences=1
|
101 |
+
)
|
102 |
+
|
103 |
+
update_progress(4)
|
104 |
+
generation_time = time.time() - start_time
|
105 |
+
st.info(f"Story generated in {generation_time:.1f} seconds")
|
106 |
|
107 |
+
# Process output
|
108 |
+
raw_story = story_result[0]['generated_text']
|
109 |
+
clean_story = raw_story.split("<|im_start|>assistant\n")[-1]
|
110 |
+
clean_story = re.sub(r'<\|.*?\|>', '', clean_story).strip()
|
111 |
+
|
112 |
+
# Format story text
|
113 |
+
sentences = []
|
114 |
+
for sent in re
|