Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -1,26 +1,35 @@
@@ -30,32 +39,30 @@ pipe = StableDiffusionXLPipeline.from_pretrained(
@@ -72,87 +79,158 @@ CONCEPTS_MAP = {
@@ -160,47 +238,63 @@ def get_image_embeds(pil_image, model=clip_model, preproc=preprocess, dev=device
@@ -220,29 +314,47 @@ def process_images(
@@ -250,32 +362,45 @@ def process_and_display(
@@ -288,137 +413,186 @@ body{
(The previous version of app.py appears only as fragments in this view: the removed lines are truncated at the points where they changed, so only the hunk ranges above are recoverable. The complete updated file follows.)
# ===========================================
# IP-Composer 🌅✚🖌️ – FULL IMPROVED UI SCRIPT
# (Same functionality; improved UI, theme, and gallery, plus a FileNotFoundError fix)
# ===========================================

import os, json, random, gc
import numpy as np
import torch
from PIL import Image
import gradio as gr
from gradio.themes import Soft
from diffusers import StableDiffusionXLPipeline
import open_clip
from huggingface_hub import hf_hub_download
from IP_Composer.IP_Adapter.ip_adapter import IPAdapterXL
from IP_Composer.perform_swap import (
    compute_dataset_embeds_svd,
    get_modified_images_embeds_composition,
)
from IP_Composer.generate_text_embeddings import (
    load_descriptions,
    generate_embeddings,
)
import spaces

# ─────────────────────────────
# 1 · Device
# ─────────────────────────────
device = "cuda" if torch.cuda.is_available() else "cpu"

# ─────────────────────────────
# 2 · Stable-Diffusion XL
# ─────────────────────────────
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
pipe = StableDiffusionXLPipeline.from_pretrained(
    # … unchanged arguments elided in the diff view …
)

# ─────────────────────────────
# 3 · IP-Adapter
# ─────────────────────────────
image_encoder_repo = "h94/IP-Adapter"
image_encoder_subfolder = "models/image_encoder"
ip_ckpt = hf_hub_download(
    "h94/IP-Adapter", subfolder="sdxl_models", filename="ip-adapter_sdxl_vit-h.bin"
)
ip_model = IPAdapterXL(
    pipe, image_encoder_repo, image_encoder_subfolder, ip_ckpt, device
)

# ─────────────────────────────
# 4 · CLIP
# ─────────────────────────────
clip_model, _, preprocess = open_clip.create_model_and_transforms(
    "hf-hub:laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
)
clip_model.to(device)
tokenizer = open_clip.get_tokenizer(
    "hf-hub:laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
)

# ─────────────────────────────
# 5 · Concept maps
# ─────────────────────────────
CONCEPTS_MAP = {
    "age": "age_descriptions.npy",
    # … unchanged entries elided in the diff view …
    "daytime": "times_of_day_descriptions.npy",
    "pose": "person_poses_descriptions.npy",
    "season": "season_descriptions.npy",
    "material": "material_descriptions_with_gems.npy",
}
RANKS_MAP = {
    "age": 30,
    "animal fur": 80,
    "dogs": 30,
    "emotions": 30,
    "flowers": 30,
    "fruit/vegtable": 30,
    "outfit type": 30,
    "outfit pattern (including color)": 80,
    "patterns": 80,
    "patterns (including color)": 80,
    "vehicle": 30,
    "daytime": 30,
    "pose": 30,
    "season": 30,
    "material": 80,
}
concept_options = list(CONCEPTS_MAP.keys())

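# RANKS_MAP above presumably supplies the per-concept default rank used when building
# each concept's SVD projection (via compute_dataset_embeds_svd); a higher rank keeps
# more directions of variation for that concept.
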
# ─────────────────────────────
# 6 · Example tuples (base_img, c1_img, …)
# ─────────────────────────────
examples = [
    [
        "./IP_Composer/assets/patterns/base.jpg",
        "./IP_Composer/assets/patterns/pattern.png",
        "patterns (including color)",
        None,
        None,
        None,
        None,
        80,
        30,
        30,
        None,
        1.0,
        0,
        30,
    ],
    [
        "./IP_Composer/assets/flowers/base.png",
        "./IP_Composer/assets/flowers/concept.png",
        "flowers",
        None,
        None,
        None,
        None,
        30,
        30,
        30,
        None,
        1.0,
        0,
        30,
    ],
    [
        "./IP_Composer/assets/materials/base.png",
        "./IP_Composer/assets/materials/concept.jpg",
        "material",
        None,
        None,
        None,
        None,
        80,
        30,
        30,
        None,
        1.0,
        0,
        30,
    ],
]

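# Each example tuple follows generate_examples' parameter order:
# (base_image, concept_image1, concept_name1, concept_image2, concept_name2,
#  concept_image3, concept_name3, rank1, rank2, rank3, prompt, scale, seed,
#  num_inference_steps).
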
# ----------------------------------------------------------
# 7 · Utility functions
# ----------------------------------------------------------
def generate_examples(
    base_image,
    concept_image1,
    concept_name1,
    concept_image2,
    concept_name2,
    concept_image3,
    concept_name3,
    rank1,
    rank2,
    rank3,
    prompt,
    scale,
    seed,
    num_inference_steps,
):
    return process_and_display(
        base_image,
        concept_image1,
        concept_name1,
        concept_image2,
        concept_name2,
        concept_image3,
        concept_name3,
        rank1,
        rank2,
        rank3,
        prompt,
        scale,
        seed,
        num_inference_steps,
    )

MAX_SEED = np.iinfo(np.int32).max


def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    return random.randint(0, MAX_SEED) if randomize_seed else seed


def change_rank_default(concept_name):
    return RANKS_MAP.get(concept_name, 30)


@spaces.GPU
def match_image_to_concept(image):
    if image is None:
        return None
    img_pil = Image.fromarray(image).convert("RGB")
    img_embed = get_image_embeds(img_pil, clip_model, preprocess, device)
    sims = {}
    for cname, cfile in CONCEPTS_MAP.items():
        try:
            with open(f"./IP_Composer/text_embeddings/{cfile}", "rb") as f:
                embeds = np.load(f)
            scores = []
            for e in embeds:
                s = np.dot(
                    img_embed.flatten() / np.linalg.norm(img_embed),
                    e.flatten() / np.linalg.norm(e),
                )
                scores.append(s)
            scores.sort(reverse=True)
            sims[cname] = np.mean(scores[:5])
        except Exception as e:
            print(cname, "error:", e)
    if sims:
        best = max(sims, key=sims.get)
        gr.Info(f"Image automatically matched to concept: {best}")
        return best
    return None


@spaces.GPU
def get_image_embeds(pil_image, model=clip_model, preproc=preprocess, dev=device):
    image = preproc(pil_image)[np.newaxis, :, :, :]
    # … unchanged line elided in the diff view …
    embeds = model.encode_image(image.to(dev))
    return embeds.cpu().detach().numpy()

|
242 |
@spaces.GPU
|
243 |
def process_images(
|
244 |
+
base_image,
|
245 |
+
concept_image1,
|
246 |
+
concept_name1,
|
247 |
+
concept_image2=None,
|
248 |
+
concept_name2=None,
|
249 |
+
concept_image3=None,
|
250 |
+
concept_name3=None,
|
251 |
+
rank1=10,
|
252 |
+
rank2=10,
|
253 |
+
rank3=10,
|
254 |
+
prompt=None,
|
255 |
+
scale=1.0,
|
256 |
+
seed=420,
|
257 |
+
num_inference_steps=50,
|
258 |
+
concpet_from_file_1=None,
|
259 |
+
concpet_from_file_2=None,
|
260 |
+
concpet_from_file_3=None,
|
261 |
+
use_concpet_from_file_1=False,
|
262 |
+
use_concpet_from_file_2=False,
|
263 |
+
use_concpet_from_file_3=False,
|
264 |
):
|
265 |
+
base_pil = Image.fromarray(base_image).convert("RGB")
|
266 |
base_embed = get_image_embeds(base_pil, clip_model, preprocess, device)
|
267 |
|
268 |
concept_images, concept_descs, ranks = [], [], []
|
269 |
skip = [False, False, False]
|
270 |
|
271 |
+
# concept 1
|
272 |
if concept_image1 is None:
|
273 |
return None, "Please upload at least one concept image"
|
274 |
concept_images.append(concept_image1)
|
275 |
if use_concpet_from_file_1 and concpet_from_file_1 is not None:
|
276 |
+
concept_descs.append(concpet_from_file_1)
|
277 |
+
skip[0] = True
|
278 |
else:
|
279 |
concept_descs.append(CONCEPTS_MAP[concept_name1])
|
280 |
ranks.append(rank1)
|
281 |
|
282 |
+
# concept 2
|
283 |
if concept_image2 is not None:
|
284 |
concept_images.append(concept_image2)
|
285 |
if use_concpet_from_file_2 and concpet_from_file_2 is not None:
|
286 |
+
concept_descs.append(concpet_from_file_2)
|
287 |
+
skip[1] = True
|
288 |
else:
|
289 |
concept_descs.append(CONCEPTS_MAP[concept_name2])
|
290 |
ranks.append(rank2)
|
291 |
|
292 |
+
# concept 3
|
293 |
if concept_image3 is not None:
|
294 |
concept_images.append(concept_image3)
|
295 |
if use_concpet_from_file_3 and concpet_from_file_3 is not None:
|
296 |
+
concept_descs.append(concpet_from_file_3)
|
297 |
+
skip[2] = True
|
298 |
else:
|
299 |
concept_descs.append(CONCEPTS_MAP[concept_name3])
|
300 |
ranks.append(rank3)
|
|
|
314 |
{"embed": e, "projection_matrix": p}
|
315 |
for e, p in zip(concept_embeds, proj_mats)
|
316 |
]
|
317 |
+
modified = get_modified_images_embeds_composition(
|
318 |
+
base_embed,
|
319 |
+
projections_data,
|
320 |
+
ip_model,
|
321 |
+
prompt=prompt,
|
322 |
+
scale=scale,
|
323 |
+
num_samples=1,
|
324 |
+
seed=seed,
|
325 |
+
num_inference_steps=num_inference_steps,
|
326 |
)
|
327 |
+
return modified[0]
|
328 |
+
|
329 |
|
330 |
@spaces.GPU
|
331 |
def get_text_embeddings(concept_file):
|
332 |
+
descs = load_descriptions(concept_file)
|
333 |
+
embeds = generate_embeddings(descs, clip_model, tokenizer, device, batch_size=100)
|
334 |
+
return embeds, True
|
335 |
+
|
336 |
|
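# A minimal, hypothetical standalone equivalent of the text-embedding path above,
# assuming a plain-text file with one description per line and reusing the open_clip
# model and tokenizer defined earlier. Shown for illustration only; app.py itself
# relies on the IP_Composer helpers load_descriptions / generate_embeddings.
def embed_descriptions(path, model=clip_model, tok=tokenizer, dev=device, batch_size=100):
    with open(path) as f:
        descs = [line.strip() for line in f if line.strip()]
    chunks = []
    with torch.no_grad():
        for i in range(0, len(descs), batch_size):
            tokens = tok(descs[i:i + batch_size]).to(dev)
            chunks.append(model.encode_text(tokens).cpu().numpy())
    return np.concatenate(chunks, axis=0)
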
def process_and_display(
    base_image,
    concept_image1,
    concept_name1="age",
    concept_image2=None,
    concept_name2=None,
    concept_image3=None,
    concept_name3=None,
    rank1=30,
    rank2=30,
    rank3=30,
    prompt=None,
    scale=1.0,
    seed=0,
    num_inference_steps=50,
    concpet_from_file_1=None,
    concpet_from_file_2=None,
    concpet_from_file_3=None,
    use_concpet_from_file_1=False,
    use_concpet_from_file_2=False,
    use_concpet_from_file_3=False,
):
    if base_image is None:
        raise gr.Error("Please upload a base image")
    # … unchanged line elided in the diff view …
        raise gr.Error("Choose at least one concept image")

    return process_images(
        base_image,
        concept_image1,
        concept_name1,
        concept_image2,
        concept_name2,
        concept_image3,
        concept_name3,
        rank1,
        rank2,
        rank3,
        prompt,
        scale,
        seed,
        num_inference_steps,
        concpet_from_file_1,
        concpet_from_file_2,
        concpet_from_file_3,
        use_concpet_from_file_1,
        use_concpet_from_file_2,
        use_concpet_from_file_3,
    )

# ----------------------------------------------------------
# 8 · THEME & CSS
# ----------------------------------------------------------
demo_theme = Soft(primary_hue="purple", font=[gr.themes.GoogleFont("Inter")])
css = """
body{
background:#0f0c29;
background:linear-gradient(135deg,#0f0c29,#302b63,#24243e);
}
#header{
text-align:center;
padding:24px 0 8px;
font-weight:700;
font-size:2.1rem;
color:#ffffff;
}
.gradio-container{max-width:1024px !important;margin:0 auto}
.card{
border-radius:18px;
/* … unchanged CSS rules elided in the diff view … */
"""

# ----------------------------------------------------------
# 9 · UI
# ----------------------------------------------------------
example_gallery = [
    ["./IP_Composer/assets/patterns/base.jpg", "Patterns demo"],
    ["./IP_Composer/assets/flowers/base.png", "Flowers demo"],
    ["./IP_Composer/assets/materials/base.png", "Material demo"],
]

with gr.Blocks(css=css, theme=demo_theme) as demo:
    gr.Markdown(
        "<div id='header'>🌅 IP-Composer "
        "<sup style='font-size:14px'>SDXL</sup></div>"
    )

    concpet_from_file_1, concpet_from_file_2, concpet_from_file_3 = (
        gr.State(),
        gr.State(),
        gr.State(),
    )
    use_concpet_from_file_1, use_concpet_from_file_2, use_concpet_from_file_3 = (
        gr.State(),
        gr.State(),
        gr.State(),
    )

    with gr.Row(equal_height=True):
        with gr.Column(elem_classes="card"):
            base_image = gr.Image(
                label="Base Image (Required)", type="numpy", height=400, width=400
            )

        for idx in (1, 2, 3):
            with gr.Column(elem_classes="card"):
                locals()[f"concept_image{idx}"] = gr.Image(
                    label=f"Concept Image {idx}"
                    if idx == 1
                    else f"Concept {idx} (Optional)",
                    type="numpy",
                    height=400,
                    width=400,
                )
                locals()[f"concept_name{idx}"] = gr.Dropdown(
                    concept_options,
                    label=f"Concept {idx}",
                    value=None if idx != 1 else "age",
                    info="Pick concept type",
                )
                with gr.Accordion("💡 Or use a new concept 👇", open=False):
                    gr.Markdown(
                        "1. Upload a file with **>100** text variations<br>"
                        "2. Tip: Ask an LLM to list variations."
                    )
                    if idx == 1:
                        concept_file_1 = gr.File(
                            label="Concept variations", file_types=["text"]
                        )
                    elif idx == 2:
                        concept_file_2 = gr.File(
                            label="Concept variations", file_types=["text"]
                        )
                    else:
                        concept_file_3 = gr.File(
                            label="Concept variations", file_types=["text"]
                        )

    with gr.Column(elem_classes="card"):
        with gr.Accordion("⚙️ Advanced options", open=False):
            prompt = gr.Textbox(
                label="Guidance Prompt (Optional)",
                placeholder="Optional text prompt to guide generation",
            )
            num_inference_steps = gr.Slider(1, 50, 30, step=1, label="Num steps")
            with gr.Row():
                scale = gr.Slider(0.1, 2.0, 1.0, step=0.1, label="Scale")
                randomize_seed = gr.Checkbox(True, label="Randomize seed")
                seed = gr.Number(value=0, label="Seed", precision=0)
            gr.Markdown(
                "If a concept is not showing enough, **increase rank** ⬇️"
            )
            with gr.Row():
                rank1 = gr.Slider(1, 150, 30, step=1, label="Rank concept 1")
                rank2 = gr.Slider(1, 150, 30, step=1, label="Rank concept 2")
                rank3 = gr.Slider(1, 150, 30, step=1, label="Rank concept 3")

    with gr.Column(elem_classes="card"):
        output_image = gr.Image(show_label=False, height=480)
        submit_btn = gr.Button("🔮 Generate", variant="primary", size="lg")

    gr.Markdown("### 🔥 Ready-made examples")
    gr.Gallery(example_gallery, label="Click to preview", columns=[3], height="auto")

    gr.Examples(
        examples,
        inputs=[
            base_image,
            concept_image1,
            concept_name1,
            concept_image2,
            concept_name2,
            concept_image3,
            concept_name3,
            rank1,
            rank2,
            rank3,
            prompt,
            scale,
            seed,
            num_inference_steps,
        ],
        outputs=[output_image],
        fn=generate_examples,
        cache_examples=False,
    )

    # Upload hooks
    concept_file_1.upload(
        get_text_embeddings,
        [concept_file_1],
        [concpet_from_file_1, use_concpet_from_file_1],
    )
    concept_file_2.upload(
        get_text_embeddings,
        [concept_file_2],
        [concpet_from_file_2, use_concpet_from_file_2],
    )
    concept_file_3.upload(
        get_text_embeddings,
        [concept_file_3],
        [concpet_from_file_3, use_concpet_from_file_3],
    )
    concept_file_1.delete(
        lambda _: False, [concept_file_1], [use_concpet_from_file_1]
    )
    concept_file_2.delete(
        lambda _: False, [concept_file_2], [use_concpet_from_file_2]
    )
    concept_file_3.delete(
        lambda _: False, [concept_file_3], [use_concpet_from_file_3]
    )

    # Dropdown auto-rank
    concept_name1.select(change_rank_default, [concept_name1], [rank1])
    concept_name2.select(change_rank_default, [concept_name2], [rank2])
    concept_name3.select(change_rank_default, [concept_name3], [rank3])

    # Auto-match on upload
    concept_image1.upload(match_image_to_concept, [concept_image1], [concept_name1])
    concept_image2.upload(match_image_to_concept, [concept_image2], [concept_name2])
    concept_image3.upload(match_image_to_concept, [concept_image3], [concept_name3])

    # Generate chain
    submit_btn.click(randomize_seed_fn, [seed, randomize_seed], seed).then(
        process_and_display,
        [
            base_image,
            concept_image1,
            concept_name1,
            concept_image2,
            concept_name2,
            concept_image3,
            concept_name3,
            rank1,
            rank2,
            rank3,
            prompt,
            scale,
            seed,
            num_inference_steps,
            concpet_from_file_1,
            concpet_from_file_2,
            concpet_from_file_3,
            use_concpet_from_file_1,
            use_concpet_from_file_2,
            use_concpet_from_file_3,
        ],
        [output_image],
    )

# ─────────────────────────────
# 10 · Launch
# ─────────────────────────────
if __name__ == "__main__":
    demo.launch()
|