mayf commited on
Commit
121e41f
·
verified ·
1 Parent(s): cc355a8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -23
app.py CHANGED
@@ -10,55 +10,58 @@ import tempfile
10
  st.set_page_config(page_title="Storyteller for Kids", layout="centered")
11
  st.title("🖼️ ➡️ 📖 Interactive Storyteller")
12
 
13
- # —––––––– Load and warm pipelines
14
  @st.cache_resource
15
  def load_pipelines():
16
- # BLIP-base for captions
17
  captioner = pipeline(
18
  "image-to-text",
19
  model="Salesforce/blip-image-captioning-base",
20
- device=-1 # GPU if available, else -1
21
  )
22
- # Flan-T5-Large for stories
23
- storyteller = pipeline(
24
- "text2text-generation",
25
- model="google/flan-t5-large",
 
26
  device=0
27
  )
28
- # Warm-up runs so user-facing calls are fast
 
29
  dummy = Image.new("RGB", (384, 384), color=(128, 128, 128))
30
  captioner(dummy)
31
- storyteller("Warm up", max_new_tokens=1)
32
- return captioner, storyteller
 
33
 
34
- captioner, storyteller = load_pipelines()
35
 
36
  # —––––––– Main UI
37
  uploaded = st.file_uploader("Upload an image:", type=["jpg", "jpeg", "png"])
38
  if uploaded:
39
- # 1) Preprocess image
40
  image = Image.open(uploaded).convert("RGB")
41
  image = image.resize((384, 384), Image.LANCZOS)
42
  st.image(image, caption="Your image", use_container_width=True)
43
 
44
- # 2) Caption
45
- with st.spinner("🔍 Generating caption"):
46
  cap = captioner(image)[0]["generated_text"].strip()
47
  st.markdown(f"**Caption:** {cap}")
48
 
49
- # 3) Build a dynamic prompt
50
  prompt = (
51
  f"Here is an image description: “{cap}”.\n"
52
  "Write an 80–100 word playful story for 3–10 year-old children that:\n"
53
- "1) Describes the scene and subject from the description.\n"
54
- "2) Explains what the subject is doing and how it feels.\n"
55
  "3) Concludes with a fun, imaginative ending.\n\n"
56
  "Story:"
57
  )
58
 
59
- # 4) Generate the story
60
- with st.spinner("✍️ Writing the story…"):
61
- output = storyteller(
62
  prompt,
63
  max_new_tokens=120,
64
  do_sample=True,
@@ -68,13 +71,13 @@ if uploaded:
68
  repetition_penalty=1.2,
69
  no_repeat_ngram_size=3
70
  )
71
- story = output[0]["generated_text"].strip()
72
 
73
  st.markdown("**Story:**")
74
  st.write(story)
75
 
76
- # 5) Text-to-Speech
77
- with st.spinner("🔊 Converting to speech"):
78
  tts = gTTS(text=story, lang="en")
79
  tmp = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False)
80
  tts.write_to_fp(tmp)
 
10
  st.set_page_config(page_title="Storyteller for Kids", layout="centered")
11
  st.title("🖼️ ➡️ 📖 Interactive Storyteller")
12
 
13
+ # —––––––– Load & warm pipelines
14
  @st.cache_resource
15
  def load_pipelines():
16
+ # 1) BLIP-base for captions
17
  captioner = pipeline(
18
  "image-to-text",
19
  model="Salesforce/blip-image-captioning-base",
20
+ device=0 # set to -1 if you only have CPU
21
  )
22
+ # 2) DeepSeek-R1-Distill (Qwen-1.5B) for stories
23
+ ds_storyteller = pipeline(
24
+ "text-generation",
25
+ model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
26
+ trust_remote_code=True,
27
  device=0
28
  )
29
+
30
+ # Warm-up both so the first real request is faster
31
  dummy = Image.new("RGB", (384, 384), color=(128, 128, 128))
32
  captioner(dummy)
33
+ ds_storyteller("Warm up", max_new_tokens=1)
34
+
35
+ return captioner, ds_storyteller
36
 
37
+ captioner, ds_storyteller = load_pipelines()
38
 
39
  # —––––––– Main UI
40
  uploaded = st.file_uploader("Upload an image:", type=["jpg", "jpeg", "png"])
41
  if uploaded:
42
+ # 1) Preprocess & display
43
  image = Image.open(uploaded).convert("RGB")
44
  image = image.resize((384, 384), Image.LANCZOS)
45
  st.image(image, caption="Your image", use_container_width=True)
46
 
47
+ # 2) Generate caption
48
+ with st.spinner("🔍 Generating caption..."):
49
  cap = captioner(image)[0]["generated_text"].strip()
50
  st.markdown(f"**Caption:** {cap}")
51
 
52
+ # 3) Build prompt
53
  prompt = (
54
  f"Here is an image description: “{cap}”.\n"
55
  "Write an 80–100 word playful story for 3–10 year-old children that:\n"
56
+ "1) Describes the scene and main subject.\n"
57
+ "2) Explains what it’s doing and how it feels.\n"
58
  "3) Concludes with a fun, imaginative ending.\n\n"
59
  "Story:"
60
  )
61
 
62
+ # 4) Generate story via DeepSeek
63
+ with st.spinner("✍️ Generating story with DeepSeek..."):
64
+ out = ds_storyteller(
65
  prompt,
66
  max_new_tokens=120,
67
  do_sample=True,
 
71
  repetition_penalty=1.2,
72
  no_repeat_ngram_size=3
73
  )
74
+ story = out[0]["generated_text"].strip()
75
 
76
  st.markdown("**Story:**")
77
  st.write(story)
78
 
79
+ # 5) Text-to-Speech via gTTS
80
+ with st.spinner("🔊 Converting to speech..."):
81
  tts = gTTS(text=story, lang="en")
82
  tmp = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False)
83
  tts.write_to_fp(tmp)