Update app.py
app.py
CHANGED
@@ -5,41 +5,60 @@ from PIL import Image
 from io import BytesIO
 from huggingface_hub import InferenceApi, login
 from transformers import pipeline
+import torch
+from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniTokenizer
 from gtts import gTTS
 import tempfile
 
 # —––––––– Page Config —–––––––
-st.set_page_config(page_title="Magic Story Generator", layout="centered")
-st.title("📖✨ Turn Images into Children's Stories")
+st.set_page_config(page_title="Magic Story Generator (Qwen2.5)", layout="centered")
+st.title("📖✨ Turn Images into Children's Stories (Qwen2.5-Omni-7B)")
 
 # —––––––– Load Clients & Pipelines (cached) —–––––––
 @st.cache_resource(show_spinner=False)
 def load_clients():
     hf_token = st.secrets["HF_TOKEN"]
-    #
+    # Authenticate for HF Hub
     os.environ["HUGGINGFACEHUB_API_TOKEN"] = hf_token
     login(hf_token)
 
-    # BLIP captioning via…
+    # 1) BLIP captioning via HF Inference API
     caption_client = InferenceApi(
         repo_id="Salesforce/blip-image-captioning-base",
         token=hf_token
     )
 
-    #
+    # 2) Qwen2.5-Omni story generator
     t0 = time.time()
+    model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
+        "Qwen/Qwen2.5-Omni-7B",
+        device_map="auto",
+        torch_dtype=torch.bfloat16,
+        attn_implementation="flash_attention_2",
+        trust_remote_code=True
+    )
+    tokenizer = Qwen2_5OmniTokenizer.from_pretrained(
+        "Qwen/Qwen2.5-Omni-7B",
+        trust_remote_code=True
+    )
     storyteller = pipeline(
         task="text2text-generation",
-        model=…
-        …
-        …
+        model=model,
+        tokenizer=tokenizer,
+        device_map="auto",
+        temperature=0.7,
+        top_p=0.9,
+        repetition_penalty=1.2,
+        no_repeat_ngram_size=3,
+        max_new_tokens=120
     )
-    …
+    load_time = time.time() - t0
+    st.text(f"✅ Story model loaded in {load_time:.1f}s (cached thereafter)")
+
     return caption_client, storyteller
 
 caption_client, storyteller = load_clients()
 
-…
 # —––––––– Helpers —–––––––
 def generate_caption(img: Image.Image) -> str:
     buf = BytesIO()
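A note on the new loader: passing attn_implementation="flash_attention_2" makes from_pretrained fail at startup when the flash-attn package (and a compatible GPU) is absent. Below is a minimal sketch of a graceful fallback, reusing the repo id from the diff; the helper name load_story_model is hypothetical and not part of this commit.

import torch
from transformers import Qwen2_5OmniForConditionalGeneration

def load_story_model(repo_id: str = "Qwen/Qwen2.5-Omni-7B"):
    """Hypothetical helper: try FlashAttention-2, fall back to default attention."""
    try:
        # Fast path: fused attention kernels; needs the flash-attn wheel
        # and a sufficiently recent GPU.
        return Qwen2_5OmniForConditionalGeneration.from_pretrained(
            repo_id,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
        )
    except (ImportError, ValueError):
        # Fallback: same weights with the library's default attention,
        # so the Space still boots on CPU or older GPUs.
        return Qwen2_5OmniForConditionalGeneration.from_pretrained(
            repo_id,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        )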
@@ -49,32 +68,33 @@ def generate_caption(img: Image.Image) -> str:
         return resp[0].get("generated_text", "").strip()
     return ""
 
+
 def generate_story(caption: str) -> str:
     prompt = (
-        "You are a creative children…
+        "You are a creative children's-story author.\n"
         f"Image description: “{caption}”\n\n"
         "Write a coherent 50–100 word story\n"
     )
     t0 = time.time()
-    …
-    …
-    …
+    outputs = storyteller(prompt)
+    gen_time = time.time() - t0
+    st.text(f"⏱ Generated in {gen_time:.1f}s on GPU/CPU")
 
-    …
+    story = outputs[0]["generated_text"].strip()
+    # Enforce ≤100 words
     words = story.split()
     if len(words) > 100:
         story = " ".join(words[:100])
-    if not story.endswith(…
-        story +=…
+    if not story.endswith('.'):
+        story += '.'
     return story
 
-…
 # —––––––– Main App —–––––––
 uploaded = st.file_uploader("Upload an image:", type=["jpg","png","jpeg"])
 if uploaded:
     img = Image.open(uploaded).convert("RGB")
     if max(img.size) > 2048:
-        img.thumbnail((2048,…
+        img.thumbnail((2048,2048))
     st.image(img, use_container_width=True)
 
     with st.spinner("🔍 Generating caption..."):
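The new post-processing hard-truncates at 100 words, which can cut the story mid-sentence before the final period is appended. A small sketch of a sentence-aware alternative; trim_to_last_sentence is a hypothetical refinement, not what the commit does.

import re

def trim_to_last_sentence(story: str, max_words: int = 100) -> str:
    # Hypothetical variant of the commit's trimming: respect the word
    # budget, then cut back to the last sentence-ending punctuation.
    words = story.split()
    if len(words) > max_words:
        story = " ".join(words[:max_words])
    match = re.match(r"^(.*[.!?])", story, flags=re.S)
    # If no sentence boundary survives the cut, fall back to the
    # commit's behavior of appending a period.
    return match.group(1) if match else story + "."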
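The diff's context ends inside the caption spinner, so the rest of the upload flow is not shown. A sketch of how the pieces plausibly chain together, assuming the helper signatures above; the gTTS narration step is inferred only from the gtts and tempfile imports and is an assumption, not part of this diff.

with st.spinner("🔍 Generating caption..."):
    caption = generate_caption(img)
st.write(f"**Caption:** {caption}")

with st.spinner("✍️ Writing story..."):
    story = generate_story(caption)
st.write(story)

# Narration via gTTS; writing to a temp file gives st.audio a playable path.
tts = gTTS(story)
with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as fp:
    tts.save(fp.name)
st.audio(fp.name)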