Update app.py
Browse files
app.py
CHANGED
@@ -9,37 +9,36 @@ from gtts import gTTS
|
|
9 |
import tempfile
|
10 |
|
11 |
# —––––––– Page Config —–––––––
|
12 |
-
st.set_page_config(page_title="Magic Story Generator
|
13 |
-
st.title("📖✨ Turn Images into Children's Stories
|
14 |
|
15 |
# —––––––– Clients (cached) —–––––––
|
16 |
@st.cache_resource(show_spinner=False)
|
17 |
def load_clients():
|
18 |
hf_token = st.secrets["HF_TOKEN"]
|
19 |
|
20 |
-
# Authenticate
|
21 |
os.environ["HUGGINGFACEHUB_API_TOKEN"] = hf_token
|
22 |
login(hf_token)
|
23 |
|
24 |
-
# Pin cache locally
|
25 |
cache_dir = "./hf_cache"
|
26 |
os.makedirs(cache_dir, exist_ok=True)
|
27 |
os.environ["TRANSFORMERS_CACHE"] = cache_dir
|
28 |
|
29 |
-
# 1) BLIP
|
30 |
caption_client = InferenceApi(
|
31 |
repo_id="Salesforce/blip-image-captioning-base",
|
32 |
token=hf_token
|
33 |
)
|
34 |
|
35 |
-
# 2) Text-generation pipeline
|
36 |
t0 = time.time()
|
37 |
story_generator = pipeline(
|
38 |
task="text-generation",
|
39 |
model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
|
40 |
tokenizer="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
|
41 |
-
device=-1
|
42 |
-
cache_dir=cache_dir
|
43 |
)
|
44 |
st.text(f"✅ Story model loaded in {time.time() - t0:.1f}s (cached thereafter)")
|
45 |
|
@@ -65,15 +64,15 @@ def generate_caption(img: Image.Image) -> str:
|
|
65 |
def generate_story(caption: str) -> str:
|
66 |
prompt = f"""
|
67 |
You are a creative children’s-story author.
|
68 |
-
Below is
|
69 |
“{caption}”
|
70 |
|
71 |
-
Write a coherent
|
72 |
-
1. Introduces the main character
|
73 |
2. Shows a simple problem or discovery.
|
74 |
-
3.
|
75 |
4. Uses clear language for ages 3–8.
|
76 |
-
5. Keeps
|
77 |
Story:
|
78 |
"""
|
79 |
t0 = time.time()
|
@@ -86,14 +85,13 @@ Story:
|
|
86 |
no_repeat_ngram_size=3,
|
87 |
do_sample=True
|
88 |
)
|
89 |
-
|
90 |
-
st.text(f"⏱ Generated in {gen_time:.1f}s on CPU")
|
91 |
|
92 |
text = outputs[0]["generated_text"].strip()
|
93 |
-
#
|
94 |
if text.startswith(prompt):
|
95 |
text = text[len(prompt):].strip()
|
96 |
-
#
|
97 |
words = text.split()
|
98 |
if len(words) > 100:
|
99 |
text = " ".join(words[:100])
|
@@ -103,7 +101,7 @@ Story:
|
|
103 |
|
104 |
|
105 |
# —––––––– Main App Flow —–––––––
|
106 |
-
uploaded = st.file_uploader("Upload an image:", type=["jpg",
|
107 |
if uploaded:
|
108 |
img = Image.open(uploaded).convert("RGB")
|
109 |
if max(img.size) > 2048:
|
@@ -134,3 +132,4 @@ if uploaded:
|
|
134 |
|
135 |
# Footer
|
136 |
st.markdown("---\n*Made with ❤️ by your friendly story wizard*")
|
|
|
|
9 |
import tempfile
|
10 |
|
11 |
# —––––––– Page Config —–––––––
|
12 |
+
st.set_page_config(page_title="Magic Story Generator", layout="centered")
|
13 |
+
st.title("📖✨ Turn Images into Children's Stories")
|
14 |
|
15 |
# —––––––– Clients (cached) —–––––––
|
16 |
@st.cache_resource(show_spinner=False)
|
17 |
def load_clients():
|
18 |
hf_token = st.secrets["HF_TOKEN"]
|
19 |
|
20 |
+
# Authenticate for both HF Hub and transformers
|
21 |
os.environ["HUGGINGFACEHUB_API_TOKEN"] = hf_token
|
22 |
login(hf_token)
|
23 |
|
24 |
+
# Pin transformers cache locally via env var
|
25 |
cache_dir = "./hf_cache"
|
26 |
os.makedirs(cache_dir, exist_ok=True)
|
27 |
os.environ["TRANSFORMERS_CACHE"] = cache_dir
|
28 |
|
29 |
+
# 1) BLIP image-captioning client
|
30 |
caption_client = InferenceApi(
|
31 |
repo_id="Salesforce/blip-image-captioning-base",
|
32 |
token=hf_token
|
33 |
)
|
34 |
|
35 |
+
# 2) Text-generation pipeline on CPU (no cache_dir arg here!)
|
36 |
t0 = time.time()
|
37 |
story_generator = pipeline(
|
38 |
task="text-generation",
|
39 |
model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
|
40 |
tokenizer="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
|
41 |
+
device=-1 # force CPU
|
|
|
42 |
)
|
43 |
st.text(f"✅ Story model loaded in {time.time() - t0:.1f}s (cached thereafter)")
|
44 |
|
|
|
64 |
def generate_story(caption: str) -> str:
|
65 |
prompt = f"""
|
66 |
You are a creative children’s-story author.
|
67 |
+
Below is an image description:
|
68 |
“{caption}”
|
69 |
|
70 |
+
Write a coherent 50–100 word story that:
|
71 |
+
1. Introduces the main character.
|
72 |
2. Shows a simple problem or discovery.
|
73 |
+
3. Has a happy resolution.
|
74 |
4. Uses clear language for ages 3–8.
|
75 |
+
5. Keeps sentences under 20 words.
|
76 |
Story:
|
77 |
"""
|
78 |
t0 = time.time()
|
|
|
85 |
no_repeat_ngram_size=3,
|
86 |
do_sample=True
|
87 |
)
|
88 |
+
st.text(f"⏱ Generated in {time.time() - t0:.1f}s on CPU")
|
|
|
89 |
|
90 |
text = outputs[0]["generated_text"].strip()
|
91 |
+
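# strip the prompt echo — NOTE(review): text was .strip()-ed just above, but prompt begins with a newline, so this startswith(prompt) check can never match; compare against prompt.strip() instead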
|
92 |
if text.startswith(prompt):
|
93 |
text = text[len(prompt):].strip()
|
94 |
+
# enforce ≤100 words
|
95 |
words = text.split()
|
96 |
if len(words) > 100:
|
97 |
text = " ".join(words[:100])
|
|
|
101 |
|
102 |
|
103 |
# —––––––– Main App Flow —–––––––
|
104 |
+
uploaded = st.file_uploader("Upload an image:", type=["jpg","png","jpeg"])
|
105 |
if uploaded:
|
106 |
img = Image.open(uploaded).convert("RGB")
|
107 |
if max(img.size) > 2048:
|
|
|
132 |
|
133 |
# Footer
|
134 |
st.markdown("---\n*Made with ❤️ by your friendly story wizard*")
|
135 |
+
|