mayf committed (verified)
Commit 88ee0a7
Parent: bddc67c

Update app.py

Files changed (1)
  1. app.py +39 -19
app.py CHANGED

@@ -5,41 +5,60 @@ from PIL import Image
 from io import BytesIO
 from huggingface_hub import InferenceApi, login
 from transformers import pipeline
+import torch
+from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniTokenizer
 from gtts import gTTS
 import tempfile
 
 # —––––––– Page Config —–––––––
-st.set_page_config(page_title="Magic Story Generator", layout="centered")
-st.title("📖✨ Turn Images into Children's Stories")
+st.set_page_config(page_title="Magic Story Generator (Qwen2.5)", layout="centered")
+st.title("📖✨ Turn Images into Children's Stories (Qwen2.5-Omni-7B)")
 
 # —––––––– Load Clients & Pipelines (cached) —–––––––
 @st.cache_resource(show_spinner=False)
 def load_clients():
     hf_token = st.secrets["HF_TOKEN"]
-    # authenticate so transformers can pick up your token
+    # Authenticate for HF Hub
     os.environ["HUGGINGFACEHUB_API_TOKEN"] = hf_token
     login(hf_token)
 
-    # BLIP captioning via Hugging Face Inference API
+    # 1) BLIP captioning via HF Inference API
     caption_client = InferenceApi(
         repo_id="Salesforce/blip-image-captioning-base",
         token=hf_token
     )
 
-    # Instruction-tuned story generator: Flan-T5
+    # 2) Qwen2.5-Omni story generator
     t0 = time.time()
+    model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
+        "Qwen/Qwen2.5-Omni-7B",
+        device_map="auto",
+        torch_dtype=torch.bfloat16,
+        attn_implementation="flash_attention_2",
+        trust_remote_code=True
+    )
+    tokenizer = Qwen2_5OmniTokenizer.from_pretrained(
+        "Qwen/Qwen2.5-Omni-7B",
+        trust_remote_code=True
+    )
     storyteller = pipeline(
         task="text2text-generation",
-        model="google/flan-t5-small",
-        device=-1,      # CPU
-        max_length=150  # prompt + generation cap
+        model=model,
+        tokenizer=tokenizer,
+        device_map="auto",
+        temperature=0.7,
+        top_p=0.9,
+        repetition_penalty=1.2,
+        no_repeat_ngram_size=3,
+        max_new_tokens=120
     )
-    st.text(f"✅ Story model loaded in {time.time() - t0:.1f}s")
+    load_time = time.time() - t0
+    st.text(f"✅ Story model loaded in {load_time:.1f}s (cached thereafter)")
+
     return caption_client, storyteller
 
 caption_client, storyteller = load_clients()
 
-
 # —––––––– Helpers —–––––––
 def generate_caption(img: Image.Image) -> str:
     buf = BytesIO()
@@ -49,32 +68,33 @@ def generate_caption(img: Image.Image) -> str:
         return resp[0].get("generated_text", "").strip()
     return ""
 
+
 def generate_story(caption: str) -> str:
     prompt = (
-        "You are a creative childrens-story author.\n"
+        "You are a creative children's-story author.\n"
         f"Image description: “{caption}”\n\n"
         "Write a coherent 50–100 word story\n"
     )
     t0 = time.time()
-    out = storyteller(prompt, max_new_tokens=120, temperature=0.7, top_p=0.9)[0]["generated_text"]
-    st.text(f"⏱ Generated in {time.time() - t0:.1f}s")
-    story = out.strip()
+    outputs = storyteller(prompt)
+    gen_time = time.time() - t0
+    st.text(f"⏱ Generated in {gen_time:.1f}s on GPU/CPU")
 
-    # Truncate to at most 100 words
+    story = outputs[0]["generated_text"].strip()
+    # Enforce ≤100 words
     words = story.split()
     if len(words) > 100:
         story = " ".join(words[:100])
-    if not story.endswith("."):
-        story += "."
+    if not story.endswith('.'):
+        story += '.'
     return story
 
-
 # —––––––– Main App —–––––––
 uploaded = st.file_uploader("Upload an image:", type=["jpg","png","jpeg"])
 if uploaded:
     img = Image.open(uploaded).convert("RGB")
     if max(img.size) > 2048:
-        img.thumbnail((2048, 2048))
+        img.thumbnail((2048,2048))
     st.image(img, use_container_width=True)
 
     with st.spinner("🔍 Generating caption..."):
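
The diff shows only the first and last lines of generate_caption; the middle of the helper (encoding the image and calling the BLIP endpoint) falls outside both hunks. For orientation, a self-contained sketch of how such a call typically looks with huggingface_hub's InferenceApi (since deprecated in favor of InferenceClient). This is an assumption, not the author's elided code:

    from io import BytesIO
    from PIL import Image
    from huggingface_hub import InferenceApi

    caption_client = InferenceApi(
        repo_id="Salesforce/blip-image-captioning-base",
        token="hf_..."  # placeholder; the app reads this from st.secrets
    )

    def generate_caption(img: Image.Image) -> str:
        # Serialize the image to JPEG bytes and POST them to the endpoint
        # (assumed body; only the first and last lines appear in the diff).
        buf = BytesIO()
        img.save(buf, format="JPEG")
        resp = caption_client(data=buf.getvalue())
        # The endpoint returns a list like [{"generated_text": "..."}].
        if isinstance(resp, list) and resp:
            return resp[0].get("generated_text", "").strip()
        return ""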
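
A note on the new generation path: the commit feeds Qwen2.5-Omni-7B into a text2text-generation pipeline through a Qwen2_5OmniTokenizer class. The Qwen/Qwen2.5-Omni-7B model card instead documents a processor-based chat interface. Below is a minimal text-only sketch of that documented path, assuming a transformers release with Qwen2.5-Omni support; the class names and the return_audio flag come from the model card, not from this commit, so treat them as assumptions:

    import torch
    from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor

    model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2.5-Omni-7B", torch_dtype=torch.bfloat16, device_map="auto"
    )
    processor = Qwen2_5OmniProcessor.from_pretrained("Qwen/Qwen2.5-Omni-7B")

    # Text-only chat prompt; the Omni model also accepts images, audio, and video.
    messages = [{"role": "user", "content": [
        {"type": "text", "text": "Write a 50-100 word children's story about a cat."}
    ]}]
    text = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    inputs = processor(text=text, return_tensors="pt").to(model.device)

    # return_audio=False skips the speech ("talker") output and returns text ids only.
    text_ids = model.generate(**inputs, max_new_tokens=120, return_audio=False)
    story = processor.batch_decode(text_ids, skip_special_tokens=True)[0]
    # Note: the decoded string includes the chat prompt; trim it before display.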
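
The post-processing in the new generate_story (cap at 100 words, then force a terminal period) is self-contained and easy to check in isolation. A minimal sketch, with truncate_story as a hypothetical name for the logic the diff keeps inline:

    def truncate_story(story: str, max_words: int = 100) -> str:
        # Cap the story at max_words words and ensure it ends with a period,
        # mirroring the inline post-processing in generate_story.
        words = story.split()
        if len(words) > max_words:
            story = " ".join(words[:max_words])
        if not story.endswith("."):
            story += "."
        return story

    assert truncate_story("A short tale") == "A short tale."
    assert len(truncate_story("word " * 150).split()) == 100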
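
Finally, the gtts and tempfile imports point at a text-to-speech step that sits outside these hunks. A minimal sketch of how that step is commonly written in a Streamlit app; the speak helper is hypothetical, since the diff never shows the TTS code:

    import tempfile
    import streamlit as st
    from gtts import gTTS

    def speak(story: str) -> None:
        # Render the story to an MP3 with Google TTS and play it inline.
        tts = gTTS(text=story, lang="en")
        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as fp:
            tts.save(fp.name)
        st.audio(fp.name, format="audio/mp3")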