Update app.py
Browse files
app.py
CHANGED
@@ -13,22 +13,22 @@ st.title("📖✨ Turn Images into Children's Stories")
|
|
13 |
# —––––––– Load Pipelines (cached) —–––––––
|
14 |
@st.cache_resource(show_spinner=False)
|
15 |
def load_pipelines():
|
16 |
-
# 1) Image
|
17 |
captioner = pipeline(
|
18 |
task="image-to-text",
|
19 |
model="Salesforce/blip-image-captioning-base",
|
20 |
-
device=-1
|
21 |
)
|
22 |
-
|
|
|
23 |
storyteller = pipeline(
|
24 |
task="text2text-generation",
|
25 |
-
model="
|
26 |
-
tokenizer="
|
27 |
device=-1,
|
|
|
28 |
temperature=0.7,
|
29 |
-
|
30 |
-
repetition_penalty=1.2,
|
31 |
-
max_new_tokens=150
|
32 |
)
|
33 |
return captioner, storyteller
|
34 |
|
@@ -43,41 +43,43 @@ if uploaded:
|
|
43 |
# Generate caption
|
44 |
with st.spinner("🔍 Generating caption..."):
|
45 |
cap = captioner(img)
|
46 |
-
caption = cap[0].get("generated_text", "").strip()
|
47 |
if not caption:
|
48 |
st.error("😢 Couldn't understand this image. Try another one!")
|
49 |
st.stop()
|
50 |
st.success(f"**Caption:** {caption}")
|
51 |
|
52 |
-
#
|
53 |
prompt = f"generate story: {caption}"
|
54 |
with st.spinner("📝 Writing story..."):
|
55 |
start = time.time()
|
56 |
-
|
57 |
gen_time = time.time() - start
|
58 |
st.text(f"⏱ Generated in {gen_time:.1f}s")
|
59 |
-
story = out[0].get("generated_text", "").strip()
|
60 |
|
61 |
-
#
|
|
|
|
|
|
|
|
|
|
|
62 |
words = story.split()
|
63 |
-
if len(words) > 100
|
64 |
-
story = " ".join(words[:100]) + ("" if story.endswith('.') else ".")
|
65 |
|
66 |
# Display story
|
67 |
st.subheader("📚 Your Magical Story")
|
68 |
st.write(story)
|
69 |
|
70 |
-
#
|
71 |
with st.spinner("🔊 Converting to audio..."):
|
72 |
try:
|
73 |
tts = gTTS(text=story, lang="en", slow=False)
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
except Exception as e:
|
78 |
-
st.warning(f"⚠️
|
79 |
|
80 |
# Footer
|
81 |
st.markdown("---\n*Made with ❤️ by your friendly story wizard*")
|
82 |
|
83 |
-
|
|
|
13 |
# —––––––– Load Pipelines (cached) —–––––––
|
14 |
@st.cache_resource(show_spinner=False)
|
15 |
def load_pipelines():
|
16 |
+
# 1) Image captioning pipeline
|
17 |
captioner = pipeline(
|
18 |
task="image-to-text",
|
19 |
model="Salesforce/blip-image-captioning-base",
|
20 |
+
device=-1
|
21 |
)
|
22 |
+
|
23 |
+
# 2) Story generation pipeline using verified model
|
24 |
storyteller = pipeline(
|
25 |
task="text2text-generation",
|
26 |
+
model="laxya007/story-generator-t5-small",
|
27 |
+
tokenizer="t5-small",
|
28 |
device=-1,
|
29 |
+
max_length=200,
|
30 |
temperature=0.7,
|
31 |
+
do_sample=True
|
|
|
|
|
32 |
)
|
33 |
return captioner, storyteller
|
34 |
|
|
|
43 |
# Generate caption
|
44 |
with st.spinner("🔍 Generating caption..."):
|
45 |
cap = captioner(img)
|
46 |
+
caption = cap[0].get("generated_text", "").strip()
|
47 |
if not caption:
|
48 |
st.error("😢 Couldn't understand this image. Try another one!")
|
49 |
st.stop()
|
50 |
st.success(f"**Caption:** {caption}")
|
51 |
|
52 |
+
# Generate story
|
53 |
prompt = f"generate story: {caption}"
|
54 |
with st.spinner("📝 Writing story..."):
|
55 |
start = time.time()
|
56 |
+
story = storyteller(prompt)[0]['generated_text']
|
57 |
gen_time = time.time() - start
|
58 |
st.text(f"⏱ Generated in {gen_time:.1f}s")
|
|
|
59 |
|
60 |
+
# Format story output
|
61 |
+
story = story.replace("<pad>", "").replace("</s>", "").strip()
|
62 |
+
if story.startswith("generate story:"):
|
63 |
+
story = story[15:].strip()
|
64 |
+
|
65 |
+
# Word limit enforcement
|
66 |
words = story.split()
|
67 |
+
story = " ".join(words[:100]) if len(words) > 100 else story
|
|
|
68 |
|
69 |
# Display story
|
70 |
st.subheader("📚 Your Magical Story")
|
71 |
st.write(story)
|
72 |
|
73 |
+
# Audio conversion
|
74 |
with st.spinner("🔊 Converting to audio..."):
|
75 |
try:
|
76 |
tts = gTTS(text=story, lang="en", slow=False)
|
77 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
|
78 |
+
tts.save(tmp.name)
|
79 |
+
st.audio(tmp.name, format="audio/mp3")
|
80 |
except Exception as e:
|
81 |
+
st.warning(f"⚠️ Audio conversion failed: {str(e)}")
|
82 |
|
83 |
# Footer
|
84 |
st.markdown("---\n*Made with ❤️ by your friendly story wizard*")
|
85 |
|
|