Fucius commited on
Commit
548d42d
·
verified ·
1 Parent(s): 2eafbc4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +708 -0
app.py ADDED
@@ -0,0 +1,708 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import sys
3
+ import os
4
+ import torch
5
+ torch.jit.script = lambda f: f
6
+
7
+ import sys
8
+ sys.path.append('./')
9
+ import argparse
10
+ import hashlib
11
+ import json
12
+ import os.path
13
+ import numpy as np
14
+ import torch
15
+ from typing import Tuple, List
16
+ from diffusers import DPMSolverMultistepScheduler
17
+ from diffusers.models import T2IAdapter
18
+ from PIL import Image
19
+ import copy
20
+ from diffusers import ControlNetModel, StableDiffusionXLPipeline
21
+ from insightface.app import FaceAnalysis
22
+ import gradio as gr
23
+ import random
24
+ from PIL import Image, ImageOps
25
+ from transformers import DPTFeatureExtractor, DPTForDepthEstimation
26
+ from controlnet_aux import OpenposeDetector
27
+ from controlnet_aux.open_pose.body import Body
28
+
29
+ try:
30
+ from inference.models import YOLOWorld
31
+ from src.efficientvit.models.efficientvit.sam import EfficientViTSamPredictor
32
+ from src.efficientvit.sam_model_zoo import create_sam_model
33
+ import supervision as sv
34
+ except:
35
+ print("YoloWorld can not be load")
36
+
37
+ try:
38
+ from groundingdino.models import build_model
39
+ from groundingdino.util import box_ops
40
+ from groundingdino.util.slconfig import SLConfig
41
+ from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
42
+ from groundingdino.util.inference import annotate, predict
43
+ from segment_anything import build_sam, SamPredictor
44
+ import groundingdino.datasets.transforms as T
45
+ except:
46
+ print("groundingdino can not be load")
47
+
48
+ from src.pipelines.instantid_pipeline import InstantidMultiConceptPipeline
49
+ from src.pipelines.instantid_single_pieline import InstantidSingleConceptPipeline
50
+ from src.prompt_attention.p2p_attention import AttentionReplace
51
+ from src.pipelines.instantid_pipeline import revise_regionally_controlnet_forward
52
+ import cv2
53
+ import math
54
+ import PIL.Image
55
+
56
+ from gradio_demo.character_template import styles, lorapath_styles
57
+ STYLE_NAMES = list(styles.keys())
58
+
59
+
60
+
61
+ MAX_SEED = np.iinfo(np.int32).max
62
+
63
+ title = r"""
64
+ <h1 align="center">OMG: Occlusion-friendly Personalized Multi-concept Generation In Diffusion Models (OMG + InstantID)</h1>
65
+ """
66
+
67
+ description = r"""
68
+ <b>Official 🤗 Gradio demo</b> for <a href='https://github.com/kongzhecn/OMG/' target='_blank'><b>OMG: Occlusion-friendly Personalized Multi-concept Generation In Diffusion Models</b></a>.<be>.<br>
69
+ <a href='https://kongzhecn.github.io/omg-project/' target='_blank'><b>[Project]</b></a>.<a href='https://github.com/kongzhecn/OMG/' target='_blank'><b>[Code]</b></a>.<a href='https://arxiv.org/abs/2403.10983/' target='_blank'><b>[Arxiv]</b></a>.<br>
70
+ How to use:<br>
71
+ 1. Select two characters.
72
+ 2. Enter a text prompt as done in normal text-to-image models.
73
+ 3. Click the <b>Submit</b> button to start customizing.
74
+ 4. Enjoy the generated image😊!
75
+ """
76
+
77
+ article = r"""
78
+ ---
79
+ 📝 **Citation**
80
+ <br>
81
+ If our work is helpful for your research or applications, please cite us via:
82
+ ```bibtex
83
+ @article{,
84
+ title={OMG: Occlusion-friendly Personalized Multi-concept Generation In Diffusion Models},
85
+ author={},
86
+ journal={},
87
+ year={}
88
+ }
89
+ ```
90
+ """
91
+
92
+ tips = r"""
93
+ ### Usage tips of OMG
94
+ 1. Input text prompts to describe a man and a woman
95
+ """
96
+
97
+ css = '''
98
+ .gradio-container {width: 85% !important}
99
+ '''
100
+
101
+
102
+
103
+ def build_dino_segment_model(ckpt_repo_id, sam_checkpoint):
104
+ ckpt_filenmae = "groundingdino_swinb_cogcoor.pth"
105
+ ckpt_config_filename = os.path.join(ckpt_repo_id, "GroundingDINO_SwinB.cfg.py")
106
+ groundingdino_model = load_model_hf(ckpt_repo_id, ckpt_filenmae, ckpt_config_filename)
107
+ sam = build_sam(checkpoint=sam_checkpoint)
108
+ sam.cuda()
109
+ sam_predictor = SamPredictor(sam)
110
+ return groundingdino_model, sam_predictor
111
+
112
+ def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
113
+ args = SLConfig.fromfile(ckpt_config_filename)
114
+ model = build_model(args)
115
+ args.device = device
116
+
117
+ checkpoint = torch.load(os.path.join(repo_id, filename), map_location='cpu')
118
+ log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
119
+ print("Model loaded from {} \n => {}".format(filename, log))
120
+ _ = model.eval()
121
+ return model
122
+
123
+ def build_yolo_segment_model(sam_path, device):
124
+ yolo_world = YOLOWorld(model_id="yolo_world/l")
125
+ sam = EfficientViTSamPredictor(
126
+ create_sam_model(name="xl1", weight_url=sam_path).to(device).eval()
127
+ )
128
+ return yolo_world, sam
129
+
130
+ def sample_image(pipe,
131
+ input_prompt,
132
+ input_neg_prompt=None,
133
+ generator=None,
134
+ concept_models=None,
135
+ num_inference_steps=50,
136
+ guidance_scale=7.5,
137
+ controller=None,
138
+ face_app=None,
139
+ image=None,
140
+ stage=None,
141
+ region_masks=None,
142
+ controlnet_conditioning_scale=None,
143
+ **extra_kargs
144
+ ):
145
+
146
+ if image is not None:
147
+ image_condition = [image]
148
+ else:
149
+ image_condition = None
150
+
151
+
152
+ images = pipe(
153
+ prompt=input_prompt,
154
+ concept_models=concept_models,
155
+ negative_prompt=input_neg_prompt,
156
+ generator=generator,
157
+ guidance_scale=guidance_scale,
158
+ num_inference_steps=num_inference_steps,
159
+ cross_attention_kwargs={"scale": 0.8},
160
+ controller=controller,
161
+ image=image_condition,
162
+ face_app=face_app,
163
+ stage=stage,
164
+ controlnet_conditioning_scale = controlnet_conditioning_scale,
165
+ region_masks=region_masks,
166
+ **extra_kargs).images
167
+ return images
168
+
169
+ def load_image_yoloworld(image_source) -> Tuple[np.array, torch.Tensor]:
170
+ image = np.asarray(image_source)
171
+ return image
172
+
173
+ def load_image_dino(image_source) -> Tuple[np.array, torch.Tensor]:
174
+ transform = T.Compose(
175
+ [
176
+ T.RandomResize([800], max_size=1333),
177
+ T.ToTensor(),
178
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
179
+ ]
180
+ )
181
+ image = np.asarray(image_source)
182
+ image_transformed, _ = transform(image_source, None)
183
+ return image, image_transformed
184
+
185
+ def draw_kps_multi(image_pil, kps_list, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]):
186
+ stickwidth = 4
187
+ limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
188
+
189
+
190
+ w, h = image_pil.size
191
+ out_img = np.zeros([h, w, 3])
192
+
193
+ for kps in kps_list:
194
+ kps = np.array(kps)
195
+ for i in range(len(limbSeq)):
196
+ index = limbSeq[i]
197
+ color = color_list[index[0]]
198
+
199
+ x = kps[index][:, 0]
200
+ y = kps[index][:, 1]
201
+ length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
202
+ angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
203
+ polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0,
204
+ 360, 1)
205
+ out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
206
+ out_img = (out_img * 0.6).astype(np.uint8)
207
+
208
+ for idx_kp, kp in enumerate(kps):
209
+ color = color_list[idx_kp]
210
+ x, y = kp
211
+ out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)
212
+
213
+ out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8))
214
+ return out_img_pil
215
+
216
+ def predict_mask(segmentmodel, sam, image, TEXT_PROMPT, segmentType, confidence = 0.2, threshold = 0.5):
217
+ if segmentType=='GroundingDINO':
218
+ image_source, image = load_image_dino(image)
219
+ boxes, logits, phrases = predict(
220
+ model=segmentmodel,
221
+ image=image,
222
+ caption=TEXT_PROMPT,
223
+ box_threshold=0.3,
224
+ text_threshold=0.25
225
+ )
226
+ sam.set_image(image_source)
227
+ H, W, _ = image_source.shape
228
+ boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])
229
+
230
+ transformed_boxes = sam.transform.apply_boxes_torch(boxes_xyxy, image_source.shape[:2]).cuda()
231
+ masks, _, _ = sam.predict_torch(
232
+ point_coords=None,
233
+ point_labels=None,
234
+ boxes=transformed_boxes,
235
+ multimask_output=False,
236
+ )
237
+ masks=masks[0].squeeze(0)
238
+ else:
239
+ image_source = load_image_yoloworld(image)
240
+ segmentmodel.set_classes(TEXT_PROMPT)
241
+ results = segmentmodel.infer(image_source, confidence=confidence)
242
+ detections = sv.Detections.from_inference(results).with_nms(
243
+ class_agnostic=True, threshold=threshold
244
+ )
245
+
246
+ masks_list = []
247
+ sam.set_image(image_source, image_format="RGB")
248
+ for xyxy in detections.xyxy:
249
+ mask, _, _ = sam.predict(box=xyxy, multimask_output=False)
250
+ masks_list.append(mask.squeeze())
251
+ detections.mask = np.array(masks_list)
252
+
253
+ mask_1 = []
254
+ mask_2 = []
255
+ for i, (class_id, confidence) in enumerate(zip(detections.class_id, detections.confidence)):
256
+ if class_id==0:
257
+ mask_1.append(torch.from_numpy(detections.mask[i]))
258
+ if class_id==1:
259
+ mask_2.append(torch.from_numpy(detections.mask[i]))
260
+ if len(mask_1)==0:
261
+ mask_1.append(None)
262
+ if len(mask_2)==0:
263
+ mask_2.append(None)
264
+ if len(TEXT_PROMPT)==2:
265
+ return mask_1[0], mask_2[0]
266
+
267
+ return mask_1[0]
268
+
269
+ def build_model_sd(pretrained_model, controlnet_path, face_adapter, device, prompts, antelopev2_path, width, height, style_lora):
270
+ controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
271
+ pipe = InstantidMultiConceptPipeline.from_pretrained(
272
+ pretrained_model, controlnet=controlnet, torch_dtype=torch.float16, variant="fp16").to(device)
273
+
274
+ controller = AttentionReplace(prompts, 50, cross_replace_steps={"default_": 1.},
275
+ self_replace_steps=0.4, tokenizer=pipe.tokenizer, device=device, width=width, height=height,
276
+ dtype=torch.float16)
277
+ revise_regionally_controlnet_forward(pipe.unet, controller)
278
+
279
+ controlnet_concept = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
280
+ pipe_concept = InstantidSingleConceptPipeline.from_pretrained(
281
+ pretrained_model,
282
+ controlnet=controlnet_concept,
283
+ torch_dtype=torch.float16
284
+ )
285
+ pipe_concept.load_ip_adapter_instantid(face_adapter)
286
+ pipe_concept.set_ip_adapter_scale(0.8)
287
+ pipe_concept.to(device)
288
+ pipe_concept.image_proj_model.to(pipe_concept._execution_device)
289
+
290
+ if style_lora is not None and os.path.exists(style_lora):
291
+ pipe.load_lora_weights(style_lora, weight_name="pytorch_lora_weights.safetensors", adapter_name='style')
292
+ pipe_concept.load_lora_weights(style_lora, weight_name="pytorch_lora_weights.safetensors", adapter_name='style')
293
+
294
+
295
+ # modify
296
+ app = FaceAnalysis(name='antelopev2', root=antelopev2_path,
297
+ providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
298
+ app.prepare(ctx_id=0, det_size=(640, 640))
299
+
300
+ return pipe, controller, pipe_concept, app
301
+
302
+
303
+ def prepare_text(prompt, region_prompts):
304
+ '''
305
+ Args:
306
+ prompt_entity: [subject1]-*-[attribute1]-*-[Location1]|[subject2]-*-[attribute2]-*-[Location2]|[global text]
307
+ Returns:
308
+ full_prompt: subject1, attribute1 and subject2, attribute2, global text
309
+ context_prompt: subject1 and subject2, global text
310
+ entity_collection: [(subject1, attribute1), Location1]
311
+ '''
312
+ region_collection = []
313
+
314
+ regions = region_prompts.split('|')
315
+
316
+ for region in regions:
317
+ if region == '':
318
+ break
319
+ prompt_region, neg_prompt_region, ref_img = region.split('-*-')
320
+ prompt_region = prompt_region.replace('[', '').replace(']', '')
321
+ neg_prompt_region = neg_prompt_region.replace('[', '').replace(']', '')
322
+
323
+ region_collection.append((prompt_region, neg_prompt_region, ref_img))
324
+ return (prompt, region_collection)
325
+
326
+ def build_model_lora(pipe, pipe_concept, style_path, condition, condition_img):
327
+ if condition == "Human pose" and condition_img is not None:
328
+ controlnet = ControlNetModel.from_pretrained(args.openpose_checkpoint, torch_dtype=torch.float16).to(device)
329
+ pipe.controlnet2 = controlnet
330
+ elif condition == "Canny Edge" and condition_img is not None:
331
+ controlnet = ControlNetModel.from_pretrained(args.canny_checkpoint, torch_dtype=torch.float16, variant="fp16").to(device)
332
+ pipe.controlnet2 = controlnet
333
+ elif condition == "Depth" and condition_img is not None:
334
+ controlnet = ControlNetModel.from_pretrained(args.depth_checkpoint, torch_dtype=torch.float16).to(device)
335
+ pipe.controlnet2 = controlnet
336
+
337
+ if style_path is not None and os.path.exists(style_path):
338
+ pipe_concept.load_lora_weights(style_path, weight_name="pytorch_lora_weights.safetensors", adapter_name='style')
339
+ pipe.load_lora_weights(style_path, weight_name="pytorch_lora_weights.safetensors", adapter_name='style')
340
+
341
+ def resize_and_center_crop(image, output_size=(1024, 576)):
342
+ width, height = image.size
343
+ aspect_ratio = width / height
344
+ new_height = output_size[1]
345
+ new_width = int(aspect_ratio * new_height)
346
+
347
+ resized_image = image.resize((new_width, new_height), Image.LANCZOS)
348
+
349
+ if new_width < output_size[0] or new_height < output_size[1]:
350
+ padding_color = "gray"
351
+ resized_image = ImageOps.expand(resized_image,
352
+ ((output_size[0] - new_width) // 2,
353
+ (output_size[1] - new_height) // 2,
354
+ (output_size[0] - new_width + 1) // 2,
355
+ (output_size[1] - new_height + 1) // 2),
356
+ fill=padding_color)
357
+
358
+ left = (resized_image.width - output_size[0]) / 2
359
+ top = (resized_image.height - output_size[1]) / 2
360
+ right = (resized_image.width + output_size[0]) / 2
361
+ bottom = (resized_image.height + output_size[1]) / 2
362
+
363
+ cropped_image = resized_image.crop((left, top, right, bottom))
364
+
365
+ return cropped_image
366
+
367
+ def main(device, segment_type):
368
+ pipe, controller, pipe_concepts, face_app = build_model_sd(args.pretrained_model, args.controlnet_path,
369
+ args.face_adapter_path, device, prompts_tmp,
370
+ args.antelopev2_path, width // 32, height // 32,
371
+ args.style_lora)
372
+ if segment_type == 'GroundingDINO':
373
+ detect_model, sam = build_dino_segment_model(args.dino_checkpoint, args.sam_checkpoint)
374
+ else:
375
+ detect_model, sam = build_yolo_segment_model(args.efficientViT_checkpoint, device)
376
+
377
+ resolution_list = ["1440*728",
378
+ "1344*768",
379
+ "1216*832",
380
+ "1152*896",
381
+ "1024*1024",
382
+ "896*1152",
383
+ "832*1216",
384
+ "768*1344",
385
+ "728*1440"]
386
+ ratio_list = [1440 / 728, 1344 / 768, 1216 / 832, 1152 / 896, 1024 / 1024, 896 / 1152, 832 / 1216, 768 / 1344,
387
+ 728 / 1440]
388
+ condition_list = ["None",
389
+ "Human pose",
390
+ "Canny Edge",
391
+ "Depth"]
392
+
393
+ depth_estimator = DPTForDepthEstimation.from_pretrained(args.dpt_checkpoint).to("cuda")
394
+ feature_extractor = DPTFeatureExtractor.from_pretrained(args.dpt_checkpoint)
395
+ body_model = Body(args.pose_detector_checkpoint)
396
+ openpose = OpenposeDetector(body_model)
397
+
398
+ prompts_rewrite = [args.prompt_rewrite]
399
+ input_prompt_test = [prepare_text(p, p_w) for p, p_w in zip(prompts, prompts_rewrite)]
400
+ input_prompt_test = [prompts, input_prompt_test[0][1]]
401
+
402
+ def remove_tips():
403
+ return gr.update(visible=False)
404
+
405
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
406
+ if randomize_seed:
407
+ seed = random.randint(0, MAX_SEED)
408
+ return seed
409
+
410
+ def get_humanpose(img):
411
+ openpose_image = openpose(img)
412
+ return openpose_image
413
+
414
+ def get_cannyedge(image):
415
+ image = np.array(image)
416
+ image = cv2.Canny(image, 100, 200)
417
+ image = image[:, :, None]
418
+ image = np.concatenate([image, image, image], axis=2)
419
+ canny_image = Image.fromarray(image)
420
+ return canny_image
421
+
422
+ def get_depth(image):
423
+ image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
424
+ with torch.no_grad(), torch.autocast("cuda"):
425
+ depth_map = depth_estimator(image).predicted_depth
426
+
427
+ depth_map = torch.nn.functional.interpolate(
428
+ depth_map.unsqueeze(1),
429
+ size=(1024, 1024),
430
+ mode="bicubic",
431
+ align_corners=False,
432
+ )
433
+ depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
434
+ depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
435
+ depth_map = (depth_map - depth_min) / (depth_max - depth_min)
436
+ image = torch.cat([depth_map] * 3, dim=1)
437
+ image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
438
+ image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
439
+ return image
440
+
441
+ @spaces.GPU
442
+ def generate_image(prompt1, negative_prompt, reference_1, reference_2, resolution, local_prompt1, local_prompt2, seed, style, identitynet_strength_ratio, adapter_strength_ratio, condition, condition_img, controlnet_ratio):
443
+ identitynet_strength_ratio = float(identitynet_strength_ratio)
444
+ adapter_strength_ratio = float(adapter_strength_ratio)
445
+ controlnet_ratio = float(controlnet_ratio)
446
+ if lorapath_styles[style] is not None and os.path.exists(lorapath_styles[style]):
447
+ styleL = True
448
+ else:
449
+ styleL = False
450
+
451
+ width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
452
+ kwargs = {
453
+ 'height': height,
454
+ 'width': width,
455
+ 't2i_controlnet_conditioning_scale': controlnet_ratio,
456
+ }
457
+
458
+ if condition == 'Human pose' and condition_img is not None:
459
+ index = ratio_list.index(
460
+ min(ratio_list, key=lambda x: abs(x - condition_img.shape[1] / condition_img.shape[0])))
461
+ resolution = resolution_list[index]
462
+ width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
463
+ kwargs['height'] = height
464
+ kwargs['width'] = width
465
+ condition_img = resize_and_center_crop(Image.fromarray(condition_img), (width, height))
466
+ spatial_condition = get_humanpose(condition_img)
467
+ elif condition == 'Canny Edge' and condition_img is not None:
468
+ index = ratio_list.index(
469
+ min(ratio_list, key=lambda x: abs(x - condition_img.shape[1] / condition_img.shape[0])))
470
+ resolution = resolution_list[index]
471
+ width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
472
+ kwargs['height'] = height
473
+ kwargs['width'] = width
474
+ condition_img = resize_and_center_crop(Image.fromarray(condition_img), (width, height))
475
+ spatial_condition = get_cannyedge(condition_img)
476
+ elif condition == 'Depth' and condition_img is not None:
477
+ index = ratio_list.index(
478
+ min(ratio_list, key=lambda x: abs(x - condition_img.shape[1] / condition_img.shape[0])))
479
+ resolution = resolution_list[index]
480
+ width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
481
+ kwargs['height'] = height
482
+ kwargs['width'] = width
483
+ condition_img = resize_and_center_crop(Image.fromarray(condition_img), (width, height))
484
+ spatial_condition = get_depth(condition_img)
485
+ else:
486
+ spatial_condition = None
487
+
488
+ kwargs['t2i_image'] = spatial_condition
489
+ pipe.unload_lora_weights()
490
+ pipe_concepts.unload_lora_weights()
491
+ build_model_lora(pipe, pipe_concepts, lorapath_styles[style], condition, condition_img)
492
+ pipe_concepts.set_ip_adapter_scale(adapter_strength_ratio)
493
+
494
+ input_list = [prompt1]
495
+
496
+
497
+ for prompt in input_list:
498
+ if prompt != '':
499
+ input_prompt = []
500
+ p = '{prompt}, 35mm photograph, film, professional, 4k, highly detailed.'
501
+ if styleL:
502
+ p = styles[style] + p
503
+ input_prompt.append([p.replace('{prompt}', prompt), p.replace("{prompt}", prompt)])
504
+ if styleL:
505
+ input_prompt.append([(styles[style] + local_prompt1, 'noisy, blurry, soft, deformed, ugly',
506
+ PIL.Image.fromarray(reference_1)),
507
+ (styles[style] + local_prompt2, 'noisy, blurry, soft, deformed, ugly',
508
+ PIL.Image.fromarray(reference_2))])
509
+ else:
510
+ input_prompt.append(
511
+ [(local_prompt1, 'noisy, blurry, soft, deformed, ugly', PIL.Image.fromarray(reference_1)),
512
+ (local_prompt2, 'noisy, blurry, soft, deformed, ugly', PIL.Image.fromarray(reference_2))])
513
+
514
+
515
+ controller.reset()
516
+ image = sample_image(
517
+ pipe,
518
+ input_prompt=input_prompt,
519
+ concept_models=pipe_concepts,
520
+ input_neg_prompt=[negative_prompt] * len(input_prompt),
521
+ generator=torch.Generator(device).manual_seed(seed),
522
+ controller=controller,
523
+ face_app=face_app,
524
+ controlnet_conditioning_scale=identitynet_strength_ratio,
525
+ stage=1,
526
+ **kwargs)
527
+
528
+ controller.reset()
529
+
530
+ if (pipe.tokenizer("man")["input_ids"][1] in pipe.tokenizer(args.prompt)["input_ids"][1:-1]) and (
531
+ pipe.tokenizer("woman")["input_ids"][1] in pipe.tokenizer(args.prompt)["input_ids"][1:-1]):
532
+ mask1, mask2 = predict_mask(detect_model, sam, image[0], ['man', 'woman'], args.segment_type, confidence=0.3,
533
+ threshold=0.5)
534
+
535
+ elif pipe.tokenizer("man")["input_ids"][1] in pipe.tokenizer(args.prompt)["input_ids"][1:-1]:
536
+ mask1 = predict_mask(detect_model, sam, image[0], ['man'], args.segment_type, confidence=0.3,
537
+ threshold=0.5)
538
+ mask2 = None
539
+
540
+ elif pipe.tokenizer("woman")["input_ids"][1] in pipe.tokenizer(args.prompt)["input_ids"][1:-1]:
541
+ mask2 = predict_mask(detect_model, sam, image[0], ['woman'], args.segment_type, confidence=0.3,
542
+ threshold=0.5)
543
+ mask1 = None
544
+ else:
545
+ mask1 = mask2 = None
546
+
547
+ if mask1 is not None or mask2 is not None:
548
+ face_info = face_app.get(cv2.cvtColor(np.array(image[0]), cv2.COLOR_RGB2BGR))
549
+ face_kps = draw_kps_multi(image[0], [face['kps'] for face in face_info])
550
+
551
+ image = sample_image(
552
+ pipe,
553
+ input_prompt=input_prompt,
554
+ concept_models=pipe_concepts,
555
+ input_neg_prompt=[negative_prompt] * len(input_prompt),
556
+ generator=torch.Generator(device).manual_seed(seed),
557
+ controller=controller,
558
+ face_app=face_app,
559
+ image=face_kps,
560
+ stage=2,
561
+ controlnet_conditioning_scale=identitynet_strength_ratio,
562
+ region_masks=[mask1, mask2],
563
+ **kwargs)
564
+
565
+ # return [image[1], spatial_condition]
566
+ return image
567
+
568
+ with gr.Blocks(css=css) as demo:
569
+ # description
570
+ gr.Markdown(title)
571
+ gr.Markdown(description)
572
+
573
+ with gr.Row():
574
+ gallery = gr.Image(label="Generated Images", height=512, width=512)
575
+ gallery1 = gr.Image(label="Generated Images", height=512, width=512)
576
+ usage_tips = gr.Markdown(label="Usage tips of OMG", value=tips, visible=False)
577
+
578
+
579
+ with gr.Row():
580
+ reference_1 = gr.Image(label="Input an RGB image for Character man", height=128, width=128)
581
+ reference_2 = gr.Image(label="Input an RGB image for Character woman", height=128, width=128)
582
+ condition_img1 = gr.Image(label="Input an RGB image for condition (Optional)", height=128, width=128)
583
+
584
+
585
+
586
+
587
+ with gr.Row():
588
+ local_prompt1 = gr.Textbox(label="Character1_prompt",
589
+ info="Describe the Character 1",
590
+ value="Close-up photo of the a man, 35mm photograph, professional, 4k, highly detailed.")
591
+ local_prompt2 = gr.Textbox(label="Character2_prompt",
592
+ info="Describe the Character 2",
593
+ value="Close-up photo of the a woman, 35mm photograph, professional, 4k, highly detailed.")
594
+ with gr.Row():
595
+ identitynet_strength_ratio = gr.Slider(
596
+ label="IdentityNet strength (for fidelity)",
597
+ minimum=0,
598
+ maximum=1.5,
599
+ step=0.05,
600
+ value=0.80,
601
+ )
602
+ adapter_strength_ratio = gr.Slider(
603
+ label="Image adapter strength (for detail)",
604
+ minimum=0,
605
+ maximum=1.5,
606
+ step=0.05,
607
+ value=0.80,
608
+ )
609
+ controlnet_ratio = gr.Slider(
610
+ label="ControlNet strength",
611
+ minimum=0,
612
+ maximum=1.5,
613
+ step=0.05,
614
+ value=1,
615
+ )
616
+ resolution = gr.Dropdown(label="Image Resolution (width*height)", choices=resolution_list,
617
+ value="1024*1024")
618
+ style = gr.Dropdown(label="style", choices=STYLE_NAMES, value="None")
619
+ condition = gr.Dropdown(label="Input condition type", choices=condition_list, value="None")
620
+
621
+
622
+ # prompt
623
+ with gr.Column():
624
+ prompt = gr.Textbox(label="Prompt 1",
625
+ info="Give a simple prompt to describe the first image content",
626
+ placeholder="Required",
627
+ value="close-up shot, photography, a man and a woman on the street, facing the camera smiling")
628
+
629
+
630
+ with gr.Accordion(open=False, label="Advanced Options"):
631
+ seed = gr.Slider(
632
+ label="Seed",
633
+ minimum=0,
634
+ maximum=MAX_SEED,
635
+ step=1,
636
+ value=42,
637
+ )
638
+ negative_prompt = gr.Textbox(label="Negative Prompt",
639
+ placeholder="noisy, blurry, soft, deformed, ugly",
640
+ value="noisy, blurry, soft, deformed, ugly")
641
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
642
+
643
+ submit = gr.Button("Submit", variant="primary")
644
+
645
+ submit.click(
646
+ fn=remove_tips,
647
+ outputs=usage_tips,
648
+ ).then(
649
+ fn=randomize_seed_fn,
650
+ inputs=[seed, randomize_seed],
651
+ outputs=seed,
652
+ queue=False,
653
+ api_name=False,
654
+ ).then(
655
+ fn=generate_image,
656
+ inputs=[prompt, negative_prompt, reference_1, reference_2, resolution, local_prompt1, local_prompt2, seed, style, identitynet_strength_ratio, adapter_strength_ratio, condition, condition_img1, controlnet_ratio],
657
+ outputs=[gallery, gallery1]
658
+ )
659
+ demo.launch(server_name='0.0.0.0',server_port=7861, debug=True)
660
+
661
+ def parse_args():
662
+ parser = argparse.ArgumentParser('', add_help=False)
663
+ parser.add_argument('--pretrained_model', default='/home/data1/kz_dir/checkpoint/YamerMIX_v8', type=str)
664
+ parser.add_argument('--controlnet_path', default='../checkpoint/InstantID/ControlNetModel', type=str)
665
+ parser.add_argument('--face_adapter_path', default='../checkpoint/InstantID/ip-adapter.bin', type=str)
666
+ parser.add_argument('--openpose_checkpoint', default='../checkpoint/controlnet-openpose-sdxl-1.0', type=str)
667
+ parser.add_argument('--canny_checkpoint', default='../checkpoint/controlnet-canny-sdxl-1.0', type=str)
668
+ parser.add_argument('--depth_checkpoint', default='../checkpoint/controlnet-depth-sdxl-1.0', type=str)
669
+ parser.add_argument('--dpt_checkpoint', default='../checkpoint/dpt-hybrid-midas', type=str)
670
+ parser.add_argument('--pose_detector_checkpoint',
671
+ default='../checkpoint/ControlNet/annotator/ckpts/body_pose_model.pth', type=str)
672
+ parser.add_argument('--efficientViT_checkpoint', default='../checkpoint/sam/xl1.pt', type=str)
673
+ parser.add_argument('--dino_checkpoint', default='../checkpoint/GroundingDINO', type=str)
674
+ parser.add_argument('--sam_checkpoint', default='../checkpoint/sam/sam_vit_h_4b8939.pth', type=str)
675
+ parser.add_argument('--antelopev2_path', default='../checkpoint/antelopev2', type=str)
676
+ parser.add_argument('--save_dir', default='results/instantID', type=str)
677
+ parser.add_argument('--prompt', default='Close-up photo of the cool man and beautiful woman as they accidentally discover a mysterious island while on vacation by the sea, facing the camera smiling, 35mm photograph, film, professional, 4k, highly detailed.', type=str)
678
+ parser.add_argument('--negative_prompt', default='noisy, blurry, soft, deformed, ugly', type=str)
679
+ parser.add_argument('--prompt_rewrite',
680
+ default='[Close-up photo of a man, 35mm photograph, professional, 4k, highly detailed.]-*'
681
+ '-[noisy, blurry, soft, deformed, ugly]-*-'
682
+ '../example/chris-evans.jpg|'
683
+ '[Close-up photo of a woman, 35mm photograph, professional, 4k, highly detailed.]-'
684
+ '*-[noisy, blurry, soft, deformed, ugly]-*-'
685
+ '../example/TaylorSwift.png',
686
+ type=str)
687
+ parser.add_argument('--seed', default=0, type=int)
688
+ parser.add_argument('--suffix', default='', type=str)
689
+ parser.add_argument('--segment_type', default='yoloworld', help='GroundingDINO or yoloworld', type=str)
690
+ parser.add_argument('--style_lora', default='', type=str)
691
+ return parser.parse_args()
692
+
693
+ if __name__ == '__main__':
694
+ args = parse_args()
695
+
696
+ prompts = [args.prompt] * 2
697
+
698
+ prompts_tmp = copy.deepcopy(prompts)
699
+
700
+ width, height = 1024, 1024
701
+ kwargs = {
702
+ 'height': height,
703
+ 'width': width,
704
+ }
705
+
706
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
707
+ main(device, args.segment_type)
708
+