mayf commited on
Commit
748a576
·
verified ·
1 Parent(s): 2eb596f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -6
app.py CHANGED
@@ -9,6 +9,13 @@ import torch
9
  from gtts import gTTS
10
  import tempfile
11
 
 
 
 
 
 
 
 
12
  # —––––––– Page Config —–––––––
13
  st.set_page_config(page_title="Magic Story Generator (Qwen2.5)", layout="centered")
14
  st.title("📖✨ Turn Images into Children's Stories (Qwen2.5-Omni-7B)")
@@ -27,7 +34,7 @@ def load_clients():
27
  token=hf_token
28
  )
29
 
30
- # 2) Load Qwen2.5-Omni model & tokenizer
31
  t0 = time.time()
32
  tokenizer = AutoTokenizer.from_pretrained(
33
  "Qwen/Qwen2.5-Omni-7B",
@@ -40,7 +47,6 @@ def load_clients():
40
  torch_dtype=torch.bfloat16,
41
  attn_implementation="flash_attention_2"
42
  )
43
- # 3) Build text2text pipeline
44
  storyteller = pipeline(
45
  task="text2text-generation",
46
  model=model,
@@ -85,12 +91,9 @@ def generate_story(caption: str) -> str:
85
  st.text(f"⏱ Generated in {gen_time:.1f}s on GPU/CPU")
86
 
87
  story = result[0]["generated_text"].strip()
88
- # Enforce ≤100 words
89
  words = story.split()
90
  if len(words) > 100:
91
- story = " ".join(words[:100])
92
- if not story.endswith('.'):
93
- story += '.'
94
  return story
95
 
96
  # —––––––– Main App —–––––––
 
9
  from gtts import gTTS
10
  import tempfile
11
 
12
+ # —––––––– Requirements —–––––––
13
+ # This app uses a Hugging Face Transformers version that supports
14
+ # the Qwen2.5-Omni architecture via `trust_remote_code`.
15
+ # Install using:
16
+ # pip install git+https://github.com/huggingface/transformers.git
17
+ # and the rest of the requirements listed at the end.
18
+
19
  # —––––––– Page Config —–––––––
20
  st.set_page_config(page_title="Magic Story Generator (Qwen2.5)", layout="centered")
21
  st.title("📖✨ Turn Images into Children's Stories (Qwen2.5-Omni-7B)")
 
34
  token=hf_token
35
  )
36
 
37
+ # 2) Load Qwen2.5-Omni model & tokenizer via remote code
38
  t0 = time.time()
39
  tokenizer = AutoTokenizer.from_pretrained(
40
  "Qwen/Qwen2.5-Omni-7B",
 
47
  torch_dtype=torch.bfloat16,
48
  attn_implementation="flash_attention_2"
49
  )
 
50
  storyteller = pipeline(
51
  task="text2text-generation",
52
  model=model,
 
91
  st.text(f"⏱ Generated in {gen_time:.1f}s on GPU/CPU")
92
 
93
  story = result[0]["generated_text"].strip()
 
94
  words = story.split()
95
  if len(words) > 100:
96
+ story = " ".join(words[:100]) + ('.' if not story.endswith('.') else '')
 
 
97
  return story
98
 
99
  # —––––––– Main App —–––––––