mayf committed (verified)
Commit 7d2ac1c · 1 Parent(s): 121e41f

Update app.py

Files changed (1):
  1. app.py +55 -44
app.py CHANGED
@@ -2,7 +2,8 @@
 
 import streamlit as st
 from PIL import Image
-from transformers import pipeline
+from io import BytesIO
+from huggingface_hub import InferenceApi
 from gtts import gTTS
 import tempfile
 
@@ -10,48 +11,52 @@ import tempfile
 st.set_page_config(page_title="Storyteller for Kids", layout="centered")
 st.title("🖼️ ➡️ 📖 Interactive Storyteller")
 
-# —––––––– Load & warm pipelines
+# —––––––– Inference clients (cached)
 @st.cache_resource
-def load_pipelines():
-    # 1) BLIP-base for captions
-    captioner = pipeline(
-        "image-to-text",
-        model="Salesforce/blip-image-captioning-base",
-        device=0 # set to -1 if you only have CPU
+def load_clients():
+    # read your HF token from Space secrets
+    hf_token = st.secrets["HF_TOKEN"]
+
+    # caption client: BLIP-base via HF Image-to-Text API
+    caption_client = InferenceApi(
+        repo_id="Salesforce/blip-image-captioning-base",
+        task="image-to-text",
+        token=hf_token
     )
-    # 2) DeepSeek-R1-Distill (Qwen-1.5B) for stories
-    ds_storyteller = pipeline(
-        "text-generation",
-        model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
-        trust_remote_code=True,
-        device=0
+    # story client: DeepSeek-R1-Distill via HF Text-Generation API
+    story_client = InferenceApi(
+        repo_id="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
+        task="text-generation",
+        token=hf_token
     )
+    return caption_client, story_client
 
-    # Warm-up both so the first real request is faster
-    dummy = Image.new("RGB", (384, 384), color=(128, 128, 128))
-    captioner(dummy)
-    ds_storyteller("Warm up", max_new_tokens=1)
-
-    return captioner, ds_storyteller
-
-captioner, ds_storyteller = load_pipelines()
+caption_client, story_client = load_clients()
 
 # —––––––– Main UI
 uploaded = st.file_uploader("Upload an image:", type=["jpg", "jpeg", "png"])
-if uploaded:
-    # 1) Preprocess & display
-    image = Image.open(uploaded).convert("RGB")
-    image = image.resize((384, 384), Image.LANCZOS)
-    st.image(image, caption="Your image", use_container_width=True)
+if not uploaded:
+    st.info("Please upload an image (JPG/PNG) to begin.")
+else:
+    # 1) Display the image
+    img = Image.open(uploaded).convert("RGB")
+    st.image(img, use_container_width=True)
 
-    # 2) Generate caption
+    # 2) Caption via HF Inference API
     with st.spinner("🔍 Generating caption..."):
-        cap = captioner(image)[0]["generated_text"].strip()
-    st.markdown(f"**Caption:** {cap}")
+        buf = BytesIO()
+        img.save(buf, format="PNG")
+        caption_output = caption_client(data=buf.getvalue())
+        # handle API return formats
+        if isinstance(caption_output, dict):
+            cap_text = caption_output.get("generated_text", "").strip()
+        else:
+            cap_text = str(caption_output).strip()
+    st.markdown(f"**Caption:** {cap_text}")
 
     # 3) Build prompt
    prompt = (
-        f"Here is an image description: “{cap}”.\n"
+        f"Here’s an image description: “{cap_text}”.\n\n"
         "Write an 80–100 word playful story for 3–10 year-old children that:\n"
         "1) Describes the scene and main subject.\n"
         "2) Explains what it’s doing and how it feels.\n"
@@ -59,20 +64,25 @@ if uploaded:
         "Story:"
     )
 
-    # 4) Generate story via DeepSeek
-    with st.spinner("✍️ Generating story with DeepSeek..."):
-        out = ds_storyteller(
-            prompt,
-            max_new_tokens=120,
-            do_sample=True,
-            temperature=0.7,
-            top_p=0.9,
-            top_k=50,
-            repetition_penalty=1.2,
-            no_repeat_ngram_size=3
+    # 4) Story via HF Inference API
+    with st.spinner("✍️ Generating story..."):
+        story_output = story_client(
+            inputs=prompt,
+            params={
+                "max_new_tokens": 120,
+                "do_sample": True,
+                "temperature": 0.7,
+                "top_p": 0.9,
+                "top_k": 50,
+                "repetition_penalty": 1.2,
+                "no_repeat_ngram_size": 3
+            }
         )
-        story = out[0]["generated_text"].strip()
-
+        # API returns list of generations or a dict
+        if isinstance(story_output, list):
+            story = story_output[0].get("generated_text", "").strip()
+        else:
+            story = story_output.get("generated_text", "").strip()
    st.markdown("**Story:**")
    st.write(story)
 
@@ -83,3 +93,4 @@ if uploaded:
        tts.write_to_fp(tmp)
        tmp.flush()
        st.audio(tmp.name, format="audio/mp3")
+
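
Note on the client API: the updated app.py uses huggingface_hub.InferenceApi, which recent huggingface_hub releases deprecate in favour of InferenceClient. Below is a minimal sketch (not part of this commit) of the same two calls with InferenceClient; the model IDs and generation parameters are taken from the diff above, while the file name, token placeholder, and helper variables are purely illustrative.

    from io import BytesIO

    from huggingface_hub import InferenceClient
    from PIL import Image

    CAPTION_MODEL = "Salesforce/blip-image-captioning-base"
    STORY_MODEL = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

    # In the Space this token would come from st.secrets["HF_TOKEN"].
    client = InferenceClient(token="hf_...")

    # Caption: send the image as PNG bytes, as the app does.
    img = Image.open("example.jpg").convert("RGB")
    buf = BytesIO()
    img.save(buf, format="PNG")
    caption_out = client.image_to_text(buf.getvalue(), model=CAPTION_MODEL)
    # Depending on the library version this is a plain string or an
    # object exposing .generated_text.
    caption = getattr(caption_out, "generated_text", str(caption_out)).strip()

    # Story: text_generation returns the generated string directly.
    # (no_repeat_ngram_size from the diff is omitted here; this method
    # does not expose that parameter.)
    story = client.text_generation(
        f"Here’s an image description: “{caption}”.\n\nStory:",
        model=STORY_MODEL,
        max_new_tokens=120,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        top_k=50,
        repetition_penalty=1.2,
    )
    print(story.strip())

Either client keeps the heavy models off the Space itself, so the app can run on CPU-only hardware while the Inference API serves BLIP and DeepSeek remotely.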