mayf commited on
Commit
2c0fb69
·
verified ·
1 Parent(s): 748a576

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -18
app.py CHANGED
@@ -4,17 +4,16 @@ import streamlit as st
4
  from PIL import Image
5
  from io import BytesIO
6
  from huggingface_hub import InferenceApi, login
7
- from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
8
  import torch
9
  from gtts import gTTS
10
  import tempfile
11
 
12
  # —––––––– Requirements —–––––––
13
- # This app uses a Hugging Face Transformers version that supports
14
- # the Qwen2.5-Omni architecture via `trust_remote_code`.
15
- # Install using:
16
  # pip install git+https://github.com/huggingface/transformers.git
17
- # and the rest of the requirements listed at the end.
 
18
 
19
  # —––––––– Page Config —–––––––
20
  st.set_page_config(page_title="Magic Story Generator (Qwen2.5)", layout="centered")
@@ -28,27 +27,28 @@ def load_clients():
28
  os.environ["HUGGINGFACEHUB_API_TOKEN"] = hf_token
29
  login(hf_token)
30
 
31
- # 1) BLIP captioning via HTTP API
32
  caption_client = InferenceApi(
33
  repo_id="Salesforce/blip-image-captioning-base",
34
  token=hf_token
35
  )
36
 
37
- # 2) Load Qwen2.5-Omni model & tokenizer via remote code
38
  t0 = time.time()
39
  tokenizer = AutoTokenizer.from_pretrained(
40
  "Qwen/Qwen2.5-Omni-7B",
41
  trust_remote_code=True
42
  )
43
- model = AutoModelForSeq2SeqLM.from_pretrained(
44
  "Qwen/Qwen2.5-Omni-7B",
45
  trust_remote_code=True,
46
  device_map="auto",
47
  torch_dtype=torch.bfloat16,
48
  attn_implementation="flash_attention_2"
49
  )
 
50
  storyteller = pipeline(
51
- task="text2text-generation",
52
  model=model,
53
  tokenizer=tokenizer,
54
  device_map="auto",
@@ -56,7 +56,8 @@ def load_clients():
56
  top_p=0.9,
57
  repetition_penalty=1.2,
58
  no_repeat_ngram_size=3,
59
- max_new_tokens=120
 
60
  )
61
  load_time = time.time() - t0
62
  st.text(f"✅ Story model loaded in {load_time:.1f}s (cached)")
@@ -78,19 +79,15 @@ def generate_story(caption: str) -> str:
78
  prompt = (
79
  "You are a creative children's-story author.\n"
80
  f"Image description: “{caption}”\n\n"
81
- "Write a coherent 50–100 word story that:\n"
82
- "1. Introduces the main character.\n"
83
- "2. Shows a simple problem or discovery.\n"
84
- "3. Has a happy resolution.\n"
85
- "4. Uses clear language for ages 3–8.\n"
86
- "5. Keeps each sentence under 20 words.\n"
87
  )
88
  t0 = time.time()
89
- result = storyteller(prompt)
90
  gen_time = time.time() - t0
91
  st.text(f"⏱ Generated in {gen_time:.1f}s on GPU/CPU")
92
 
93
- story = result[0]["generated_text"].strip()
 
94
  words = story.split()
95
  if len(words) > 100:
96
  story = " ".join(words[:100]) + ('.' if not story.endswith('.') else '')
@@ -128,3 +125,4 @@ if uploaded:
128
 
129
  # Footer
130
  st.markdown("---\n*Made with ❤️ by your friendly story wizard*")
 
 
4
  from PIL import Image
5
  from io import BytesIO
6
  from huggingface_hub import InferenceApi, login
7
+ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
8
  import torch
9
  from gtts import gTTS
10
  import tempfile
11
 
12
  # —––––––– Requirements —–––––––
13
+ # Install transformers with remote code support:
 
 
14
  # pip install git+https://github.com/huggingface/transformers.git
15
+ # plus:
16
+ # pip install streamlit torch accelerate huggingface_hub sentencepiece pillow gTTS
17
 
18
  # —––––––– Page Config —–––––––
19
  st.set_page_config(page_title="Magic Story Generator (Qwen2.5)", layout="centered")
 
27
  os.environ["HUGGINGFACEHUB_API_TOKEN"] = hf_token
28
  login(hf_token)
29
 
30
+ # 1) BLIP captioning
31
  caption_client = InferenceApi(
32
  repo_id="Salesforce/blip-image-captioning-base",
33
  token=hf_token
34
  )
35
 
36
+ # 2) Load Qwen2.5-Omni causal LM
37
  t0 = time.time()
38
  tokenizer = AutoTokenizer.from_pretrained(
39
  "Qwen/Qwen2.5-Omni-7B",
40
  trust_remote_code=True
41
  )
42
+ model = AutoModelForCausalLM.from_pretrained(
43
  "Qwen/Qwen2.5-Omni-7B",
44
  trust_remote_code=True,
45
  device_map="auto",
46
  torch_dtype=torch.bfloat16,
47
  attn_implementation="flash_attention_2"
48
  )
49
+ # 3) Text-generation pipeline
50
  storyteller = pipeline(
51
+ task="text-generation",
52
  model=model,
53
  tokenizer=tokenizer,
54
  device_map="auto",
 
56
  top_p=0.9,
57
  repetition_penalty=1.2,
58
  no_repeat_ngram_size=3,
59
+ max_new_tokens=120,
60
+ return_full_text=False
61
  )
62
  load_time = time.time() - t0
63
  st.text(f"✅ Story model loaded in {load_time:.1f}s (cached)")
 
79
  prompt = (
80
  "You are a creative children's-story author.\n"
81
  f"Image description: “{caption}”\n\n"
82
+ "Write a coherent 50–100 word story."
 
 
 
 
 
83
  )
84
  t0 = time.time()
85
+ outputs = storyteller(prompt)
86
  gen_time = time.time() - t0
87
  st.text(f"⏱ Generated in {gen_time:.1f}s on GPU/CPU")
88
 
89
+ story = outputs[0]["generated_text"].strip()
90
+ # Enforce ≤100 words
91
  words = story.split()
92
  if len(words) > 100:
93
  story = " ".join(words[:100]) + ('.' if not story.endswith('.') else '')
 
125
 
126
  # Footer
127
  st.markdown("---\n*Made with ❤️ by your friendly story wizard*")
128
+