jwengr committed on
Commit f6018b4
Parent: 18e76e0

Upload folder using huggingface_hub

__init__.py ADDED
File without changes
gray-inpaint/config.json ADDED
@@ -0,0 +1,15 @@
+ {
+   "architectures": [
+     "SDGrayInpaintModel"
+   ],
+   "auto_map": {
+     "AutoConfig": "modeling_sd_gray_inpaint.SDGrayInpaintConfig",
+     "AutoModel": "modeling_sd_gray_inpaint.SDGrayInpaintModel"
+   },
+   "base_model": "stabilityai/stable-diffusion-2-inpainting",
+   "height": 512,
+   "model_type": "sd_gray_inpaint",
+   "torch_dtype": "float32",
+   "transformers_version": "4.46.3",
+   "width": 512
+ }
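The auto_map entries resolve AutoConfig and AutoModel to the classes defined in the bundled modeling_sd_gray_inpaint.py, so loading goes through trust_remote_code rather than a class built into transformers. A minimal loading sketch; the repository id below is a placeholder (this commit does not state it), and passing subfolder for the gray-inpaint directory is assumed to work:

    from transformers import AutoModel

    # trust_remote_code=True lets transformers import SDGrayInpaintModel from the
    # modeling_sd_gray_inpaint.py uploaded in this commit.
    model = AutoModel.from_pretrained(
        "jwengr/<repo-id>",        # placeholder repository id
        subfolder="gray-inpaint",
        trust_remote_code=True,
    )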
gray-inpaint/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2c6d964dca7f33a3a87e90056e8ab617efeabd99e3cfcea71f73d459b133f231
+ size 4055354432
gray-inpaint/modeling_sd_gray_inpaint.py ADDED
@@ -0,0 +1,98 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from copy import deepcopy
+ from torchvision.transforms.functional import rgb_to_grayscale
+ import segmentation_models_pytorch as smp
+ from diffusers import StableDiffusionInpaintPipeline
+ from diffusers.utils.torch_utils import randn_tensor
+ from transformers import PretrainedConfig, PreTrainedModel
+
+ class SDGrayInpaintConfig(PretrainedConfig):
+     model_type = "sd_gray_inpaint"
+     def __init__(
+         self,
+         base_model="stabilityai/stable-diffusion-2-inpainting",
+         height=512,
+         width=512,
+         **kwargs
+     ):
+         self.base_model = base_model
+         self.height = height
+         self.width = width
+         super().__init__(**kwargs)
+
+ class SDGrayInpaintModel(PreTrainedModel):
+     config_class = SDGrayInpaintConfig
+     def __init__(self, config):
+         super().__init__(config)
+         pipe = StableDiffusionInpaintPipeline.from_pretrained(config.base_model)
+         self.mask_predictor = smp.Unet(
+             encoder_name="mit_b4",
+             encoder_weights="imagenet",
+             in_channels=3,
+             classes=1,
+         )
+         self.image_processor = pipe.image_processor
+         self.scheduler = pipe.scheduler
+         self.unet = pipe.unet
+         self.vae = pipe.vae
+         self.prompt_embeds = nn.Parameter(torch.randn(1, 77, 1024))
+         self.height = config.height
+         self.width = config.width
+
+     def forward(
+         self,
+         images_gray_masked,
+         masks=None,
+         num_inference_steps=250,
+         seed=42,
+         input_type='pil',
+         output_type='pil'
+     ):
+         generator = torch.Generator()
+         generator.manual_seed(seed)
+         if input_type == 'pil':
+             images_gray_masked = self.image_processor.preprocess(images_gray_masked, height=self.height, width=self.width).float()
+         elif input_type == 'pt':
+             images_gray_masked = images_gray_masked
+         else:
+             raise ValueError('unsupported input_type')
+         images_gray_masked = images_gray_masked.to(self.vae.device)
+         if masks is None:
+             masks_logits = self.mask_predictor(images_gray_masked)
+             masks = (torch.sigmoid(masks_logits) > 0.5) * 1.
+         masks = masks.float().to(self.vae.device)
+         B, C, H, W = images_gray_masked.shape
+         prompt_embeds = self.prompt_embeds.repeat(B, 1, 1)
+
+         scheduler = deepcopy(self.scheduler)
+         scheduler.set_timesteps(num_inference_steps=num_inference_steps, device=self.vae.device)
+         masked_image_latents = self.vae.encode(images_gray_masked).latent_dist.mode() * self.vae.config.scaling_factor
+         mask_latents = F.interpolate(masks, size=(self.unet.config.sample_size, self.unet.config.sample_size))
+         latents = randn_tensor(masked_image_latents.shape, generator=generator).to(self.device) * self.scheduler.init_noise_sigma
+         for t in scheduler.timesteps:
+             latents = scheduler.scale_model_input(latents, t)
+             latent_model_input = torch.cat([latents, mask_latents, masked_image_latents], dim=1)
+             noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds)[0]
+             latents = scheduler.step(noise_pred, t, latents)[0]
+         latents = latents / self.vae.config.scaling_factor
+         images_gray_restored = self.vae.decode(latents.detach())[0]
+         images_gray_restored = images_gray_masked * (1 - masks) + images_gray_restored.detach() * masks
+         images_gray_restored = rgb_to_grayscale(images_gray_restored)
+
+         if output_type == 'pil':
+             images_gray_restored = self.image_processor.postprocess(images_gray_restored)
+         elif output_type == 'np':
+             images_gray_restored = self.image_processor.postprocess(images_gray_restored, 'np')
+         elif output_type == 'pt':
+             images_gray_restored = self.image_processor.postprocess(images_gray_restored, 'pt')
+         elif output_type == 'none':
+             images_gray_restored = images_gray_restored
+         else:
+             raise ValueError('unsupported output_type')
+
+         return images_gray_restored
+
+
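SDGrayInpaintModel.forward accepts a masked grayscale image, predicts the missing-region mask with the mit_b4 U-Net when masks is None, then runs the Stable Diffusion inpainting UNet for num_inference_steps denoising steps conditioned on the learned prompt embedding, and finally pastes the restored pixels back into the masked area. A usage sketch, assuming the model was loaded as above, a CUDA device is available, and masked.png is a hypothetical 512x512 masked grayscale input:

    import torch
    from PIL import Image

    model = model.eval().to("cuda")
    image = Image.open("masked.png").convert("RGB")  # 3 channels, as the VAE and mask predictor expect

    with torch.no_grad():
        restored = model(
            [image],                 # list of PIL images for input_type='pil'
            masks=None,              # let mask_predictor locate the holes
            num_inference_steps=50,  # fewer than the default 250, for a quick test
            output_type="pil",
        )
    restored[0].save("restored_gray.png")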
gray2rgb/config.json ADDED
@@ -0,0 +1,15 @@
+ {
+   "architectures": [
+     "SeResVaeModel"
+   ],
+   "auto_map": {
+     "AutoConfig": "modeling_seresvae.SeResVaeConfig",
+     "AutoModel": "modeling_seresvae.SeResVaeModel"
+   },
+   "base_model": "stabilityai/stable-diffusion-2-1",
+   "height": 512,
+   "model_type": "seresvae",
+   "torch_dtype": "float32",
+   "transformers_version": "4.46.3",
+   "width": 512
+ }
gray2rgb/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:657257a0969eec19b5a3ff2629500454e1941615ae646fa72f2d88e9d41be737
+ size 3799014812
gray2rgb/modeling_seresvae.py ADDED
@@ -0,0 +1,124 @@
+ import torch
+ import torch.nn as nn
+
+ from diffusers import AutoencoderKL, UNet2DConditionModel
+ from diffusers.image_processor import VaeImageProcessor
+ from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
+ from transformers import PretrainedConfig, PreTrainedModel
+
+ class SEPath(nn.Module):
+     def __init__(self, in_channels, out_channels, reduction=16):
+         super(SEPath, self).__init__()
+         self.fc = nn.Sequential(
+             nn.Linear(in_channels, in_channels // reduction, bias=False),
+             nn.ReLU(inplace=True),
+             nn.Linear(in_channels // reduction, out_channels, bias=False),
+             nn.Sigmoid()
+         )
+
+     def forward(self, in_tensor, out_tensor):
+         B, C, H, W = in_tensor.size()
+         # Squeeze operation
+         x = in_tensor.view(B, C, -1).mean(dim=2)
+         # Excitation operation
+         x = self.fc(x).unsqueeze(2).unsqueeze(2)
+
+         return out_tensor * x
+
+ class SeResVaeConfig(PretrainedConfig):
+     model_type = "seresvae"
+     def __init__(
+         self,
+         base_model="stabilityai/stable-diffusion-2-1",
+         height=512,
+         width=512,
+         **kwargs
+     ):
+         self.base_model = base_model
+         self.height = height
+         self.width = width
+         super().__init__(**kwargs)
+
+ class SeResVaeModel(PreTrainedModel):
+     config_class = SeResVaeConfig
+     def __init__(self, config):
+         super().__init__(config)
+         self.image_processor = VaeImageProcessor()
+         self.vae = AutoencoderKL.from_pretrained(config.base_model, subfolder='vae')
+         self.unet = UNet2DConditionModel.from_pretrained(config.base_model, subfolder='unet')
+         self.se_paths = nn.ModuleList([SEPath(8, 4), SEPath(512, 512), SEPath(512, 512), SEPath(256, 512), SEPath(128, 256)])
+         self.prompt_embeds = nn.Parameter(torch.randn(1, 77, 1024))
+         self.height = config.height
+         self.width = config.width
+
+     def forward(self, images_gray, input_type='pil', output_type='pil'):
+         if input_type == 'pil':
+             images_gray = self.image_processor.preprocess(images_gray, height=self.height, width=self.width).float()
+         elif input_type == 'pt':
+             images_gray = images_gray
+         else:
+             raise ValueError('unsupported input_type')
+         images_gray = images_gray.to(self.vae.device)
+         B, C, H, W = images_gray.shape
+         prompt_embeds = self.prompt_embeds.repeat(B, 1, 1)
+
+         posterior, encode_residual = self.encode_with_residual(images_gray)
+         latents = posterior.mode()
+         t = torch.LongTensor([500]).repeat(B).to(self.vae.device)
+         noise_pred = self.unet(latents, t, encoder_hidden_states=prompt_embeds)[0]
+         denoised_latents = latents - noise_pred
+         images_rgb = self.decode_with_residual(denoised_latents, *encode_residual)
+
+         if output_type == 'pil':
+             images_rgb = self.image_processor.postprocess(images_rgb)
+         elif output_type == 'np':
+             images_rgb = self.image_processor.postprocess(images_rgb, 'np')
+         elif output_type == 'pt':
+             images_rgb = self.image_processor.postprocess(images_rgb, 'pt')
+         elif output_type == 'none':
+             images_rgb = images_rgb
+         else:
+             raise ValueError('unsupported output_type')
+
+         return images_rgb
+
+     def encode_with_residual(self, sample):
+         re = self.vae.encoder.conv_in(sample)
+         re0, re0_out = self._DownEncoderBlock2D_res_forward(self.vae.encoder.down_blocks[0], re)
+         re1, re1_out = self._DownEncoderBlock2D_res_forward(self.vae.encoder.down_blocks[1], re0)
+         re2, re2_out = self._DownEncoderBlock2D_res_forward(self.vae.encoder.down_blocks[2], re1)
+         re3, re3_out = self._DownEncoderBlock2D_res_forward(self.vae.encoder.down_blocks[3], re2)
+         rem = self.vae.encoder.mid_block(re3)
+         re_out = self.vae.encoder.conv_norm_out(rem)
+         re_out = self.vae.encoder.conv_act(re_out)
+         re_out = self.vae.encoder.conv_out(re_out)
+         re_out = self.vae.quant_conv(re_out)
+
+         posterior = DiagonalGaussianDistribution(re_out)
+         return posterior, (re0_out, re1_out, re2_out, rem, re_out)
+
+     def decode_with_residual(self, z, re0_out, re1_out, re2_out, rem, re_out):
+         rd = self.vae.post_quant_conv(self.se_paths[0](re_out, z))
+         rd = self.vae.decoder.conv_in(rd)
+         rdm = self.vae.decoder.mid_block(self.se_paths[1](rem, rd)).to(torch.float32)
+         rd0 = self.vae.decoder.up_blocks[0](rdm)
+         rd1 = self.vae.decoder.up_blocks[1](self.se_paths[2](re2_out, rd0))
+         rd2 = self.vae.decoder.up_blocks[2](self.se_paths[3](re1_out, rd1))
+         rd3 = self.vae.decoder.up_blocks[3](self.se_paths[4](re0_out, rd2))
+         rd_out = self.vae.decoder.conv_norm_out(rd3)
+         rd_out = self.vae.decoder.conv_act(rd_out)
+         sample_out = self.vae.decoder.conv_out(rd_out)
+         return sample_out
+
+     def _DownEncoderBlock2D_res_forward(self, down_encoder_block_2d, hidden_states):
+         for resnet in down_encoder_block_2d.resnets:
+             hidden_states = resnet(hidden_states, temb=None)
+
+         output_states = hidden_states
+         if down_encoder_block_2d.downsamplers is not None:
+             for downsampler in down_encoder_block_2d.downsamplers:
+                 hidden_states = downsampler(hidden_states)
+
+         return hidden_states, output_states
+
+
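SeResVaeModel colorizes a grayscale image in a single pass: encode_with_residual collects the VAE encoder's intermediate activations, the UNet performs one denoising step at the fixed timestep 500, and decode_with_residual feeds the stored activations into the decoder through the SEPath squeeze-and-excitation gates. Continuing the sketch above, the gray2rgb folder can be loaded the same way and chained after the inpainting stage (the repository id stays a placeholder):

    from transformers import AutoModel

    gray2rgb = AutoModel.from_pretrained(
        "jwengr/<repo-id>",  # placeholder repository id
        subfolder="gray2rgb",
        trust_remote_code=True,
    ).eval().to("cuda")

    with torch.no_grad():
        images_rgb = gray2rgb(
            [img.convert("RGB") for img in restored],  # VAE conv_in expects 3 channels
            input_type="pil",
            output_type="pil",
        )
    images_rgb[0].save("colorized.png")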
modeling_sd_gray_inpaint.py ADDED
@@ -0,0 +1,98 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from copy import deepcopy
+ from torchvision.transforms.functional import rgb_to_grayscale
+ import segmentation_models_pytorch as smp
+ from diffusers import StableDiffusionInpaintPipeline
+ from diffusers.utils.torch_utils import randn_tensor
+ from transformers import PretrainedConfig, PreTrainedModel
+
+ class SDGrayInpaintConfig(PretrainedConfig):
+     model_type = "sd_gray_inpaint"
+     def __init__(
+         self,
+         base_model="stabilityai/stable-diffusion-2-inpainting",
+         height=512,
+         width=512,
+         **kwargs
+     ):
+         self.base_model = base_model
+         self.height = height
+         self.width = width
+         super().__init__(**kwargs)
+
+ class SDGrayInpaintModel(PreTrainedModel):
+     config_class = SDGrayInpaintConfig
+     def __init__(self, config):
+         super().__init__(config)
+         pipe = StableDiffusionInpaintPipeline.from_pretrained(config.base_model)
+         self.mask_predictor = smp.Unet(
+             encoder_name="mit_b4",
+             encoder_weights="imagenet",
+             in_channels=3,
+             classes=1,
+         )
+         self.image_processor = pipe.image_processor
+         self.scheduler = pipe.scheduler
+         self.unet = pipe.unet
+         self.vae = pipe.vae
+         self.prompt_embeds = nn.Parameter(torch.randn(1, 77, 1024))
+         self.height = config.height
+         self.width = config.width
+
+     def forward(
+         self,
+         images_gray_masked,
+         masks=None,
+         num_inference_steps=250,
+         seed=42,
+         input_type='pil',
+         output_type='pil'
+     ):
+         generator = torch.Generator()
+         generator.manual_seed(seed)
+         if input_type == 'pil':
+             images_gray_masked = self.image_processor.preprocess(images_gray_masked, height=self.height, width=self.width).float()
+         elif input_type == 'pt':
+             images_gray_masked = images_gray_masked
+         else:
+             raise ValueError('unsupported input_type')
+         images_gray_masked = images_gray_masked.to(self.vae.device)
+         if masks is None:
+             masks_logits = self.mask_predictor(images_gray_masked)
+             masks = (torch.sigmoid(masks_logits) > 0.5) * 1.
+         masks = masks.float().to(self.vae.device)
+         B, C, H, W = images_gray_masked.shape
+         prompt_embeds = self.prompt_embeds.repeat(B, 1, 1)
+
+         scheduler = deepcopy(self.scheduler)
+         scheduler.set_timesteps(num_inference_steps=num_inference_steps, device=self.vae.device)
+         masked_image_latents = self.vae.encode(images_gray_masked).latent_dist.mode() * self.vae.config.scaling_factor
+         mask_latents = F.interpolate(masks, size=(self.unet.config.sample_size, self.unet.config.sample_size))
+         latents = randn_tensor(masked_image_latents.shape, generator=generator).to(self.device) * self.scheduler.init_noise_sigma
+         for t in scheduler.timesteps:
+             latents = scheduler.scale_model_input(latents, t)
+             latent_model_input = torch.cat([latents, mask_latents, masked_image_latents], dim=1)
+             noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds)[0]
+             latents = scheduler.step(noise_pred, t, latents)[0]
+         latents = latents / self.vae.config.scaling_factor
+         images_gray_restored = self.vae.decode(latents.detach())[0]
+         images_gray_restored = images_gray_masked * (1 - masks) + images_gray_restored.detach() * masks
+         images_gray_restored = rgb_to_grayscale(images_gray_restored)
+
+         if output_type == 'pil':
+             images_gray_restored = self.image_processor.postprocess(images_gray_restored)
+         elif output_type == 'np':
+             images_gray_restored = self.image_processor.postprocess(images_gray_restored, 'np')
+         elif output_type == 'pt':
+             images_gray_restored = self.image_processor.postprocess(images_gray_restored, 'pt')
+         elif output_type == 'none':
+             images_gray_restored = images_gray_restored
+         else:
+             raise ValueError('unsupported output_type')
+
+         return images_gray_restored
+
+
modeling_seresvae.py ADDED
@@ -0,0 +1,124 @@
+ import torch
+ import torch.nn as nn
+
+ from diffusers import AutoencoderKL, UNet2DConditionModel
+ from diffusers.image_processor import VaeImageProcessor
+ from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
+ from transformers import PretrainedConfig, PreTrainedModel
+
+ class SEPath(nn.Module):
+     def __init__(self, in_channels, out_channels, reduction=16):
+         super(SEPath, self).__init__()
+         self.fc = nn.Sequential(
+             nn.Linear(in_channels, in_channels // reduction, bias=False),
+             nn.ReLU(inplace=True),
+             nn.Linear(in_channels // reduction, out_channels, bias=False),
+             nn.Sigmoid()
+         )
+
+     def forward(self, in_tensor, out_tensor):
+         B, C, H, W = in_tensor.size()
+         # Squeeze operation
+         x = in_tensor.view(B, C, -1).mean(dim=2)
+         # Excitation operation
+         x = self.fc(x).unsqueeze(2).unsqueeze(2)
+
+         return out_tensor * x
+
+ class SeResVaeConfig(PretrainedConfig):
+     model_type = "seresvae"
+     def __init__(
+         self,
+         base_model="stabilityai/stable-diffusion-2-1",
+         height=512,
+         width=512,
+         **kwargs
+     ):
+         self.base_model = base_model
+         self.height = height
+         self.width = width
+         super().__init__(**kwargs)
+
+ class SeResVaeModel(PreTrainedModel):
+     config_class = SeResVaeConfig
+     def __init__(self, config):
+         super().__init__(config)
+         self.image_processor = VaeImageProcessor()
+         self.vae = AutoencoderKL.from_pretrained(config.base_model, subfolder='vae')
+         self.unet = UNet2DConditionModel.from_pretrained(config.base_model, subfolder='unet')
+         self.se_paths = nn.ModuleList([SEPath(8, 4), SEPath(512, 512), SEPath(512, 512), SEPath(256, 512), SEPath(128, 256)])
+         self.prompt_embeds = nn.Parameter(torch.randn(1, 77, 1024))
+         self.height = config.height
+         self.width = config.width
+
+     def forward(self, images_gray, input_type='pil', output_type='pil'):
+         if input_type == 'pil':
+             images_gray = self.image_processor.preprocess(images_gray, height=self.height, width=self.width).float()
+         elif input_type == 'pt':
+             images_gray = images_gray
+         else:
+             raise ValueError('unsupported input_type')
+         images_gray = images_gray.to(self.vae.device)
+         B, C, H, W = images_gray.shape
+         prompt_embeds = self.prompt_embeds.repeat(B, 1, 1)
+
+         posterior, encode_residual = self.encode_with_residual(images_gray)
+         latents = posterior.mode()
+         t = torch.LongTensor([500]).repeat(B).to(self.vae.device)
+         noise_pred = self.unet(latents, t, encoder_hidden_states=prompt_embeds)[0]
+         denoised_latents = latents - noise_pred
+         images_rgb = self.decode_with_residual(denoised_latents, *encode_residual)
+
+         if output_type == 'pil':
+             images_rgb = self.image_processor.postprocess(images_rgb)
+         elif output_type == 'np':
+             images_rgb = self.image_processor.postprocess(images_rgb, 'np')
+         elif output_type == 'pt':
+             images_rgb = self.image_processor.postprocess(images_rgb, 'pt')
+         elif output_type == 'none':
+             images_rgb = images_rgb
+         else:
+             raise ValueError('unsupported output_type')
+
+         return images_rgb
+
+     def encode_with_residual(self, sample):
+         re = self.vae.encoder.conv_in(sample)
+         re0, re0_out = self._DownEncoderBlock2D_res_forward(self.vae.encoder.down_blocks[0], re)
+         re1, re1_out = self._DownEncoderBlock2D_res_forward(self.vae.encoder.down_blocks[1], re0)
+         re2, re2_out = self._DownEncoderBlock2D_res_forward(self.vae.encoder.down_blocks[2], re1)
+         re3, re3_out = self._DownEncoderBlock2D_res_forward(self.vae.encoder.down_blocks[3], re2)
+         rem = self.vae.encoder.mid_block(re3)
+         re_out = self.vae.encoder.conv_norm_out(rem)
+         re_out = self.vae.encoder.conv_act(re_out)
+         re_out = self.vae.encoder.conv_out(re_out)
+         re_out = self.vae.quant_conv(re_out)
+
+         posterior = DiagonalGaussianDistribution(re_out)
+         return posterior, (re0_out, re1_out, re2_out, rem, re_out)
+
+     def decode_with_residual(self, z, re0_out, re1_out, re2_out, rem, re_out):
+         rd = self.vae.post_quant_conv(self.se_paths[0](re_out, z))
+         rd = self.vae.decoder.conv_in(rd)
+         rdm = self.vae.decoder.mid_block(self.se_paths[1](rem, rd)).to(torch.float32)
+         rd0 = self.vae.decoder.up_blocks[0](rdm)
+         rd1 = self.vae.decoder.up_blocks[1](self.se_paths[2](re2_out, rd0))
+         rd2 = self.vae.decoder.up_blocks[2](self.se_paths[3](re1_out, rd1))
+         rd3 = self.vae.decoder.up_blocks[3](self.se_paths[4](re0_out, rd2))
+         rd_out = self.vae.decoder.conv_norm_out(rd3)
+         rd_out = self.vae.decoder.conv_act(rd_out)
+         sample_out = self.vae.decoder.conv_out(rd_out)
+         return sample_out
+
+     def _DownEncoderBlock2D_res_forward(self, down_encoder_block_2d, hidden_states):
+         for resnet in down_encoder_block_2d.resnets:
+             hidden_states = resnet(hidden_states, temb=None)
+
+         output_states = hidden_states
+         if down_encoder_block_2d.downsamplers is not None:
+             for downsampler in down_encoder_block_2d.downsamplers:
+                 hidden_states = downsampler(hidden_states)
+
+         return hidden_states, output_states
+
+