mayf committed on
Commit
b3abd21
·
verified ·
1 Parent(s): 2aae3c9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -118
app.py CHANGED
@@ -1,135 +1,76 @@
1
  import os
2
- import time
3
- import streamlit as st
 
4
  from PIL import Image
5
  from io import BytesIO
6
- from huggingface_hub import InferenceApi, login
7
- from transformers import pipeline
8
- from gtts import gTTS
9
- import tempfile
10
 
11
- # —––––––– Page Config —–––––––
12
- st.set_page_config(page_title="Magic Story Generator", layout="centered")
13
- st.title("📖✨ Turn Images into Children's Stories")
14
 
15
- # —––––––– Clients (cached) —–––––––
16
- @st.cache_resource(show_spinner=False)
17
- def load_clients():
18
- hf_token = st.secrets["HF_TOKEN"]
19
-
20
- # Authenticate for both HF Hub and transformers
21
- os.environ["HUGGINGFACEHUB_API_TOKEN"] = hf_token
22
- login(hf_token)
23
-
24
- # Pin transformers cache locally via env var
25
- cache_dir = "./hf_cache"
26
- os.makedirs(cache_dir, exist_ok=True)
27
- os.environ["TRANSFORMERS_CACHE"] = cache_dir
28
-
29
- # 1) BLIP image-captioning client
30
- caption_client = InferenceApi(
31
- repo_id="Salesforce/blip-image-captioning-base",
32
- token=hf_token
33
- )
34
-
35
- # 2) Text-generation pipeline on CPU (no cache_dir arg here!)
36
- t0 = time.time()
37
- story_generator = pipeline(
38
- task="text-generation",
39
- model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
40
- tokenizer="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
41
- device=-1 # force CPU
42
- )
43
- st.text(f"✅ Story model loaded in {time.time() - t0:.1f}s (cached thereafter)")
44
-
45
- return caption_client, story_generator
46
-
47
- caption_client, story_generator = load_clients()
48
-
49
-
50
- # —––––––– Helper: Generate Caption —–––––––
51
- def generate_caption(img: Image.Image) -> str:
52
  buf = BytesIO()
53
  img.save(buf, format="JPEG")
54
- try:
55
- resp = caption_client(data=buf.getvalue())
56
- if isinstance(resp, list) and resp:
57
- return resp[0].get("generated_text", "").strip()
58
- except Exception as e:
59
- st.error(f"Caption generation error: {e}")
60
  return ""
61
 
62
-
63
- # —––––––– Helper: Generate Story via pipeline —–––––––
64
- def generate_story(caption: str) -> str:
65
- prompt = f"""
66
- You are a creative children’s-story author.
67
- Below is an image description:
68
- “{caption}”
69
-
70
- Write a coherent 50–100 word story that:
71
- 1. Introduces the main character.
72
- 2. Shows a simple problem or discovery.
73
- 3. Has a happy resolution.
74
- 4. Uses clear language for ages 3–8.
75
- 5. Keeps sentences under 20 words.
76
- Story:
77
- """
78
- t0 = time.time()
79
- outputs = story_generator(
80
- prompt,
81
  max_new_tokens=120,
82
  temperature=0.7,
83
  top_p=0.9,
84
  repetition_penalty=1.1,
85
  no_repeat_ngram_size=3,
86
- do_sample=True
 
87
  )
88
- st.text(f"⏱ Generated in {time.time() - t0:.1f}s on CPU")
89
-
90
- text = outputs[0]["generated_text"].strip()
91
- # strip the prompt echo
92
- if text.startswith(prompt):
93
- text = text[len(prompt):].strip()
94
- # enforce ≤100 words
95
- words = text.split()
96
  if len(words) > 100:
97
- text = " ".join(words[:100])
98
- if not text.endswith("."):
99
- text += "."
100
- return text
101
-
102
-
103
- # —––––––– Main App Flow —–––––––
104
- uploaded = st.file_uploader("Upload an image:", type=["jpg","png","jpeg"])
105
- if uploaded:
106
- img = Image.open(uploaded).convert("RGB")
107
- if max(img.size) > 2048:
108
- img.thumbnail((2048, 2048))
109
- st.image(img, use_container_width=True)
110
-
111
- with st.spinner("🔍 Generating caption..."):
112
- caption = generate_caption(img)
113
- if not caption:
114
- st.error("😢 Couldn't understand this image. Try another one!")
115
- st.stop()
116
- st.success(f"**Caption:** {caption}")
117
-
118
- with st.spinner("📝 Writing magical story..."):
119
- story = generate_story(caption)
120
-
121
- st.subheader("📚 Your Magical Story")
122
- st.write(story)
123
-
124
- with st.spinner("🔊 Converting to audio..."):
125
- try:
126
- tts = gTTS(text=story, lang="en", slow=False)
127
- with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as fp:
128
- tts.save(fp.name)
129
- st.audio(fp.name, format="audio/mp3")
130
- except Exception as e:
131
- st.warning(f"⚠️ TTS failed: {e}")
132
-
133
- # Footer
134
- st.markdown("---\n*Made with ❤️ by your friendly story wizard*")
135
 
 
1
  import os
2
+ import torch
3
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel
4
+ from huggingface_hub import InferenceApi
5
  from PIL import Image
6
  from io import BytesIO
 
 
 
 
7
 
8
def load_caption_client(token: str):
    """Build a Hugging Face InferenceApi client for BLIP image captioning.

    Args:
        token: Hugging Face API token used to authenticate the client.

    Returns:
        An ``InferenceApi`` instance bound to the BLIP base captioning repo.
    """
    repo = "Salesforce/blip-image-captioning-base"
    return InferenceApi(repo_id=repo, token=token)
 
10
 
11
def generate_caption(image_path: str, caption_client) -> str:
    """Caption the image at *image_path* using the BLIP inference client.

    The image is loaded, converted to RGB, re-encoded as JPEG in memory,
    and the raw bytes are posted to the client.

    Args:
        image_path: Filesystem path of the image to caption.
        caption_client: Callable inference client accepting ``data=`` bytes.

    Returns:
        The stripped caption text, or "" when the response is not a
        non-empty list.
    """
    picture = Image.open(image_path).convert("RGB")
    jpeg_buf = BytesIO()
    picture.save(jpeg_buf, format="JPEG")
    response = caption_client(data=jpeg_buf.getvalue())
    # Anything other than a non-empty list is treated as "no caption".
    if not (isinstance(response, list) and response):
        return ""
    first_result = response[0]
    return first_result.get("generated_text", "").strip()
19
 
20
def load_gpt2(model_name="gpt2"):
    """Load a GPT-2 tokenizer/model pair ready for inference.

    Args:
        model_name: Hugging Face model id or local path (default "gpt2").

    Returns:
        Tuple ``(tokenizer, model)``; the model is switched to eval mode
        so dropout is disabled during generation.
    """
    tok = GPT2Tokenizer.from_pretrained(model_name)
    # nn.Module.eval() returns the module itself, so this chains safely.
    lm = GPT2LMHeadModel.from_pretrained(model_name).eval()
    return tok, lm
25
+
26
def generate_story(caption: str, tokenizer, model) -> str:
    """Generate a short children's story from an image caption.

    Args:
        caption: Image description used to seed the prompt.
        tokenizer: Tokenizer providing ``__call__``, ``decode`` and
            ``eos_token_id`` (GPT-2 compatible).
        model: Causal LM exposing ``generate()``.

    Returns:
        The generated story text, truncated to at most 100 words and
        guaranteed to end with a period.
    """
    # Build a strong prompt
    prompt = (
        f"You are a creative children’s-story author.\n"
        f"Image description: “{caption}”\n\n"
        "Write a coherent, 50–100 word story:\n"
    )
    # Tokenize; tensors stay on whatever device the caller put the model on.
    inputs = tokenizer(prompt, return_tensors="pt")
    # Inference only — no gradients needed, saves memory on CPU.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=120,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.1,
            no_repeat_ngram_size=3,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id  # GPT-2 has no pad token
        )
    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # BUG FIX: decode() is not guaranteed to reproduce the prompt verbatim
    # (special characters such as the curly quotes above can be normalized),
    # so blindly slicing text[len(prompt):] could cut into the story itself.
    # Strip the prompt echo only when it is actually present.
    if text.startswith(prompt):
        story = text[len(prompt):].strip()
    else:
        story = text.strip()
    # Truncate to 100 words if needed and ensure a terminal period.
    words = story.split()
    if len(words) > 100:
        story = " ".join(words[:100])
    if not story.endswith("."):
        story += "."
    return story
56
+
57
if __name__ == "__main__":
    import sys

    # 1) Read the Hugging Face API token from the environment.
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        raise RuntimeError("Please set HF_TOKEN env var")

    # 2) Generate a caption for the input image.
    #    Improvement: take the image path from the command line instead of
    #    requiring a source edit; the old placeholder remains the default.
    caption_client = load_caption_client(hf_token)
    image_path = sys.argv[1] if len(sys.argv) > 1 else "path/to/your/image.jpg"
    caption = generate_caption(image_path, caption_client)
    print(f"Caption: {caption}\n")

    # 3) Load GPT-2 on CPU (optionally move to GPU: model.to("cuda")).
    tokenizer, model = load_gpt2("gpt2")

    # 4) Generate & print the story.
    story = generate_story(caption, tokenizer, model)
    print("Story:\n", story)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76