# Copyright 2024 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: Apache-2.0
import torch
from pathlib import Path
import numpy as np
from diffusers import ControlNetModel, EulerDiscreteScheduler
from diffusers.loaders.unet import UNet2DConditionLoadersMixin

from .pipeline_idpatch_sd_xl import StableDiffusionXLIDPatchPipeline


class IDPatchInferencer:
    """Inference wrapper around the ID-Patch SDXL ControlNet pipeline.

    Loads the ID-Patch projection heads and ControlNet from ``idp_model_path``
    and an SDXL base model from ``base_model_path``, then turns per-face
    identity embeddings plus face locations into a generated image.
    """

    def __init__(self, base_model_path, idp_model_path, patch_size=64, torch_device='cuda:0', torch_dtype=torch.bfloat16):
        """Load all models and build the diffusion pipeline.

        Args:
            base_model_path: SDXL base model path/repo id passed to
                ``from_pretrained``; its ``scheduler`` subfolder is also read.
            idp_model_path: Directory containing ``id-patch.bin`` (with
                ``patch_proj`` and ``prompt_proj`` entries) and a
                ``ControlNetModel`` subdirectory.
            patch_size: Side length, in pixels, of the square ID patch added to
                the control image for each face.
            torch_device: Device every model is moved to.
            torch_dtype: Dtype for model weights and embeddings.
        """
        super().__init__()
        self.patch_size = patch_size
        self.torch_device = torch_device
        self.torch_dtype = torch_dtype

        # NOTE(review): torch.load unpickles the checkpoint, which can execute
        # arbitrary code. Only load trusted files; consider weights_only=True if
        # the checkpoint holds nothing but tensors.
        idp_state_dict = torch.load(Path(idp_model_path) / 'id-patch.bin', map_location="cpu")
        # The mixin is instantiated only to reuse its IP-Adapter projection converter.
        loader = UNet2DConditionLoadersMixin()
        self.id_patch_projection = loader._convert_ip_adapter_image_proj_to_diffusers(idp_state_dict['patch_proj']).to(self.torch_device, dtype=self.torch_dtype).eval()
        self.id_prompt_projection = loader._convert_ip_adapter_image_proj_to_diffusers(idp_state_dict['prompt_proj']).to(self.torch_device, dtype=self.torch_dtype).eval()

        controlnet = ControlNetModel.from_pretrained(Path(idp_model_path) / 'ControlNetModel').to(self.torch_device, dtype=self.torch_dtype).eval()
        scheduler = EulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler")
        self.pipe = StableDiffusionXLIDPatchPipeline.from_pretrained(
            base_model_path,
            controlnet=controlnet,
            scheduler=scheduler,
            torch_dtype=self.torch_dtype,
        ).to(self.torch_device)

    def get_text_embeds_from_strings(self, text_strings):
        """Encode ``text_strings`` with both SDXL tokenizer/text-encoder pairs.

        Each string is padded/truncated to the tokenizer's ``model_max_length``.
        This method does not disable autograd itself; ``generate`` calls it
        under ``torch.no_grad()``.

        Returns:
            ``(text_embeds, pooled_embeds)``. ``text_embeds`` is the penultimate
            hidden state of each encoder, concatenated along the last dimension.
            ``pooled_embeds`` is the ``text_embeds`` output of the second encoder.
        """
        pipe = self.pipe
        device = pipe.device
        encoder_outputs = []
        for tokenizer, text_encoder in [(pipe.tokenizer, pipe.text_encoder), (pipe.tokenizer_2, pipe.text_encoder_2)]:
            input_ids = tokenizer(
                text_strings,
                max_length=tokenizer.model_max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            ).input_ids.to(device)
            encoder_outputs.append(text_encoder(input_ids, output_hidden_states=True))
        pooled_embeds = encoder_outputs[1]['text_embeds']
        text_embeds = torch.concat(
            [encoder_outputs[0]['hidden_states'][-2], encoder_outputs[1]['hidden_states'][-2]],
            dim=2,
        )
        return text_embeds, pooled_embeds

    def _paste_id_patches(self, control_image, face_locations, patch_prompt_embeds):
        """Return ``control_image`` with each face's ID patch added at its location.

        Args:
            control_image: 1 x 3 x H x W tensor.
            face_locations: n_faces x 2 tensor of (x, y) pixel coordinates.
            patch_prompt_embeds: n_faces x 3 x patch_size x patch_size tensor.

        Faces whose rounded top-left corner plus ``patch_size // 2`` (the rounded
        centre) falls outside the image are skipped. The patches are accumulated
        on a float32 canvas before being added to ``control_image``.
        """
        height, width = control_image.shape[2:4]
        pad = self.patch_size // 2
        # Padded canvas so patches overlapping the image border can be written
        # without clipping; the padding is cropped away afterwards.
        canvas = torch.zeros((1, 3, height + pad * 2, width + pad * 2), device=self.torch_device)
        # Do coordinate math in float32, never the model dtype. bfloat16 (the
        # default) quantises pixel positions above 256 to steps of 2, 4, 8, ...
        # Rounding the min and max corners separately can then yield a slice
        # whose width differs from patch_size, which makes the += below fail.
        top_left = torch.round(face_locations.to(torch.float32) - pad).long().tolist()
        for f in range(patch_prompt_embeds.shape[0]):
            # Canvas coordinates of the patch's top-left corner (= rounded centre).
            x0 = top_left[f][0] + pad
            y0 = top_left[f][1] + pad
            if not (0 <= x0 < width and 0 <= y0 < height):
                continue
            # The slice is derived from patch_size, so it always matches the patch shape.
            canvas[0, :, y0:y0 + self.patch_size, x0:x0 + self.patch_size] += patch_prompt_embeds[f]
        # Explicit end indices: a `pad:-pad` crop would be empty when pad == 0.
        return control_image + canvas[:, :, pad:pad + height, pad:pad + width]

    @torch.no_grad()
    def generate(self, face_embeds, face_locations, control_image, prompt, negative_prompt="", guidance_scale=5.0, num_inference_steps=50, controlnet_conditioning_scale=0.8, id_injection_ratio=0.8, seed=-1):
        """Generate one image conditioned on face identities and face locations.

        Args:
            face_embeds: n_faces x 512 tensor of identity embeddings.
            face_locations: n_faces x 2 tensor of (x, y) face-centre pixel
                coordinates in ``control_image``.
            control_image: PIL image. It is converted with ``np.array`` and
                permuted to channels-first, so an H x W x 3 image is expected.
            prompt: Positive text prompt.
            negative_prompt: Negative text prompt.
            guidance_scale, num_inference_steps, controlnet_conditioning_scale,
                id_injection_ratio: Forwarded unchanged to the pipeline.
            seed: If >= 0, seeds a ``torch.Generator`` on ``torch_device``;
                otherwise no generator is passed.

        Returns:
            The first image of the pipeline output (``output_type='pil'``).

        Runs under ``torch.no_grad()`` so that the text-encoder and projection
        activations are not kept alive by autograd during denoising.
        """
        control_image = torch.from_numpy(np.array(control_image)).to(self.torch_device, dtype=self.torch_dtype).permute(2, 0, 1)[None] / 255.0

        # text_embeds: 2 x 77 x 2048, pooled_embeds: 2 x 1280
        # Row 0 is the negative prompt, row 1 the positive prompt.
        text_embeds, pooled_embeds = self.get_text_embeds_from_strings([negative_prompt, prompt])
        negative_pooled_embeds, pooled_embeds = pooled_embeds[:1], pooled_embeds[1:]
        negative_text_embeds, text_embeds = text_embeds[:1], text_embeds[1:]

        n_faces = len(face_embeds)
        # Negative identity tokens are the projection of all-zero face embeddings.
        negative_id_embeds = self.id_prompt_projection(
            torch.zeros(n_faces, 1, 512, device=self.torch_device, dtype=self.torch_dtype)
        )  # F x 16 x 2048
        negative_id_embeds = negative_id_embeds.flatten(0, 1)[None]  # 1 x (F*16) x 2048
        negative_text_id_embeds = torch.concat([negative_text_embeds, negative_id_embeds], dim=1)

        face_embeds = face_embeds.to(self.torch_device, self.torch_dtype).reshape(n_faces, 1, 512)
        id_embeds = self.id_prompt_projection(face_embeds).flatten(0, 1)[None]  # 1 x (F*16) x 2048
        text_id_embeds = torch.concat([text_embeds, id_embeds], dim=1)  # 1 x (77 + F*16) x 2048

        # The projection output is assumed to hold 3 * patch_size**2 values per face.
        patch_prompt_embeds = self.id_patch_projection(face_embeds).reshape(n_faces, 3, self.patch_size, self.patch_size)
        condition_image = self._paste_id_patches(control_image, face_locations, patch_prompt_embeds)

        generator = torch.Generator(self.torch_device).manual_seed(seed) if seed >= 0 else None

        output_image = self.pipe(
            prompt_embeds=text_id_embeds,
            pooled_prompt_embeds=pooled_embeds,
            negative_prompt_embeds=negative_text_id_embeds,
            negative_pooled_prompt_embeds=negative_pooled_embeds,
            image=condition_image,
            guidance_scale=guidance_scale,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            num_inference_steps=num_inference_steps,
            id_injection_ratio=id_injection_ratio,
            output_type='pil',
            generator=generator,
        ).images[0]
        return output_image