Fucius commited on
Commit
64fdfbc
·
verified ·
1 Parent(s): 83de213

Delete gradio_demo/app_instantID.py

Browse files
Files changed (1) hide show
  1. gradio_demo/app_instantID.py +0 -701
gradio_demo/app_instantID.py DELETED
@@ -1,701 +0,0 @@
1
- import sys
2
- sys.path.append('./')
3
- import argparse
4
- import hashlib
5
- import json
6
- import os.path
7
- import numpy as np
8
- import torch
9
- from typing import Tuple, List
10
- from diffusers import DPMSolverMultistepScheduler
11
- from diffusers.models import T2IAdapter
12
- from PIL import Image
13
- import copy
14
- from diffusers import ControlNetModel, StableDiffusionXLPipeline
15
- from insightface.app import FaceAnalysis
16
- import gradio as gr
17
- import random
18
- from PIL import Image, ImageOps
19
- from transformers import DPTFeatureExtractor, DPTForDepthEstimation
20
- from controlnet_aux import OpenposeDetector
21
- from controlnet_aux.open_pose.body import Body
22
-
23
- try:
24
- from inference.models import YOLOWorld
25
- from src.efficientvit.models.efficientvit.sam import EfficientViTSamPredictor
26
- from src.efficientvit.sam_model_zoo import create_sam_model
27
- import supervision as sv
28
- except:
29
- print("YoloWorld can not be load")
30
-
31
- try:
32
- from groundingdino.models import build_model
33
- from groundingdino.util import box_ops
34
- from groundingdino.util.slconfig import SLConfig
35
- from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
36
- from groundingdino.util.inference import annotate, predict
37
- from segment_anything import build_sam, SamPredictor
38
- import groundingdino.datasets.transforms as T
39
- except:
40
- print("groundingdino can not be load")
41
-
42
- from src.pipelines.instantid_pipeline import InstantidMultiConceptPipeline
43
- from src.pipelines.instantid_single_pieline import InstantidSingleConceptPipeline
44
- from src.prompt_attention.p2p_attention import AttentionReplace
45
- from src.pipelines.instantid_pipeline import revise_regionally_controlnet_forward
46
- import cv2
47
- import math
48
- import PIL.Image
49
-
50
- from gradio_demo.character_template import styles, lorapath_styles
51
- STYLE_NAMES = list(styles.keys())
52
-
53
-
54
-
55
- MAX_SEED = np.iinfo(np.int32).max
56
-
57
- title = r"""
58
- <h1 align="center">OMG: Occlusion-friendly Personalized Multi-concept Generation In Diffusion Models (OMG + InstantID)</h1>
59
- """
60
-
61
- description = r"""
62
- <b>Official 🤗 Gradio demo</b> for <a href='https://github.com/kongzhecn/OMG/' target='_blank'><b>OMG: Occlusion-friendly Personalized Multi-concept Generation In Diffusion Models</b></a>.<be>.<br>
63
- <a href='https://kongzhecn.github.io/omg-project/' target='_blank'><b>[Project]</b></a>.<a href='https://github.com/kongzhecn/OMG/' target='_blank'><b>[Code]</b></a>.<a href='https://arxiv.org/abs/2403.10983/' target='_blank'><b>[Arxiv]</b></a>.<br>
64
- How to use:<br>
65
- 1. Select two characters.
66
- 2. Enter a text prompt as done in normal text-to-image models.
67
- 3. Click the <b>Submit</b> button to start customizing.
68
- 4. Enjoy the generated image😊!
69
- """
70
-
71
- article = r"""
72
- ---
73
- 📝 **Citation**
74
- <br>
75
- If our work is helpful for your research or applications, please cite us via:
76
- ```bibtex
77
- @article{,
78
- title={OMG: Occlusion-friendly Personalized Multi-concept Generation In Diffusion Models},
79
- author={},
80
- journal={},
81
- year={}
82
- }
83
- ```
84
- """
85
-
86
- tips = r"""
87
- ### Usage tips of OMG
88
- 1. Input text prompts to describe a man and a woman
89
- """
90
-
91
- css = '''
92
- .gradio-container {width: 85% !important}
93
- '''
94
-
95
-
96
-
97
- def build_dino_segment_model(ckpt_repo_id, sam_checkpoint):
98
- ckpt_filenmae = "groundingdino_swinb_cogcoor.pth"
99
- ckpt_config_filename = os.path.join(ckpt_repo_id, "GroundingDINO_SwinB.cfg.py")
100
- groundingdino_model = load_model_hf(ckpt_repo_id, ckpt_filenmae, ckpt_config_filename)
101
- sam = build_sam(checkpoint=sam_checkpoint)
102
- sam.cuda()
103
- sam_predictor = SamPredictor(sam)
104
- return groundingdino_model, sam_predictor
105
-
106
- def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
107
- args = SLConfig.fromfile(ckpt_config_filename)
108
- model = build_model(args)
109
- args.device = device
110
-
111
- checkpoint = torch.load(os.path.join(repo_id, filename), map_location='cpu')
112
- log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
113
- print("Model loaded from {} \n => {}".format(filename, log))
114
- _ = model.eval()
115
- return model
116
-
117
- def build_yolo_segment_model(sam_path, device):
118
- yolo_world = YOLOWorld(model_id="yolo_world/l")
119
- sam = EfficientViTSamPredictor(
120
- create_sam_model(name="xl1", weight_url=sam_path).to(device).eval()
121
- )
122
- return yolo_world, sam
123
-
124
- def sample_image(pipe,
125
- input_prompt,
126
- input_neg_prompt=None,
127
- generator=None,
128
- concept_models=None,
129
- num_inference_steps=50,
130
- guidance_scale=7.5,
131
- controller=None,
132
- face_app=None,
133
- image=None,
134
- stage=None,
135
- region_masks=None,
136
- controlnet_conditioning_scale=None,
137
- **extra_kargs
138
- ):
139
-
140
- if image is not None:
141
- image_condition = [image]
142
- else:
143
- image_condition = None
144
-
145
-
146
- images = pipe(
147
- prompt=input_prompt,
148
- concept_models=concept_models,
149
- negative_prompt=input_neg_prompt,
150
- generator=generator,
151
- guidance_scale=guidance_scale,
152
- num_inference_steps=num_inference_steps,
153
- cross_attention_kwargs={"scale": 0.8},
154
- controller=controller,
155
- image=image_condition,
156
- face_app=face_app,
157
- stage=stage,
158
- controlnet_conditioning_scale = controlnet_conditioning_scale,
159
- region_masks=region_masks,
160
- **extra_kargs).images
161
- return images
162
-
163
- def load_image_yoloworld(image_source) -> Tuple[np.array, torch.Tensor]:
164
- image = np.asarray(image_source)
165
- return image
166
-
167
- def load_image_dino(image_source) -> Tuple[np.array, torch.Tensor]:
168
- transform = T.Compose(
169
- [
170
- T.RandomResize([800], max_size=1333),
171
- T.ToTensor(),
172
- T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
173
- ]
174
- )
175
- image = np.asarray(image_source)
176
- image_transformed, _ = transform(image_source, None)
177
- return image, image_transformed
178
-
179
- def draw_kps_multi(image_pil, kps_list, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]):
180
- stickwidth = 4
181
- limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
182
-
183
-
184
- w, h = image_pil.size
185
- out_img = np.zeros([h, w, 3])
186
-
187
- for kps in kps_list:
188
- kps = np.array(kps)
189
- for i in range(len(limbSeq)):
190
- index = limbSeq[i]
191
- color = color_list[index[0]]
192
-
193
- x = kps[index][:, 0]
194
- y = kps[index][:, 1]
195
- length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
196
- angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
197
- polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0,
198
- 360, 1)
199
- out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
200
- out_img = (out_img * 0.6).astype(np.uint8)
201
-
202
- for idx_kp, kp in enumerate(kps):
203
- color = color_list[idx_kp]
204
- x, y = kp
205
- out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)
206
-
207
- out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8))
208
- return out_img_pil
209
-
210
- def predict_mask(segmentmodel, sam, image, TEXT_PROMPT, segmentType, confidence = 0.2, threshold = 0.5):
211
- if segmentType=='GroundingDINO':
212
- image_source, image = load_image_dino(image)
213
- boxes, logits, phrases = predict(
214
- model=segmentmodel,
215
- image=image,
216
- caption=TEXT_PROMPT,
217
- box_threshold=0.3,
218
- text_threshold=0.25
219
- )
220
- sam.set_image(image_source)
221
- H, W, _ = image_source.shape
222
- boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])
223
-
224
- transformed_boxes = sam.transform.apply_boxes_torch(boxes_xyxy, image_source.shape[:2]).cuda()
225
- masks, _, _ = sam.predict_torch(
226
- point_coords=None,
227
- point_labels=None,
228
- boxes=transformed_boxes,
229
- multimask_output=False,
230
- )
231
- masks=masks[0].squeeze(0)
232
- else:
233
- image_source = load_image_yoloworld(image)
234
- segmentmodel.set_classes(TEXT_PROMPT)
235
- results = segmentmodel.infer(image_source, confidence=confidence)
236
- detections = sv.Detections.from_inference(results).with_nms(
237
- class_agnostic=True, threshold=threshold
238
- )
239
-
240
- masks_list = []
241
- sam.set_image(image_source, image_format="RGB")
242
- for xyxy in detections.xyxy:
243
- mask, _, _ = sam.predict(box=xyxy, multimask_output=False)
244
- masks_list.append(mask.squeeze())
245
- detections.mask = np.array(masks_list)
246
-
247
- mask_1 = []
248
- mask_2 = []
249
- for i, (class_id, confidence) in enumerate(zip(detections.class_id, detections.confidence)):
250
- if class_id==0:
251
- mask_1.append(torch.from_numpy(detections.mask[i]))
252
- if class_id==1:
253
- mask_2.append(torch.from_numpy(detections.mask[i]))
254
- if len(mask_1)==0:
255
- mask_1.append(None)
256
- if len(mask_2)==0:
257
- mask_2.append(None)
258
- if len(TEXT_PROMPT)==2:
259
- return mask_1[0], mask_2[0]
260
-
261
- return mask_1[0]
262
-
263
- def build_model_sd(pretrained_model, controlnet_path, face_adapter, device, prompts, antelopev2_path, width, height, style_lora):
264
- controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
265
- pipe = InstantidMultiConceptPipeline.from_pretrained(
266
- pretrained_model, controlnet=controlnet, torch_dtype=torch.float16, variant="fp16").to(device)
267
-
268
- controller = AttentionReplace(prompts, 50, cross_replace_steps={"default_": 1.},
269
- self_replace_steps=0.4, tokenizer=pipe.tokenizer, device=device, width=width, height=height,
270
- dtype=torch.float16)
271
- revise_regionally_controlnet_forward(pipe.unet, controller)
272
-
273
- controlnet_concept = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
274
- pipe_concept = InstantidSingleConceptPipeline.from_pretrained(
275
- pretrained_model,
276
- controlnet=controlnet_concept,
277
- torch_dtype=torch.float16
278
- )
279
- pipe_concept.load_ip_adapter_instantid(face_adapter)
280
- pipe_concept.set_ip_adapter_scale(0.8)
281
- pipe_concept.to(device)
282
- pipe_concept.image_proj_model.to(pipe_concept._execution_device)
283
-
284
- if style_lora is not None and os.path.exists(style_lora):
285
- pipe.load_lora_weights(style_lora, weight_name="pytorch_lora_weights.safetensors", adapter_name='style')
286
- pipe_concept.load_lora_weights(style_lora, weight_name="pytorch_lora_weights.safetensors", adapter_name='style')
287
-
288
-
289
- # modify
290
- app = FaceAnalysis(name='antelopev2', root=antelopev2_path,
291
- providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
292
- app.prepare(ctx_id=0, det_size=(640, 640))
293
-
294
- return pipe, controller, pipe_concept, app
295
-
296
-
297
- def prepare_text(prompt, region_prompts):
298
- '''
299
- Args:
300
- prompt_entity: [subject1]-*-[attribute1]-*-[Location1]|[subject2]-*-[attribute2]-*-[Location2]|[global text]
301
- Returns:
302
- full_prompt: subject1, attribute1 and subject2, attribute2, global text
303
- context_prompt: subject1 and subject2, global text
304
- entity_collection: [(subject1, attribute1), Location1]
305
- '''
306
- region_collection = []
307
-
308
- regions = region_prompts.split('|')
309
-
310
- for region in regions:
311
- if region == '':
312
- break
313
- prompt_region, neg_prompt_region, ref_img = region.split('-*-')
314
- prompt_region = prompt_region.replace('[', '').replace(']', '')
315
- neg_prompt_region = neg_prompt_region.replace('[', '').replace(']', '')
316
-
317
- region_collection.append((prompt_region, neg_prompt_region, ref_img))
318
- return (prompt, region_collection)
319
-
320
- def build_model_lora(pipe, pipe_concept, style_path, condition, condition_img):
321
- if condition == "Human pose" and condition_img is not None:
322
- controlnet = ControlNetModel.from_pretrained(args.openpose_checkpoint, torch_dtype=torch.float16).to(device)
323
- pipe.controlnet2 = controlnet
324
- elif condition == "Canny Edge" and condition_img is not None:
325
- controlnet = ControlNetModel.from_pretrained(args.canny_checkpoint, torch_dtype=torch.float16, variant="fp16").to(device)
326
- pipe.controlnet2 = controlnet
327
- elif condition == "Depth" and condition_img is not None:
328
- controlnet = ControlNetModel.from_pretrained(args.depth_checkpoint, torch_dtype=torch.float16).to(device)
329
- pipe.controlnet2 = controlnet
330
-
331
- if style_path is not None and os.path.exists(style_path):
332
- pipe_concept.load_lora_weights(style_path, weight_name="pytorch_lora_weights.safetensors", adapter_name='style')
333
- pipe.load_lora_weights(style_path, weight_name="pytorch_lora_weights.safetensors", adapter_name='style')
334
-
335
- def resize_and_center_crop(image, output_size=(1024, 576)):
336
- width, height = image.size
337
- aspect_ratio = width / height
338
- new_height = output_size[1]
339
- new_width = int(aspect_ratio * new_height)
340
-
341
- resized_image = image.resize((new_width, new_height), Image.LANCZOS)
342
-
343
- if new_width < output_size[0] or new_height < output_size[1]:
344
- padding_color = "gray"
345
- resized_image = ImageOps.expand(resized_image,
346
- ((output_size[0] - new_width) // 2,
347
- (output_size[1] - new_height) // 2,
348
- (output_size[0] - new_width + 1) // 2,
349
- (output_size[1] - new_height + 1) // 2),
350
- fill=padding_color)
351
-
352
- left = (resized_image.width - output_size[0]) / 2
353
- top = (resized_image.height - output_size[1]) / 2
354
- right = (resized_image.width + output_size[0]) / 2
355
- bottom = (resized_image.height + output_size[1]) / 2
356
-
357
- cropped_image = resized_image.crop((left, top, right, bottom))
358
-
359
- return cropped_image
360
-
361
- def main(device, segment_type):
362
- pipe, controller, pipe_concepts, face_app = build_model_sd(args.pretrained_model, args.controlnet_path,
363
- args.face_adapter_path, device, prompts_tmp,
364
- args.antelopev2_path, width // 32, height // 32,
365
- args.style_lora)
366
- if segment_type == 'GroundingDINO':
367
- detect_model, sam = build_dino_segment_model(args.dino_checkpoint, args.sam_checkpoint)
368
- else:
369
- detect_model, sam = build_yolo_segment_model(args.efficientViT_checkpoint, device)
370
-
371
- resolution_list = ["1440*728",
372
- "1344*768",
373
- "1216*832",
374
- "1152*896",
375
- "1024*1024",
376
- "896*1152",
377
- "832*1216",
378
- "768*1344",
379
- "728*1440"]
380
- ratio_list = [1440 / 728, 1344 / 768, 1216 / 832, 1152 / 896, 1024 / 1024, 896 / 1152, 832 / 1216, 768 / 1344,
381
- 728 / 1440]
382
- condition_list = ["None",
383
- "Human pose",
384
- "Canny Edge",
385
- "Depth"]
386
-
387
- depth_estimator = DPTForDepthEstimation.from_pretrained(args.dpt_checkpoint).to("cuda")
388
- feature_extractor = DPTFeatureExtractor.from_pretrained(args.dpt_checkpoint)
389
- body_model = Body(args.pose_detector_checkpoint)
390
- openpose = OpenposeDetector(body_model)
391
-
392
- prompts_rewrite = [args.prompt_rewrite]
393
- input_prompt_test = [prepare_text(p, p_w) for p, p_w in zip(prompts, prompts_rewrite)]
394
- input_prompt_test = [prompts, input_prompt_test[0][1]]
395
-
396
- def remove_tips():
397
- return gr.update(visible=False)
398
-
399
- def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
400
- if randomize_seed:
401
- seed = random.randint(0, MAX_SEED)
402
- return seed
403
-
404
- def get_humanpose(img):
405
- openpose_image = openpose(img)
406
- return openpose_image
407
-
408
- def get_cannyedge(image):
409
- image = np.array(image)
410
- image = cv2.Canny(image, 100, 200)
411
- image = image[:, :, None]
412
- image = np.concatenate([image, image, image], axis=2)
413
- canny_image = Image.fromarray(image)
414
- return canny_image
415
-
416
- def get_depth(image):
417
- image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
418
- with torch.no_grad(), torch.autocast("cuda"):
419
- depth_map = depth_estimator(image).predicted_depth
420
-
421
- depth_map = torch.nn.functional.interpolate(
422
- depth_map.unsqueeze(1),
423
- size=(1024, 1024),
424
- mode="bicubic",
425
- align_corners=False,
426
- )
427
- depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
428
- depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
429
- depth_map = (depth_map - depth_min) / (depth_max - depth_min)
430
- image = torch.cat([depth_map] * 3, dim=1)
431
- image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
432
- image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
433
- return image
434
-
435
- def generate_image(prompt1, negative_prompt, reference_1, reference_2, resolution, local_prompt1, local_prompt2, seed, style, identitynet_strength_ratio, adapter_strength_ratio, condition, condition_img, controlnet_ratio):
436
- identitynet_strength_ratio = float(identitynet_strength_ratio)
437
- adapter_strength_ratio = float(adapter_strength_ratio)
438
- controlnet_ratio = float(controlnet_ratio)
439
- if lorapath_styles[style] is not None and os.path.exists(lorapath_styles[style]):
440
- styleL = True
441
- else:
442
- styleL = False
443
-
444
- width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
445
- kwargs = {
446
- 'height': height,
447
- 'width': width,
448
- 't2i_controlnet_conditioning_scale': controlnet_ratio,
449
- }
450
-
451
- if condition == 'Human pose' and condition_img is not None:
452
- index = ratio_list.index(
453
- min(ratio_list, key=lambda x: abs(x - condition_img.shape[1] / condition_img.shape[0])))
454
- resolution = resolution_list[index]
455
- width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
456
- kwargs['height'] = height
457
- kwargs['width'] = width
458
- condition_img = resize_and_center_crop(Image.fromarray(condition_img), (width, height))
459
- spatial_condition = get_humanpose(condition_img)
460
- elif condition == 'Canny Edge' and condition_img is not None:
461
- index = ratio_list.index(
462
- min(ratio_list, key=lambda x: abs(x - condition_img.shape[1] / condition_img.shape[0])))
463
- resolution = resolution_list[index]
464
- width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
465
- kwargs['height'] = height
466
- kwargs['width'] = width
467
- condition_img = resize_and_center_crop(Image.fromarray(condition_img), (width, height))
468
- spatial_condition = get_cannyedge(condition_img)
469
- elif condition == 'Depth' and condition_img is not None:
470
- index = ratio_list.index(
471
- min(ratio_list, key=lambda x: abs(x - condition_img.shape[1] / condition_img.shape[0])))
472
- resolution = resolution_list[index]
473
- width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
474
- kwargs['height'] = height
475
- kwargs['width'] = width
476
- condition_img = resize_and_center_crop(Image.fromarray(condition_img), (width, height))
477
- spatial_condition = get_depth(condition_img)
478
- else:
479
- spatial_condition = None
480
-
481
- kwargs['t2i_image'] = spatial_condition
482
- pipe.unload_lora_weights()
483
- pipe_concepts.unload_lora_weights()
484
- build_model_lora(pipe, pipe_concepts, lorapath_styles[style], condition, condition_img)
485
- pipe_concepts.set_ip_adapter_scale(adapter_strength_ratio)
486
-
487
- input_list = [prompt1]
488
-
489
-
490
- for prompt in input_list:
491
- if prompt != '':
492
- input_prompt = []
493
- p = '{prompt}, 35mm photograph, film, professional, 4k, highly detailed.'
494
- if styleL:
495
- p = styles[style] + p
496
- input_prompt.append([p.replace('{prompt}', prompt), p.replace("{prompt}", prompt)])
497
- if styleL:
498
- input_prompt.append([(styles[style] + local_prompt1, 'noisy, blurry, soft, deformed, ugly',
499
- PIL.Image.fromarray(reference_1)),
500
- (styles[style] + local_prompt2, 'noisy, blurry, soft, deformed, ugly',
501
- PIL.Image.fromarray(reference_2))])
502
- else:
503
- input_prompt.append(
504
- [(local_prompt1, 'noisy, blurry, soft, deformed, ugly', PIL.Image.fromarray(reference_1)),
505
- (local_prompt2, 'noisy, blurry, soft, deformed, ugly', PIL.Image.fromarray(reference_2))])
506
-
507
-
508
- controller.reset()
509
- image = sample_image(
510
- pipe,
511
- input_prompt=input_prompt,
512
- concept_models=pipe_concepts,
513
- input_neg_prompt=[negative_prompt] * len(input_prompt),
514
- generator=torch.Generator(device).manual_seed(seed),
515
- controller=controller,
516
- face_app=face_app,
517
- controlnet_conditioning_scale=identitynet_strength_ratio,
518
- stage=1,
519
- **kwargs)
520
-
521
- controller.reset()
522
-
523
- if (pipe.tokenizer("man")["input_ids"][1] in pipe.tokenizer(args.prompt)["input_ids"][1:-1]) and (
524
- pipe.tokenizer("woman")["input_ids"][1] in pipe.tokenizer(args.prompt)["input_ids"][1:-1]):
525
- mask1, mask2 = predict_mask(detect_model, sam, image[0], ['man', 'woman'], args.segment_type, confidence=0.3,
526
- threshold=0.5)
527
-
528
- elif pipe.tokenizer("man")["input_ids"][1] in pipe.tokenizer(args.prompt)["input_ids"][1:-1]:
529
- mask1 = predict_mask(detect_model, sam, image[0], ['man'], args.segment_type, confidence=0.3,
530
- threshold=0.5)
531
- mask2 = None
532
-
533
- elif pipe.tokenizer("woman")["input_ids"][1] in pipe.tokenizer(args.prompt)["input_ids"][1:-1]:
534
- mask2 = predict_mask(detect_model, sam, image[0], ['woman'], args.segment_type, confidence=0.3,
535
- threshold=0.5)
536
- mask1 = None
537
- else:
538
- mask1 = mask2 = None
539
-
540
- if mask1 is not None or mask2 is not None:
541
- face_info = face_app.get(cv2.cvtColor(np.array(image[0]), cv2.COLOR_RGB2BGR))
542
- face_kps = draw_kps_multi(image[0], [face['kps'] for face in face_info])
543
-
544
- image = sample_image(
545
- pipe,
546
- input_prompt=input_prompt,
547
- concept_models=pipe_concepts,
548
- input_neg_prompt=[negative_prompt] * len(input_prompt),
549
- generator=torch.Generator(device).manual_seed(seed),
550
- controller=controller,
551
- face_app=face_app,
552
- image=face_kps,
553
- stage=2,
554
- controlnet_conditioning_scale=identitynet_strength_ratio,
555
- region_masks=[mask1, mask2],
556
- **kwargs)
557
-
558
- # return [image[1], spatial_condition]
559
- return image
560
-
561
- with gr.Blocks(css=css) as demo:
562
- # description
563
- gr.Markdown(title)
564
- gr.Markdown(description)
565
-
566
- with gr.Row():
567
- gallery = gr.Image(label="Generated Images", height=512, width=512)
568
- gallery1 = gr.Image(label="Generated Images", height=512, width=512)
569
- usage_tips = gr.Markdown(label="Usage tips of OMG", value=tips, visible=False)
570
-
571
-
572
- with gr.Row():
573
- reference_1 = gr.Image(label="Input an RGB image for Character man", height=128, width=128)
574
- reference_2 = gr.Image(label="Input an RGB image for Character woman", height=128, width=128)
575
- condition_img1 = gr.Image(label="Input an RGB image for condition (Optional)", height=128, width=128)
576
-
577
-
578
-
579
-
580
- with gr.Row():
581
- local_prompt1 = gr.Textbox(label="Character1_prompt",
582
- info="Describe the Character 1",
583
- value="Close-up photo of the a man, 35mm photograph, professional, 4k, highly detailed.")
584
- local_prompt2 = gr.Textbox(label="Character2_prompt",
585
- info="Describe the Character 2",
586
- value="Close-up photo of the a woman, 35mm photograph, professional, 4k, highly detailed.")
587
- with gr.Row():
588
- identitynet_strength_ratio = gr.Slider(
589
- label="IdentityNet strength (for fidelity)",
590
- minimum=0,
591
- maximum=1.5,
592
- step=0.05,
593
- value=0.80,
594
- )
595
- adapter_strength_ratio = gr.Slider(
596
- label="Image adapter strength (for detail)",
597
- minimum=0,
598
- maximum=1.5,
599
- step=0.05,
600
- value=0.80,
601
- )
602
- controlnet_ratio = gr.Slider(
603
- label="ControlNet strength",
604
- minimum=0,
605
- maximum=1.5,
606
- step=0.05,
607
- value=1,
608
- )
609
- resolution = gr.Dropdown(label="Image Resolution (width*height)", choices=resolution_list,
610
- value="1024*1024")
611
- style = gr.Dropdown(label="style", choices=STYLE_NAMES, value="None")
612
- condition = gr.Dropdown(label="Input condition type", choices=condition_list, value="None")
613
-
614
-
615
- # prompt
616
- with gr.Column():
617
- prompt = gr.Textbox(label="Prompt 1",
618
- info="Give a simple prompt to describe the first image content",
619
- placeholder="Required",
620
- value="close-up shot, photography, a man and a woman on the street, facing the camera smiling")
621
-
622
-
623
- with gr.Accordion(open=False, label="Advanced Options"):
624
- seed = gr.Slider(
625
- label="Seed",
626
- minimum=0,
627
- maximum=MAX_SEED,
628
- step=1,
629
- value=42,
630
- )
631
- negative_prompt = gr.Textbox(label="Negative Prompt",
632
- placeholder="noisy, blurry, soft, deformed, ugly",
633
- value="noisy, blurry, soft, deformed, ugly")
634
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
635
-
636
- submit = gr.Button("Submit", variant="primary")
637
-
638
- submit.click(
639
- fn=remove_tips,
640
- outputs=usage_tips,
641
- ).then(
642
- fn=randomize_seed_fn,
643
- inputs=[seed, randomize_seed],
644
- outputs=seed,
645
- queue=False,
646
- api_name=False,
647
- ).then(
648
- fn=generate_image,
649
- inputs=[prompt, negative_prompt, reference_1, reference_2, resolution, local_prompt1, local_prompt2, seed, style, identitynet_strength_ratio, adapter_strength_ratio, condition, condition_img1, controlnet_ratio],
650
- outputs=[gallery, gallery1]
651
- )
652
- demo.launch(server_name='0.0.0.0',server_port=7861, debug=True)
653
-
654
- def parse_args():
655
- parser = argparse.ArgumentParser('', add_help=False)
656
- parser.add_argument('--pretrained_model', default='/home/data1/kz_dir/checkpoint/YamerMIX_v8', type=str)
657
- parser.add_argument('--controlnet_path', default='../checkpoint/InstantID/ControlNetModel', type=str)
658
- parser.add_argument('--face_adapter_path', default='../checkpoint/InstantID/ip-adapter.bin', type=str)
659
- parser.add_argument('--openpose_checkpoint', default='../checkpoint/controlnet-openpose-sdxl-1.0', type=str)
660
- parser.add_argument('--canny_checkpoint', default='../checkpoint/controlnet-canny-sdxl-1.0', type=str)
661
- parser.add_argument('--depth_checkpoint', default='../checkpoint/controlnet-depth-sdxl-1.0', type=str)
662
- parser.add_argument('--dpt_checkpoint', default='../checkpoint/dpt-hybrid-midas', type=str)
663
- parser.add_argument('--pose_detector_checkpoint',
664
- default='../checkpoint/ControlNet/annotator/ckpts/body_pose_model.pth', type=str)
665
- parser.add_argument('--efficientViT_checkpoint', default='../checkpoint/sam/xl1.pt', type=str)
666
- parser.add_argument('--dino_checkpoint', default='../checkpoint/GroundingDINO', type=str)
667
- parser.add_argument('--sam_checkpoint', default='../checkpoint/sam/sam_vit_h_4b8939.pth', type=str)
668
- parser.add_argument('--antelopev2_path', default='../checkpoint/antelopev2', type=str)
669
- parser.add_argument('--save_dir', default='results/instantID', type=str)
670
- parser.add_argument('--prompt', default='Close-up photo of the cool man and beautiful woman as they accidentally discover a mysterious island while on vacation by the sea, facing the camera smiling, 35mm photograph, film, professional, 4k, highly detailed.', type=str)
671
- parser.add_argument('--negative_prompt', default='noisy, blurry, soft, deformed, ugly', type=str)
672
- parser.add_argument('--prompt_rewrite',
673
- default='[Close-up photo of a man, 35mm photograph, professional, 4k, highly detailed.]-*'
674
- '-[noisy, blurry, soft, deformed, ugly]-*-'
675
- '../example/chris-evans.jpg|'
676
- '[Close-up photo of a woman, 35mm photograph, professional, 4k, highly detailed.]-'
677
- '*-[noisy, blurry, soft, deformed, ugly]-*-'
678
- '../example/TaylorSwift.png',
679
- type=str)
680
- parser.add_argument('--seed', default=0, type=int)
681
- parser.add_argument('--suffix', default='', type=str)
682
- parser.add_argument('--segment_type', default='yoloworld', help='GroundingDINO or yoloworld', type=str)
683
- parser.add_argument('--style_lora', default='', type=str)
684
- return parser.parse_args()
685
-
686
- if __name__ == '__main__':
687
- args = parse_args()
688
-
689
- prompts = [args.prompt] * 2
690
-
691
- prompts_tmp = copy.deepcopy(prompts)
692
-
693
- width, height = 1024, 1024
694
- kwargs = {
695
- 'height': height,
696
- 'width': width,
697
- }
698
-
699
- device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
700
- main(device, args.segment_type)
701
-