import torch
import torch.nn as nn
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from transformers import PretrainedConfig, PreTrainedModel


class SEPath(nn.Module):
    # Squeeze-and-Excitation path: channel attention computed from `in_tensor` gates `out_tensor`.
    def __init__(self, in_channels, out_channels, reduction=16):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels // reduction, out_channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, in_tensor, out_tensor):
        B, C, H, W = in_tensor.size()
        # Squeeze operation: global average pooling over the spatial dimensions
        x = in_tensor.view(B, C, -1).mean(dim=2)
        # Excitation operation: per-channel gates, broadcast over H and W
        x = self.fc(x).unsqueeze(2).unsqueeze(2)
        return out_tensor * x


class SeResVaeConfig(PretrainedConfig):
    model_type = "seresvae"

    def __init__(
        self,
        base_model="stabilityai/stable-diffusion-2-1",
        height=512,
        width=512,
        **kwargs
    ):
        self.base_model = base_model
        self.height = height
        self.width = width
        super().__init__(**kwargs)


class SeResVaeModel(PreTrainedModel):
    config_class = SeResVaeConfig

    def __init__(self, config):
        super().__init__(config)
        self.image_processor = VaeImageProcessor()
        self.vae = AutoencoderKL.from_pretrained(config.base_model, subfolder='vae')
        self.unet = UNet2DConditionModel.from_pretrained(config.base_model, subfolder='unet')
        # SE paths connect encoder activations to the matching decoder stages
        # (channel counts follow the SD 2.1 VAE encoder/decoder widths)
        self.se_paths = nn.ModuleList([SEPath(8, 4), SEPath(512, 512), SEPath(512, 512), SEPath(256, 512), SEPath(128, 256)])
        # Learned prompt embedding used in place of a text-encoder output (77 tokens x 1024 dims for SD 2.1)
        self.prompt_embeds = nn.Parameter(torch.randn(1, 77, 1024))
        self.height = config.height
        self.width = config.width

    def forward(self, images_gray, input_type='pil', output_type='pil'):
        if input_type == 'pil':
            images_gray = self.image_processor.preprocess(images_gray, height=self.height, width=self.width).float()
        elif input_type == 'pt':
            pass  # already a tensor in the expected layout
        else:
            raise ValueError('unsupported input_type')
        images_gray = images_gray.to(self.vae.device)
        B, C, H, W = images_gray.shape
        prompt_embeds = self.prompt_embeds.repeat(B, 1, 1)
        # Encode, keeping intermediate encoder activations for the SE-gated decoder
        posterior, encode_residual = self.encode_with_residual(images_gray)
        latents = posterior.mode()
        # Single UNet pass at a fixed timestep; the prediction is subtracted from the latents
        t = torch.full((B,), 500, dtype=torch.long, device=self.vae.device)
        noise_pred = self.unet(latents, t, encoder_hidden_states=prompt_embeds)[0]
        denoised_latents = latents - noise_pred
        images_rgb = self.decode_with_residual(denoised_latents, *encode_residual)
        if output_type == 'pil':
            images_rgb = self.image_processor.postprocess(images_rgb)
        elif output_type == 'np':
            images_rgb = self.image_processor.postprocess(images_rgb, 'np')
        elif output_type == 'pt':
            images_rgb = self.image_processor.postprocess(images_rgb, 'pt')
        elif output_type == 'none':
            pass  # return the raw decoder output
        else:
            raise ValueError('unsupported output_type')
        return images_rgb

    def encode_with_residual(self, sample):
        # Mirrors AutoencoderKL encoding, but also returns the intermediate activations of each encoder stage
        re = self.vae.encoder.conv_in(sample)
        re0, re0_out = self._DownEncoderBlock2D_res_forward(self.vae.encoder.down_blocks[0], re)
        re1, re1_out = self._DownEncoderBlock2D_res_forward(self.vae.encoder.down_blocks[1], re0)
        re2, re2_out = self._DownEncoderBlock2D_res_forward(self.vae.encoder.down_blocks[2], re1)
        re3, re3_out = self._DownEncoderBlock2D_res_forward(self.vae.encoder.down_blocks[3], re2)
        rem = self.vae.encoder.mid_block(re3)
        re_out = self.vae.encoder.conv_norm_out(rem)
        re_out = self.vae.encoder.conv_act(re_out)
        re_out = self.vae.encoder.conv_out(re_out)
        re_out = self.vae.quant_conv(re_out)
        posterior = DiagonalGaussianDistribution(re_out)
        return posterior, (re0_out, re1_out, re2_out, rem, re_out)

    def decode_with_residual(self, z, re0_out, re1_out, re2_out, rem, re_out):
        rd = self.vae.post_quant_conv(self.se_paths[0](re_out, z))
        rd = self.vae.decoder.conv_in(rd)
        rdm = self.vae.decoder.mid_block(self.se_paths[1](rem, rd)).to(torch.float32)
        rd0 = self.vae.decoder.up_blocks[0](rdm)
        rd1 = self.vae.decoder.up_blocks[1](self.se_paths[2](re2_out, rd0))
        rd2 = self.vae.decoder.up_blocks[2](self.se_paths[3](re1_out, rd1))
        rd3 = self.vae.decoder.up_blocks[3](self.se_paths[4](re0_out, rd2))
        rd_out = self.vae.decoder.conv_norm_out(rd3)
        rd_out = self.vae.decoder.conv_act(rd_out)
        sample_out = self.vae.decoder.conv_out(rd_out)
        return sample_out

    def _DownEncoderBlock2D_res_forward(self, down_encoder_block_2d, hidden_states):
        # Same flow as DownEncoderBlock2D.forward, but also returns the pre-downsample activations
        for resnet in down_encoder_block_2d.resnets:
            hidden_states = resnet(hidden_states, temb=None)
        output_states = hidden_states
        if down_encoder_block_2d.downsamplers is not None:
            for downsampler in down_encoder_block_2d.downsamplers:
                hidden_states = downsampler(hidden_states)
        return hidden_states, output_states
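

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): it assumes the default
    # SD 2.1 weights can be downloaded and that a grayscale image exists at the
    # hypothetical path below; adjust the path and device to your setup.
    from PIL import Image

    config = SeResVaeConfig()
    model = SeResVaeModel(config).eval()
    image_gray = Image.open("example_gray.png").convert("RGB")  # hypothetical input file
    with torch.no_grad():
        images_rgb = model([image_gray])  # default input/output types are PIL
    images_rgb[0].save("example_rgb.png")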