mayf committed on
Commit
6adb177
·
verified ·
1 Parent(s): e508bdf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -64
app.py CHANGED
@@ -1,8 +1,10 @@
1
  import os
2
  import time
3
  import streamlit as st
4
- from transformers import pipeline
5
  from PIL import Image
 
 
 
6
  from gtts import gTTS
7
  import tempfile
8
 
@@ -10,94 +12,92 @@ import tempfile
10
  st.set_page_config(page_title="Magic Story Generator", layout="centered")
11
  st.title("📖✨ Turn Images into Children's Stories")
12
 
13
# —––––––– Load Pipelines (cached) —–––––––
@st.cache_resource(show_spinner=False)
def load_pipelines():
    """Build the two model pipelines used by the app.

    Decorated with ``st.cache_resource`` so Streamlit can reuse the returned
    objects instead of rebuilding them.

    Returns:
        A ``(captioner, storyteller)`` tuple: an ``image-to-text`` pipeline
        backed by ``Salesforce/blip-image-captioning-base`` and a
        ``text-generation`` pipeline backed by ``gpt2``.
    """
    # Point the transformers cache at a local folder unless already configured.
    # NOTE(review): this runs after `transformers` has been imported — confirm
    # the library still honours the variable when it is set this late.
    os.environ.setdefault("TRANSFORMERS_CACHE", "./hf_cache")

    cpu_device = -1  # force CPU; use 0 for GPU

    # 1) BLIP image captioner.
    blip = pipeline(
        task="image-to-text",
        model="Salesforce/blip-image-captioning-base",
        device=cpu_device,
    )

    # 2) GPT-2 text generator for the story.
    gpt2 = pipeline(
        task="text-generation",
        model="gpt2",
        tokenizer="gpt2",
        device=cpu_device,
    )

    return blip, gpt2


captioner, storyteller = load_pipelines()
37
 
38
# —––––––– Main App Flow —–––––––
# Upload -> caption -> story -> audio. Everything below the uploader only
# runs once the user has supplied an image.
uploaded = st.file_uploader("Upload an image:", type=["jpg", "png", "jpeg"])
if uploaded:
    # Load image, force RGB and cap the longest side at 2048 px.
    img = Image.open(uploaded).convert("RGB")
    if max(img.size) > 2048:
        img.thumbnail((2048, 2048))
    st.image(img, use_container_width=True)

    # Generate caption
    with st.spinner("🔍 Generating caption..."):
        raw = captioner(img)
        caption = raw[0].get("generated_text", "").strip()
    if not caption:
        st.error("😢 Couldn't understand this image. Try another one!")
        st.stop()
    st.success(f"**Caption:** {caption}")

    # Build storytelling prompt
    prompt = f"""
You are a creative children’s-story author.
Image description: “{caption}”

Write a coherent, 50–100 word story that:
1. Introduces the main character.
2. Shows a simple problem or discovery.
3. Has a happy resolution.
4. Uses clear language for ages 3–8.
5. Keeps sentences under 20 words.
Story:
"""
    # Generate story
    with st.spinner("📝 Writing story..."):
        t0 = time.time()
        outputs = storyteller(
            prompt,
            max_new_tokens=120,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.1,
            no_repeat_ngram_size=3,
            do_sample=True,
            pad_token_id=storyteller.tokenizer.eos_token_id,
        )
        generated = outputs[0]["generated_text"]
        load_time = time.time() - t0
    st.text(f"⏱ Story generated in {load_time:.1f}s")

    # Post-process: drop the echoed prompt BEFORE trimming whitespace. The
    # prompt begins with a newline, so stripping first would make the
    # startswith() check fail and leave the whole prompt in the story.
    if generated.startswith(prompt):
        generated = generated[len(prompt):]
    story_text = generated.strip()

    # Truncate to at most 100 words; the final period is decided on the
    # truncated text, not on the untruncated one.
    words = story_text.split()
    if len(words) > 100:
        story_text = " ".join(words[:100])
        if not story_text.endswith("."):
            story_text += "."

    # Display story
    st.subheader("📚 Your Magical Story")
    st.write(story_text)

    # Convert to audio
    with st.spinner("🔊 Converting to audio..."):
        try:
            tts = gTTS(text=story_text, lang="en", slow=False)
            # Auto-deleting temp file: nothing is left on disk after playback
            # data has been handed to Streamlit as bytes.
            with tempfile.NamedTemporaryFile(suffix=".mp3") as fp:
                tts.write_to_fp(fp)
                fp.seek(0)
                st.audio(fp.read(), format="audio/mp3")
        except Exception as e:  # best-effort: a TTS failure must not hide the story
            st.warning(f"⚠️ TTS failed: {e}")

# Footer
st.markdown("---\n*Made with ❤️ by your friendly story wizard* ")
 
1
  import os
2
  import time
3
  import streamlit as st
 
4
  from PIL import Image
5
+ from io import BytesIO
6
+ from huggingface_hub import InferenceApi, login
7
+ from transformers import pipeline
8
  from gtts import gTTS
9
  import tempfile
10
 
 
12
  st.set_page_config(page_title="Magic Story Generator", layout="centered")
13
  st.title("📖✨ Turn Images into Children's Stories")
14
 
15
# —––––––– Load Clients & Pipelines (cached) —–––––––
@st.cache_resource(show_spinner=False)
def load_clients():
    """Create the captioning client and the story-generation pipeline.

    Reads ``HF_TOKEN`` from ``st.secrets``, exports it as
    ``HUGGINGFACEHUB_API_TOKEN`` and passes it to ``login`` before building
    either object. Decorated with ``st.cache_resource`` so Streamlit can
    reuse the returned objects.

    Returns:
        A ``(caption_client, storyteller)`` tuple: an ``InferenceApi`` client
        for ``Salesforce/blip-image-captioning-base`` and a
        ``text2text-generation`` pipeline for ``google/flan-t5-small``.

    Raises:
        KeyError: presumably, if ``HF_TOKEN`` is absent from the secrets —
            verify against the Streamlit version in use.
    """
    token = st.secrets["HF_TOKEN"]

    # Make the token available both through the environment and the hub login.
    os.environ["HUGGINGFACEHUB_API_TOKEN"] = token
    login(token)

    # Captioning goes through the hosted Inference API rather than a local model.
    blip_client = InferenceApi(
        repo_id="Salesforce/blip-image-captioning-base",
        token=token,
    )

    # Flan-T5 story generator; time the load so it can be reported in the UI.
    started = time.time()
    flan_t5 = pipeline(
        task="text2text-generation",
        model="google/flan-t5-small",
        device=-1,  # CPU
        max_length=150,  # default length cap handed to the pipeline
    )
    st.text(f"✅ Story model loaded in {time.time() - started:.1f}s")

    return blip_client, flan_t5


caption_client, storyteller = load_clients()
41
+
42
+
43
+ # —––––––– Helpers —–––––––
44
def generate_caption(img: Image.Image) -> str:
    """Caption *img* using the module-level ``caption_client``.

    The image is encoded as JPEG in memory and the bytes are sent to the
    client.

    Args:
        img: Image to describe.

    Returns:
        The stripped ``generated_text`` of the first result, or ``""`` when
        the response is not a non-empty list.
    """
    jpeg_buffer = BytesIO()
    img.save(jpeg_buffer, format="JPEG")
    response = caption_client(data=jpeg_buffer.getvalue())

    # Anything other than a non-empty list is treated as "no caption".
    if not isinstance(response, list) or not response:
        return ""

    first_result = response[0]
    return first_result.get("generated_text", "").strip()
51
+
52
def generate_story(caption: str) -> str:
    """Write a short children's story from an image caption.

    Builds an instruction prompt around *caption*, runs it through the
    module-level ``storyteller`` pipeline, shows the generation time with
    ``st.text`` and caps the result at 100 words.

    Args:
        caption: Text description of the uploaded image.

    Returns:
        The generated story, stripped of surrounding whitespace. When the
        output is cut at 100 words it is closed with a period unless it
        already ends in sentence punctuation.
    """
    prompt = (
        "You are a creative children’s-story author.\n"
        f"Image description: “{caption}”\n\n"
        "Write a coherent 50–100 word story that:\n"
        "1. Introduces the main character.\n"
        "2. Shows a simple problem or discovery.\n"
        "3. Has a happy resolution.\n"
        "4. Uses clear language for ages 3–8.\n"
        "5. Keeps each sentence under 20 words.\n"
    )

    t0 = time.time()
    # do_sample=True is required for temperature/top_p to have any effect;
    # without it generation is greedy and the sampling knobs are ignored.
    outputs = storyteller(
        prompt,
        max_new_tokens=120,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
    )
    st.text(f"⏱ Generated in {time.time() - t0:.1f}s")
    story = outputs[0]["generated_text"].strip()

    # Truncate to at most 100 words
    words = story.split()
    if len(words) > 100:
        story = " ".join(words[:100])
        # Accept any sentence terminator so "Hooray!" does not become "Hooray!."
        if not story.endswith((".", "!", "?")):
            story += "."
    return story
75
 
 
76
 
77
# —––––––– Main App —–––––––
# Upload -> caption -> story -> audio. Everything below the uploader only
# runs once the user has supplied an image.
uploaded = st.file_uploader("Upload an image:", type=["jpg", "png", "jpeg"])
if uploaded:
    # Force RGB and cap the longest side at 2048 px.
    img = Image.open(uploaded).convert("RGB")
    if max(img.size) > 2048:
        img.thumbnail((2048, 2048))
    st.image(img, use_container_width=True)

    with st.spinner("🔍 Generating caption..."):
        caption = generate_caption(img)
    if not caption:
        st.error("😢 Couldn't understand this image. Try another one!")
        st.stop()
    st.success(f"**Caption:** {caption}")

    with st.spinner("📝 Writing story..."):
        story = generate_story(caption)

    st.subheader("📚 Your Magical Story")
    st.write(story)

    with st.spinner("🔊 Converting to audio..."):
        try:
            # Synthesize into memory: avoids leaving a delete=False temp file
            # behind on every rerun and saving to a path that is still open.
            audio_buffer = BytesIO()
            gTTS(text=story, lang="en", slow=False).write_to_fp(audio_buffer)
            st.audio(audio_buffer.getvalue(), format="audio/mp3")
        except Exception as e:  # best-effort: a TTS failure must not hide the story
            st.warning(f"⚠️ TTS failed: {e}")

# Footer
st.markdown("---\n*Made with ❤️ by your friendly story wizard*")