File size: 6,049 Bytes
54c1f4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: Apache-2.0

import torch
from pathlib import Path
import numpy as np
from diffusers import ControlNetModel, EulerDiscreteScheduler
from diffusers.loaders.unet import UNet2DConditionLoadersMixin

from .pipeline_idpatch_sd_xl import StableDiffusionXLIDPatchPipeline

class IDPatchInferencer:
    """Identity-conditioned SDXL generation with ID-Patch.

    Loads the ID-Patch projection heads and ControlNet from ``idp_model_path`` and a
    ``StableDiffusionXLIDPatchPipeline`` from ``base_model_path``. In ``generate`` each
    face embedding is injected twice: as extra prompt tokens (``id_prompt_projection``)
    and as a square RGB patch added to the image passed to the pipeline as ``image``
    at the face's location (``id_patch_projection``).
    """

    def __init__(self, base_model_path, idp_model_path, patch_size=64, torch_device='cuda:0', torch_dtype=torch.bfloat16):
        """Load every model and move it to ``torch_device`` / ``torch_dtype``.

        Args:
            base_model_path: SDXL base model location passed to diffusers
                ``from_pretrained``; must provide a ``scheduler`` subfolder.
            idp_model_path: directory holding ``id-patch.bin`` (with ``patch_proj`` and
                ``prompt_proj`` entries) and a ``ControlNetModel`` subfolder.
            patch_size: side length in pixels of the square ID patch drawn per face.
            torch_device: device all models are placed on.
            torch_dtype: dtype all models are cast to.
        """
        super().__init__()
        self.patch_size = patch_size
        self.torch_device = torch_device
        self.torch_dtype = torch_dtype
        # NOTE(review): torch.load unpickles arbitrary objects; only load trusted
        # checkpoints (or pass weights_only=True on torch versions that support it).
        idp_state_dict = torch.load(Path(idp_model_path) / 'id-patch.bin', map_location="cpu")
        # Build both projection heads with diffusers' IP-Adapter image-projection converter.
        loader = UNet2DConditionLoadersMixin()
        self.id_patch_projection = loader._convert_ip_adapter_image_proj_to_diffusers(idp_state_dict['patch_proj']).to(self.torch_device, dtype=self.torch_dtype).eval()
        self.id_prompt_projection = loader._convert_ip_adapter_image_proj_to_diffusers(idp_state_dict['prompt_proj']).to(self.torch_device, dtype=self.torch_dtype).eval()
        controlnet = ControlNetModel.from_pretrained(Path(idp_model_path) / 'ControlNetModel').to(self.torch_device, dtype=self.torch_dtype).eval()
        scheduler = EulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler")
        self.pipe = StableDiffusionXLIDPatchPipeline.from_pretrained(
            base_model_path,
            controlnet=controlnet,
            scheduler=scheduler,
            torch_dtype=self.torch_dtype,
        ).to(self.torch_device)

    # Inference only: no_grad avoids building an autograd graph through the text encoders.
    @torch.no_grad()
    def get_text_embeds_from_strings(self, text_strings):
        """Encode ``text_strings`` with both text encoders of ``self.pipe``.

        Each string is tokenized to the tokenizer's ``model_max_length`` (padded and
        truncated).

        Returns:
            (text_embeds, pooled_embeds): the penultimate hidden states of the two
            encoders concatenated along dim 2, and the ``text_embeds`` output of the
            second encoder.
        """
        pipe = self.pipe
        device = pipe.device
        tokenizer_1 = pipe.tokenizer
        tokenizer_2 = pipe.tokenizer_2
        text_encoder_1 = pipe.text_encoder
        text_encoder_2 = pipe.text_encoder_2

        text_embeds = []
        for tokenizer, text_encoder in [(tokenizer_1, text_encoder_1), (tokenizer_2, text_encoder_2)]:
            input_ids = tokenizer(
                text_strings,
                max_length=tokenizer.model_max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt"
            ).input_ids.to(device)
            text_embeds.append(text_encoder(input_ids, output_hidden_states=True))
        pooled_embeds = text_embeds[1]['text_embeds']
        text_embeds = torch.concat([text_embeds[0]['hidden_states'][-2], text_embeds[1]['hidden_states'][-2]], dim=2)
        return text_embeds, pooled_embeds

    @staticmethod
    def _compose_condition_image(control_image, patches, face_locations, patch_size):
        """Return ``control_image`` with one square patch added per face.

        Args:
            control_image: 1 x 3 x H x W tensor.
            patches: n_faces x 3 x patch_size x patch_size tensor.
            face_locations: n_faces x 2 [x, y] pixel coordinates of each patch center
                (tensor, array or nested list; any dtype/device).
            patch_size: side length of each patch in pixels (odd sizes allowed).

        Returns:
            A 1 x 3 x H x W tensor (``control_image`` plus a float32 canvas).
            Overlapping patches are summed, patches are clipped at the image border,
            and faces whose rounded center lies outside the image are skipped.

        Raises:
            ValueError: if the number of patches and face locations differ.
        """
        height, width = control_image.shape[2:4]
        pad = patch_size // 2
        # Round in float32 on the CPU: casting pixel coordinates to bfloat16 (the model
        # dtype) would snap them to multiples of 2/4/8 px above 256/512/1024.
        locations = torch.as_tensor(face_locations).detach().to('cpu', torch.float32)
        starts = torch.round(locations - pad).long().tolist()  # top-left [x, y] per patch
        if len(starts) != len(patches):
            raise ValueError(f"got {len(patches)} face patches but {len(starts)} face locations")
        # Enlarge the canvas by patch_size so border patches can be pasted whole, then cropped.
        canvas = torch.zeros((1, 3, height + patch_size, width + patch_size), dtype=torch.float32, device=control_image.device)
        for (x0, y0), patch in zip(starts, patches):
            # Rounded patch center in image coordinates == patch top-left in canvas coordinates.
            cx, cy = x0 + pad, y0 + pad
            if not (0 <= cx < width and 0 <= cy < height):
                continue
            canvas[0, :, cy:cy + patch_size, cx:cx + patch_size] += patch
        return control_image + canvas[:, :, pad:pad + height, pad:pad + width]

    @torch.no_grad()
    def generate(self, face_embeds, face_locations, control_image, prompt, negative_prompt="", guidance_scale=5.0, num_inference_steps=50, controlnet_conditioning_scale=0.8, id_injection_ratio=0.8, seed=-1):
        """Generate one image conditioned on face identities, their positions and a control image.

        Args:
            face_embeds: n_faces x 512 per-face embeddings.
            face_locations: n_faces x 2 [x, y] pixel coordinates (in ``control_image``)
                of each face's patch center. Faces whose rounded center falls outside
                the image get no patch.
            control_image: PIL image (converted with ``np.array``; expected H x W x 3).
            prompt: positive text prompt.
            negative_prompt: negative text prompt.
            guidance_scale, num_inference_steps, controlnet_conditioning_scale,
            id_injection_ratio: forwarded unchanged to ``self.pipe``.
            seed: when >= 0, seeds a ``torch.Generator`` on ``torch_device``;
                otherwise sampling is unseeded.

        Returns:
            The first image of the pipeline output (requested with ``output_type='pil'``).
        """
        n_faces = len(face_embeds)

        # float32 here: the image is summed with a float32 patch canvas anyway, so dividing
        # in bfloat16 would only add an extra rounding step.
        control_image = torch.from_numpy(np.array(control_image)).to(self.torch_device, dtype=torch.float32).permute(2, 0, 1)[None] / 255.0

        text_embeds, pooled_embeds = self.get_text_embeds_from_strings([negative_prompt, prompt])  # text_embeds: 2 x 77 x 2048, pooled_embeds: 2 x 1280
        negative_pooled_embeds, pooled_embeds = pooled_embeds[:1], pooled_embeds[1:]
        negative_text_embeds, text_embeds = text_embeds[:1], text_embeds[1:]

        # Negative identity tokens: the prompt projection applied to all-zero face embeddings.
        negative_id_embeds = self.id_prompt_projection(torch.zeros(n_faces, 1, 512, device=self.torch_device, dtype=self.torch_dtype))  # (BxF) x 16 x 2048
        negative_id_embeds = negative_id_embeds.reshape(1, -1, negative_id_embeds.shape[2])  # B x (Fx16) x 2048
        negative_text_id_embeds = torch.concat([negative_text_embeds, negative_id_embeds], dim=1)

        face_embeds = torch.as_tensor(face_embeds)[None].to(self.torch_device, self.torch_dtype)  # 1 x faces x 512
        id_embeds = self.id_prompt_projection(face_embeds.reshape(-1, 1, 512))  # (BxF) x 16 x 2048
        id_embeds = id_embeds.reshape(face_embeds.shape[0], -1, id_embeds.shape[2])  # B x (Fx16) x 2048
        text_id_embeds = torch.concat([text_embeds, id_embeds], dim=1)  # B x (77+Fx16) x 2048

        patch_prompt_embeds = self.id_patch_projection(face_embeds.reshape(-1, 1, 512))  # (Bxn_faces) x 3 x (64*64)
        patch_prompt_embeds = patch_prompt_embeds.reshape(n_faces, 3, self.patch_size, self.patch_size)
        condition_image = self._compose_condition_image(control_image, patch_prompt_embeds, face_locations, self.patch_size)

        if seed >= 0:
            generator = torch.Generator(self.torch_device).manual_seed(seed)
        else:
            generator = None
        output_image = self.pipe(
            prompt_embeds=text_id_embeds,
            pooled_prompt_embeds=pooled_embeds,
            negative_prompt_embeds=negative_text_id_embeds,
            negative_pooled_prompt_embeds=negative_pooled_embeds,
            image=condition_image,
            guidance_scale=guidance_scale,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            num_inference_steps=num_inference_steps,
            id_injection_ratio=id_injection_ratio,
            output_type='pil',
            generator=generator,
        ).images[0]
        return output_image