mayf committed on
Commit
e508bdf
·
verified ·
1 Parent(s): b3abd21

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -64
app.py CHANGED
@@ -1,76 +1,108 @@
1
  import os
2
- import torch
3
- from transformers import GPT2Tokenizer, GPT2LMHeadModel
4
- from huggingface_hub import InferenceApi
5
  from PIL import Image
6
- from io import BytesIO
 
7
 
8
- def load_caption_client(token: str):
9
- return InferenceApi(repo_id="Salesforce/blip-image-captioning-base", token=token)
 
10
 
11
- def generate_caption(image_path: str, caption_client) -> str:
12
- img = Image.open(image_path).convert("RGB")
13
- buf = BytesIO()
14
- img.save(buf, format="JPEG")
15
- resp = caption_client(data=buf.getvalue())
16
- if isinstance(resp, list) and resp:
17
- return resp[0].get("generated_text", "").strip()
18
- return ""
19
 
20
- def load_gpt2(model_name="gpt2"):
21
- tokenizer = GPT2Tokenizer.from_pretrained(model_name)
22
- model = GPT2LMHeadModel.from_pretrained(model_name)
23
- model.eval()
24
- return tokenizer, model
25
-
26
- def generate_story(caption: str, tokenizer, model) -> str:
27
- # Build a strong prompt
28
- prompt = (
29
- f"You are a creative children’s-story author.\n"
30
- f"Image description: “{caption}”\n\n"
31
- "Write a coherent, 50–100 word story:\n"
32
  )
33
- # Tokenize and move to device
34
- inputs = tokenizer(prompt, return_tensors="pt")
35
- # Generate up to ~120 new tokens
36
- outputs = model.generate(
37
- **inputs,
38
- max_new_tokens=120,
39
- temperature=0.7,
40
- top_p=0.9,
41
- repetition_penalty=1.1,
42
- no_repeat_ngram_size=3,
43
- do_sample=True,
44
- pad_token_id=tokenizer.eos_token_id
45
  )
46
- # Decode and strip the prompt echo
47
- text = tokenizer.decode(outputs[0], skip_special_tokens=True)
48
- story = text[len(prompt):].strip()
49
- # Truncate to 100 words if needed
50
- words = story.split()
51
- if len(words) > 100:
52
- story = " ".join(words[:100])
53
- if not story.endswith("."):
54
- story += "."
55
- return story
56
 
57
- if __name__ == "__main__":
58
- # 1) Read your HF token
59
- hf_token = os.environ.get("HF_TOKEN")
60
- if not hf_token:
61
- raise RuntimeError("Please set HF_TOKEN env var")
 
 
 
 
 
 
 
62
 
63
- # 2) Generate caption
64
- caption_client = load_caption_client(hf_token)
65
- image_path = "path/to/your/image.jpg" # <-- change me
66
- caption = generate_caption(image_path, caption_client)
67
- print(f"Caption: {caption}\n")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
- # 3) Load GPT-2
70
- tokenizer, model = load_gpt2("gpt2")
71
- # (optionally move model to GPU: model.to("cuda"))
72
 
73
- # 4) Generate & print story
74
- story = generate_story(caption, tokenizer, model)
75
- print("Story:\n", story)
 
 
 
 
 
 
76
 
 
 
 
1
  import os
2
+ import time
3
+ import streamlit as st
4
+ from transformers import pipeline
5
  from PIL import Image
6
+ from gtts import gTTS
7
+ import tempfile
8
 
9
# —––––––– Page Config —–––––––
# Browser-tab title and centered layout first, then the visible page heading.
st.set_page_config(
    layout="centered",
    page_title="Magic Story Generator",
)
st.title("📖✨ Turn Images into Children's Stories")
12
 
13
+ # —––––––– Load Pipelines (cached) —–––––––
14
@st.cache_resource(show_spinner=False)
def load_pipelines():
    """Build the captioning and storytelling pipelines.

    Wrapped in ``st.cache_resource`` so the returned objects are cached by
    Streamlit instead of being rebuilt by every call.

    Returns:
        A ``(captioner, storyteller)`` tuple: an ``image-to-text`` pipeline
        backed by ``Salesforce/blip-image-captioning-base`` and a
        ``text-generation`` pipeline backed by ``gpt2``.
    """
    # Point the transformers cache at a local folder unless already set.
    # NOTE(review): transformers is imported before this function runs;
    # confirm the library still honours TRANSFORMERS_CACHE when it is set
    # this late (newer releases favour HF_HOME).
    os.environ.setdefault("TRANSFORMERS_CACHE", "./hf_cache")

    cpu_device = -1  # force CPU; use 0 for GPU

    # 1) Image-to-text pipeline for captioning (BLIP)
    caption_pipe = pipeline(
        "image-to-text",
        model="Salesforce/blip-image-captioning-base",
        device=cpu_device,
    )

    # 2) Text-generation pipeline for storytelling (GPT-2)
    story_pipe = pipeline(
        "text-generation",
        model="gpt2",
        tokenizer="gpt2",
        device=cpu_device,
    )

    return caption_pipe, story_pipe
35
+
36
captioner, storyteller = load_pipelines()

# —––––––– Story helpers —–––––––
MAX_STORY_WORDS = 100
SENTENCE_ENDINGS = (".", "!", "?")

# Prompt handed to the text-generation model; {caption} is filled per image.
PROMPT_TEMPLATE = """
You are a creative children’s-story author.
Image description: “{caption}”

Write a coherent, 50–100 word story that:
1. Introduces the main character.
2. Shows a simple problem or discovery.
3. Has a happy resolution.
4. Uses clear language for ages 3–8.
5. Keeps sentences under 20 words.
Story:
"""


def _extract_story(generated, prompt, max_words=MAX_STORY_WORDS):
    """Return the story part of *generated*, capped at *max_words* words.

    Args:
        generated: Raw text returned by the text-generation pipeline.
        prompt: The exact prompt that was sent to the pipeline.
        max_words: Maximum number of whitespace-separated words to keep.

    Returns:
        The text with the echoed prompt removed, surrounding whitespace
        trimmed, and, when it had to be cut, closed with a period unless the
        cut text already ends with sentence punctuation.
    """
    # Remove the echoed prompt BEFORE trimming whitespace: the prompt starts
    # with a newline, so an already-stripped string never has it as a prefix.
    if generated.startswith(prompt):
        generated = generated[len(prompt):]
    story = generated.strip()

    words = story.split()
    if len(words) > max_words:
        story = " ".join(words[:max_words])
        # Inspect the truncated text, not the full text, when deciding
        # whether a closing period is needed.
        if not story.endswith(SENTENCE_ENDINGS):
            story += "."
    return story


# —––––––– Main App Flow —–––––––
uploaded = st.file_uploader("Upload an image:", type=["jpg", "png", "jpeg"])
if uploaded:
    # Load image; anything larger than 2048 px on a side is downscaled
    img = Image.open(uploaded).convert("RGB")
    if max(img.size) > 2048:
        img.thumbnail((2048, 2048))
    st.image(img, use_container_width=True)

    # Generate caption
    with st.spinner("🔍 Generating caption..."):
        raw = captioner(img)
    # An empty result list is treated the same as an empty caption.
    first_result = raw[0] if raw else {}
    caption = first_result.get("generated_text", "").strip()
    if not caption:
        st.error("😢 Couldn't understand this image. Try another one!")
        st.stop()
    st.success(f"**Caption:** {caption}")

    # Build storytelling prompt
    prompt = PROMPT_TEMPLATE.format(caption=caption)

    # Generate story
    with st.spinner("📝 Writing story..."):
        started = time.perf_counter()
        outputs = storyteller(
            prompt,
            max_new_tokens=120,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.1,
            no_repeat_ngram_size=3,
            do_sample=True,
            pad_token_id=storyteller.tokenizer.eos_token_id,
        )
        elapsed = time.perf_counter() - started
    st.text(f"⏱ Story generated in {elapsed:.1f}s")

    # Post-process: strip prompt echo and truncate
    story_text = _extract_story(outputs[0]["generated_text"], prompt)

    # Display story
    st.subheader("📚 Your Magical Story")
    st.write(story_text)

    # Convert to audio
    with st.spinner("🔊 Converting to audio..."):
        try:
            tts = gTTS(text=story_text, lang="en", slow=False)
            # Write the MP3 into a throwaway directory and hand the bytes to
            # Streamlit, so no temporary file is left behind after each run.
            with tempfile.TemporaryDirectory() as tmp_dir:
                mp3_path = os.path.join(tmp_dir, "story.mp3")
                tts.save(mp3_path)
                with open(mp3_path, "rb") as audio_file:
                    audio_bytes = audio_file.read()
            st.audio(audio_bytes, format="audio/mp3")
        except Exception as e:
            # Audio is best-effort: the written story stays on screen.
            st.warning(f"⚠️ TTS failed: {e}")

# Footer
st.markdown("---\n*Made with ❤️ by your friendly story wizard* ")