seawolf2357 committed on
Commit 2608108 · verified · 1 Parent(s): 6aa0725

Update app.py

Files changed (1)
  1. app.py +378 -204
app.py CHANGED
@@ -1,26 +1,35 @@
1
  import os, json, random, gc
2
  import numpy as np
3
  import torch
4
  from PIL import Image
5
  import gradio as gr
6
- from gradio.themes import Soft
7
  from diffusers import StableDiffusionXLPipeline
8
  import open_clip
9
  from huggingface_hub import hf_hub_download
10
  from IP_Composer.IP_Adapter.ip_adapter import IPAdapterXL
11
- from IP_Composer.perform_swap import (compute_dataset_embeds_svd,
12
- get_modified_images_embeds_composition)
13
- from IP_Composer.generate_text_embeddings import (load_descriptions,
14
- generate_embeddings)
15
  import spaces
16
 
17
  # ─────────────────────────────
18
- # 1· Device
19
  # ─────────────────────────────
20
  device = "cuda" if torch.cuda.is_available() else "cpu"
21
 
22
  # ─────────────────────────────
23
- # 2· Stable-Diffusion XL
24
  # ─────────────────────────────
25
  base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
26
  pipe = StableDiffusionXLPipeline.from_pretrained(
@@ -30,32 +39,30 @@ pipe = StableDiffusionXLPipeline.from_pretrained(
30
  )
31
 
32
  # ─────────────────────────────
33
- # 3· IP-Adapter
34
  # ─────────────────────────────
35
- image_encoder_repo = 'h94/IP-Adapter'
36
- image_encoder_subfolder = 'models/image_encoder'
37
  ip_ckpt = hf_hub_download(
38
- 'h94/IP-Adapter',
39
- subfolder="sdxl_models",
40
- filename='ip-adapter_sdxl_vit-h.bin'
 
41
  )
42
- ip_model = IPAdapterXL(pipe, image_encoder_repo,
43
- image_encoder_subfolder,
44
- ip_ckpt, device)
45
 
46
  # ─────────────────────────────
47
- # 4· CLIP
48
  # ─────────────────────────────
49
  clip_model, _, preprocess = open_clip.create_model_and_transforms(
50
- 'hf-hub:laion/CLIP-ViT-H-14-laion2B-s32B-b79K'
51
  )
52
  clip_model.to(device)
53
  tokenizer = open_clip.get_tokenizer(
54
- 'hf-hub:laion/CLIP-ViT-H-14-laion2B-s32B-b79K'
55
  )
56
 
57
  # ─────────────────────────────
58
- # 5· Concept maps
59
  # ─────────────────────────────
60
  CONCEPTS_MAP = {
61
  "age": "age_descriptions.npy",
@@ -72,87 +79,158 @@ CONCEPTS_MAP = {
72
  "daytime": "times_of_day_descriptions.npy",
73
  "pose": "person_poses_descriptions.npy",
74
  "season": "season_descriptions.npy",
75
- "material": "material_descriptions_with_gems.npy"
76
  }
77
  RANKS_MAP = {
78
- "age": 30, "animal fur": 80, "dogs": 30, "emotions": 30,
79
- "flowers": 30, "fruit/vegtable": 30, "outfit type": 30,
80
- "outfit pattern (including color)": 80, "patterns": 80,
81
- "patterns (including color)": 80, "vehicle": 30,
82
- "daytime": 30, "pose": 30, "season": 30, "material": 80
83
  }
84
  concept_options = list(CONCEPTS_MAP.keys())
85
 
86
  # ─────────────────────────────
87
- # 6· Example tuples (base_img, c1_img, …)
88
  # ─────────────────────────────
89
  examples = [
90
- ['./IP_Composer/assets/patterns/base.jpg',
91
- './IP_Composer/assets/patterns/pattern.png',
92
- 'patterns (including color)', None, None, None, None,
93
- 80, 30, 30, None, 1.0, 0, 30],
94
- ['./IP_Composer/assets/flowers/base.png',
95
- './IP_Composer/assets/flowers/concept.png',
96
- 'flowers', None, None, None, None,
97
- 30, 30, 30, None, 1.0, 0, 30],
98
- ['./IP_Composer/assets/materials/base.png',
99
- './IP_Composer/assets/materials/concept.jpg',
100
- 'material', None, None, None, None,
101
- 80, 30, 30, None, 1.0, 0, 30],
102
- # … (more examples can be added here without omission)
103
  ]
104
 
105
  # ----------------------------------------------------------
106
- # 7· Utility functions (unchanged except docstring tweaks)
107
  # ----------------------------------------------------------
108
- def generate_examples(base_image,
109
- concept_image1, concept_name1,
110
- concept_image2, concept_name2,
111
- concept_image3, concept_name3,
112
- rank1, rank2, rank3,
113
- prompt, scale, seed, num_inference_steps):
114
- return process_and_display(base_image,
115
- concept_image1, concept_name1,
116
- concept_image2, concept_name2,
117
- concept_image3, concept_name3,
118
- rank1, rank2, rank3,
119
- prompt, scale, seed, num_inference_steps)
120
 
121
  MAX_SEED = np.iinfo(np.int32).max
 
 
122
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
123
  return random.randint(0, MAX_SEED) if randomize_seed else seed
124
 
125
- def change_rank_default(concept_name): # auto-adjust rank
 
126
  return RANKS_MAP.get(concept_name, 30)
127
 
 
128
  @spaces.GPU
129
  def match_image_to_concept(image):
130
  if image is None:
131
  return None
132
- img_pil = Image.fromarray(image).convert("RGB")
133
  img_embed = get_image_embeds(img_pil, clip_model, preprocess, device)
134
-
135
- similarities = {}
136
- for concept_name, concept_file in CONCEPTS_MAP.items():
137
  try:
138
- embeds_path = f"./IP_Composer/text_embeddings/{concept_file}"
139
- with open(embeds_path, "rb") as f:
140
- concept_embeds = np.load(f)
141
- sim_scores = []
142
- for embed in concept_embeds:
143
- sim = np.dot(img_embed.flatten()/np.linalg.norm(img_embed),
144
- embed.flatten()/np.linalg.norm(embed))
145
- sim_scores.append(sim)
146
- sim_scores.sort(reverse=True)
147
- similarities[concept_name] = np.mean(sim_scores[:5])
 
148
  except Exception as e:
149
- print(f"Concept {concept_name} error: {e}")
150
- if similarities:
151
- detected = max(similarities, key=similarities.get)
152
- gr.Info(f"Image automatically matched to concept: {detected}")
153
- return detected
154
  return None
155
 
 
156
  @spaces.GPU
157
  def get_image_embeds(pil_image, model=clip_model, preproc=preprocess, dev=device):
158
  image = preproc(pil_image)[np.newaxis, :, :, :]
@@ -160,47 +238,63 @@ def get_image_embeds(pil_image, model=clip_model, preproc=preprocess, dev=device
160
  embeds = model.encode_image(image.to(dev))
161
  return embeds.cpu().detach().numpy()
162
 
 
163
  @spaces.GPU
164
  def process_images(
165
- base_image,
166
- concept_image1, concept_name1,
167
- concept_image2=None, concept_name2=None,
168
- concept_image3=None, concept_name3=None,
169
- rank1=10, rank2=10, rank3=10,
170
- prompt=None, scale=1.0, seed=420, num_inference_steps=50,
171
- concpet_from_file_1=None, concpet_from_file_2=None, concpet_from_file_3=None,
172
- use_concpet_from_file_1=False, use_concpet_from_file_2=False, use_concpet_from_file_3=False
173
  ):
174
- base_pil = Image.fromarray(base_image).convert("RGB")
175
  base_embed = get_image_embeds(base_pil, clip_model, preprocess, device)
176
 
177
  concept_images, concept_descs, ranks = [], [], []
178
  skip = [False, False, False]
179
 
180
- # ─── concept 1
181
  if concept_image1 is None:
182
  return None, "Please upload at least one concept image"
183
  concept_images.append(concept_image1)
184
  if use_concpet_from_file_1 and concpet_from_file_1 is not None:
185
- concept_descs.append(concpet_from_file_1); skip[0] = True
 
186
  else:
187
  concept_descs.append(CONCEPTS_MAP[concept_name1])
188
  ranks.append(rank1)
189
 
190
- # ─── concept 2
191
  if concept_image2 is not None:
192
  concept_images.append(concept_image2)
193
  if use_concpet_from_file_2 and concpet_from_file_2 is not None:
194
- concept_descs.append(concpet_from_file_2); skip[1] = True
 
195
  else:
196
  concept_descs.append(CONCEPTS_MAP[concept_name2])
197
  ranks.append(rank2)
198
 
199
- # ─── concept 3
200
  if concept_image3 is not None:
201
  concept_images.append(concept_image3)
202
  if use_concpet_from_file_3 and concpet_from_file_3 is not None:
203
- concept_descs.append(concpet_from_file_3); skip[2] = True
 
204
  else:
205
  concept_descs.append(CONCEPTS_MAP[concept_name3])
206
  ranks.append(rank3)
@@ -220,29 +314,47 @@ def process_images(
220
  {"embed": e, "projection_matrix": p}
221
  for e, p in zip(concept_embeds, proj_mats)
222
  ]
223
- modified_images = get_modified_images_embeds_composition(
224
- base_embed, projections_data, ip_model,
225
- prompt=prompt, scale=scale,
226
- num_samples=1, seed=seed, num_inference_steps=num_inference_steps
227
  )
228
- return modified_images[0]
 
229
 
230
  @spaces.GPU
231
  def get_text_embeddings(concept_file):
232
- descriptions = load_descriptions(concept_file)
233
- embeddings = generate_embeddings(descriptions, clip_model,
234
- tokenizer, device, batch_size=100)
235
- return embeddings, True
236
 
237
  def process_and_display(
238
- base_image,
239
- concept_image1, concept_name1="age",
240
- concept_image2=None, concept_name2=None,
241
- concept_image3=None, concept_name3=None,
242
- rank1=30, rank2=30, rank3=30,
243
- prompt=None, scale=1.0, seed=0, num_inference_steps=50,
244
- concpet_from_file_1=None, concpet_from_file_2=None, concpet_from_file_3=None,
245
- use_concpet_from_file_1=False, use_concpet_from_file_2=False, use_concpet_from_file_3=False
246
  ):
247
  if base_image is None:
248
  raise gr.Error("Please upload a base image")
@@ -250,32 +362,45 @@ def process_and_display(
250
  raise gr.Error("Choose at least one concept image")
251
 
252
  return process_images(
253
- base_image, concept_image1, concept_name1,
254
- concept_image2, concept_name2,
255
- concept_image3, concept_name3,
256
- rank1, rank2, rank3,
257
- prompt, scale, seed, num_inference_steps,
258
- concpet_from_file_1, concpet_from_file_2, concpet_from_file_3,
259
- use_concpet_from_file_1, use_concpet_from_file_2, use_concpet_from_file_3
260
  )
261
 
 
262
  # ----------------------------------------------------------
263
- # 8· 💄 THEME & CSS UPGRADE
264
  # ----------------------------------------------------------
265
- demo_theme = Soft( # ★ NEW
266
- primary_hue="purple",
267
- font=[gr.themes.GoogleFont("Inter")]
268
- )
269
  css = """
270
  body{
271
  background:#0f0c29;
272
  background:linear-gradient(135deg,#0f0c29,#302b63,#24243e);
273
  }
274
- #header{ text-align:center;
275
- padding:24px 0 8px;
276
- font-weight:700;
277
- font-size:2.1rem;
278
- color:#ffffff;}
 
 
279
  .gradio-container{max-width:1024px !important;margin:0 auto}
280
  .card{
281
  border-radius:18px;
@@ -288,137 +413,186 @@ body{
288
  """
289
 
290
  # ----------------------------------------------------------
291
- # 9· 🖼️ Demo UI
292
  # ----------------------------------------------------------
293
  example_gallery = [
294
- ['./IP_Composer/assets/patterns/base.jpg', "Patterns demo"],
295
- ['./IP_Composer/assets/flowers/base.png', "Flowers demo"],
296
- ['./IP_Composer/assets/materials/base.png',"Material demo"],
297
  ]
298
 
299
  with gr.Blocks(css=css, theme=demo_theme) as demo:
300
- # ─── Header
301
- gr.Markdown("<div id='header'>🌅 IP-Composer&nbsp;"
302
- "<sup style='font-size:14px'>SDXL</sup></div>")
303
-
304
- # ─── States for custom concepts
305
- concpet_from_file_1 = gr.State()
306
- concpet_from_file_2 = gr.State()
307
- concpet_from_file_3 = gr.State()
308
- use_concpet_from_file_1 = gr.State()
309
- use_concpet_from_file_2 = gr.State()
310
- use_concpet_from_file_3 = gr.State()
311
-
312
- # ─── Main layout
 
 
 
313
  with gr.Row(equal_height=True):
314
- # Base image card
315
  with gr.Column(elem_classes="card"):
316
- base_image = gr.Image(label="Base Image (Required)",
317
- type="numpy", height=400, width=400)
 
318
 
319
- # Concept cards (1 · 2 · 3)
320
  for idx in (1, 2, 3):
321
  with gr.Column(elem_classes="card"):
322
  locals()[f"concept_image{idx}"] = gr.Image(
323
- label=f"Concept Image {idx}" if idx == 1 else f"Concept {idx} (Optional)",
324
- type="numpy", height=400, width=400
325
  )
326
  locals()[f"concept_name{idx}"] = gr.Dropdown(
327
- concept_options, label=f"Concept {idx}",
 
328
  value=None if idx != 1 else "age",
329
- info="Pick concept type"
330
  )
331
  with gr.Accordion("💡 Or use a new concept 👇", open=False):
332
- gr.Markdown("1. Upload a file with **>100** text variations<br>"
333
- "2. Tip: Ask an LLM to list variations.")
 
 
334
  if idx == 1:
335
- concept_file_1 = gr.File("Concept variations",
336
- file_types=["text"])
 
337
  elif idx == 2:
338
- concept_file_2 = gr.File("Concept variations",
339
- file_types=["text"])
 
340
  else:
341
- concept_file_3 = gr.File("Concept variations",
342
- file_types=["text"])
 
343
 
344
- # ─── Advanced options card (full width)
345
  with gr.Column(elem_classes="card"):
346
  with gr.Accordion("⚙️ Advanced options", open=False):
347
- prompt = gr.Textbox(label="Guidance Prompt (Optional)",
348
- placeholder="Optional text prompt to guide generation")
349
- num_inference_steps = gr.Slider(1, 50, value=30, step=1,
350
- label="Num steps")
 
351
  with gr.Row():
352
- scale = gr.Slider(0.1, 2.0, value=1.0, step=0.1, label="Scale")
353
  randomize_seed = gr.Checkbox(True, label="Randomize seed")
354
  seed = gr.Number(value=0, label="Seed", precision=0)
355
- gr.Markdown("If a concept is not showing enough, **increase rank** ⬇️")
 
 
356
  with gr.Row():
357
- rank1 = gr.Slider(1, 150, value=30, step=1, label="Rank concept 1")
358
- rank2 = gr.Slider(1, 150, value=30, step=1, label="Rank concept 2")
359
- rank3 = gr.Slider(1, 150, value=30, step=1, label="Rank concept 3")
360
 
361
- # ─── Output & Generate button
362
  with gr.Column(elem_classes="card"):
363
  output_image = gr.Image(show_label=False, height=480)
364
  submit_btn = gr.Button("🔮 Generate", variant="primary", size="lg")
365
 
366
- # ─── Ready-made Gallery
367
  gr.Markdown("### 🔥 Ready-made examples")
368
- gr.Gallery(example_gallery, label="Preview",
369
- columns=[3], height="auto")
370
 
371
- # ─── Example usage (kept for quick test)
372
  gr.Examples(
373
  examples,
374
- inputs=[base_image, concept_image1, concept_name1,
375
- concept_image2, concept_name2,
376
- concept_image3, concept_name3,
377
- rank1, rank2, rank3,
378
- prompt, scale, seed, num_inference_steps],
379
  outputs=[output_image],
380
  fn=generate_examples,
381
- cache_examples=False
382
  )
383
 
384
- # ─── File upload triggers
385
- concept_file_1.upload(get_text_embeddings, [concept_file_1],
386
- [concpet_from_file_1, use_concpet_from_file_1])
387
- concept_file_2.upload(get_text_embeddings, [concept_file_2],
388
- [concpet_from_file_2, use_concpet_from_file_2])
389
- concept_file_3.upload(get_text_embeddings, [concept_file_3],
390
- [concpet_from_file_3, use_concpet_from_file_3])
391
- concept_file_1.delete(lambda x: False, [concept_file_1],
392
- [use_concpet_from_file_1])
393
- concept_file_2.delete(lambda x: False, [concept_file_2],
394
- [use_concpet_from_file_2])
395
- concept_file_3.delete(lambda x: False, [concept_file_3],
396
- [use_concpet_from_file_3])
397
-
398
- # ─── Dropdown auto-rank
399
  concept_name1.select(change_rank_default, [concept_name1], [rank1])
400
  concept_name2.select(change_rank_default, [concept_name2], [rank2])
401
  concept_name3.select(change_rank_default, [concept_name3], [rank3])
402
 
403
- # ─── Auto-match concept type on image upload
404
  concept_image1.upload(match_image_to_concept, [concept_image1], [concept_name1])
405
  concept_image2.upload(match_image_to_concept, [concept_image2], [concept_name2])
406
  concept_image3.upload(match_image_to_concept, [concept_image3], [concept_name3])
407
 
408
- # ─── Generate click chain
409
- submit_btn.click(randomize_seed_fn, [seed, randomize_seed], seed) \
410
- .then(process_and_display,
411
- [base_image, concept_image1, concept_name1,
412
- concept_image2, concept_name2,
413
- concept_image3, concept_name3,
414
- rank1, rank2, rank3,
415
- prompt, scale, seed, num_inference_steps,
416
- concpet_from_file_1, concpet_from_file_2, concpet_from_file_3,
417
- use_concpet_from_file_1, use_concpet_from_file_2, use_concpet_from_file_3],
418
- [output_image])
419
 
420
  # ─────────────────────────────
421
- # 10· Launch
422
  # ─────────────────────────────
423
  if __name__ == "__main__":
424
  demo.launch()
 
1
+ # ===========================================
2
+ # IP-Composer 🌅✚🖌️ – FULL IMPROVED UI SCRIPT
3
+ # (same functionality; improved UI, theme, and gallery + FileNotFoundError fix)
4
+ # ===========================================
5
+
6
  import os, json, random, gc
7
  import numpy as np
8
  import torch
9
  from PIL import Image
10
  import gradio as gr
11
+ from gradio.themes import Soft
12
  from diffusers import StableDiffusionXLPipeline
13
  import open_clip
14
  from huggingface_hub import hf_hub_download
15
  from IP_Composer.IP_Adapter.ip_adapter import IPAdapterXL
16
+ from IP_Composer.perform_swap import (
17
+ compute_dataset_embeds_svd,
18
+ get_modified_images_embeds_composition,
19
+ )
20
+ from IP_Composer.generate_text_embeddings import (
21
+ load_descriptions,
22
+ generate_embeddings,
23
+ )
24
  import spaces
25
 
26
  # ─────────────────────────────
27
+ # 1 · Device
28
  # ─────────────────────────────
29
  device = "cuda" if torch.cuda.is_available() else "cpu"
30
 
31
  # ─────────────────────────────
32
+ # 2 · Stable-Diffusion XL
33
  # ─────────────────────────────
34
  base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
35
  pipe = StableDiffusionXLPipeline.from_pretrained(
 
39
  )
40
 
41
  # ─────────────────────────────
42
+ # 3 · IP-Adapter
43
  # ─────────────────────────────
44
+ image_encoder_repo = "h94/IP-Adapter"
45
+ image_encoder_subfolder = "models/image_encoder"
46
  ip_ckpt = hf_hub_download(
47
+ "h94/IP-Adapter", subfolder="sdxl_models", filename="ip-adapter_sdxl_vit-h.bin"
48
+ )
49
+ ip_model = IPAdapterXL(
50
+ pipe, image_encoder_repo, image_encoder_subfolder, ip_ckpt, device
51
  )
 
 
 
52
 
53
  # ─────────────────────────────
54
+ # 4 · CLIP
55
  # ─────────────────────────────
56
  clip_model, _, preprocess = open_clip.create_model_and_transforms(
57
+ "hf-hub:laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
58
  )
59
  clip_model.to(device)
60
  tokenizer = open_clip.get_tokenizer(
61
+ "hf-hub:laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
62
  )
63
 
64
  # ─────────────────────────────
65
+ # 5 · Concept maps
66
  # ─────────────────────────────
67
  CONCEPTS_MAP = {
68
  "age": "age_descriptions.npy",
 
79
  "daytime": "times_of_day_descriptions.npy",
80
  "pose": "person_poses_descriptions.npy",
81
  "season": "season_descriptions.npy",
82
+ "material": "material_descriptions_with_gems.npy",
83
  }
84
  RANKS_MAP = {
85
+ "age": 30,
86
+ "animal fur": 80,
87
+ "dogs": 30,
88
+ "emotions": 30,
89
+ "flowers": 30,
90
+ "fruit/vegtable": 30,
91
+ "outfit type": 30,
92
+ "outfit pattern (including color)": 80,
93
+ "patterns": 80,
94
+ "patterns (including color)": 80,
95
+ "vehicle": 30,
96
+ "daytime": 30,
97
+ "pose": 30,
98
+ "season": 30,
99
+ "material": 80,
100
  }
101
  concept_options = list(CONCEPTS_MAP.keys())
102
 
103
  # ─────────────────────────────
104
+ # 6 · Example tuples (base_img, c1_img, …)
105
  # ─────────────────────────────
106
  examples = [
107
+ [
108
+ "./IP_Composer/assets/patterns/base.jpg",
109
+ "./IP_Composer/assets/patterns/pattern.png",
110
+ "patterns (including color)",
111
+ None,
112
+ None,
113
+ None,
114
+ None,
115
+ 80,
116
+ 30,
117
+ 30,
118
+ None,
119
+ 1.0,
120
+ 0,
121
+ 30,
122
+ ],
123
+ [
124
+ "./IP_Composer/assets/flowers/base.png",
125
+ "./IP_Composer/assets/flowers/concept.png",
126
+ "flowers",
127
+ None,
128
+ None,
129
+ None,
130
+ None,
131
+ 30,
132
+ 30,
133
+ 30,
134
+ None,
135
+ 1.0,
136
+ 0,
137
+ 30,
138
+ ],
139
+ [
140
+ "./IP_Composer/assets/materials/base.png",
141
+ "./IP_Composer/assets/materials/concept.jpg",
142
+ "material",
143
+ None,
144
+ None,
145
+ None,
146
+ None,
147
+ 80,
148
+ 30,
149
+ 30,
150
+ None,
151
+ 1.0,
152
+ 0,
153
+ 30,
154
+ ],
155
  ]
156
 
157
  # ----------------------------------------------------------
158
+ # 7 · Utility functions
159
  # ----------------------------------------------------------
160
+ def generate_examples(
161
+ base_image,
162
+ concept_image1,
163
+ concept_name1,
164
+ concept_image2,
165
+ concept_name2,
166
+ concept_image3,
167
+ concept_name3,
168
+ rank1,
169
+ rank2,
170
+ rank3,
171
+ prompt,
172
+ scale,
173
+ seed,
174
+ num_inference_steps,
175
+ ):
176
+ return process_and_display(
177
+ base_image,
178
+ concept_image1,
179
+ concept_name1,
180
+ concept_image2,
181
+ concept_name2,
182
+ concept_image3,
183
+ concept_name3,
184
+ rank1,
185
+ rank2,
186
+ rank3,
187
+ prompt,
188
+ scale,
189
+ seed,
190
+ num_inference_steps,
191
+ )
192
+
193
 
194
  MAX_SEED = np.iinfo(np.int32).max
195
+
196
+
197
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
198
  return random.randint(0, MAX_SEED) if randomize_seed else seed
199
 
200
+
201
+ def change_rank_default(concept_name):
202
  return RANKS_MAP.get(concept_name, 30)
203
 
204
+
205
  @spaces.GPU
206
  def match_image_to_concept(image):
207
  if image is None:
208
  return None
209
+ img_pil = Image.fromarray(image).convert("RGB")
210
  img_embed = get_image_embeds(img_pil, clip_model, preprocess, device)
211
+ sims = {}
212
+ for cname, cfile in CONCEPTS_MAP.items():
 
213
  try:
214
+ with open(f"./IP_Composer/text_embeddings/{cfile}", "rb") as f:
215
+ embeds = np.load(f)
216
+ scores = []
217
+ for e in embeds:
218
+ s = np.dot(
219
+ img_embed.flatten() / np.linalg.norm(img_embed),
220
+ e.flatten() / np.linalg.norm(e),
221
+ )
222
+ scores.append(s)
223
+ scores.sort(reverse=True)
224
+ sims[cname] = np.mean(scores[:5])
225
  except Exception as e:
226
+ print(cname, "error:", e)
227
+ if sims:
228
+ best = max(sims, key=sims.get)
229
+ gr.Info(f"Image automatically matched to concept: {best}")
230
+ return best
231
  return None
232
 
233
+
234
  @spaces.GPU
235
  def get_image_embeds(pil_image, model=clip_model, preproc=preprocess, dev=device):
236
  image = preproc(pil_image)[np.newaxis, :, :, :]
 
238
  embeds = model.encode_image(image.to(dev))
239
  return embeds.cpu().detach().numpy()
240
 
241
+
242
  @spaces.GPU
243
  def process_images(
244
+ base_image,
245
+ concept_image1,
246
+ concept_name1,
247
+ concept_image2=None,
248
+ concept_name2=None,
249
+ concept_image3=None,
250
+ concept_name3=None,
251
+ rank1=10,
252
+ rank2=10,
253
+ rank3=10,
254
+ prompt=None,
255
+ scale=1.0,
256
+ seed=420,
257
+ num_inference_steps=50,
258
+ concpet_from_file_1=None,
259
+ concpet_from_file_2=None,
260
+ concpet_from_file_3=None,
261
+ use_concpet_from_file_1=False,
262
+ use_concpet_from_file_2=False,
263
+ use_concpet_from_file_3=False,
264
  ):
265
+ base_pil = Image.fromarray(base_image).convert("RGB")
266
  base_embed = get_image_embeds(base_pil, clip_model, preprocess, device)
267
 
268
  concept_images, concept_descs, ranks = [], [], []
269
  skip = [False, False, False]
270
 
271
+ # concept 1
272
  if concept_image1 is None:
273
  return None, "Please upload at least one concept image"
274
  concept_images.append(concept_image1)
275
  if use_concpet_from_file_1 and concpet_from_file_1 is not None:
276
+ concept_descs.append(concpet_from_file_1)
277
+ skip[0] = True
278
  else:
279
  concept_descs.append(CONCEPTS_MAP[concept_name1])
280
  ranks.append(rank1)
281
 
282
+ # concept 2
283
  if concept_image2 is not None:
284
  concept_images.append(concept_image2)
285
  if use_concpet_from_file_2 and concpet_from_file_2 is not None:
286
+ concept_descs.append(concpet_from_file_2)
287
+ skip[1] = True
288
  else:
289
  concept_descs.append(CONCEPTS_MAP[concept_name2])
290
  ranks.append(rank2)
291
 
292
+ # concept 3
293
  if concept_image3 is not None:
294
  concept_images.append(concept_image3)
295
  if use_concpet_from_file_3 and concpet_from_file_3 is not None:
296
+ concept_descs.append(concpet_from_file_3)
297
+ skip[2] = True
298
  else:
299
  concept_descs.append(CONCEPTS_MAP[concept_name3])
300
  ranks.append(rank3)
 
314
  {"embed": e, "projection_matrix": p}
315
  for e, p in zip(concept_embeds, proj_mats)
316
  ]
317
+ modified = get_modified_images_embeds_composition(
318
+ base_embed,
319
+ projections_data,
320
+ ip_model,
321
+ prompt=prompt,
322
+ scale=scale,
323
+ num_samples=1,
324
+ seed=seed,
325
+ num_inference_steps=num_inference_steps,
326
  )
327
+ return modified[0]
328
+
329
 
330
  @spaces.GPU
331
  def get_text_embeddings(concept_file):
332
+ descs = load_descriptions(concept_file)
333
+ embeds = generate_embeddings(descs, clip_model, tokenizer, device, batch_size=100)
334
+ return embeds, True
335
+
336
 
337
  def process_and_display(
338
+ base_image,
339
+ concept_image1,
340
+ concept_name1="age",
341
+ concept_image2=None,
342
+ concept_name2=None,
343
+ concept_image3=None,
344
+ concept_name3=None,
345
+ rank1=30,
346
+ rank2=30,
347
+ rank3=30,
348
+ prompt=None,
349
+ scale=1.0,
350
+ seed=0,
351
+ num_inference_steps=50,
352
+ concpet_from_file_1=None,
353
+ concpet_from_file_2=None,
354
+ concpet_from_file_3=None,
355
+ use_concpet_from_file_1=False,
356
+ use_concpet_from_file_2=False,
357
+ use_concpet_from_file_3=False,
358
  ):
359
  if base_image is None:
360
  raise gr.Error("Please upload a base image")
 
362
  raise gr.Error("Choose at least one concept image")
363
 
364
  return process_images(
365
+ base_image,
366
+ concept_image1,
367
+ concept_name1,
368
+ concept_image2,
369
+ concept_name2,
370
+ concept_image3,
371
+ concept_name3,
372
+ rank1,
373
+ rank2,
374
+ rank3,
375
+ prompt,
376
+ scale,
377
+ seed,
378
+ num_inference_steps,
379
+ concpet_from_file_1,
380
+ concpet_from_file_2,
381
+ concpet_from_file_3,
382
+ use_concpet_from_file_1,
383
+ use_concpet_from_file_2,
384
+ use_concpet_from_file_3,
385
  )
386
 
387
+
388
  # ----------------------------------------------------------
389
+ # 8 · THEME & CSS
390
  # ----------------------------------------------------------
391
+ demo_theme = Soft(primary_hue="purple", font=[gr.themes.GoogleFont("Inter")])
 
 
 
392
  css = """
393
  body{
394
  background:#0f0c29;
395
  background:linear-gradient(135deg,#0f0c29,#302b63,#24243e);
396
  }
397
+ #header{
398
+ text-align:center;
399
+ padding:24px 0 8px;
400
+ font-weight:700;
401
+ font-size:2.1rem;
402
+ color:#ffffff;
403
+ }
404
  .gradio-container{max-width:1024px !important;margin:0 auto}
405
  .card{
406
  border-radius:18px;
 
413
  """
414
 
415
  # ----------------------------------------------------------
416
+ # 9 · UI
417
  # ----------------------------------------------------------
418
  example_gallery = [
419
+ ["./IP_Composer/assets/patterns/base.jpg", "Patterns demo"],
420
+ ["./IP_Composer/assets/flowers/base.png", "Flowers demo"],
421
+ ["./IP_Composer/assets/materials/base.png", "Material demo"],
422
  ]
423
 
424
  with gr.Blocks(css=css, theme=demo_theme) as demo:
425
+ gr.Markdown(
426
+ "<div id='header'>🌅 IP-Composer&nbsp;"
427
+ "<sup style='font-size:14px'>SDXL</sup></div>"
428
+ )
429
+
430
+ concpet_from_file_1, concpet_from_file_2, concpet_from_file_3 = (
431
+ gr.State(),
432
+ gr.State(),
433
+ gr.State(),
434
+ )
435
+ use_concpet_from_file_1, use_concpet_from_file_2, use_concpet_from_file_3 = (
436
+ gr.State(),
437
+ gr.State(),
438
+ gr.State(),
439
+ )
440
+
441
  with gr.Row(equal_height=True):
 
442
  with gr.Column(elem_classes="card"):
443
+ base_image = gr.Image(
444
+ label="Base Image (Required)", type="numpy", height=400, width=400
445
+ )
446
 
 
447
  for idx in (1, 2, 3):
448
  with gr.Column(elem_classes="card"):
449
  locals()[f"concept_image{idx}"] = gr.Image(
450
+ label=f"Concept Image {idx}"
451
+ if idx == 1
452
+ else f"Concept {idx} (Optional)",
453
+ type="numpy",
454
+ height=400,
455
+ width=400,
456
  )
457
  locals()[f"concept_name{idx}"] = gr.Dropdown(
458
+ concept_options,
459
+ label=f"Concept {idx}",
460
  value=None if idx != 1 else "age",
461
+ info="Pick concept type",
462
  )
463
  with gr.Accordion("💡 Or use a new concept 👇", open=False):
464
+ gr.Markdown(
465
+ "1. Upload a file with **>100** text variations<br>"
466
+ "2. Tip: Ask an LLM to list variations."
467
+ )
468
  if idx == 1:
469
+ concept_file_1 = gr.File(
470
+ label="Concept variations", file_types=["text"]
471
+ )
472
  elif idx == 2:
473
+ concept_file_2 = gr.File(
474
+ label="Concept variations", file_types=["text"]
475
+ )
476
  else:
477
+ concept_file_3 = gr.File(
478
+ label="Concept variations", file_types=["text"]
479
+ )
480
 
 
481
  with gr.Column(elem_classes="card"):
482
  with gr.Accordion("⚙️ Advanced options", open=False):
483
+ prompt = gr.Textbox(
484
+ label="Guidance Prompt (Optional)",
485
+ placeholder="Optional text prompt to guide generation",
486
+ )
487
+ num_inference_steps = gr.Slider(1, 50, 30, step=1, label="Num steps")
488
  with gr.Row():
489
+ scale = gr.Slider(0.1, 2.0, 1.0, step=0.1, label="Scale")
490
  randomize_seed = gr.Checkbox(True, label="Randomize seed")
491
  seed = gr.Number(value=0, label="Seed", precision=0)
492
+ gr.Markdown(
493
+ "If a concept is not showing enough, **increase rank** ⬇️"
494
+ )
495
  with gr.Row():
496
+ rank1 = gr.Slider(1, 150, 30, step=1, label="Rank concept 1")
497
+ rank2 = gr.Slider(1, 150, 30, step=1, label="Rank concept 2")
498
+ rank3 = gr.Slider(1, 150, 30, step=1, label="Rank concept 3")
499
 
 
500
  with gr.Column(elem_classes="card"):
501
  output_image = gr.Image(show_label=False, height=480)
502
  submit_btn = gr.Button("🔮 Generate", variant="primary", size="lg")
503
 
 
504
  gr.Markdown("### 🔥 Ready-made examples")
505
+ gr.Gallery(example_gallery, label="Click to preview", columns=[3], height="auto")
 
506
 
 
507
  gr.Examples(
508
  examples,
509
+ inputs=[
510
+ base_image,
511
+ concept_image1,
512
+ concept_name1,
513
+ concept_image2,
514
+ concept_name2,
515
+ concept_image3,
516
+ concept_name3,
517
+ rank1,
518
+ rank2,
519
+ rank3,
520
+ prompt,
521
+ scale,
522
+ seed,
523
+ num_inference_steps,
524
+ ],
525
  outputs=[output_image],
526
  fn=generate_examples,
527
+ cache_examples=False,
528
  )
529
 
530
+ # Upload hooks
531
+ concept_file_1.upload(
532
+ get_text_embeddings,
533
+ [concept_file_1],
534
+ [concpet_from_file_1, use_concpet_from_file_1],
535
+ )
536
+ concept_file_2.upload(
537
+ get_text_embeddings,
538
+ [concept_file_2],
539
+ [concpet_from_file_2, use_concpet_from_file_2],
540
+ )
541
+ concept_file_3.upload(
542
+ get_text_embeddings,
543
+ [concept_file_3],
544
+ [concpet_from_file_3, use_concpet_from_file_3],
545
+ )
546
+ concept_file_1.delete(
547
+ lambda _: False, [concept_file_1], [use_concpet_from_file_1]
548
+ )
549
+ concept_file_2.delete(
550
+ lambda _: False, [concept_file_2], [use_concpet_from_file_2]
551
+ )
552
+ concept_file_3.delete(
553
+ lambda _: False, [concept_file_3], [use_concpet_from_file_3]
554
+ )
555
+
556
+ # Dropdown auto-rank
557
  concept_name1.select(change_rank_default, [concept_name1], [rank1])
558
  concept_name2.select(change_rank_default, [concept_name2], [rank2])
559
  concept_name3.select(change_rank_default, [concept_name3], [rank3])
560
 
561
+ # Auto-match on upload
562
  concept_image1.upload(match_image_to_concept, [concept_image1], [concept_name1])
563
  concept_image2.upload(match_image_to_concept, [concept_image2], [concept_name2])
564
  concept_image3.upload(match_image_to_concept, [concept_image3], [concept_name3])
565
 
566
+ # Generate chain
567
+ submit_btn.click(randomize_seed_fn, [seed, randomize_seed], seed).then(
568
+ process_and_display,
569
+ [
570
+ base_image,
571
+ concept_image1,
572
+ concept_name1,
573
+ concept_image2,
574
+ concept_name2,
575
+ concept_image3,
576
+ concept_name3,
577
+ rank1,
578
+ rank2,
579
+ rank3,
580
+ prompt,
581
+ scale,
582
+ seed,
583
+ num_inference_steps,
584
+ concpet_from_file_1,
585
+ concpet_from_file_2,
586
+ concpet_from_file_3,
587
+ use_concpet_from_file_1,
588
+ use_concpet_from_file_2,
589
+ use_concpet_from_file_3,
590
+ ],
591
+ [output_image],
592
+ )
593
 
594
  # ─────────────────────────────
595
+ # 10 · Launch
596
  # ─────────────────────────────
597
  if __name__ == "__main__":
598
  demo.launch()
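
Note on the unchanged math behind this UI: each concept dropdown maps to a `.npy` file of CLIP text embeddings, `compute_dataset_embeds_svd` reduces those embeddings to a rank-`r` subspace, and `get_modified_images_embeds_composition` swaps the base image's component in that subspace for the concept image's before handing the result to `IPAdapterXL`. The sketch below is only an illustration of that idea in plain NumPy, with assumed shapes and a simplified swap formula; it is not the `IP_Composer.perform_swap` implementation.

```python
# Illustrative sketch only -- NOT the IP_Composer.perform_swap implementation.
# Shapes, names, and the exact swap formula are assumptions for this example.
import numpy as np

def concept_projection_matrix(text_embeds: np.ndarray, rank: int) -> np.ndarray:
    """SVD over N concept descriptions (N x D); returns a D x D projector
    onto the top-`rank` right-singular directions (the concept subspace)."""
    X = text_embeds / np.linalg.norm(text_embeds, axis=1, keepdims=True)
    _, _, vt = np.linalg.svd(X, full_matrices=False)
    v = vt[:rank]              # (rank, D) orthonormal basis of the subspace
    return v.T @ v             # (D, D) orthogonal projector

def swap_concept(base_embed: np.ndarray,
                 concept_embed: np.ndarray,
                 proj: np.ndarray) -> np.ndarray:
    """Replace the base embedding's component inside the concept subspace
    with the concept image's component (one way to realize the composition)."""
    return base_embed - base_embed @ proj + concept_embed @ proj

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    texts = rng.normal(size=(120, 1024))   # e.g. >100 text variations, CLIP ViT-H dim
    base = rng.normal(size=(1024,))
    concept_img = rng.normal(size=(1024,))
    P = concept_projection_matrix(texts, rank=30)    # rank ~ RANKS_MAP default
    print(swap_concept(base, concept_img, P).shape)  # -> (1024,)
```

Under this reading, raising a concept's rank (as the advanced-options hint suggests) enlarges the subspace, so more of the concept image's variation survives the swap.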