mayf committed (verified)
Commit e616e4e · Parent: 5a9c362

Update app.py

Files changed (1):
  1. app.py +21 -17
app.py CHANGED
@@ -4,9 +4,8 @@ import streamlit as st
 from PIL import Image
 from io import BytesIO
 from huggingface_hub import InferenceApi, login
-from transformers import pipeline
+from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
 import torch
-from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniTokenizer
 from gtts import gTTS
 import tempfile
 
 
@@ -18,29 +17,30 @@ st.title("📖✨ Turn Images into Children's Stories (Qwen2.5-Omni-7B)")
 @st.cache_resource(show_spinner=False)
 def load_clients():
     hf_token = st.secrets["HF_TOKEN"]
-    # Authenticate for HF Hub
+    # Authenticate for Hugging Face Hub
     os.environ["HUGGINGFACEHUB_API_TOKEN"] = hf_token
     login(hf_token)
 
-    # 1) BLIP captioning via HF Inference API
+    # 1) BLIP captioning via HTTP API
     caption_client = InferenceApi(
         repo_id="Salesforce/blip-image-captioning-base",
         token=hf_token
     )
 
-    # 2) Qwen2.5-Omni story generator
+    # 2) Load Qwen2.5-Omni model & tokenizer
     t0 = time.time()
-    model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
+    tokenizer = AutoTokenizer.from_pretrained(
         "Qwen/Qwen2.5-Omni-7B",
-        device_map="auto",
-        torch_dtype=torch.bfloat16,
-        attn_implementation="flash_attention_2",
         trust_remote_code=True
     )
-    tokenizer = Qwen2_5OmniTokenizer.from_pretrained(
+    model = AutoModelForSeq2SeqLM.from_pretrained(
         "Qwen/Qwen2.5-Omni-7B",
-        trust_remote_code=True
+        trust_remote_code=True,
+        device_map="auto",
+        torch_dtype=torch.bfloat16,
+        attn_implementation="flash_attention_2"
     )
+    # 3) Build text2text pipeline
     storyteller = pipeline(
         task="text2text-generation",
         model=model,
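Note on this hunk: Qwen2.5-Omni is not an encoder-decoder model, so `AutoModelForSeq2SeqLM` (and a `text2text-generation` pipeline built on it) may fail to resolve the architecture even with `trust_remote_code=True`. The dedicated class the removed lines pointed at is the documented loading path. A minimal sketch of that path, assuming a transformers release that ships Qwen2.5-Omni support; `Qwen2_5OmniProcessor` is the tokenizer-side companion from the model card, and `flash_attention_2` is left out because it additionally requires the `flash-attn` package and a supported GPU:

```python
import torch
from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor

# Dedicated Qwen2.5-Omni classes per the model card; a sketch, not the
# app's committed code.
model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2.5-Omni-7B",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
processor = Qwen2_5OmniProcessor.from_pretrained("Qwen/Qwen2.5-Omni-7B")
```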
@@ -53,8 +53,7 @@ def load_clients():
         max_new_tokens=120
     )
     load_time = time.time() - t0
-    st.text(f"✅ Story model loaded in {load_time:.1f}s (cached thereafter)")
-
+    st.text(f"✅ Story model loaded in {load_time:.1f}s (cached)")
     return caption_client, storyteller
 
 caption_client, storyteller = load_clients()
@@ -73,14 +72,19 @@ def generate_story(caption: str) -> str:
     prompt = (
         "You are a creative children's-story author.\n"
         f"Image description: “{caption}”\n\n"
-        "Write a coherent 50–100 word story\n"
+        "Write a coherent 50–100 word story that:\n"
+        "1. Introduces the main character.\n"
+        "2. Shows a simple problem or discovery.\n"
+        "3. Has a happy resolution.\n"
+        "4. Uses clear language for ages 3–8.\n"
+        "5. Keeps each sentence under 20 words.\n"
     )
     t0 = time.time()
-    outputs = storyteller(prompt)
+    result = storyteller(prompt)
     gen_time = time.time() - t0
     st.text(f"⏱ Generated in {gen_time:.1f}s on GPU/CPU")
 
-    story = outputs[0]["generated_text"].strip()
+    story = result[0]["generated_text"].strip()
     # Enforce ≤100 words
     words = story.split()
     if len(words) > 100:
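For reference, transformers generation pipelines return a list of dicts keyed by `generated_text`, which is what `result[0]["generated_text"]` unpacks. The hunk cuts off at the `if`, so the truncation body below is a plausible reconstruction, not the committed code:

```python
result = storyteller(prompt)   # -> [{"generated_text": "Once upon a time ..."}]
story = result[0]["generated_text"].strip()

# Enforce <=100 words; hypothetical body for the truncated `if` branch.
words = story.split()
if len(words) > 100:
    story = " ".join(words[:100])
```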
@@ -120,4 +124,4 @@ if uploaded:
         st.warning(f"⚠️ TTS failed: {e}")
 
 # Footer
-st.markdown("---\n*Made with ❤️ by your friendly story wizard*")
+st.markdown("---\n*Made with ❤️ by your friendly story wizard*")
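The TTS path appears in this diff only through its failure branch. For context, a typical gTTS-to-Streamlit flow looks like the sketch below; it assumes the app writes the MP3 to a temp file, since the committed code for this part is not shown:

```python
from gtts import gTTS
import tempfile

# Synthesize the story and hand the file path to Streamlit's audio player.
tts = gTTS(story, lang="en")
with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as fp:
    tts.save(fp.name)
st.audio(fp.name)
```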
 
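Finally, a forward-compatibility flag: `InferenceApi` is deprecated in recent `huggingface_hub` releases in favor of `InferenceClient`. A hedged sketch of the captioning call on the newer client; on current versions `image_to_text` returns a dataclass with a `.generated_text` field, while older versions returned a bare string:

```python
from huggingface_hub import InferenceClient

client = InferenceClient(token=hf_token)  # hf_token from st.secrets, as above
# Server-side BLIP captioning of the uploaded image.
output = client.image_to_text(
    image_bytes,  # hypothetical name for the raw bytes of the uploaded file
    model="Salesforce/blip-image-captioning-base",
)
caption = output.generated_text
```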