|
import torch |
|
import torch.nn as nn |
|
|
|
from diffusers import AutoencoderKL, UNet2DConditionModel |
|
from diffusers.image_processor import VaeImageProcessor |
|
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution |
|
from transformers import PretrainedConfig, PreTrainedModel |
|
|
|
class SEPath(nn.Module):
    """Channel gate computed from one feature map and applied to another.

    ``in_tensor`` is averaged over its two spatial axes, passed through a
    bias-free two-layer bottleneck ending in a sigmoid, and the resulting
    per-channel weights are multiplied into ``out_tensor``.
    """

    def __init__(self, in_channels, out_channels, reduction=16):
        """Build the gating network.

        Args:
            in_channels: Channel count of the tensor the gate is computed from.
            out_channels: Channel count of the tensor the gate is applied to.
            reduction: Divisor used to size the hidden layer
                (``in_channels // reduction``).
        """
        super().__init__()
        # NOTE: the hidden width is 0 whenever in_channels < reduction. The two
        # bias-free linears then emit zeros and the gate is a constant
        # sigmoid(0) = 0.5. Deliberately left unchanged: altering the width
        # would change parameter shapes and break previously saved weights.
        hidden_channels = in_channels // reduction
        self.fc = nn.Sequential(
            nn.Linear(in_channels, hidden_channels, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_channels, out_channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, in_tensor, out_tensor):
        """Scale ``out_tensor`` channel-wise by a gate derived from ``in_tensor``.

        Args:
            in_tensor: 4-D tensor ``(B, C, H, W)``; ``C`` must equal the
                ``in_channels`` given at construction.
            out_tensor: Tensor that broadcasts against a
                ``(B, out_channels, 1, 1)`` gate.

        Returns:
            ``out_tensor`` multiplied by the gate.

        Raises:
            ValueError: If ``in_tensor`` is not 4-D (tuple unpacking fails).
        """
        batch, channels, _, _ = in_tensor.shape
        # reshape (not view) so non-contiguous inputs, e.g. spatial slices,
        # are pooled instead of raising a RuntimeError.
        pooled = in_tensor.reshape(batch, channels, -1).mean(dim=2)
        # (B, out_channels) -> (B, out_channels, 1, 1) for broadcasting.
        gate = self.fc(pooled).unsqueeze(2).unsqueeze(2)
        return out_tensor * gate
|
|
|
class SeResVaeConfig(PretrainedConfig):
    """Configuration object registered under the ``seresvae`` model type.

    Attributes set here:
        base_model: Model identifier string (defaults to
            ``"stabilityai/stable-diffusion-2-1"``).
        height: Image height stored on the config (default 512).
        width: Image width stored on the config (default 512).

    Any remaining keyword arguments are forwarded to ``PretrainedConfig``.
    """

    model_type = "seresvae"

    def __init__(
        self,
        base_model="stabilityai/stable-diffusion-2-1",
        height=512,
        width=512,
        **kwargs
    ):
        # Own fields are stored first, then the base class consumes the rest.
        self.base_model, self.height, self.width = base_model, height, width
        super().__init__(**kwargs)
|
|
|
class SeResVaeModel(PreTrainedModel):
    """VAE + UNet pipeline with SE-gated encoder-to-decoder paths.

    An image batch is encoded by the VAE encoder while intermediate
    activations are kept, the latent is passed once through the UNet at a
    fixed timestep, and the result is decoded with each decoder stage input
    gated by an ``SEPath`` driven by the matching encoder activation.
    Argument names suggest grayscale input and RGB output; the code itself
    does not check channel counts.
    """

    config_class = SeResVaeConfig

    def __init__(self, config):
        """Load the pretrained VAE/UNet named by ``config.base_model`` and
        create the trainable gates and prompt embedding."""
        super().__init__(config)
        self.image_processor = VaeImageProcessor()
        self.vae = AutoencoderKL.from_pretrained(config.base_model, subfolder='vae')
        self.unet = UNet2DConditionModel.from_pretrained(config.base_model, subfolder='unet')
        # (gate-source channels, gated-tensor channels) for each path, in the
        # order decode_with_residual uses them. Presumably these match the
        # channel widths of the default base model's VAE -- confirm before
        # switching base_model.
        gate_channels = ((8, 4), (512, 512), (512, 512), (256, 512), (128, 256))
        self.se_paths = nn.ModuleList(
            [SEPath(src_ch, dst_ch) for src_ch, dst_ch in gate_channels]
        )
        # Learned stand-in for text-encoder output, fed to the UNet as
        # encoder_hidden_states.
        self.prompt_embeds = nn.Parameter(torch.randn(1, 77, 1024))
        self.height = config.height
        self.width = config.width

    def forward(self, images_gray, input_type='pil', output_type='pil'):
        """Run the full encode -> UNet -> decode pass.

        Args:
            images_gray: Input batch. With ``input_type='pil'`` it is passed
                through ``VaeImageProcessor.preprocess`` at the configured
                height/width and cast to float; with ``'pt'`` it is used as
                given and must be 4-D.
            input_type: ``'pil'`` or ``'pt'``.
            output_type: ``'pil'``, ``'np'`` or ``'pt'`` to post-process the
                decoder output with the image processor, or ``'none'`` to
                return the raw decoder output.

        Returns:
            The decoded images in the requested format.

        Raises:
            ValueError: On an unsupported ``input_type`` or ``output_type``.
        """
        if input_type == 'pil':
            images_gray = self.image_processor.preprocess(
                images_gray, height=self.height, width=self.width
            ).float()
        elif input_type != 'pt':
            raise ValueError('unsupported input_type')

        device = self.vae.device
        images_gray = images_gray.to(device)
        batch_size, _, _, _ = images_gray.shape
        prompt_embeds = self.prompt_embeds.repeat(batch_size, 1, 1)

        posterior, encode_residual = self.encode_with_residual(images_gray)
        latents = posterior.mode()

        # One UNet evaluation at a fixed timestep of 500 for every sample.
        timesteps = torch.full((batch_size,), 500, dtype=torch.long, device=device)
        noise_pred = self.unet(latents, timesteps, encoder_hidden_states=prompt_embeds)[0]
        denoised_latents = latents - noise_pred

        images_rgb = self.decode_with_residual(denoised_latents, *encode_residual)

        if output_type == 'none':
            return images_rgb
        if output_type not in ('pil', 'np', 'pt'):
            raise ValueError('unsupported output_type')
        return self.image_processor.postprocess(images_rgb, output_type)

    def encode_with_residual(self, sample):
        """Run the VAE encoder, keeping activations for the decoder gates.

        Returns:
            ``(posterior, residuals)`` where ``posterior`` is a
            ``DiagonalGaussianDistribution`` over the ``quant_conv`` output
            and ``residuals`` is the 5-tuple: pre-downsample outputs of down
            blocks 0, 1 and 2, the mid-block output, and the ``quant_conv``
            output.
        """
        encoder = self.vae.encoder
        hidden = encoder.conv_in(sample)

        pre_downsample = []
        for block_idx in range(4):
            hidden, skip = self._DownEncoderBlock2D_res_forward(
                encoder.down_blocks[block_idx], hidden
            )
            pre_downsample.append(skip)

        mid = encoder.mid_block(hidden)
        moments = encoder.conv_norm_out(mid)
        moments = encoder.conv_act(moments)
        moments = encoder.conv_out(moments)
        moments = self.vae.quant_conv(moments)

        posterior = DiagonalGaussianDistribution(moments)
        # The fourth block's pre-downsample output is not part of the result.
        skip0, skip1, skip2 = pre_downsample[:3]
        return posterior, (skip0, skip1, skip2, mid, moments)

    def decode_with_residual(self, z, re0_out, re1_out, re2_out, rem, re_out):
        """Decode ``z`` with SE gates driven by the encoder activations.

        Args:
            z: Latent to decode.
            re0_out, re1_out, re2_out: Pre-downsample outputs of encoder down
                blocks 0-2; they gate the inputs of up blocks 3, 2 and 1.
            rem: Encoder mid-block output; gates the decoder mid-block input.
            re_out: Encoder ``quant_conv`` output; gates ``z`` itself.

        Returns:
            The output of the decoder's final convolution.
        """
        decoder = self.vae.decoder

        hidden = self.se_paths[0](re_out, z)
        hidden = decoder.conv_in(self.vae.post_quant_conv(hidden))

        hidden = self.se_paths[1](rem, hidden)
        hidden = decoder.mid_block(hidden).to(torch.float32)

        # The first up block runs ungated; the remaining three are gated by
        # encoder activations taken deepest-first.
        hidden = decoder.up_blocks[0](hidden)
        for up_idx, skip in enumerate((re2_out, re1_out, re0_out), start=1):
            gated = self.se_paths[up_idx + 1](skip, hidden)
            hidden = decoder.up_blocks[up_idx](gated)

        hidden = decoder.conv_norm_out(hidden)
        hidden = decoder.conv_act(hidden)
        return decoder.conv_out(hidden)

    def _DownEncoderBlock2D_res_forward(self, down_encoder_block_2d, hidden_states):
        """Run one encoder down block, also returning its pre-downsample output.

        Returns:
            ``(downsampled, features)``: the block output after any
            downsamplers, and the output of the resnet stack before them.
        """
        features = hidden_states
        for resnet in down_encoder_block_2d.resnets:
            features = resnet(features, temb=None)

        downsampled = features
        for downsampler in (down_encoder_block_2d.downsamplers or ()):
            downsampled = downsampler(downsampled)

        return downsampled, features
|
|
|
|
|
|