import torch
import torch.nn.functional as F
from PIL import Image
from transformers import CLIPImageProcessor
from transformers import CLIPVisionModel as OriginalCLIPVisionModel

from ._clip import CLIPVisionModel

def is_torch2_available():
    return hasattr(F, "scaled_dot_product_attention")
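
# Prefer the fused scaled-dot-product-attention processors when running on
# PyTorch >= 2.0 (detected via F.scaled_dot_product_attention).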
if is_torch2_available():
    from .attention_processor import SSRAttnProcessor2_0 as SSRAttnProcessor, AttnProcessor2_0 as AttnProcessor
else:
    from .attention_processor import SSRAttnProcessor, AttnProcessor
from .resampler import Resampler

class detail_encoder(torch.nn.Module):
    """Makeup detail encoder adapted from the SSR-Encoder.

    Encodes a reference image with a CLIP vision encoder, concatenates a subset
    of its hidden states, resamples them, and feeds the result to the UNet
    through the SSR cross-attention processors installed in __init__.
    """
    def __init__(self, unet, image_encoder_path, device="cuda", dtype=torch.float32):
        super().__init__()
        self.device = device
        self.dtype = dtype

        # load image encoder
        clip_encoder = OriginalCLIPVisionModel.from_pretrained(image_encoder_path)
        self.image_encoder = CLIPVisionModel(clip_encoder.config)
        state_dict = clip_encoder.state_dict()
        self.image_encoder.load_state_dict(state_dict, strict=False)
        self.image_encoder.to(self.device, self.dtype)
        del clip_encoder
        self.clip_image_processor = CLIPImageProcessor()

        # load SSR layers
        attn_procs = {}
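        # Self-attention layers ("attn1") keep the default processor; every
        # cross-attention layer is replaced with an SSR processor that attends
        # to the 1024-dim CLIP image features produced by the resampler.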
        for name in unet.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]
            if cross_attention_dim is None:
                attn_procs[name] = AttnProcessor()
            else:
                attn_procs[name] = SSRAttnProcessor(hidden_size=hidden_size, cross_attention_dim=1024, scale=1).to(self.device, dtype=self.dtype)
        unet.set_attn_processor(attn_procs)
        adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
        self.SSR_layers = adapter_modules
        self.SSR_layers.to(self.device, dtype=self.dtype)
        self.resampler = self.init_proj()

    def init_proj(self):
        resampler = Resampler().to(self.device, dtype=self.dtype)
        return resampler

    def forward(self, img):
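        # Keep every second hidden state (starting from layer 2), concatenate
        # them along the token dimension, and map them through the resampler.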
        image_embeds = self.image_encoder(img, output_hidden_states=True)['hidden_states'][2::2]
        image_embeds = torch.cat(image_embeds, dim=1)
        image_embeds = self.resampler(image_embeds)
        return image_embeds

    @torch.inference_mode()
    def get_image_embeds(self, pil_image):
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]
        clip_image = []
        for pil in pil_image:
            tensor_image = self.clip_image_processor(images=pil, return_tensors="pt").pixel_values.to(self.device, dtype=self.dtype)
            clip_image.append(tensor_image)
        clip_image = torch.cat(clip_image, dim=0)

        # Conditional embeddings: every second CLIP hidden state (12 layers of
        # shape (batch, 257, 1024)) concatenated along the token dimension.
        clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True)['hidden_states'][2::2]
        clip_image_embeds = torch.cat(clip_image_embeds, dim=1)
        # Unconditional embeddings from an all-zero image, for classifier-free guidance.
        uncond_clip_image_embeds = self.image_encoder(torch.zeros_like(clip_image), output_hidden_states=True)['hidden_states'][2::2]
        uncond_clip_image_embeds = torch.cat(uncond_clip_image_embeds, dim=1)
        clip_image_embeds = self.resampler(clip_image_embeds)
        uncond_clip_image_embeds = self.resampler(uncond_clip_image_embeds)
        return clip_image_embeds, uncond_clip_image_embeds

    def generate(
            self,
            id_image,
            makeup_image,
            seed=None,
            guidance_scale=2,
            num_inference_steps=30,
            pipe=None,
            **kwargs,
    ):
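        """Transfer the makeup of `makeup_image` onto `id_image` using `pipe`,
        a diffusion pipeline whose UNet was passed to this encoder's __init__."""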
        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(makeup_image)

        prompt_embeds = image_prompt_embeds
        negative_prompt_embeds = uncond_image_prompt_embeds

        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
        image = pipe(
            image=id_image,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            **kwargs,
        ).images[0]

        return image
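
# --- Usage sketch -----------------------------------------------------------
# A minimal sketch of how this module is typically wired up. The checkpoint
# names, the StableDiffusionImg2ImgPipeline class, and the image paths below
# are illustrative assumptions, not requirements: any diffusers pipeline that
# accepts `image`, `prompt_embeds`, and `negative_prompt_embeds`, and whose
# UNet was handed to detail_encoder, should work the same way.
if __name__ == "__main__":
    from diffusers import StableDiffusionImg2ImgPipeline

    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5"  # placeholder base model
    ).to("cuda")

    # Building the encoder installs the SSR attention processors on pipe.unet.
    encoder = detail_encoder(pipe.unet, "openai/clip-vit-large-patch14", device="cuda")

    id_image = Image.open("examples/id.png").convert("RGB")          # placeholder path
    makeup_image = Image.open("examples/makeup.png").convert("RGB")  # placeholder path

    result = encoder.generate(
        id_image=id_image,
        makeup_image=makeup_image,
        pipe=pipe,
        seed=42,
    )
    result.save("makeup_result.png")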