mayf committed on
Commit
c876f7b
·
verified ·
1 Parent(s): 91713d8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -21
app.py CHANGED
@@ -13,22 +13,22 @@ st.title("📖✨ Turn Images into Children's Stories")
13
  # —––––––– Load Pipelines (cached) —–––––––
14
  @st.cache_resource(show_spinner=False)
15
  def load_pipelines():
16
- # 1) Image-captioning pipeline (BLIP)
17
  captioner = pipeline(
18
  task="image-to-text",
19
  model="Salesforce/blip-image-captioning-base",
20
- device=-1 # CPU; set to 0+ for GPU
21
  )
22
- # 2) Story-generation pipeline (T5-base Story)
 
23
  storyteller = pipeline(
24
  task="text2text-generation",
25
- model="mrm8488/t5-base-finetuned-story-generation",
26
- tokenizer="mrm8488/t5-base-finetuned-story-generation",
27
  device=-1,
 
28
  temperature=0.7,
29
- top_p=0.9,
30
- repetition_penalty=1.2,
31
- max_new_tokens=150
32
  )
33
  return captioner, storyteller
34
 
@@ -43,41 +43,43 @@ if uploaded:
43
  # Generate caption
44
  with st.spinner("🔍 Generating caption..."):
45
  cap = captioner(img)
46
- caption = cap[0].get("generated_text", "").strip() if isinstance(cap, list) else ""
47
  if not caption:
48
  st.error("😢 Couldn't understand this image. Try another one!")
49
  st.stop()
50
  st.success(f"**Caption:** {caption}")
51
 
52
- # Build prompt and generate story
53
  prompt = f"generate story: {caption}"
54
  with st.spinner("📝 Writing story..."):
55
  start = time.time()
56
- out = storyteller(prompt)
57
  gen_time = time.time() - start
58
  st.text(f"⏱ Generated in {gen_time:.1f}s")
59
- story = out[0].get("generated_text", "").strip()
60
 
61
- # Enforce ≤100 words
 
 
 
 
 
62
  words = story.split()
63
- if len(words) > 100:
64
- story = " ".join(words[:100]) + ("" if story.endswith('.') else ".")
65
 
66
  # Display story
67
  st.subheader("📚 Your Magical Story")
68
  st.write(story)
69
 
70
- # Convert to audio
71
  with st.spinner("🔊 Converting to audio..."):
72
  try:
73
  tts = gTTS(text=story, lang="en", slow=False)
74
- tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
75
- tts.save(tmp.name)
76
- st.audio(tmp.name, format="audio/mp3")
77
  except Exception as e:
78
- st.warning(f"⚠️ TTS failed: {e}")
79
 
80
  # Footer
81
  st.markdown("---\n*Made with ❤️ by your friendly story wizard*")
82
 
83
-
 
13
  # —––––––– Load Pipelines (cached) —–––––––
14
  @st.cache_resource(show_spinner=False)
15
  def load_pipelines():
16
+ # 1) Image captioning pipeline
17
  captioner = pipeline(
18
  task="image-to-text",
19
  model="Salesforce/blip-image-captioning-base",
20
+ device=-1
21
  )
22
+
23
+ # 2) Story generation pipeline using verified model
24
  storyteller = pipeline(
25
  task="text2text-generation",
26
+ model="laxya007/story-generator-t5-small",
27
+ tokenizer="t5-small",
28
  device=-1,
29
+ max_length=200,
30
  temperature=0.7,
31
+ do_sample=True
 
 
32
  )
33
  return captioner, storyteller
34
 
 
43
  # Generate caption
44
  with st.spinner("🔍 Generating caption..."):
45
  cap = captioner(img)
46
+ caption = cap[0].get("generated_text", "").strip()
47
  if not caption:
48
  st.error("😢 Couldn't understand this image. Try another one!")
49
  st.stop()
50
  st.success(f"**Caption:** {caption}")
51
 
52
+ # Generate story
53
  prompt = f"generate story: {caption}"
54
  with st.spinner("📝 Writing story..."):
55
  start = time.time()
56
+ story = storyteller(prompt)[0]['generated_text']
57
  gen_time = time.time() - start
58
  st.text(f"⏱ Generated in {gen_time:.1f}s")
 
59
 
60
+ # Format story output
61
+ story = story.replace("<pad>", "").replace("</s>", "").strip()
62
+ if story.startswith("generate story:"):
63
+ story = story[15:].strip()
64
+
65
+ # Word limit enforcement
66
  words = story.split()
67
+ story = " ".join(words[:100]) if len(words) > 100 else story
 
68
 
69
  # Display story
70
  st.subheader("📚 Your Magical Story")
71
  st.write(story)
72
 
73
+ # Audio conversion
74
  with st.spinner("🔊 Converting to audio..."):
75
  try:
76
  tts = gTTS(text=story, lang="en", slow=False)
77
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
78
+ tts.save(tmp.name)
79
+ st.audio(tmp.name, format="audio/mp3")
80
  except Exception as e:
81
+ st.warning(f"⚠️ Audio conversion failed: {str(e)}")
82
 
83
  # Footer
84
  st.markdown("---\n*Made with ❤️ by your friendly story wizard*")
85