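"""SeResVAE: a Stable Diffusion 2.1 autoencoder whose encoder activations
gate the decoder through squeeze-and-excitation (SE) skip paths, combined
with a single fixed-timestep UNet denoising pass and packaged as a Hugging
Face PreTrainedModel for grayscale-to-RGB image translation."""
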
import torch
import torch.nn as nn

from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from transformers import PretrainedConfig, PreTrainedModel

class SEPath(nn.Module):
    """Squeeze-and-excitation skip path: channel statistics of an encoder
    tensor produce per-channel gates that rescale a decoder tensor."""

    def __init__(self, in_channels, out_channels, reduction=16):
        super().__init__()
        # Guard against a zero-width bottleneck (e.g. in_channels=8 with
        # reduction=16 would otherwise give Linear(8, 0) and constant gates)
        hidden_channels = max(in_channels // reduction, 1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, hidden_channels, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_channels, out_channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, in_tensor, out_tensor):
        # Squeeze: global average pool over spatial dims -> (B, in_channels)
        x = in_tensor.mean(dim=(2, 3))
        # Excitation: per-channel gates in [0, 1], shaped (B, out_channels, 1, 1)
        x = self.fc(x).unsqueeze(-1).unsqueeze(-1)
        return out_tensor * x
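
# Illustrative shape check (added example, not original code): an 8-channel
# encoder tensor gates a 4-channel latent tensor, matching se_paths[0] below.
#   gate = SEPath(8, 4)
#   out = gate(torch.randn(2, 8, 64, 64), torch.randn(2, 4, 64, 64))
#   out.shape  # torch.Size([2, 4, 64, 64])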

class SeResVaeConfig(PretrainedConfig):
    model_type = "seresvae"

    def __init__(
        self,
        base_model="stabilityai/stable-diffusion-2-1",
        height=512,
        width=512,
        **kwargs,
    ):
        self.base_model = base_model
        self.height = height
        self.width = width
        super().__init__(**kwargs)

class SeResVaeModel(PreTrainedModel):
    config_class = SeResVaeConfig

    def __init__(self, config):
        super().__init__(config)
        self.image_processor = VaeImageProcessor()
        self.vae = AutoencoderKL.from_pretrained(config.base_model, subfolder='vae')
        self.unet = UNet2DConditionModel.from_pretrained(config.base_model, subfolder='unet')
        # One SE gate per junction: the latent bottleneck, the decoder mid
        # block, and the inputs to decoder up blocks 1-3
        self.se_paths = nn.ModuleList(
            [SEPath(8, 4), SEPath(512, 512), SEPath(512, 512), SEPath(256, 512), SEPath(128, 256)]
        )
        # Learned prompt embedding standing in for a text encoder
        # (77 tokens x 1024, the cross-attention width of SD 2.1)
        self.prompt_embeds = nn.Parameter(torch.randn(1, 77, 1024))
        self.height = config.height
        self.width = config.width

    def forward(self, images_gray, input_type='pil', output_type='pil'):
        if input_type == 'pil':
            images_gray = self.image_processor.preprocess(
                images_gray, height=self.height, width=self.width
            ).float()
        elif input_type != 'pt':
            raise ValueError('unsupported input_type')
        images_gray = images_gray.to(self.vae.device)
        B, C, H, W = images_gray.shape
        prompt_embeds = self.prompt_embeds.repeat(B, 1, 1)

        # Encode, keeping the intermediate activations for the SE skip paths
        posterior, encode_residual = self.encode_with_residual(images_gray)
        latents = posterior.mode()
        # Single denoising pass at a fixed timestep (t=500); the UNet's
        # prediction is subtracted directly rather than stepped by a scheduler
        t = torch.full((B,), 500, dtype=torch.long, device=self.vae.device)
        noise_pred = self.unet(latents, t, encoder_hidden_states=prompt_embeds)[0]
        denoised_latents = latents - noise_pred
        images_rgb = self.decode_with_residual(denoised_latents, *encode_residual)

        if output_type in ('pil', 'np', 'pt'):
            images_rgb = self.image_processor.postprocess(images_rgb, output_type=output_type)
        elif output_type != 'none':
            raise ValueError('unsupported output_type')

        return images_rgb

    def encode_with_residual(self, sample):
        """Run the VAE encoder, returning the posterior together with the
        intermediate activations consumed by decode_with_residual."""
        re = self.vae.encoder.conv_in(sample)
        re0, re0_out = self._DownEncoderBlock2D_res_forward(self.vae.encoder.down_blocks[0], re)
        re1, re1_out = self._DownEncoderBlock2D_res_forward(self.vae.encoder.down_blocks[1], re0)
        re2, re2_out = self._DownEncoderBlock2D_res_forward(self.vae.encoder.down_blocks[2], re1)
        re3, re3_out = self._DownEncoderBlock2D_res_forward(self.vae.encoder.down_blocks[3], re2)
        rem = self.vae.encoder.mid_block(re3)
        re_out = self.vae.encoder.conv_norm_out(rem)
        re_out = self.vae.encoder.conv_act(re_out)
        re_out = self.vae.encoder.conv_out(re_out)
        re_out = self.vae.quant_conv(re_out)

        # re3_out is unused: the mid-block output (rem) serves as the deepest skip
        posterior = DiagonalGaussianDistribution(re_out)
        return posterior, (re0_out, re1_out, re2_out, rem, re_out)

    def decode_with_residual(self, z, re0_out, re1_out, re2_out, rem, re_out):
        """Run the VAE decoder, gating each stage's input with the matching
        encoder activation via the SE paths."""
        rd = self.vae.post_quant_conv(self.se_paths[0](re_out, z))
        rd = self.vae.decoder.conv_in(rd)
        rdm = self.vae.decoder.mid_block(self.se_paths[1](rem, rd)).to(torch.float32)
        rd0 = self.vae.decoder.up_blocks[0](rdm)
        rd1 = self.vae.decoder.up_blocks[1](self.se_paths[2](re2_out, rd0))
        rd2 = self.vae.decoder.up_blocks[2](self.se_paths[3](re1_out, rd1))
        rd3 = self.vae.decoder.up_blocks[3](self.se_paths[4](re0_out, rd2))
        rd_out = self.vae.decoder.conv_norm_out(rd3)
        rd_out = self.vae.decoder.conv_act(rd_out)
        sample_out = self.vae.decoder.conv_out(rd_out)
        return sample_out

    def _DownEncoderBlock2D_res_forward(self, down_encoder_block_2d, hidden_states):
        """Forward through one DownEncoderBlock2D, also returning the
        pre-downsample activations for use as an SE skip."""
        for resnet in down_encoder_block_2d.resnets:
            hidden_states = resnet(hidden_states, temb=None)

        output_states = hidden_states
        if down_encoder_block_2d.downsamplers is not None:
            for downsampler in down_encoder_block_2d.downsamplers:
                hidden_states = downsampler(hidden_states)

        return hidden_states, output_states
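
# Minimal usage sketch (added example): instantiating from a fresh config
# leaves the SE paths and prompt embedding randomly initialized, so for real
# inference load trained weights via SeResVaeModel.from_pretrained(...).
# "input.png" and "output.png" are placeholder paths.
if __name__ == "__main__":
    from PIL import Image

    config = SeResVaeConfig(height=512, width=512)
    model = SeResVaeModel(config).eval()
    # The VAE's conv_in expects 3 channels, so replicate grayscale to RGB
    gray = Image.open("input.png").convert("RGB")
    with torch.no_grad():
        images = model([gray], input_type='pil', output_type='pil')
    images[0].save("output.png")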