mayf committed · verified
Commit fd1d947 · Parent(s): 2c0fb69

Update app.py

Files changed (1)
  1. app.py +43 -56
app.py CHANGED
@@ -3,102 +3,89 @@ import time
 import streamlit as st
 from PIL import Image
 from io import BytesIO
-from huggingface_hub import InferenceApi, login
-from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
-import torch
+from transformers import pipeline
+from huggingface_hub import login
 from gtts import gTTS
 import tempfile
 
 # —––––––– Requirements —–––––––
-# Install transformers with remote code support:
-#   pip install git+https://github.com/huggingface/transformers.git
-# plus:
-#   pip install streamlit torch accelerate huggingface_hub sentencepiece pillow gTTS
+# pip install streamlit pillow gTTS transformers huggingface_hub
 
 # —––––––– Page Config —–––––––
-st.set_page_config(page_title="Magic Story Generator (Qwen2.5)", layout="centered")
-st.title("📖✨ Turn Images into Children's Stories (Qwen2.5-Omni-7B)")
+st.set_page_config(page_title="Magic Story Generator (Local Pipeline)", layout="centered")
+st.title("📖✨ Turn Images into Children's Stories")
 
 # —––––––– Load Clients & Pipelines (cached) —–––––––
 @st.cache_resource(show_spinner=False)
 def load_clients():
-    hf_token = st.secrets["HF_TOKEN"]
-    # Authenticate for Hugging Face Hub
-    os.environ["HUGGINGFACEHUB_API_TOKEN"] = hf_token
-    login(hf_token)
-
-    # 1) BLIP captioning
-    caption_client = InferenceApi(
-        repo_id="Salesforce/blip-image-captioning-base",
-        token=hf_token
+    # Authenticate to pull private or remote-code models if needed
+    hf_token = st.secrets.get("HF_TOKEN")
+    if hf_token:
+        os.environ["HUGGINGFACEHUB_API_TOKEN"] = hf_token
+        login(hf_token)
+
+    # 1) Image-captioning pipeline (BLIP)
+    captioner = pipeline(
+        task="image-to-text",
+        model="Salesforce/blip-image-captioning-base",
+        device=-1  # CPU; change to 0 for GPU
     )
 
-    # 2) Load Qwen2.5-Omni causal LM
-    t0 = time.time()
-    tokenizer = AutoTokenizer.from_pretrained(
-        "Qwen/Qwen2.5-Omni-7B",
-        trust_remote_code=True
-    )
-    model = AutoModelForCausalLM.from_pretrained(
-        "Qwen/Qwen2.5-Omni-7B",
-        trust_remote_code=True,
-        device_map="auto",
-        torch_dtype=torch.bfloat16,
-        attn_implementation="flash_attention_2"
-    )
-    # 3) Text-generation pipeline
+    # 2) Story-generation pipeline (DeepSeek-R1-Distill-Qwen)
     storyteller = pipeline(
         task="text-generation",
-        model=model,
-        tokenizer=tokenizer,
-        device_map="auto",
-        temperature=0.7,
+        model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
+        tokenizer="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
+        trust_remote_code=True,
+        device=-1,  # CPU; set 0+ for GPU
+        temperature=0.6,
         top_p=0.9,
-        repetition_penalty=1.2,
-        no_repeat_ngram_size=3,
+        repetition_penalty=1.1,
+        no_repeat_ngram_size=2,
         max_new_tokens=120,
         return_full_text=False
     )
-    load_time = time.time() - t0
-    st.text(f"✅ Story model loaded in {load_time:.1f}s (cached)")
-    return caption_client, storyteller
 
-caption_client, storyteller = load_clients()
+    return captioner, storyteller
+
+captioner, storyteller = load_clients()
 
 # —––––––– Helpers —–––––––
 def generate_caption(img: Image.Image) -> str:
-    buf = BytesIO()
-    img.save(buf, format="JPEG")
-    resp = caption_client(data=buf.getvalue())
-    if isinstance(resp, list) and resp:
-        return resp[0].get("generated_text", "").strip()
+    # Use the BLIP pipeline to generate a caption
+    result = captioner(img)
+    if isinstance(result, list) and result:
+        return result[0].get("generated_text", "").strip()
     return ""
 
 
 def generate_story(caption: str) -> str:
+    # Build a simple prompt incorporating the caption
     prompt = (
-        "You are a creative children's-story author.\n"
-        f"Image description: “{caption}”\n\n"
-        "Write a coherent 50–100 word story."
+        f"Image description: {caption}\n"
+        "Write a coherent 50-100 word children's story that flows naturally."
     )
+
     t0 = time.time()
-    outputs = storyteller(prompt)
+    outputs = storyteller(
+        prompt
+    )
     gen_time = time.time() - t0
-    st.text(f"⏱ Generated in {gen_time:.1f}s on GPU/CPU")
+    st.text(f"⏱ Generated in {gen_time:.1f}s")
 
-    story = outputs[0]["generated_text"].strip()
-    # Enforce 100 words
+    story = outputs[0].get("generated_text", "").strip()
+    # Truncate to 100 words
     words = story.split()
     if len(words) > 100:
         story = " ".join(words[:100]) + ('.' if not story.endswith('.') else '')
     return story
 
 # —––––––– Main App —–––––––
-uploaded = st.file_uploader("Upload an image:", type=["jpg","png","jpeg"])
+uploaded = st.file_uploader("Upload an image:", type=["jpg", "png", "jpeg"])
 if uploaded:
     img = Image.open(uploaded).convert("RGB")
     if max(img.size) > 2048:
-        img.thumbnail((2048,2048))
+        img.thumbnail((2048, 2048))
     st.image(img, use_container_width=True)
 
     with st.spinner("🔍 Generating caption..."):
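For reference, a minimal standalone sketch of the new wiring, runnable outside Streamlit. This script is not part of the commit: "sample.jpg" and the do_sample flag are hypothetical additions, while the model IDs and generation settings are copied from the diff.

# sketch.py — exercise the two pipelines from app.py without the Streamlit UI
from PIL import Image
from transformers import pipeline

# Image-captioning pipeline (BLIP), on CPU as in the committed code
captioner = pipeline(
    task="image-to-text",
    model="Salesforce/blip-image-captioning-base",
    device=-1,
)

# Story-generation pipeline with the same sampling settings as the commit
storyteller = pipeline(
    task="text-generation",
    model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    device=-1,
    do_sample=True,  # not in the commit; makes temperature/top_p take effect
    temperature=0.6,
    top_p=0.9,
    repetition_penalty=1.1,
    no_repeat_ngram_size=2,
    max_new_tokens=120,
    return_full_text=False,
)

img = Image.open("sample.jpg").convert("RGB")  # hypothetical test image
caption = captioner(img)[0]["generated_text"].strip()
prompt = (
    f"Image description: {caption}\n"
    "Write a coherent 50-100 word children's story that flows naturally."
)
story = storyteller(prompt)[0]["generated_text"].strip()
print("Caption:", caption)
print("Story:", story)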