Upload folder using huggingface_hub
Browse files
- __init__.py +0 -0
- gray-inpaint/config.json +15 -0
- gray-inpaint/model.safetensors +3 -0
- gray-inpaint/modeling_sd_gray_inpaint.py +98 -0
- gray2rgb/config.json +15 -0
- gray2rgb/model.safetensors +3 -0
- gray2rgb/modeling_seresvae.py +124 -0
- modeling_sd_gray_inpaint.py +98 -0
- modeling_seresvae.py +124 -0
__init__.py
ADDED
File without changes
gray-inpaint/config.json
ADDED
@@ -0,0 +1,15 @@
{
  "architectures": [
    "SDGrayInpaintModel"
  ],
  "auto_map": {
    "AutoConfig": "modeling_sd_gray_inpaint.SDGrayInpaintConfig",
    "AutoModel": "modeling_sd_gray_inpaint.SDGrayInpaintModel"
  },
  "base_model": "stabilityai/stable-diffusion-2-inpainting",
  "height": 512,
  "model_type": "sd_gray_inpaint",
  "torch_dtype": "float32",
  "transformers_version": "4.46.3",
  "width": 512
}
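The auto_map block above points AutoConfig/AutoModel at the custom classes in modeling_sd_gray_inpaint.py, so the checkpoint can be loaded with remote code enabled. A minimal loading sketch, assuming the folder is pushed to a Hub repo; the repo id below is a placeholder, and the subfolder argument assumes this directory layout is kept:

from transformers import AutoModel

# Placeholder repo id; substitute the repo this folder was actually uploaded to.
model = AutoModel.from_pretrained(
    "your-username/your-repo",   # assumption: the Hub repo containing this upload
    subfolder="gray-inpaint",    # this config.json lives under gray-inpaint/
    trust_remote_code=True,      # lets auto_map import modeling_sd_gray_inpaint.py
)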
gray-inpaint/model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2c6d964dca7f33a3a87e90056e8ab617efeabd99e3cfcea71f73d459b133f231
size 4055354432
gray-inpaint/modeling_sd_gray_inpaint.py
ADDED
@@ -0,0 +1,98 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from copy import deepcopy
from torchvision.transforms.functional import rgb_to_grayscale
import segmentation_models_pytorch as smp
from diffusers import StableDiffusionInpaintPipeline
from diffusers.utils.torch_utils import randn_tensor
from transformers import PretrainedConfig, PreTrainedModel

class SDGrayInpaintConfig(PretrainedConfig):
    model_type = "sd_gray_inpaint"
    def __init__(
        self,
        base_model="stabilityai/stable-diffusion-2-inpainting",
        height=512,
        width=512,
        **kwargs
    ):
        self.base_model = base_model
        self.height = height
        self.width = width
        super().__init__(**kwargs)

class SDGrayInpaintModel(PreTrainedModel):
    config_class = SDGrayInpaintConfig
    def __init__(self, config):
        super().__init__(config)
        pipe = StableDiffusionInpaintPipeline.from_pretrained(config.base_model)
        self.mask_predictor = smp.Unet(
            encoder_name="mit_b4",
            encoder_weights="imagenet",
            in_channels=3,
            classes=1,
        )
        self.image_processor = pipe.image_processor
        self.scheduler = pipe.scheduler
        self.unet = pipe.unet
        self.vae = pipe.vae
        self.prompt_embeds = nn.Parameter(torch.randn(1, 77, 1024))
        self.height = config.height
        self.width = config.width

    def forward(
        self,
        images_gray_masked,
        masks=None,
        num_inference_steps=250,
        seed=42,
        input_type='pil',
        output_type='pil'
    ):
        generator = torch.Generator()
        generator.manual_seed(seed)
        if input_type == 'pil':
            images_gray_masked = self.image_processor.preprocess(images_gray_masked, height=self.height, width=self.width).float()
        elif input_type == 'pt':
            images_gray_masked = images_gray_masked
        else:
            raise ValueError('unsupported input_type')
        images_gray_masked = images_gray_masked.to(self.vae.device)
        if masks is None:
            masks_logits = self.mask_predictor(images_gray_masked)
            masks = (torch.sigmoid(masks_logits) > 0.5) * 1.
        masks = masks.float().to(self.vae.device)
        B, C, H, W = images_gray_masked.shape
        prompt_embeds = self.prompt_embeds.repeat(B, 1, 1)

        scheduler = deepcopy(self.scheduler)
        scheduler.set_timesteps(num_inference_steps=num_inference_steps, device=self.vae.device)
        masked_image_latents = self.vae.encode(images_gray_masked).latent_dist.mode() * self.vae.config.scaling_factor
        mask_latents = F.interpolate(masks, size=(self.unet.config.sample_size, self.unet.config.sample_size))
        latents = randn_tensor(masked_image_latents.shape, generator=generator).to(self.device) * self.scheduler.init_noise_sigma
        for t in scheduler.timesteps:
            latents = scheduler.scale_model_input(latents, t)
            latent_model_input = torch.cat([latents, mask_latents, masked_image_latents], dim=1)
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds)[0]
            latents = scheduler.step(noise_pred, t, latents)[0]
        latents = latents / self.vae.config.scaling_factor
        images_gray_restored = self.vae.decode(latents.detach())[0]
        images_gray_restored = images_gray_masked * (1 - masks) + images_gray_restored.detach() * masks
        images_gray_restored = rgb_to_grayscale(images_gray_restored)

        if output_type == 'pil':
            images_gray_restored = self.image_processor.postprocess(images_gray_restored)
        elif output_type == 'np':
            images_gray_restored = self.image_processor.postprocess(images_gray_restored, 'np')
        elif output_type == 'pt':
            images_gray_restored = self.image_processor.postprocess(images_gray_restored, 'pt')
        elif output_type == 'none':
            images_gray_restored = images_gray_restored
        else:
            raise ValueError('unsupported output_type')

        return images_gray_restored
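A hedged inference sketch for the forward pass above: with masks=None the built-in mask_predictor locates the damaged regions itself, and the diffusion loop then fills them in. The repo id and image path are placeholders, and loading assumes the same remote-code setup as the earlier sketch:

import torch
from PIL import Image
from transformers import AutoModel

inpaint = AutoModel.from_pretrained("your-username/your-repo", subfolder="gray-inpaint", trust_remote_code=True)
inpaint.eval()

# Hypothetical input: a grayscale photo with masked/missing regions, replicated to 3 channels.
image_gray_masked = Image.open("photo_masked.png").convert("RGB")
with torch.no_grad():
    restored = inpaint(image_gray_masked, num_inference_steps=50)  # masks=None -> predict the mask
restored[0].save("photo_gray_restored.png")  # output_type='pil' (default) returns a list of PIL images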
gray2rgb/config.json
ADDED
@@ -0,0 +1,15 @@
{
  "architectures": [
    "SeResVaeModel"
  ],
  "auto_map": {
    "AutoConfig": "modeling_seresvae.SeResVaeConfig",
    "AutoModel": "modeling_seresvae.SeResVaeModel"
  },
  "base_model": "stabilityai/stable-diffusion-2-1",
  "height": 512,
  "model_type": "seresvae",
  "torch_dtype": "float32",
  "transformers_version": "4.46.3",
  "width": 512
}
gray2rgb/model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:657257a0969eec19b5a3ff2629500454e1941615ae646fa72f2d88e9d41be737
size 3799014812
gray2rgb/modeling_seresvae.py
ADDED
@@ -0,0 +1,124 @@
import torch
import torch.nn as nn

from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from transformers import PretrainedConfig, PreTrainedModel

class SEPath(nn.Module):
    def __init__(self, in_channels, out_channels, reduction=16):
        super(SEPath, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels // reduction, out_channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, in_tensor, out_tensor):
        B, C, H, W = in_tensor.size()
        # Squeeze operation
        x = in_tensor.view(B, C, -1).mean(dim=2)
        # Excitation operation
        x = self.fc(x).unsqueeze(2).unsqueeze(2)

        return out_tensor * x

class SeResVaeConfig(PretrainedConfig):
    model_type = "seresvae"
    def __init__(
        self,
        base_model="stabilityai/stable-diffusion-2-1",
        height=512,
        width=512,
        **kwargs
    ):
        self.base_model = base_model
        self.height = height
        self.width = width
        super().__init__(**kwargs)

class SeResVaeModel(PreTrainedModel):
    config_class = SeResVaeConfig
    def __init__(self, config):
        super().__init__(config)
        self.image_processor = VaeImageProcessor()
        self.vae = AutoencoderKL.from_pretrained(config.base_model, subfolder='vae')
        self.unet = UNet2DConditionModel.from_pretrained(config.base_model, subfolder='unet')
        self.se_paths = nn.ModuleList([SEPath(8, 4), SEPath(512, 512), SEPath(512, 512), SEPath(256, 512), SEPath(128, 256)])
        self.prompt_embeds = nn.Parameter(torch.randn(1, 77, 1024))
        self.height = config.height
        self.width = config.width

    def forward(self, images_gray, input_type='pil', output_type='pil'):
        if input_type == 'pil':
            images_gray = self.image_processor.preprocess(images_gray, height=self.height, width=self.width).float()
        elif input_type == 'pt':
            images_gray = images_gray
        else:
            raise ValueError('unsupported input_type')
        images_gray = images_gray.to(self.vae.device)
        B, C, H, W = images_gray.shape
        prompt_embeds = self.prompt_embeds.repeat(B, 1, 1)

        posterior, encode_residual = self.encode_with_residual(images_gray)
        latents = posterior.mode()
        t = torch.LongTensor([500]).repeat(B).to(self.vae.device)
        noise_pred = self.unet(latents, t, encoder_hidden_states=prompt_embeds)[0]
        denoised_latents = latents - noise_pred
        images_rgb = self.decode_with_residual(denoised_latents, *encode_residual)

        if output_type == 'pil':
            images_rgb = self.image_processor.postprocess(images_rgb)
        elif output_type == 'np':
            images_rgb = self.image_processor.postprocess(images_rgb, 'np')
        elif output_type == 'pt':
            images_rgb = self.image_processor.postprocess(images_rgb, 'pt')
        elif output_type == 'none':
            images_rgb = images_rgb
        else:
            raise ValueError('unsupported output_type')

        return images_rgb

    def encode_with_residual(self, sample):
        re = self.vae.encoder.conv_in(sample)
        re0, re0_out = self._DownEncoderBlock2D_res_forward(self.vae.encoder.down_blocks[0], re)
        re1, re1_out = self._DownEncoderBlock2D_res_forward(self.vae.encoder.down_blocks[1], re0)
        re2, re2_out = self._DownEncoderBlock2D_res_forward(self.vae.encoder.down_blocks[2], re1)
        re3, re3_out = self._DownEncoderBlock2D_res_forward(self.vae.encoder.down_blocks[3], re2)
        rem = self.vae.encoder.mid_block(re3)
        re_out = self.vae.encoder.conv_norm_out(rem)
        re_out = self.vae.encoder.conv_act(re_out)
        re_out = self.vae.encoder.conv_out(re_out)
        re_out = self.vae.quant_conv(re_out)

        posterior = DiagonalGaussianDistribution(re_out)
        return posterior, (re0_out, re1_out, re2_out, rem, re_out)

    def decode_with_residual(self, z, re0_out, re1_out, re2_out, rem, re_out):
        rd = self.vae.post_quant_conv(self.se_paths[0](re_out, z))
        rd = self.vae.decoder.conv_in(rd)
        rdm = self.vae.decoder.mid_block(self.se_paths[1](rem, rd)).to(torch.float32)
        rd0 = self.vae.decoder.up_blocks[0](rdm)
        rd1 = self.vae.decoder.up_blocks[1](self.se_paths[2](re2_out, rd0))
        rd2 = self.vae.decoder.up_blocks[2](self.se_paths[3](re1_out, rd1))
        rd3 = self.vae.decoder.up_blocks[3](self.se_paths[4](re0_out, rd2))
        rd_out = self.vae.decoder.conv_norm_out(rd3)
        rd_out = self.vae.decoder.conv_act(rd_out)
        sample_out = self.vae.decoder.conv_out(rd_out)
        return sample_out

    def _DownEncoderBlock2D_res_forward(self, down_encoder_block_2d, hidden_states):
        for resnet in down_encoder_block_2d.resnets:
            hidden_states = resnet(hidden_states, temb=None)

        output_states = hidden_states
        if down_encoder_block_2d.downsamplers is not None:
            for downsampler in down_encoder_block_2d.downsamplers:
                hidden_states = downsampler(hidden_states)

        return hidden_states, output_states
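The two folders suggest a two-stage restoration pipeline: gray-inpaint fills in missing regions of a grayscale image, and gray2rgb colorizes the result through the SE-gated VAE above. A hedged end-to-end sketch with placeholder repo id and file names; the grayscale output is converted back to 3 channels before colorization, since the VAE encoder expects RGB-shaped input:

import torch
from PIL import Image
from transformers import AutoModel

inpaint = AutoModel.from_pretrained("your-username/your-repo", subfolder="gray-inpaint", trust_remote_code=True)
gray2rgb = AutoModel.from_pretrained("your-username/your-repo", subfolder="gray2rgb", trust_remote_code=True)

image = Image.open("old_photo_masked.png").convert("RGB")  # hypothetical damaged grayscale photo
with torch.no_grad():
    gray_restored = inpaint(image)[0]                       # grayscale PIL image, holes filled
    rgb = gray2rgb(gray_restored.convert("RGB"))[0]         # colorized PIL image
rgb.save("old_photo_restored_rgb.png")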
modeling_sd_gray_inpaint.py
ADDED
@@ -0,0 +1,98 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from copy import deepcopy
from torchvision.transforms.functional import rgb_to_grayscale
import segmentation_models_pytorch as smp
from diffusers import StableDiffusionInpaintPipeline
from diffusers.utils.torch_utils import randn_tensor
from transformers import PretrainedConfig, PreTrainedModel

class SDGrayInpaintConfig(PretrainedConfig):
    model_type = "sd_gray_inpaint"
    def __init__(
        self,
        base_model="stabilityai/stable-diffusion-2-inpainting",
        height=512,
        width=512,
        **kwargs
    ):
        self.base_model = base_model
        self.height = height
        self.width = width
        super().__init__(**kwargs)

class SDGrayInpaintModel(PreTrainedModel):
    config_class = SDGrayInpaintConfig
    def __init__(self, config):
        super().__init__(config)
        pipe = StableDiffusionInpaintPipeline.from_pretrained(config.base_model)
        self.mask_predictor = smp.Unet(
            encoder_name="mit_b4",
            encoder_weights="imagenet",
            in_channels=3,
            classes=1,
        )
        self.image_processor = pipe.image_processor
        self.scheduler = pipe.scheduler
        self.unet = pipe.unet
        self.vae = pipe.vae
        self.prompt_embeds = nn.Parameter(torch.randn(1, 77, 1024))
        self.height = config.height
        self.width = config.width

    def forward(
        self,
        images_gray_masked,
        masks=None,
        num_inference_steps=250,
        seed=42,
        input_type='pil',
        output_type='pil'
    ):
        generator = torch.Generator()
        generator.manual_seed(seed)
        if input_type == 'pil':
            images_gray_masked = self.image_processor.preprocess(images_gray_masked, height=self.height, width=self.width).float()
        elif input_type == 'pt':
            images_gray_masked = images_gray_masked
        else:
            raise ValueError('unsupported input_type')
        images_gray_masked = images_gray_masked.to(self.vae.device)
        if masks is None:
            masks_logits = self.mask_predictor(images_gray_masked)
            masks = (torch.sigmoid(masks_logits) > 0.5) * 1.
        masks = masks.float().to(self.vae.device)
        B, C, H, W = images_gray_masked.shape
        prompt_embeds = self.prompt_embeds.repeat(B, 1, 1)

        scheduler = deepcopy(self.scheduler)
        scheduler.set_timesteps(num_inference_steps=num_inference_steps, device=self.vae.device)
        masked_image_latents = self.vae.encode(images_gray_masked).latent_dist.mode() * self.vae.config.scaling_factor
        mask_latents = F.interpolate(masks, size=(self.unet.config.sample_size, self.unet.config.sample_size))
        latents = randn_tensor(masked_image_latents.shape, generator=generator).to(self.device) * self.scheduler.init_noise_sigma
        for t in scheduler.timesteps:
            latents = scheduler.scale_model_input(latents, t)
            latent_model_input = torch.cat([latents, mask_latents, masked_image_latents], dim=1)
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds)[0]
            latents = scheduler.step(noise_pred, t, latents)[0]
        latents = latents / self.vae.config.scaling_factor
        images_gray_restored = self.vae.decode(latents.detach())[0]
        images_gray_restored = images_gray_masked * (1 - masks) + images_gray_restored.detach() * masks
        images_gray_restored = rgb_to_grayscale(images_gray_restored)

        if output_type == 'pil':
            images_gray_restored = self.image_processor.postprocess(images_gray_restored)
        elif output_type == 'np':
            images_gray_restored = self.image_processor.postprocess(images_gray_restored, 'np')
        elif output_type == 'pt':
            images_gray_restored = self.image_processor.postprocess(images_gray_restored, 'pt')
        elif output_type == 'none':
            images_gray_restored = images_gray_restored
        else:
            raise ValueError('unsupported output_type')

        return images_gray_restored
modeling_seresvae.py
ADDED
@@ -0,0 +1,124 @@
import torch
import torch.nn as nn

from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from transformers import PretrainedConfig, PreTrainedModel

class SEPath(nn.Module):
    def __init__(self, in_channels, out_channels, reduction=16):
        super(SEPath, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels // reduction, out_channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, in_tensor, out_tensor):
        B, C, H, W = in_tensor.size()
        # Squeeze operation
        x = in_tensor.view(B, C, -1).mean(dim=2)
        # Excitation operation
        x = self.fc(x).unsqueeze(2).unsqueeze(2)

        return out_tensor * x

class SeResVaeConfig(PretrainedConfig):
    model_type = "seresvae"
    def __init__(
        self,
        base_model="stabilityai/stable-diffusion-2-1",
        height=512,
        width=512,
        **kwargs
    ):
        self.base_model = base_model
        self.height = height
        self.width = width
        super().__init__(**kwargs)

class SeResVaeModel(PreTrainedModel):
    config_class = SeResVaeConfig
    def __init__(self, config):
        super().__init__(config)
        self.image_processor = VaeImageProcessor()
        self.vae = AutoencoderKL.from_pretrained(config.base_model, subfolder='vae')
        self.unet = UNet2DConditionModel.from_pretrained(config.base_model, subfolder='unet')
        self.se_paths = nn.ModuleList([SEPath(8, 4), SEPath(512, 512), SEPath(512, 512), SEPath(256, 512), SEPath(128, 256)])
        self.prompt_embeds = nn.Parameter(torch.randn(1, 77, 1024))
        self.height = config.height
        self.width = config.width

    def forward(self, images_gray, input_type='pil', output_type='pil'):
        if input_type == 'pil':
            images_gray = self.image_processor.preprocess(images_gray, height=self.height, width=self.width).float()
        elif input_type == 'pt':
            images_gray = images_gray
        else:
            raise ValueError('unsupported input_type')
        images_gray = images_gray.to(self.vae.device)
        B, C, H, W = images_gray.shape
        prompt_embeds = self.prompt_embeds.repeat(B, 1, 1)

        posterior, encode_residual = self.encode_with_residual(images_gray)
        latents = posterior.mode()
        t = torch.LongTensor([500]).repeat(B).to(self.vae.device)
        noise_pred = self.unet(latents, t, encoder_hidden_states=prompt_embeds)[0]
        denoised_latents = latents - noise_pred
        images_rgb = self.decode_with_residual(denoised_latents, *encode_residual)

        if output_type == 'pil':
            images_rgb = self.image_processor.postprocess(images_rgb)
        elif output_type == 'np':
            images_rgb = self.image_processor.postprocess(images_rgb, 'np')
        elif output_type == 'pt':
            images_rgb = self.image_processor.postprocess(images_rgb, 'pt')
        elif output_type == 'none':
            images_rgb = images_rgb
        else:
            raise ValueError('unsupported output_type')

        return images_rgb

    def encode_with_residual(self, sample):
        re = self.vae.encoder.conv_in(sample)
        re0, re0_out = self._DownEncoderBlock2D_res_forward(self.vae.encoder.down_blocks[0], re)
        re1, re1_out = self._DownEncoderBlock2D_res_forward(self.vae.encoder.down_blocks[1], re0)
        re2, re2_out = self._DownEncoderBlock2D_res_forward(self.vae.encoder.down_blocks[2], re1)
        re3, re3_out = self._DownEncoderBlock2D_res_forward(self.vae.encoder.down_blocks[3], re2)
        rem = self.vae.encoder.mid_block(re3)
        re_out = self.vae.encoder.conv_norm_out(rem)
        re_out = self.vae.encoder.conv_act(re_out)
        re_out = self.vae.encoder.conv_out(re_out)
        re_out = self.vae.quant_conv(re_out)

        posterior = DiagonalGaussianDistribution(re_out)
        return posterior, (re0_out, re1_out, re2_out, rem, re_out)

    def decode_with_residual(self, z, re0_out, re1_out, re2_out, rem, re_out):
        rd = self.vae.post_quant_conv(self.se_paths[0](re_out, z))
        rd = self.vae.decoder.conv_in(rd)
        rdm = self.vae.decoder.mid_block(self.se_paths[1](rem, rd)).to(torch.float32)
        rd0 = self.vae.decoder.up_blocks[0](rdm)
        rd1 = self.vae.decoder.up_blocks[1](self.se_paths[2](re2_out, rd0))
        rd2 = self.vae.decoder.up_blocks[2](self.se_paths[3](re1_out, rd1))
        rd3 = self.vae.decoder.up_blocks[3](self.se_paths[4](re0_out, rd2))
        rd_out = self.vae.decoder.conv_norm_out(rd3)
        rd_out = self.vae.decoder.conv_act(rd_out)
        sample_out = self.vae.decoder.conv_out(rd_out)
        return sample_out

    def _DownEncoderBlock2D_res_forward(self, down_encoder_block_2d, hidden_states):
        for resnet in down_encoder_block_2d.resnets:
            hidden_states = resnet(hidden_states, temb=None)

        output_states = hidden_states
        if down_encoder_block_2d.downsamplers is not None:
            for downsampler in down_encoder_block_2d.downsamplers:
                hidden_states = downsampler(hidden_states)

        return hidden_states, output_states