alexnasa committed
Commit cebd833 · verified · 1 Parent(s): e7ad2ae

Update app.py

Files changed (1)
  1. app.py +239 -241
app.py CHANGED
@@ -31,117 +31,117 @@ os.environ["NCCL_P2P_DISABLE"]="1"
 os.environ["NCCL_IB_DISABLE"]="1"
 
 import src.flux.generate
-# from src.flux.generate import generate_from_test_sample, seed_everything
-# from src.flux.pipeline_tools import CustomFluxPipeline, load_modulation_adapter, load_dit_lora
-# from src.utils.data_utils import get_train_config, image_grid, pil2tensor, json_dump, pad_to_square, cv2pil, merge_bboxes
-# from eval.tools.face_id import FaceID
-# from eval.tools.florence_sam import ObjectDetector
-# import shutil
-# import yaml
-# import numpy as np
-# from huggingface_hub import snapshot_download, hf_hub_download
+from src.flux.generate import generate_from_test_sample, seed_everything
+from src.flux.pipeline_tools import CustomFluxPipeline, load_modulation_adapter, load_dit_lora
+from src.utils.data_utils import get_train_config, image_grid, pil2tensor, json_dump, pad_to_square, cv2pil, merge_bboxes
+from eval.tools.face_id import FaceID
+from eval.tools.florence_sam import ObjectDetector
+import shutil
+import yaml
+import numpy as np
+from huggingface_hub import snapshot_download, hf_hub_download
 import torch
 
-# # FLUX.1-dev
-# snapshot_download(
-#     repo_id="black-forest-labs/FLUX.1-dev",
-#     local_dir="./checkpoints/FLUX.1-dev",
-#     local_dir_use_symlinks=False
-# )
-
-# # Florence-2-large
-# snapshot_download(
-#     repo_id="microsoft/Florence-2-large",
-#     local_dir="./checkpoints/Florence-2-large",
-#     local_dir_use_symlinks=False
-# )
-
-# # CLIP ViT Large
-# snapshot_download(
-#     repo_id="openai/clip-vit-large-patch14",
-#     local_dir="./checkpoints/clip-vit-large-patch14",
-#     local_dir_use_symlinks=False
-# )
-
-# # DINO ViT-s16
-# snapshot_download(
-#     repo_id="facebook/dino-vits16",
-#     local_dir="./checkpoints/dino-vits16",
-#     local_dir_use_symlinks=False
-# )
-
-# # mPLUG Visual Question Answering
-# snapshot_download(
-#     repo_id="xingjianleng/mplug_visual-question-answering_coco_large_en",
-#     local_dir="./checkpoints/mplug_visual-question-answering_coco_large_en",
-#     local_dir_use_symlinks=False
-# )
-
-# # XVerse
-# snapshot_download(
-#     repo_id="ByteDance/XVerse",
-#     local_dir="./checkpoints/XVerse",
-#     local_dir_use_symlinks=False
-# )
-
-# hf_hub_download(
-#     repo_id="facebook/sam2.1-hiera-large",
-#     local_dir="./checkpoints/",
-#     filename="sam2.1_hiera_large.pt",
-# )
-
-
-
-# os.environ["FLORENCE2_MODEL_PATH"] = "./checkpoints/Florence-2-large"
-# os.environ["SAM2_MODEL_PATH"] = "./checkpoints/sam2.1_hiera_large.pt"
-# os.environ["FACE_ID_MODEL_PATH"] = "./checkpoints/model_ir_se50.pth"
-# os.environ["CLIP_MODEL_PATH"] = "./checkpoints/clip-vit-large-patch14"
-# os.environ["FLUX_MODEL_PATH"] = "./checkpoints/FLUX.1-dev"
-# os.environ["DPG_VQA_MODEL_PATH"] = "./checkpoints/mplug_visual-question-answering_coco_large_en"
-# os.environ["DINO_MODEL_PATH"] = "./checkpoints/dino-vits16"
-
-# dtype = torch.bfloat16
-# device = "cuda"
-
-# config_path = "train/config/XVerse_config_demo.yaml"
-
-# config = config_train = get_train_config(config_path)
-# # config["model"]["dit_quant"] = "int8-quanto"
-# config["model"]["use_dit_lora"] = False
-# model = CustomFluxPipeline(
-#     config, device, torch_dtype=dtype,
-# )
-# model.pipe.set_progress_bar_config(leave=False)
-
-# face_model = FaceID(device)
-# detector = ObjectDetector(device)
-
-# config = get_train_config(config_path)
-# model.config = config
-
-# run_mode = "mod_only"  # orig_only, mod_only, both
-# store_attn_map = False
-# run_name = time.strftime("%m%d-%H%M")
-
-# num_inputs = 6
-
-# ckpt_root = "./checkpoints/XVerse"
-# model.clear_modulation_adapters()
-# model.pipe.unload_lora_weights()
-# if not os.path.exists(ckpt_root):
-#     print("Checkpoint root does not exist.")
-
-# modulation_adapter = load_modulation_adapter(model, config, dtype, device, f"{ckpt_root}/modulation_adapter", is_training=False)
-# model.add_modulation_adapter(modulation_adapter)
-# if config["model"]["use_dit_lora"]:
-#     load_dit_lora(model, model.pipe, config, dtype, device, f"{ckpt_root}", is_training=False)
-
-# vae_skip_iter = None
-# attn_skip_iter = 0
-
-
-# def clear_images():
-#     return [None, ]*num_inputs
+# FLUX.1-dev
+snapshot_download(
+    repo_id="black-forest-labs/FLUX.1-dev",
+    local_dir="./checkpoints/FLUX.1-dev",
+    local_dir_use_symlinks=False
+)
+
+# Florence-2-large
+snapshot_download(
+    repo_id="microsoft/Florence-2-large",
+    local_dir="./checkpoints/Florence-2-large",
+    local_dir_use_symlinks=False
+)
+
+# CLIP ViT Large
+snapshot_download(
+    repo_id="openai/clip-vit-large-patch14",
+    local_dir="./checkpoints/clip-vit-large-patch14",
+    local_dir_use_symlinks=False
+)
+
+# DINO ViT-s16
+snapshot_download(
+    repo_id="facebook/dino-vits16",
+    local_dir="./checkpoints/dino-vits16",
+    local_dir_use_symlinks=False
+)
+
+# mPLUG Visual Question Answering
+snapshot_download(
+    repo_id="xingjianleng/mplug_visual-question-answering_coco_large_en",
+    local_dir="./checkpoints/mplug_visual-question-answering_coco_large_en",
+    local_dir_use_symlinks=False
+)
+
+# XVerse
+snapshot_download(
+    repo_id="ByteDance/XVerse",
+    local_dir="./checkpoints/XVerse",
+    local_dir_use_symlinks=False
+)
+
+hf_hub_download(
+    repo_id="facebook/sam2.1-hiera-large",
+    local_dir="./checkpoints/",
+    filename="sam2.1_hiera_large.pt",
+)
+
+
+
+os.environ["FLORENCE2_MODEL_PATH"] = "./checkpoints/Florence-2-large"
+os.environ["SAM2_MODEL_PATH"] = "./checkpoints/sam2.1_hiera_large.pt"
+os.environ["FACE_ID_MODEL_PATH"] = "./checkpoints/model_ir_se50.pth"
+os.environ["CLIP_MODEL_PATH"] = "./checkpoints/clip-vit-large-patch14"
+os.environ["FLUX_MODEL_PATH"] = "./checkpoints/FLUX.1-dev"
+os.environ["DPG_VQA_MODEL_PATH"] = "./checkpoints/mplug_visual-question-answering_coco_large_en"
+os.environ["DINO_MODEL_PATH"] = "./checkpoints/dino-vits16"
+
+dtype = torch.bfloat16
+device = "cuda"
+
+config_path = "train/config/XVerse_config_demo.yaml"
+
+config = config_train = get_train_config(config_path)
+# config["model"]["dit_quant"] = "int8-quanto"
+config["model"]["use_dit_lora"] = False
+model = CustomFluxPipeline(
+    config, device, torch_dtype=dtype,
+)
+model.pipe.set_progress_bar_config(leave=False)
+
+face_model = FaceID(device)
+detector = ObjectDetector(device)
+
+config = get_train_config(config_path)
+model.config = config
+
+run_mode = "mod_only"  # orig_only, mod_only, both
+store_attn_map = False
+run_name = time.strftime("%m%d-%H%M")
+
+num_inputs = 6
+
+ckpt_root = "./checkpoints/XVerse"
+model.clear_modulation_adapters()
+model.pipe.unload_lora_weights()
+if not os.path.exists(ckpt_root):
+    print("Checkpoint root does not exist.")
+
+modulation_adapter = load_modulation_adapter(model, config, dtype, device, f"{ckpt_root}/modulation_adapter", is_training=False)
+model.add_modulation_adapter(modulation_adapter)
+if config["model"]["use_dit_lora"]:
+    load_dit_lora(model, model.pipe, config, dtype, device, f"{ckpt_root}", is_training=False)
+
+vae_skip_iter = None
+attn_skip_iter = 0
+
+
+def clear_images():
+    return [None, ]*num_inputs
 
 @spaces.GPU()
 def det_seg_img(image, label):
@@ -211,145 +211,143 @@ def generate_image(
     indexs,  # newly added parameter
     # *images_captions_faces,  # Combine all unpacked arguments into one tuple
 ):
-    # torch.cuda.empty_cache()
-    # num_images = 1
-
-    # # Determine the number of images, captions, and faces based on the indexs length
-    # images = list(images_captions_faces[:num_inputs])
-    # captions = list(images_captions_faces[num_inputs:2 * num_inputs])
-    # idips_checkboxes = list(images_captions_faces[2 * num_inputs:3 * num_inputs])
-    # images = [images[i] for i in indexs]
-    # captions = [captions[i] for i in indexs]
-    # idips_checkboxes = [idips_checkboxes[i] for i in indexs]
-
-    # print(f"Length of images: {len(images)}")
-    # print(f"Length of captions: {len(captions)}")
-    # print(f"Indexs: {indexs}")
+    torch.cuda.empty_cache()
+    num_images = 1
+
+    # Determine the number of images, captions, and faces based on the indexs length
+    images = list(images_captions_faces[:num_inputs])
+    captions = list(images_captions_faces[num_inputs:2 * num_inputs])
+    idips_checkboxes = list(images_captions_faces[2 * num_inputs:3 * num_inputs])
+    images = [images[i] for i in indexs]
+    captions = [captions[i] for i in indexs]
+    idips_checkboxes = [idips_checkboxes[i] for i in indexs]
+
+    print(f"Length of images: {len(images)}")
+    print(f"Length of captions: {len(captions)}")
+    print(f"Indexs: {indexs}")
 
-    # print(f"Control weight lambda: {control_weight_lambda}")
-    # if control_weight_lambda != "no":
-    #     parts = control_weight_lambda.split(',')
-    #     new_parts = []
-    #     for part in parts:
-    #         if ':' in part:
-    #             left, right = part.split(':')
-    #             values = right.split('/')
-    #             # save the global value
-    #             global_value = values[0]
-    #             id_value = values[1]
-    #             ip_value = values[2]
-    #             new_values = [global_value]
-    #             for is_id in idips_checkboxes:
-    #                 if is_id:
-    #                     new_values.append(id_value)
-    #                 else:
-    #                     new_values.append(ip_value)
-    #             new_part = f"{left}:{('/'.join(new_values))}"
-    #             new_parts.append(new_part)
-    #         else:
-    #             new_parts.append(part)
-    #     control_weight_lambda = ','.join(new_parts)
+    print(f"Control weight lambda: {control_weight_lambda}")
+    if control_weight_lambda != "no":
+        parts = control_weight_lambda.split(',')
+        new_parts = []
+        for part in parts:
+            if ':' in part:
+                left, right = part.split(':')
+                values = right.split('/')
+                # save the global value
+                global_value = values[0]
+                id_value = values[1]
+                ip_value = values[2]
+                new_values = [global_value]
+                for is_id in idips_checkboxes:
+                    if is_id:
+                        new_values.append(id_value)
+                    else:
+                        new_values.append(ip_value)
+                new_part = f"{left}:{('/'.join(new_values))}"
+                new_parts.append(new_part)
+            else:
+                new_parts.append(part)
+        control_weight_lambda = ','.join(new_parts)
 
-    # print(f"Control weight lambda: {control_weight_lambda}")
-
-    # src_inputs = []
-    # use_words = []
-    # cur_run_time = time.strftime("%m%d-%H%M%S")
-    # tmp_dir_root = f"tmp/gradio_demo/{run_name}"
-    # temp_dir = f"{tmp_dir_root}/{cur_run_time}_{generate_random_string(4)}"
-    # os.makedirs(temp_dir, exist_ok=True)
-    # print(f"Temporary directory created: {temp_dir}")
-    # for i, (image_path, caption) in enumerate(zip(images, captions)):
-    #     if image_path:
-    #         if caption.startswith("a ") or caption.startswith("A "):
-    #             word = caption[2:]
-    #         else:
-    #             word = caption
+    print(f"Control weight lambda: {control_weight_lambda}")
+
+    src_inputs = []
+    use_words = []
+    cur_run_time = time.strftime("%m%d-%H%M%S")
+    tmp_dir_root = f"tmp/gradio_demo/{run_name}"
+    temp_dir = f"{tmp_dir_root}/{cur_run_time}_{generate_random_string(4)}"
+    os.makedirs(temp_dir, exist_ok=True)
+    print(f"Temporary directory created: {temp_dir}")
+    for i, (image_path, caption) in enumerate(zip(images, captions)):
+        if image_path:
+            if caption.startswith("a ") or caption.startswith("A "):
+                word = caption[2:]
+            else:
+                word = caption
 
-    #         if f"ENT{i+1}" in prompt:
-    #             prompt = prompt.replace(f"ENT{i+1}", caption)
+            if f"ENT{i+1}" in prompt:
+                prompt = prompt.replace(f"ENT{i+1}", caption)
 
-    #         image = resize_keep_aspect_ratio(Image.open(image_path), 768)
-    #         save_path = f"{temp_dir}/tmp_resized_input_{i}.png"
-    #         image.save(save_path)
+            image = resize_keep_aspect_ratio(Image.open(image_path), 768)
+            save_path = f"{temp_dir}/tmp_resized_input_{i}.png"
+            image.save(save_path)
 
-    #         input_image_path = save_path
-
-    #         src_inputs.append(
-    #             {
-    #                 "image_path": input_image_path,
-    #                 "caption": caption
-    #             }
-    #         )
-    #         use_words.append((i, word, word))
-
-
-    # test_sample = dict(
-    #     input_images=[], position_delta=[0, -32],
-    #     prompt=prompt,
-    #     target_height=target_height,
-    #     target_width=target_width,
-    #     seed=seed,
-    #     cond_size=cond_size,
-    #     vae_skip_iter=vae_skip_iter,
-    #     lora_scale=ip_scale,
-    #     control_weight_lambda=control_weight_lambda,
-    #     latent_sblora_scale=latent_sblora_scale_str,
-    #     condition_sblora_scale=vae_lora_scale,
-    #     double_attention=double_attention,
-    #     single_attention=single_attention,
-    # )
-    # if len(src_inputs) > 0:
-    #     test_sample["modulation"] = [
-    #         dict(
-    #             type="adapter",
-    #             src_inputs=src_inputs,
-    #             use_words=use_words,
-    #         ),
-    #     ]
+            input_image_path = save_path
+
+            src_inputs.append(
+                {
+                    "image_path": input_image_path,
+                    "caption": caption
+                }
+            )
+            use_words.append((i, word, word))
+
+
+    test_sample = dict(
+        input_images=[], position_delta=[0, -32],
+        prompt=prompt,
+        target_height=target_height,
+        target_width=target_width,
+        seed=seed,
+        cond_size=cond_size,
+        vae_skip_iter=vae_skip_iter,
+        lora_scale=ip_scale,
+        control_weight_lambda=control_weight_lambda,
+        latent_sblora_scale=latent_sblora_scale_str,
+        condition_sblora_scale=vae_lora_scale,
+        double_attention=double_attention,
+        single_attention=single_attention,
+    )
+    if len(src_inputs) > 0:
+        test_sample["modulation"] = [
+            dict(
+                type="adapter",
+                src_inputs=src_inputs,
+                use_words=use_words,
+            ),
+        ]
 
-    # json_dump(test_sample, f"{temp_dir}/test_sample.json", 'utf-8')
-    # assert single_attention == True
-    # target_size = int(round((target_width * target_height) ** 0.5) // 16 * 16)
-    # print(test_sample)
+    json_dump(test_sample, f"{temp_dir}/test_sample.json", 'utf-8')
+    assert single_attention == True
+    target_size = int(round((target_width * target_height) ** 0.5) // 16 * 16)
+    print(test_sample)
 
-    # model.config["train"]["dataset"]["val_condition_size"] = cond_size
-    # model.config["train"]["dataset"]["val_target_size"] = target_size
+    model.config["train"]["dataset"]["val_condition_size"] = cond_size
+    model.config["train"]["dataset"]["val_target_size"] = target_size
 
-    # if control_weight_lambda == "no":
-    #     control_weight_lambda = None
-    # if vae_skip_iter == "no":
-    #     vae_skip_iter = None
-    # use_condition_sblora_control = True
-    # use_latent_sblora_control = True
-    # image = generate_from_test_sample(
-    #     test_sample, model.pipe, model.config,
-    #     num_images=num_images,
-    #     target_height=target_height,
-    #     target_width=target_width,
-    #     seed=seed,
-    #     store_attn_map=store_attn_map,
-    #     vae_skip_iter=vae_skip_iter,  # use the new parameter
-    #     control_weight_lambda=control_weight_lambda,  # pass the new parameter
-    #     double_attention=double_attention,  # newly added parameter
-    #     single_attention=single_attention,  # newly added parameter
-    #     ip_scale=ip_scale,
-    #     use_latent_sblora_control=use_latent_sblora_control,
-    #     latent_sblora_scale=latent_sblora_scale_str,
-    #     use_condition_sblora_control=use_condition_sblora_control,
-    #     condition_sblora_scale=vae_lora_scale,
-    # )
-    # if isinstance(image, list):
-    #     num_cols = 2
-    #     num_rows = int(math.ceil(num_images / num_cols))
-    #     image = image_grid(image, num_rows, num_cols)
-
-    # save_path = f"{temp_dir}/tmp_result.png"
-    # image.save(save_path)
-
-    # return image
-
-    return None
+    if control_weight_lambda == "no":
+        control_weight_lambda = None
+    if vae_skip_iter == "no":
+        vae_skip_iter = None
+    use_condition_sblora_control = True
+    use_latent_sblora_control = True
+    image = generate_from_test_sample(
+        test_sample, model.pipe, model.config,
+        num_images=num_images,
+        target_height=target_height,
+        target_width=target_width,
+        seed=seed,
+        store_attn_map=store_attn_map,
+        vae_skip_iter=vae_skip_iter,  # use the new parameter
+        control_weight_lambda=control_weight_lambda,  # pass the new parameter
+        double_attention=double_attention,  # newly added parameter
+        single_attention=single_attention,  # newly added parameter
+        ip_scale=ip_scale,
+        use_latent_sblora_control=use_latent_sblora_control,
+        latent_sblora_scale=latent_sblora_scale_str,
+        use_condition_sblora_control=use_condition_sblora_control,
+        condition_sblora_scale=vae_lora_scale,
+    )
+    if isinstance(image, list):
+        num_cols = 2
+        num_rows = int(math.ceil(num_images / num_cols))
+        image = image_grid(image, num_rows, num_cols)
+
+    save_path = f"{temp_dir}/tmp_result.png"
+    image.save(save_path)
+
+    return image
 
 
 
@@ -533,7 +531,7 @@ if __name__ == "__main__":
     )
 
     # # Modify the outputs of the clear function
-    # clear_btn.click(clear_images, outputs=images)
+    clear_btn.click(clear_images, outputs=images)
 
     face_btn_1.click(crop_face_img, inputs=[image_1], outputs=[image_1])
    det_btn_1.click(det_seg_img, inputs=[image_1, caption_1], outputs=[image_1])
 
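Note: the restored generate_image body expands each "left:global/id/ip" segment of control_weight_lambda into one weight per selected subject, using the id weight where the subject's identity checkbox is ticked and the ip weight otherwise. A minimal standalone sketch of that expansion (the helper name expand_control_weights is ours, not from the repo):

# Standalone sketch of the per-subject weight expansion in generate_image;
# expand_control_weights is a hypothetical helper name, not from the repo.
def expand_control_weights(control_weight_lambda: str, idips_checkboxes: list) -> str:
    new_parts = []
    for part in control_weight_lambda.split(','):
        if ':' not in part:
            new_parts.append(part)
            continue
        left, right = part.split(':')
        values = right.split('/')
        # keep the global value, then pick the id or ip weight per subject
        new_values = [values[0]]
        for is_id in idips_checkboxes:
            new_values.append(values[1] if is_id else values[2])
        new_parts.append(f"{left}:{'/'.join(new_values)}")
    return ','.join(new_parts)

# Two identity subjects and one non-identity subject:
print(expand_control_weights("0-1:0.8/0.6/0.4", [True, False, True]))
# -> "0-1:0.8/0.6/0.4/0.6"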
 
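Note: the target_size computation in the restored body snaps the square root of the requested pixel area down to a multiple of 16. A worked example with an assumed 768x1024 target:

# Worked example of the target_size rounding used in generate_image.
target_width, target_height = 768, 1024
target_size = int(round((target_width * target_height) ** 0.5) // 16 * 16)
print(target_size)  # 880: round(sqrt(786432)) = 887; 887 // 16 * 16 = 880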
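
Note: with the downloads re-enabled at import time, a failed or partial snapshot would only surface once a model loads. A pre-launch sanity check over the same paths the commit wires into os.environ could catch that earlier; a minimal sketch (the helper name missing_checkpoints is ours, not part of the commit). Note also that newer huggingface_hub releases deprecate and ignore local_dir_use_symlinks, so that flag is likely kept only for backward compatibility.

# Hypothetical pre-launch check; missing_checkpoints is our name, not the repo's.
import os

EXPECTED = [
    "./checkpoints/FLUX.1-dev",
    "./checkpoints/Florence-2-large",
    "./checkpoints/clip-vit-large-patch14",
    "./checkpoints/dino-vits16",
    "./checkpoints/mplug_visual-question-answering_coco_large_en",
    "./checkpoints/XVerse",
    "./checkpoints/sam2.1_hiera_large.pt",
]

def missing_checkpoints(paths=EXPECTED):
    """Return every expected checkpoint path that is absent on disk."""
    return [p for p in paths if not os.path.exists(p)]

if __name__ == "__main__":
    missing = missing_checkpoints()
    if missing:
        raise SystemExit("Missing checkpoints: " + ", ".join(missing))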