sjtu-deepvision committed on
Commit 311419e · verified · 1 Parent(s): c360cac

Upload 9 files

DAI/__init__.py ADDED
File without changes
DAI/controlnetvae.py ADDED
@@ -0,0 +1,250 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ from torch import nn
19
+ from torch.nn import functional as F
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
23
+ from diffusers.utils import BaseOutput, logging
24
+ from diffusers.models.attention_processor import (
25
+ ADDED_KV_ATTENTION_PROCESSORS,
26
+ CROSS_ATTENTION_PROCESSORS,
27
+ AttentionProcessor,
28
+ AttnAddedKVProcessor,
29
+ AttnProcessor,
30
+ )
31
+ from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
32
+ from diffusers.models.modeling_utils import ModelMixin
33
+ from diffusers.models.unets.unet_2d_blocks import (
34
+ CrossAttnDownBlock2D,
35
+ DownBlock2D,
36
+ UNetMidBlock2D,
37
+ UNetMidBlock2DCrossAttn,
38
+ get_down_block,
39
+ )
40
+ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
41
+ from diffusers.models.controlnet import ControlNetOutput
42
+ from diffusers.models import ControlNetModel
43
+
44
+ import pdb
45
+
46
+
47
+ class ControlNetVAEModel(ControlNetModel):
48
+ def forward(
49
+ self,
50
+ sample: torch.Tensor,
51
+ timestep: Union[torch.Tensor, float, int],
52
+ encoder_hidden_states: torch.Tensor,
53
+ controlnet_cond: torch.Tensor = None,
54
+ conditioning_scale: float = 1.0,
55
+ class_labels: Optional[torch.Tensor] = None,
56
+ timestep_cond: Optional[torch.Tensor] = None,
57
+ attention_mask: Optional[torch.Tensor] = None,
58
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
59
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
60
+ guess_mode: bool = False,
61
+ return_dict: bool = True,
62
+ ) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
63
+ """
64
+ The [`ControlNetVAEModel`] forward method.
65
+
66
+ Args:
67
+ sample (`torch.Tensor`):
68
+ The noisy input tensor.
69
+ timestep (`Union[torch.Tensor, float, int]`):
70
+ The number of timesteps to denoise an input.
71
+ encoder_hidden_states (`torch.Tensor`):
72
+ The encoder hidden states.
73
+ controlnet_cond (`torch.Tensor`):
74
+ The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
75
+ conditioning_scale (`float`, defaults to `1.0`):
76
+ The scale factor for ControlNet outputs.
77
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
78
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
79
+ timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
80
+ Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
81
+ timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
82
+ embeddings.
83
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
84
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
85
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
86
+ negative values to the attention scores corresponding to "discard" tokens.
87
+ added_cond_kwargs (`dict`):
88
+ Additional conditions for the Stable Diffusion XL UNet.
89
+ cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
90
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
91
+ guess_mode (`bool`, defaults to `False`):
92
+ In this mode, the ControlNet encoder tries its best to recognize the content of the input even if
93
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
94
+ return_dict (`bool`, defaults to `True`):
95
+ Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
96
+
97
+ Returns:
98
+ [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
99
+ If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
100
+ returned where the first element is the sample tensor.
101
+ """
102
+ # check channel order
103
+
104
+
105
+ channel_order = self.config.controlnet_conditioning_channel_order
106
+
107
+ if channel_order == "rgb":
108
+ # in rgb order by default
109
+ ...
110
+ elif channel_order == "bgr":
111
+ controlnet_cond = torch.flip(controlnet_cond, dims=[1])
112
+ else:
113
+ raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
114
+
115
+ # prepare attention_mask
116
+ if attention_mask is not None:
117
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
118
+ attention_mask = attention_mask.unsqueeze(1)
119
+
120
+ # 1. time
121
+ timesteps = timestep
122
+ if not torch.is_tensor(timesteps):
123
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
124
+ # This would be a good case for the `match` statement (Python 3.10+)
125
+ is_mps = sample.device.type == "mps"
126
+ if isinstance(timestep, float):
127
+ dtype = torch.float32 if is_mps else torch.float64
128
+ else:
129
+ dtype = torch.int32 if is_mps else torch.int64
130
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
131
+ elif len(timesteps.shape) == 0:
132
+ timesteps = timesteps[None].to(sample.device)
133
+
134
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
135
+ timesteps = timesteps.expand(sample.shape[0])
136
+
137
+ t_emb = self.time_proj(timesteps)
138
+
139
+ # timesteps does not contain any weights and will always return f32 tensors
140
+ # but time_embedding might actually be running in fp16. so we need to cast here.
141
+ # there might be better ways to encapsulate this.
142
+ t_emb = t_emb.to(dtype=sample.dtype)
143
+
144
+ emb = self.time_embedding(t_emb, timestep_cond)
145
+ aug_emb = None
146
+
147
+ if self.class_embedding is not None:
148
+ if class_labels is None:
149
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
150
+
151
+ if self.config.class_embed_type == "timestep":
152
+ class_labels = self.time_proj(class_labels)
153
+
154
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
155
+ emb = emb + class_emb
156
+
157
+ if self.config.addition_embed_type is not None:
158
+ if self.config.addition_embed_type == "text":
159
+ aug_emb = self.add_embedding(encoder_hidden_states)
160
+
161
+ elif self.config.addition_embed_type == "text_time":
162
+ if "text_embeds" not in added_cond_kwargs:
163
+ raise ValueError(
164
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
165
+ )
166
+ text_embeds = added_cond_kwargs.get("text_embeds")
167
+ if "time_ids" not in added_cond_kwargs:
168
+ raise ValueError(
169
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
170
+ )
171
+ time_ids = added_cond_kwargs.get("time_ids")
172
+ time_embeds = self.add_time_proj(time_ids.flatten())
173
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
174
+
175
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
176
+ add_embeds = add_embeds.to(emb.dtype)
177
+ aug_emb = self.add_embedding(add_embeds)
178
+
179
+
180
+ emb = emb + aug_emb if aug_emb is not None else emb
181
+ # 2. pre-process
182
+ sample = self.conv_in(sample)
183
+
184
+ # 3. down
185
+ down_block_res_samples = (sample,)
186
+ for downsample_block in self.down_blocks:
187
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
188
+ sample, res_samples = downsample_block(
189
+ hidden_states=sample,
190
+ temb=emb,
191
+ encoder_hidden_states=encoder_hidden_states,
192
+ attention_mask=attention_mask,
193
+ cross_attention_kwargs=cross_attention_kwargs,
194
+ )
195
+ else:
196
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
197
+
198
+ down_block_res_samples += res_samples
199
+
200
+ # 4. mid
201
+ if self.mid_block is not None:
202
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
203
+ sample = self.mid_block(
204
+ sample,
205
+ emb,
206
+ encoder_hidden_states=encoder_hidden_states,
207
+ attention_mask=attention_mask,
208
+ cross_attention_kwargs=cross_attention_kwargs,
209
+ )
210
+ else:
211
+ sample = self.mid_block(sample, emb)
212
+
213
+ # 5. Control net blocks
214
+
215
+ controlnet_down_block_res_samples = ()
216
+
217
+ # NOTE: the ControlNet down blocks are zero convolutions; we skip them here and pass the
+ # residuals through unchanged.
218
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
219
+ down_block_res_sample = down_block_res_sample
220
+ controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
221
+
222
+ down_block_res_samples = controlnet_down_block_res_samples
223
+
224
+ mid_block_res_sample = sample
225
+
226
+ # 6. scaling
227
+ if guess_mode and not self.config.global_pool_conditions:
228
+ scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
229
+ scales = scales * conditioning_scale
230
+ down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
231
+ mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
232
+ else:
233
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
234
+ mid_block_res_sample = mid_block_res_sample * conditioning_scale
235
+
236
+ if self.config.global_pool_conditions:
237
+ down_block_res_samples = [
238
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
239
+ ]
240
+ mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
241
+
242
+ if not return_dict:
243
+ return (down_block_res_samples, mid_block_res_sample)
244
+
245
+ return ControlNetOutput(
246
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
247
+ )
248
+
249
+
250
+
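For orientation, here is a minimal, hypothetical sketch of exercising `ControlNetVAEModel` on its own. The checkpoint path and tensor shapes are placeholders and not part of this commit; the subclass only overrides `forward`, which skips the conditioning-embedding path and the zero-conv down blocks and returns the residuals directly.

import torch
from DAI.controlnetvae import ControlNetVAEModel

# Placeholder path; load whichever ControlNet weights accompany this repository.
controlnet = ControlNetVAEModel.from_pretrained("path/to/controlnet", torch_dtype=torch.float32)

sample = torch.randn(1, 4, 96, 96)                # image latent fed as the ControlNet input
encoder_hidden_states = torch.randn(1, 77, 1024)  # text embeddings; width depends on the text encoder

down_res, mid_res = controlnet(
    sample,
    timestep=0,                                   # the single-step setting uses t = 0
    encoder_hidden_states=encoder_hidden_states,
    return_dict=False,
)
print(len(down_res), mid_res.shape)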
DAI/decoder.py ADDED
@@ -0,0 +1,313 @@
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ from typing import Dict, List, Optional, Tuple, Union
5
+ from diffusers import AutoencoderKL
6
+ from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution, Encoder, Decoder
7
+ from diffusers.models.attention_processor import Attention, AttentionProcessor
8
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
9
+ from diffusers.models.unets.unet_2d_blocks import (
10
+ AutoencoderTinyBlock,
11
+ UNetMidBlock2D,
12
+ get_down_block,
13
+ get_up_block,
14
+ )
15
+ from diffusers.utils import is_torch_version
+ from diffusers.utils.accelerate_utils import apply_forward_hook
16
+
17
+ class ZeroConv2d(nn.Module):
18
+ """
19
+ Zero Convolution layer, similar to the one used in ControlNet.
20
+ """
21
+ def __init__(self, in_channels, out_channels):
22
+ super().__init__()
23
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
24
+ self.conv.weight.data.zero_()
25
+ self.conv.bias.data.zero_()
26
+
27
+ def forward(self, x):
28
+ return self.conv(x)
29
+
30
+ class CustomAutoencoderKL(AutoencoderKL):
31
+ def __init__(
32
+ self,
33
+ in_channels: int = 3,
34
+ out_channels: int = 3,
35
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
36
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
37
+ block_out_channels: Tuple[int] = (64,),
38
+ layers_per_block: int = 1,
39
+ act_fn: str = "silu",
40
+ latent_channels: int = 4,
41
+ norm_num_groups: int = 32,
42
+ sample_size: int = 32,
43
+ scaling_factor: float = 0.18215,
44
+ force_upcast: float = True,
45
+ use_quant_conv: bool = True,
46
+ use_post_quant_conv: bool = True,
47
+ mid_block_add_attention: bool = True,
48
+ ):
49
+ super().__init__(
50
+ in_channels=in_channels,
51
+ out_channels=out_channels,
52
+ down_block_types=down_block_types,
53
+ up_block_types=up_block_types,
54
+ block_out_channels=block_out_channels,
55
+ layers_per_block=layers_per_block,
56
+ act_fn=act_fn,
57
+ latent_channels=latent_channels,
58
+ norm_num_groups=norm_num_groups,
59
+ sample_size=sample_size,
60
+ scaling_factor=scaling_factor,
61
+ force_upcast=force_upcast,
62
+ use_quant_conv=use_quant_conv,
63
+ use_post_quant_conv=use_post_quant_conv,
64
+ mid_block_add_attention=mid_block_add_attention,
65
+ )
66
+
67
+ # Add Zero Convolution layers to the encoder
68
+ # self.zero_convs = nn.ModuleList()
69
+ # for i, out_channels_ in enumerate(block_out_channels):
70
+ # self.zero_convs.append(ZeroConv2d(out_channels_, out_channels_))
71
+
72
+ # Modify the decoder to accept skip connections
73
+ self.decoder = CustomDecoder(
74
+ in_channels=latent_channels,
75
+ out_channels=out_channels,
76
+ up_block_types=up_block_types,
77
+ block_out_channels=block_out_channels,
78
+ layers_per_block=layers_per_block,
79
+ norm_num_groups=norm_num_groups,
80
+ act_fn=act_fn,
81
+ mid_block_add_attention=mid_block_add_attention,
82
+ )
83
+ self.encoder = CustomEncoder(
84
+ in_channels=in_channels,
85
+ out_channels=latent_channels,
86
+ down_block_types=down_block_types,
87
+ block_out_channels=block_out_channels,
88
+ layers_per_block=layers_per_block,
89
+ norm_num_groups=norm_num_groups,
90
+ act_fn=act_fn,
91
+ mid_block_add_attention=mid_block_add_attention,
92
+ )
93
+
94
+ def encode(self, x: torch.Tensor, return_dict: bool = True):
95
+ # Unlike `AutoencoderKL.encode`, this override returns only the skip connections produced by
+ # the custom encoder; the latent posterior is not computed here.
96
+ _, skip_connections = self.encoder(x)
97
+
98
+ return skip_connections
99
+
100
+ def decode(self, z: torch.Tensor, skip_connections: list, return_dict: bool = True):
101
+ if self.post_quant_conv is not None:
102
+ z = self.post_quant_conv(z)
103
+ # Decode the latent representation with skip connections
104
+ dec = self.decoder(z, skip_connections)
105
+
106
+ if not return_dict:
107
+ return (dec,)
108
+
109
+ return DecoderOutput(sample=dec)
110
+
111
+ def forward(
112
+ self,
113
+ sample: torch.Tensor,
114
+ sample_posterior: bool = False,
115
+ return_dict: bool = True,
116
+ generator: Optional[torch.Generator] = None,
117
+ ):
118
+ # Encode the input and get the skip connections
119
+ posterior, skip_connections = self.encode(sample, return_dict=True)
120
+
121
+ # Sample from the posterior
122
+ if sample_posterior:
123
+ z = posterior.sample(generator=generator)
124
+ else:
125
+ z = posterior.mode()
126
+
127
+ # Decode the latent representation with skip connections
128
+ dec = self.decode(z, skip_connections, return_dict=return_dict)
129
+
130
+ if not return_dict:
131
+ return (dec,)
132
+
133
+ return DecoderOutput(sample=dec)
134
+
135
+
136
+ class CustomDecoder(Decoder):
137
+ def __init__(
138
+ self,
139
+ in_channels: int,
140
+ out_channels: int,
141
+ up_block_types: Tuple[str, ...],
142
+ block_out_channels: Tuple[int, ...],
143
+ layers_per_block: int,
144
+ norm_num_groups: int,
145
+ act_fn: str,
146
+ mid_block_add_attention: bool,
147
+ ):
148
+ super().__init__(
149
+ in_channels=in_channels,
150
+ out_channels=out_channels,
151
+ up_block_types=up_block_types,
152
+ block_out_channels=block_out_channels,
153
+ layers_per_block=layers_per_block,
154
+ norm_num_groups=norm_num_groups,
155
+ act_fn=act_fn,
156
+ mid_block_add_attention=mid_block_add_attention,
157
+ )
158
+
159
+ def forward(
160
+ self,
161
+ sample: torch.Tensor,
162
+ skip_connections: list,
163
+ latent_embeds: Optional[torch.Tensor] = None,
164
+ ) -> torch.Tensor:
165
+ r"""The forward method of the `Decoder` class."""
166
+
167
+ sample = self.conv_in(sample)
168
+
169
+ upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
170
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
171
+
172
+ def create_custom_forward(module):
173
+ def custom_forward(*inputs):
174
+ return module(*inputs)
175
+
176
+ return custom_forward
177
+
178
+ if is_torch_version(">=", "1.11.0"):
179
+ # middle
180
+ sample = torch.utils.checkpoint.checkpoint(
181
+ create_custom_forward(self.mid_block),
182
+ sample,
183
+ latent_embeds,
184
+ use_reentrant=False,
185
+ )
186
+ sample = sample.to(upscale_dtype)
187
+
188
+ # up
189
+ for up_block in self.up_blocks:
190
+ sample = torch.utils.checkpoint.checkpoint(
191
+ create_custom_forward(up_block),
192
+ sample,
193
+ latent_embeds,
194
+ use_reentrant=False,
195
+ )
196
+ else:
197
+ # middle
198
+ sample = torch.utils.checkpoint.checkpoint(
199
+ create_custom_forward(self.mid_block), sample, latent_embeds
200
+ )
201
+ sample = sample.to(upscale_dtype)
202
+
203
+ # up
204
+ for up_block in self.up_blocks:
205
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
206
+ else:
207
+ # middle
208
+ sample = self.mid_block(sample, latent_embeds)
209
+ sample = sample.to(upscale_dtype)
210
+
211
+ # up
212
+ # for up_block in self.up_blocks:
213
+ # sample = up_block(sample, latent_embeds)
214
+ for i, up_block in enumerate(self.up_blocks):
215
+ # Add skip connections directly
216
+ if i < len(skip_connections):
217
+ skip_connection = skip_connections[-(i + 1)]
218
+ # import pdb; pdb.set_trace()
219
+ sample = sample + skip_connection
220
+ # import pdb; pdb.set_trace() #torch.Size([1, 512, 96, 96]
221
+ sample = up_block(sample)
222
+
223
+ # post-process
224
+ if latent_embeds is None:
225
+ sample = self.conv_norm_out(sample)
226
+ else:
227
+ sample = self.conv_norm_out(sample, latent_embeds)
228
+ sample = self.conv_act(sample)
229
+ sample = self.conv_out(sample)
230
+
231
+ return sample
232
+
233
+ class CustomEncoder(Encoder):
234
+ r"""
235
+ Custom Encoder that adds Zero Convolution layers to each block's output
236
+ to generate skip connections.
237
+ """
238
+ def __init__(
239
+ self,
240
+ in_channels: int = 3,
241
+ out_channels: int = 3,
242
+ down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
243
+ block_out_channels: Tuple[int, ...] = (64,),
244
+ layers_per_block: int = 2,
245
+ norm_num_groups: int = 32,
246
+ act_fn: str = "silu",
247
+ double_z: bool = True,
248
+ mid_block_add_attention: bool = True,
249
+ ):
250
+ super().__init__(
251
+ in_channels=in_channels,
252
+ out_channels=out_channels,
253
+ down_block_types=down_block_types,
254
+ block_out_channels=block_out_channels,
255
+ layers_per_block=layers_per_block,
256
+ norm_num_groups=norm_num_groups,
257
+ act_fn=act_fn,
258
+ double_z=double_z,
259
+ mid_block_add_attention=mid_block_add_attention,
260
+ )
261
+
262
+ # Add Zero Convolution layers to each block's output
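+ # For the first two blocks the zero convolution also doubles the channel count so that, once
+ # reversed, each skip connection matches the width of the corresponding decoder up block
+ # (with the standard (128, 256, 512, 512) VAE layout the skips become 256/512/512/512 channels).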
263
+ self.zero_convs = nn.ModuleList()
264
+ for i, out_channels in enumerate(block_out_channels):
265
+ if i < 2:
266
+ self.zero_convs.append(ZeroConv2d(out_channels, out_channels * 2))
267
+ else:
268
+ self.zero_convs.append(ZeroConv2d(out_channels, out_channels))
269
+
270
+ def forward(self, sample: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
271
+ r"""
272
+ Forward pass of the CustomEncoder.
273
+
274
+ Args:
275
+ sample (`torch.Tensor`): Input tensor.
276
+
277
+ Returns:
278
+ `Tuple[torch.Tensor, List[torch.Tensor]]`:
279
+ - The final latent representation.
280
+ - A list of skip connections from each block.
281
+ """
282
+ skip_connections = []
283
+
284
+ # Initial convolution
285
+ sample = self.conv_in(sample)
286
+
287
+ # Down blocks
288
+ for i, (down_block, zero_conv) in enumerate(zip(self.down_blocks, self.zero_convs)):
289
+ # import pdb; pdb.set_trace()
290
+ sample = down_block(sample)
291
+ if i != len(self.down_blocks) - 1:
292
+ sample_out = nn.functional.interpolate(zero_conv(sample), scale_factor=2, mode='bilinear', align_corners=False)
293
+ else:
294
+ sample_out = zero_conv(sample)
295
+ skip_connections.append(sample_out)
296
+
297
+
298
+ # import pdb; pdb.set_trace()
299
+ # torch.Size([1, 128, 768, 768])
300
+ # torch.Size([1, 128, 384, 384])
301
+ # torch.Size([1, 256, 192, 192])
302
+ # torch.Size([1, 512, 96, 96])
303
+ # torch.Size([1, 512, 96, 96])
304
+
305
+ # # Middle block
306
+ # sample = self.mid_block(sample)
307
+
308
+ # # Post-process
309
+ # sample = self.conv_norm_out(sample)
310
+ # sample = self.conv_act(sample)
311
+ # sample = self.conv_out(sample)
312
+
313
+ return sample, skip_connections
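A small sketch of how the custom VAE's encode/decode pair fits together (mirroring how the DAI pipeline below uses it). The constructor arguments follow the standard Stable Diffusion VAE layout and the model is randomly initialized purely for illustration; a real run would load trained weights instead.

import torch
from DAI.decoder import CustomAutoencoderKL

# Assumed SD-style VAE layout, for illustration only.
vae_2 = CustomAutoencoderKL(
    down_block_types=("DownEncoderBlock2D",) * 4,
    up_block_types=("UpDecoderBlock2D",) * 4,
    block_out_channels=(128, 256, 512, 512),
    layers_per_block=2,
    latent_channels=4,
).eval()

image = torch.randn(1, 3, 768, 768)   # preprocessed input image
latent = torch.randn(1, 4, 96, 96)    # latent predicted by the UNet (768 / 8 = 96)

with torch.no_grad():
    # `encode` returns only the skip connections tapped off the zero convolutions in the encoder.
    skip_connections = vae_2.encode(image)
    # `decode` adds those skip connections while upsampling the latent back to image space.
    restored = vae_2.decode(latent / vae_2.config.scaling_factor, skip_connections, return_dict=False)[0]
print(restored.shape)  # torch.Size([1, 3, 768, 768])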
DAI/pipeline_all.py ADDED
@@ -0,0 +1,733 @@
1
+ # Copyright 2024 Marigold authors, PRS ETH Zurich. All rights reserved.
2
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # --------------------------------------------------------------------------
16
+ # More information and citation instructions are available on the
+ # Marigold project website: https://marigoldmonodepth.github.io
17
+ # --------------------------------------------------------------------------
18
+ import inspect  # used by `retrieve_timesteps` below
+ from dataclasses import dataclass
19
+ from typing import Any, Dict, List, Optional, Tuple, Union
20
+
21
+ import numpy as np
22
+ import torch
23
+ from PIL import Image
24
+ from tqdm.auto import tqdm
25
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
26
+
27
+
28
+ from diffusers.image_processor import PipelineImageInput
29
+ from diffusers.models import (
30
+ AutoencoderKL,
31
+ UNet2DConditionModel,
32
+ ControlNetModel,
33
+ )
34
+ from diffusers.schedulers import (
35
+ DDIMScheduler
36
+ )
37
+
38
+ from diffusers.utils import (
39
+ BaseOutput,
40
+ logging,
41
+ replace_example_docstring,
42
+ )
43
+
44
+
45
+ from diffusers.utils.torch_utils import randn_tensor
46
+ from diffusers.pipelines.controlnet import StableDiffusionControlNetPipeline
47
+ from diffusers.pipelines.marigold.marigold_image_processing import MarigoldImageProcessor
48
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
49
+
50
+ from DAI.decoder import CustomAutoencoderKL
51
+
52
+ import pdb
53
+
54
+
55
+
56
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
57
+
58
+
59
+ EXAMPLE_DOC_STRING = """
60
+ Examples:
61
+ ```py
62
+ >>> import diffusers
63
+ >>> import torch
64
+
65
+ >>> pipe = diffusers.MarigoldNormalsPipeline.from_pretrained(
66
+ ... "prs-eth/marigold-normals-lcm-v0-1", variant="fp16", torch_dtype=torch.float16
67
+ ... ).to("cuda")
68
+
69
+ >>> image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
70
+ >>> normals = pipe(image)
71
+
72
+ >>> vis = pipe.image_processor.visualize_normals(normals.prediction)
73
+ >>> vis[0].save("einstein_normals.png")
74
+ ```
75
+ """
76
+
77
+
78
+ @dataclass
79
+ class DAIOutput(BaseOutput):
80
+ """
81
+ Output class for the DAI prediction pipeline.
+
+ Args:
+ prediction (`np.ndarray`, `torch.Tensor`):
+ Predicted image. The shape is always $numimages \times 3 \times height \times width$, regardless of
+ whether the images were passed as a 4D array or a list.
+ latent (`None`, `torch.Tensor`):
+ Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline.
+ The shape is $numimages * numensemble \times 4 \times latentheight \times latentwidth$.
+ gaus_noise (`None`, `torch.Tensor`):
+ The initial latent captured before the UNet forward pass, returned for potential reuse.
93
+ """
94
+
95
+ prediction: Union[np.ndarray, torch.Tensor]
96
+ latent: Union[None, torch.Tensor]
97
+ gaus_noise: Union[None, torch.Tensor]
98
+
99
+
100
+ class DAIPipeline(StableDiffusionControlNetPipeline):
101
+ """ Pipeline for monocular normals estimation using the Marigold method: https://marigoldmonodepth.github.io.
102
+ Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
103
+
104
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
105
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
106
+
107
+ The pipeline also inherits the following loading methods:
108
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
109
+ - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
110
+ - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
111
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
112
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
113
+
114
+ Args:
115
+ vae ([`AutoencoderKL`]):
116
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
117
+ text_encoder ([`~transformers.CLIPTextModel`]):
118
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
119
+ tokenizer ([`~transformers.CLIPTokenizer`]):
120
+ A `CLIPTokenizer` to tokenize text.
121
+ unet ([`UNet2DConditionModel`]):
122
+ A `UNet2DConditionModel` to denoise the encoded image latents.
123
+ controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
124
+ Provides additional conditioning to the `unet` during the denoising process. If you set multiple
125
+ ControlNets as a list, the outputs from each ControlNet are added together to create one combined
126
+ additional conditioning.
127
+ scheduler ([`SchedulerMixin`]):
128
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
129
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
130
+ safety_checker ([`StableDiffusionSafetyChecker`]):
131
+ Classification module that estimates whether generated images could be considered offensive or harmful.
132
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
133
+ about a model's potential harms.
134
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
135
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
136
+ """
137
+
138
+ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
139
+ _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
140
+ _exclude_from_cpu_offload = ["safety_checker"]
141
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
142
+
143
+
144
+
145
+ def __init__(
146
+ self,
147
+ # vae_2: CustomAutoencoderKL,
148
+ vae: AutoencoderKL,
149
+ text_encoder: CLIPTextModel,
150
+ tokenizer: CLIPTokenizer,
151
+ unet: UNet2DConditionModel,
152
+ controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel]],
153
+ scheduler: Union[DDIMScheduler],
154
+ safety_checker: StableDiffusionSafetyChecker,
155
+ feature_extractor: CLIPImageProcessor,
156
+ image_encoder: CLIPVisionModelWithProjection = None,
157
+ requires_safety_checker: bool = True,
158
+ default_denoising_steps: Optional[int] = 1,
159
+ default_processing_resolution: Optional[int] = 768,
160
+ prompt="remove glass reflection",
161
+ empty_text_embedding=None,
162
+ t_start: Optional[int] = 0,
163
+ ):
164
+ super().__init__(
165
+ vae,
166
+ text_encoder,
167
+ tokenizer,
168
+ unet,
169
+ controlnet,
170
+ scheduler,
171
+ safety_checker,
172
+ feature_extractor,
173
+ image_encoder,
174
+ requires_safety_checker,
175
+ )
176
+ # self.vae_2 = vae_2
177
+
178
+ self.image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor)
179
+ self.control_image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor)
180
+ self.default_denoising_steps = default_denoising_steps
181
+ self.default_processing_resolution = default_processing_resolution
182
+ self.prompt = prompt
183
+ self.prompt_embeds = None
184
+ self.empty_text_embedding = empty_text_embedding
185
+ self.t_start = t_start  # timestep at which the single-step prediction is performed
186
+
187
+
188
+ def check_inputs(
189
+ self,
190
+ image: PipelineImageInput,
191
+ num_inference_steps: int,
192
+ ensemble_size: int,
193
+ processing_resolution: int,
194
+ resample_method_input: str,
195
+ resample_method_output: str,
196
+ batch_size: int,
197
+ ensembling_kwargs: Optional[Dict[str, Any]],
198
+ latents: Optional[torch.Tensor],
199
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]],
200
+ output_type: str,
201
+ output_uncertainty: bool,
202
+ ) -> int:
203
+ if num_inference_steps is None:
204
+ raise ValueError("`num_inference_steps` is not specified and could not be resolved from the model config.")
205
+ if num_inference_steps < 1:
206
+ raise ValueError("`num_inference_steps` must be positive.")
207
+ if ensemble_size < 1:
208
+ raise ValueError("`ensemble_size` must be positive.")
209
+ if ensemble_size == 2:
210
+ logger.warning(
211
+ "`ensemble_size` == 2 results are similar to no ensembling (1); "
212
+ "consider increasing the value to at least 3."
213
+ )
214
+ if ensemble_size == 1 and output_uncertainty:
215
+ raise ValueError(
216
+ "Computing uncertainty by setting `output_uncertainty=True` also requires setting `ensemble_size` "
217
+ "greater than 1."
218
+ )
219
+ if processing_resolution is None:
220
+ raise ValueError(
221
+ "`processing_resolution` is not specified and could not be resolved from the model config."
222
+ )
223
+ if processing_resolution < 0:
224
+ raise ValueError(
225
+ "`processing_resolution` must be non-negative: 0 for native resolution, or any positive value for "
226
+ "downsampled processing."
227
+ )
228
+ if processing_resolution % self.vae_scale_factor != 0:
229
+ raise ValueError(f"`processing_resolution` must be a multiple of {self.vae_scale_factor}.")
230
+ if resample_method_input not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"):
231
+ raise ValueError(
232
+ "`resample_method_input` takes string values compatible with PIL library: "
233
+ "nearest, nearest-exact, bilinear, bicubic, area."
234
+ )
235
+ if resample_method_output not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"):
236
+ raise ValueError(
237
+ "`resample_method_output` takes string values compatible with PIL library: "
238
+ "nearest, nearest-exact, bilinear, bicubic, area."
239
+ )
240
+ if batch_size < 1:
241
+ raise ValueError("`batch_size` must be positive.")
242
+ if output_type not in ["pt", "np"]:
243
+ raise ValueError("`output_type` must be one of `pt` or `np`.")
244
+ if latents is not None and generator is not None:
245
+ raise ValueError("`latents` and `generator` cannot be used together.")
246
+ if ensembling_kwargs is not None:
247
+ if not isinstance(ensembling_kwargs, dict):
248
+ raise ValueError("`ensembling_kwargs` must be a dictionary.")
249
+ if "reduction" in ensembling_kwargs and ensembling_kwargs["reduction"] not in ("closest", "mean"):
250
+ raise ValueError("`ensembling_kwargs['reduction']` can be either `'closest'` or `'mean'`.")
251
+
252
+ # image checks
253
+ num_images = 0
254
+ W, H = None, None
255
+ if not isinstance(image, list):
256
+ image = [image]
257
+ for i, img in enumerate(image):
258
+ if isinstance(img, np.ndarray) or torch.is_tensor(img):
259
+ if img.ndim not in (2, 3, 4):
260
+ raise ValueError(f"`image[{i}]` has unsupported dimensions or shape: {img.shape}.")
261
+ H_i, W_i = img.shape[-2:]
262
+ N_i = 1
263
+ if img.ndim == 4:
264
+ N_i = img.shape[0]
265
+ elif isinstance(img, Image.Image):
266
+ W_i, H_i = img.size
267
+ N_i = 1
268
+ else:
269
+ raise ValueError(f"Unsupported `image[{i}]` type: {type(img)}.")
270
+ if W is None:
271
+ W, H = W_i, H_i
272
+ elif (W, H) != (W_i, H_i):
273
+ raise ValueError(
274
+ f"Input `image[{i}]` has incompatible dimensions {(W_i, H_i)} with the previous images {(W, H)}"
275
+ )
276
+ num_images += N_i
277
+
278
+ # latents checks
279
+ if latents is not None:
280
+ if not torch.is_tensor(latents):
281
+ raise ValueError("`latents` must be a torch.Tensor.")
282
+ if latents.dim() != 4:
283
+ raise ValueError(f"`latents` has unsupported dimensions or shape: {latents.shape}.")
284
+
285
+ if processing_resolution > 0:
286
+ max_orig = max(H, W)
287
+ new_H = H * processing_resolution // max_orig
288
+ new_W = W * processing_resolution // max_orig
289
+ if new_H == 0 or new_W == 0:
290
+ raise ValueError(f"Extreme aspect ratio of the input image: [{W} x {H}]")
291
+ W, H = new_W, new_H
292
+ w = (W + self.vae_scale_factor - 1) // self.vae_scale_factor
293
+ h = (H + self.vae_scale_factor - 1) // self.vae_scale_factor
294
+ shape_expected = (num_images * ensemble_size, self.vae.config.latent_channels, h, w)
295
+
296
+ if latents.shape != shape_expected:
297
+ raise ValueError(f"`latents` has unexpected shape={latents.shape} expected={shape_expected}.")
298
+
299
+ # generator checks
300
+ if generator is not None:
301
+ if isinstance(generator, list):
302
+ if len(generator) != num_images * ensemble_size:
303
+ raise ValueError(
304
+ "The number of generators must match the total number of ensemble members for all input images."
305
+ )
306
+ if not all(g.device.type == generator[0].device.type for g in generator):
307
+ raise ValueError("`generator` device placement is not consistent in the list.")
308
+ elif not isinstance(generator, torch.Generator):
309
+ raise ValueError(f"Unsupported generator type: {type(generator)}.")
310
+
311
+ return num_images
312
+
313
+ def progress_bar(self, iterable=None, total=None, desc=None, leave=True):
314
+ if not hasattr(self, "_progress_bar_config"):
315
+ self._progress_bar_config = {}
316
+ elif not isinstance(self._progress_bar_config, dict):
317
+ raise ValueError(
318
+ f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
319
+ )
320
+
321
+ progress_bar_config = dict(**self._progress_bar_config)
322
+ progress_bar_config["desc"] = progress_bar_config.get("desc", desc)
323
+ progress_bar_config["leave"] = progress_bar_config.get("leave", leave)
324
+ if iterable is not None:
325
+ return tqdm(iterable, **progress_bar_config)
326
+ elif total is not None:
327
+ return tqdm(total=total, **progress_bar_config)
328
+ else:
329
+ raise ValueError("Either `total` or `iterable` has to be defined.")
330
+
331
+ @torch.no_grad()
332
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
333
+ def __call__(
334
+ self,
335
+ image: PipelineImageInput,
336
+ vae_2: CustomAutoencoderKL,
337
+ prompt: Union[str, List[str]] = None,
338
+ negative_prompt: Optional[Union[str, List[str]]] = None,
339
+ num_inference_steps: Optional[int] = None,
340
+ ensemble_size: int = 1,
341
+ processing_resolution: Optional[int] = None,
342
+ match_input_resolution: bool = True,
343
+ resample_method_input: str = "bilinear",
344
+ resample_method_output: str = "bilinear",
345
+ batch_size: int = 1,
346
+ ensembling_kwargs: Optional[Dict[str, Any]] = None,
347
+ latents: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
348
+ prompt_embeds: Optional[torch.Tensor] = None,
349
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
350
+ num_images_per_prompt: Optional[int] = 1,
351
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
352
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
353
+ output_type: str = "np",
354
+ output_uncertainty: bool = False,
355
+ output_latent: bool = False,
356
+ skip_preprocess: bool = False,
357
+ return_dict: bool = True,
358
+ **kwargs,
359
+ ):
360
+ """
361
+ Function invoked when calling the pipeline.
362
+
363
+ Args:
364
+ image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`),
365
+ `List[torch.Tensor]`: An input image or images used as an input for the prediction task. For
366
+ arrays and tensors, the expected value range is between `[0, 1]`. Passing a batch of images is possible
367
+ by providing a four-dimensional array or a tensor. Additionally, a list of images of two- or
368
+ three-dimensional arrays or tensors can be passed. In the latter case, all list elements must have the
369
+ same width and height.
370
+ num_inference_steps (`int`, *optional*, defaults to `None`):
371
+ Number of denoising diffusion steps during inference. The default value `None` results in automatic
372
+ selection. The number of steps should be at least 10 with the full Marigold models, and between 1 and 4
373
+ for Marigold-LCM models.
374
+ ensemble_size (`int`, defaults to `1`):
375
+ Number of ensemble predictions. Recommended values are 5 and higher for better precision, or 1 for
376
+ faster inference.
377
+ processing_resolution (`int`, *optional*, defaults to `None`):
378
+ Effective processing resolution. When set to `0`, matches the larger input image dimension. This
379
+ produces crisper predictions, but may also lead to the overall loss of global context. The default
380
+ value `None` resolves to the optimal value from the model config.
381
+ match_input_resolution (`bool`, *optional*, defaults to `True`):
382
+ When enabled, the output prediction is resized to match the input dimensions. When disabled, the longer
383
+ side of the output will equal to `processing_resolution`.
384
+ resample_method_input (`str`, *optional*, defaults to `"bilinear"`):
385
+ Resampling method used to resize input images to `processing_resolution`. The accepted values are:
386
+ `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`.
387
+ resample_method_output (`str`, *optional*, defaults to `"bilinear"`):
388
+ Resampling method used to resize output predictions to match the input resolution. The accepted values
389
+ are `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`.
390
+ batch_size (`int`, *optional*, defaults to `1`):
391
+ Batch size; only matters when setting `ensemble_size` or passing a tensor of images.
392
+ ensembling_kwargs (`dict`, *optional*, defaults to `None`)
393
+ Extra dictionary with arguments for precise ensembling control. The following options are available:
394
+ - reduction (`str`, *optional*, defaults to `"closest"`): Defines the ensembling function applied in
395
+ every pixel location, can be either `"closest"` or `"mean"`.
396
+ latents (`torch.Tensor`, *optional*, defaults to `None`):
397
+ Latent noise tensors to replace the random initialization. These can be taken from the previous
398
+ function call's output.
399
+ generator (`torch.Generator`, or `List[torch.Generator]`, *optional*, defaults to `None`):
400
+ Random number generator object to ensure reproducibility.
401
+ output_type (`str`, *optional*, defaults to `"np"`):
402
+ Preferred format of the output's `prediction` and the optional `uncertainty` fields. The accepted
403
+ values are: `"np"` (numpy array) or `"pt"` (torch tensor).
404
+ output_uncertainty (`bool`, *optional*, defaults to `False`):
405
+ When enabled, the output's `uncertainty` field contains the predictive uncertainty map, provided that
406
+ the `ensemble_size` argument is set to a value above 2.
407
+ output_latent (`bool`, *optional*, defaults to `False`):
408
+ When enabled, the output's `latent` field contains the latent codes corresponding to the predictions
409
+ within the ensemble. These codes can be saved, modified, and used for subsequent calls with the
410
+ `latents` argument.
411
+ return_dict (`bool`, *optional*, defaults to `True`):
412
+ Whether or not to return a [`DAIOutput`] instead of a plain tuple.
413
+
414
+ Examples:
415
+
416
+ Returns:
417
+ [`DAIOutput`] or `tuple`:
+ If `return_dict` is `True`, a [`DAIOutput`] is returned, otherwise a `tuple` is returned where the
+ first element is the prediction, the second element is the latent (or `None`), and the third is the
+ initial latent noise (or `None`).
421
+ """
422
+
423
+ # 0. Resolving variables.
424
+ device = self._execution_device
425
+ dtype = self.dtype
426
+
427
+ # Model-specific optimal default values leading to fast and reasonable results.
428
+ if num_inference_steps is None:
429
+ num_inference_steps = self.default_denoising_steps
430
+ if processing_resolution is None:
431
+ processing_resolution = self.default_processing_resolution
432
+
433
+
434
+ # 1. Check inputs.
435
+ num_images = self.check_inputs(
436
+ image,
437
+ num_inference_steps,
438
+ ensemble_size,
439
+ processing_resolution,
440
+ resample_method_input,
441
+ resample_method_output,
442
+ batch_size,
443
+ ensembling_kwargs,
444
+ latents,
445
+ generator,
446
+ output_type,
447
+ output_uncertainty,
448
+ )
449
+
450
+
451
+ # 2. Prepare empty text conditioning.
452
+ # Model invocation: self.tokenizer, self.text_encoder.
453
+ if self.empty_text_embedding is None:
454
+ prompt = ""
455
+ text_inputs = self.tokenizer(
456
+ prompt,
457
+ padding="do_not_pad",
458
+ max_length=self.tokenizer.model_max_length,
459
+ truncation=True,
460
+ return_tensors="pt",
461
+ )
462
+ text_input_ids = text_inputs.input_ids.to(device)
463
+ self.empty_text_embedding = self.text_encoder(text_input_ids)[0] # [1,2,1024]
464
+
465
+
466
+
467
+ # 3. prepare prompt
468
+ if self.prompt_embeds is None:
469
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
470
+ self.prompt,
471
+ device,
472
+ num_images_per_prompt,
473
+ False,
474
+ negative_prompt,
475
+ prompt_embeds=prompt_embeds,
476
+ negative_prompt_embeds=None,
477
+ lora_scale=None,
478
+ clip_skip=None,
479
+ )
480
+ self.prompt_embeds = prompt_embeds
481
+ self.negative_prompt_embeds = negative_prompt_embeds
482
+
483
+
484
+
485
+ # 4. Preprocess input images. This function loads input image or images of compatible dimensions `(H, W)`,
486
+ # optionally downsamples them to the `processing_resolution` `(PH, PW)`, where
487
+ # `max(PH, PW) == processing_resolution`, and pads the dimensions to `(PPH, PPW)` such that these values are
488
+ # divisible by the latent space downscaling factor (typically 8 in Stable Diffusion). The default value `None`
489
+ # of `processing_resolution` resolves to the optimal value from the model config. It is a recommended mode of
490
+ # operation and leads to the most reasonable results. Using the native image resolution or any other processing
491
+ # resolution can lead to loss of either fine details or global context in the output predictions.
492
+ if not skip_preprocess:
493
+ image, padding, original_resolution = self.image_processor.preprocess(
494
+ image, processing_resolution, resample_method_input, device, dtype
495
+ ) # [N,3,PPH,PPW]
496
+ else:
497
+ padding = (0, 0)
498
+ original_resolution = image.shape[2:]
499
+ # 5. Encode input image into latent space. At this step, each of the `N` input images is represented with `E`
500
+ # ensemble members. Each ensemble member is an independent diffused prediction, just initialized independently.
501
+ # Latents of each such predictions across all input images and all ensemble members are represented in the
502
+ # `pred_latent` variable. The variable `image_latent` is of the same shape: it contains each input image encoded
503
+ # into latent space and replicated `E` times. The latents can be either generated (see `generator` to ensure
504
+ # reproducibility), or passed explicitly via the `latents` argument. The latter can be set outside the pipeline
505
+ # code. For example, in the Marigold-LCM video processing demo, the latents initialization of a frame is taken
506
+ # as a convex combination of the latents output of the pipeline for the previous frame and a newly-sampled
507
+ # noise. This behavior can be achieved by setting the `output_latent` argument to `True`. The latent space
508
+ # dimensions are `(h, w)`. Encoding into latent space happens in batches of size `batch_size`.
509
+ # Model invocation: self.vae.encoder.
510
+ image_latent, pred_latent = self.prepare_latents(
511
+ image, latents, generator, ensemble_size, batch_size
512
+ ) # [N*E,4,h,w], [N*E,4,h,w]
513
+
514
+ gaus_noise = pred_latent.detach().clone()
515
+ # del image
516
+
517
+
518
+ # 6. obtain control_output
519
+
520
+ cond_scale = controlnet_conditioning_scale
521
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
522
+ image_latent.detach(),
523
+ self.t_start,
524
+ encoder_hidden_states=self.prompt_embeds,
525
+ conditioning_scale=cond_scale,
526
+ guess_mode=False,
527
+ return_dict=False,
528
+ )
529
+
530
+ # 7. Onestep sampling
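+ # A single UNet forward pass at timestep `t_start` (0 by default) produces the prediction latent
+ # directly, with the ControlNet residuals injected; no iterative denoising loop is run.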
531
+ latent_x_t = self.unet(
532
+ pred_latent,
533
+ self.t_start,
534
+ encoder_hidden_states=self.prompt_embeds,
535
+ down_block_additional_residuals=down_block_res_samples,
536
+ mid_block_additional_residual=mid_block_res_sample,
537
+ return_dict=False,
538
+ )[0]
539
+
540
+
541
+ del (
542
+ pred_latent,
543
+ image_latent,
544
+ )
545
+
546
+ # 8. Decode. The second VAE encodes the original image into skip connections, which the custom
+ # decoder consumes while decoding the predicted latent back to image space.
+ skip_connections = vae_2.encode(image)
+ prediction = self.decode_prediction(latent_x_t, skip_connections, vae_2)
550
+
551
+
552
+ prediction = self.image_processor.unpad_image(prediction, padding) # [N*E,3,PH,PW]
553
+
554
+ prediction = self.image_processor.resize_antialias(
555
+ prediction, original_resolution, resample_method_output, is_aa=False
556
+ ) # [N,3,H,W]
557
+
558
+ if output_type == "np":
559
+ prediction = self.image_processor.pt_to_numpy(prediction) # [N,H,W,3]
560
+
561
+ # 11. Offload all models
562
+ self.maybe_free_model_hooks()
563
+
564
+ return DAIOutput(
565
+ prediction=prediction,
566
+ latent=latent_x_t,
567
+ gaus_noise=gaus_noise,
568
+ )
569
+
570
+ # Copied from diffusers.pipelines.marigold.pipeline_marigold_depth.MarigoldDepthPipeline.prepare_latents
571
+ def prepare_latents(
572
+ self,
573
+ image: torch.Tensor,
574
+ latents: Optional[torch.Tensor],
575
+ generator: Optional[torch.Generator],
576
+ ensemble_size: int,
577
+ batch_size: int,
578
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
579
+ def retrieve_latents(encoder_output):
580
+ if hasattr(encoder_output, "latent_dist"):
581
+ return encoder_output.latent_dist.mode()
582
+ elif hasattr(encoder_output, "latents"):
583
+ return encoder_output.latents
584
+ else:
585
+ raise AttributeError("Could not access latents of provided encoder_output")
586
+
587
+
588
+
589
+ image_latent = torch.cat(
590
+ [
591
+ retrieve_latents(self.vae.encode(image[i : i + batch_size]))
592
+ for i in range(0, image.shape[0], batch_size)
593
+ ],
594
+ dim=0,
595
+ ) # [N,4,h,w]
596
+ image_latent = image_latent * self.vae.config.scaling_factor
597
+ image_latent = image_latent.repeat_interleave(ensemble_size, dim=0) # [N*E,4,h,w]
598
+
599
+ pred_latent = latents
+ if pred_latent is None:
+ # Deterministic single-step prediction: initialize the latent with zeros rather than Gaussian noise.
+ pred_latent = torch.zeros_like(image_latent) # [N*E,4,h,w]
607
+
608
+ return image_latent, pred_latent
609
+
610
+ def decode_prediction(self, pred_latent: torch.Tensor, skip_connections: list, vae_2: CustomAutoencoderKL) -> torch.Tensor:
611
+ if pred_latent.dim() != 4 or pred_latent.shape[1] != vae_2.config.latent_channels:
612
+ raise ValueError(
613
+ f"Expecting 4D tensor of shape [B,{vae_2.config.latent_channels},H,W]; got {pred_latent.shape}."
614
+ )
615
+
616
+ prediction = vae_2.decode(pred_latent / vae_2.config.scaling_factor, skip_connections, return_dict=False)[0] # [B,3,H,W]
617
+
618
+ return prediction # [B,3,H,W]
619
+
620
+ @staticmethod
621
+ def normalize_normals(normals: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
622
+ if normals.dim() != 4 or normals.shape[1] != 3:
623
+ raise ValueError(f"Expecting 4D tensor of shape [B,3,H,W]; got {normals.shape}.")
624
+
625
+ norm = torch.norm(normals, dim=1, keepdim=True)
626
+ normals /= norm.clamp(min=eps)
627
+
628
+ return normals
629
+
630
+ @staticmethod
631
+ def ensemble_normals(
632
+ normals: torch.Tensor, output_uncertainty: bool, reduction: str = "closest"
633
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
634
+ """
635
+ Ensembles the normals maps represented by the `normals` tensor with expected shape `(B, 3, H, W)`, where B is
636
+ the number of ensemble members for a given prediction of size `(H x W)`.
637
+
638
+ Args:
639
+ normals (`torch.Tensor`):
640
+ Input ensemble normals maps.
641
+ output_uncertainty (`bool`, *optional*, defaults to `False`):
642
+ Whether to output uncertainty map.
643
+ reduction (`str`, *optional*, defaults to `"closest"`):
644
+ Reduction method used to ensemble aligned predictions. The accepted values are: `"closest"` and
645
+ `"mean"`.
646
+
647
+ Returns:
648
+ A tensor of aligned and ensembled normals maps with shape `(1, 3, H, W)` and optionally a tensor of
649
+ uncertainties of shape `(1, 1, H, W)`.
650
+ """
651
+ if normals.dim() != 4 or normals.shape[1] != 3:
652
+ raise ValueError(f"Expecting 4D tensor of shape [B,3,H,W]; got {normals.shape}.")
653
+ if reduction not in ("closest", "mean"):
654
+ raise ValueError(f"Unrecognized reduction method: {reduction}.")
655
+
656
+ mean_normals = normals.mean(dim=0, keepdim=True) # [1,3,H,W]
657
+ mean_normals = DAIPipeline.normalize_normals(mean_normals) # [1,3,H,W]
658
+
659
+ sim_cos = (mean_normals * normals).sum(dim=1, keepdim=True) # [E,1,H,W]
660
+ sim_cos = sim_cos.clamp(-1, 1) # required to avoid NaN in uncertainty with fp16
661
+
662
+ uncertainty = None
663
+ if output_uncertainty:
664
+ uncertainty = sim_cos.arccos() # [E,1,H,W]
665
+ uncertainty = uncertainty.mean(dim=0, keepdim=True) / np.pi # [1,1,H,W]
666
+
667
+ if reduction == "mean":
668
+ return mean_normals, uncertainty # [1,3,H,W], [1,1,H,W]
669
+
670
+ closest_indices = sim_cos.argmax(dim=0, keepdim=True) # [1,1,H,W]
671
+ closest_indices = closest_indices.repeat(1, 3, 1, 1) # [1,3,H,W]
672
+ closest_normals = torch.gather(normals, 0, closest_indices) # [1,3,H,W]
673
+
674
+ return closest_normals, uncertainty # [1,3,H,W], [1,1,H,W]
675
+
676
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
677
+ def retrieve_timesteps(
678
+ scheduler,
679
+ num_inference_steps: Optional[int] = None,
680
+ device: Optional[Union[str, torch.device]] = None,
681
+ timesteps: Optional[List[int]] = None,
682
+ sigmas: Optional[List[float]] = None,
683
+ **kwargs,
684
+ ):
685
+ """
686
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
687
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
688
+
689
+ Args:
690
+ scheduler (`SchedulerMixin`):
691
+ The scheduler to get timesteps from.
692
+ num_inference_steps (`int`):
693
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
694
+ must be `None`.
695
+ device (`str` or `torch.device`, *optional*):
696
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
697
+ timesteps (`List[int]`, *optional*):
698
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
699
+ `num_inference_steps` and `sigmas` must be `None`.
700
+ sigmas (`List[float]`, *optional*):
701
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
702
+ `num_inference_steps` and `timesteps` must be `None`.
703
+
704
+ Returns:
705
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
706
+ second element is the number of inference steps.
707
+ """
708
+ if timesteps is not None and sigmas is not None:
709
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
710
+ if timesteps is not None:
711
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
712
+ if not accepts_timesteps:
713
+ raise ValueError(
714
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
715
+ f" timestep schedules. Please check whether you are using the correct scheduler."
716
+ )
717
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
718
+ timesteps = scheduler.timesteps
719
+ num_inference_steps = len(timesteps)
720
+ elif sigmas is not None:
721
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
722
+ if not accept_sigmas:
723
+ raise ValueError(
724
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
725
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
726
+ )
727
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
728
+ timesteps = scheduler.timesteps
729
+ num_inference_steps = len(timesteps)
730
+ else:
731
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
732
+ timesteps = scheduler.timesteps
733
+ return timesteps, num_inference_steps
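The helper above only delegates to `scheduler.set_timesteps` and echoes the resulting schedule back, so it can be exercised in isolation. A minimal sketch, assuming `retrieve_timesteps` is in scope (imported from this pipeline module) and using a default-configured `DDIMScheduler` purely for illustration:

```python
from diffusers import DDIMScheduler

scheduler = DDIMScheduler()  # default config: 1000 training timesteps

# Default path: the scheduler derives its own spacing for the requested step count.
timesteps, num_steps = retrieve_timesteps(scheduler, num_inference_steps=10, device="cpu")
print(num_steps, timesteps.shape)  # 10, torch.Size([10])

# Custom `timesteps`/`sigmas` are honored only when the scheduler's `set_timesteps`
# accepts them; otherwise the helper raises a ValueError, as coded above.
```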
DAI/pipeline_onestep.py ADDED
@@ -0,0 +1,723 @@
1
+ # Copyright 2024 Marigold authors, PRS ETH Zurich. All rights reserved.
2
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # --------------------------------------------------------------------------
16
+ # More information and citation instructions are available on the Marigold project website: https://marigoldmonodepth.github.io
17
+ # --------------------------------------------------------------------------
18
+ from dataclasses import dataclass
19
+ from typing import Any, Dict, List, Optional, Tuple, Union
20
+
21
+ import inspect
+ import numpy as np
22
+ import torch
23
+ from PIL import Image
24
+ from tqdm.auto import tqdm
25
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
26
+
27
+
28
+ from diffusers.image_processor import PipelineImageInput
29
+ from diffusers.models import (
30
+ AutoencoderKL,
31
+ UNet2DConditionModel,
32
+ ControlNetModel,
33
+ )
34
+ from diffusers.schedulers import (
35
+ DDIMScheduler
36
+ )
37
+
38
+ from diffusers.utils import (
39
+ BaseOutput,
40
+ logging,
41
+ replace_example_docstring,
42
+ )
43
+
44
+
45
+ from diffusers.utils.torch_utils import randn_tensor
46
+ from diffusers.pipelines.controlnet import StableDiffusionControlNetPipeline
47
+ from diffusers.pipelines.marigold.marigold_image_processing import MarigoldImageProcessor
48
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
49
+
50
+ import pdb
51
+
52
+
53
+
54
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
55
+
56
+
57
+ EXAMPLE_DOC_STRING = """
58
+ Examples:
59
+ ```py
60
+ >>> import diffusers
61
+ >>> import torch
62
+
63
+ >>> pipe = diffusers.MarigoldNormalsPipeline.from_pretrained(
64
+ ... "prs-eth/marigold-normals-lcm-v0-1", variant="fp16", torch_dtype=torch.float16
65
+ ... ).to("cuda")
66
+
67
+ >>> image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
68
+ >>> normals = pipe(image)
69
+
70
+ >>> vis = pipe.image_processor.visualize_normals(normals.prediction)
71
+ >>> vis[0].save("einstein_normals.png")
72
+ ```
73
+ """
74
+
75
+
76
+ @dataclass
77
+ class DAIOutput(BaseOutput):
78
+ """
79
+ Output class for the one-step DAI (Dereflection Any Image) prediction pipeline.
+
+ Args:
+ prediction (`np.ndarray`, `torch.Tensor`):
+ Predicted (reflection-removed) image with values in the range [-1, 1]. The shape is always $numimages
+ \times 3 \times height \times width$, regardless of whether the images were passed as a 4D array or a list.
+ latent (`None`, `torch.Tensor`):
+ Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline.
+ The shape is $numimages * numensemble \times 4 \times latentheight \times latentwidth$.
+ gaus_noise (`None`, `torch.Tensor`):
+ The initial prediction latent recorded before the one-step UNet pass; same shape as `latent`.
91
+ """
92
+
93
+ prediction: Union[np.ndarray, torch.Tensor]
94
+ latent: Union[None, torch.Tensor]
95
+ gaus_noise: Union[None, torch.Tensor]
96
+
97
+
98
+ class OneStepPipeline(StableDiffusionControlNetPipeline):
99
+ """ One-step pipeline for image dereflection, built on Stable Diffusion with ControlNet guidance
+ and adapted from the Marigold pipelines: https://marigoldmonodepth.github.io.
101
+
102
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
103
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
104
+
105
+ The pipeline also inherits the following loading methods:
106
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
107
+ - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
108
+ - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
109
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
110
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
111
+
112
+ Args:
113
+ vae ([`AutoencoderKL`]):
114
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
115
+ text_encoder ([`~transformers.CLIPTextModel`]):
116
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
117
+ tokenizer ([`~transformers.CLIPTokenizer`]):
118
+ A `CLIPTokenizer` to tokenize text.
119
+ unet ([`UNet2DConditionModel`]):
120
+ A `UNet2DConditionModel` to denoise the encoded image latents.
121
+ controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
122
+ Provides additional conditioning to the `unet` during the denoising process. If you set multiple
123
+ ControlNets as a list, the outputs from each ControlNet are added together to create one combined
124
+ additional conditioning.
125
+ scheduler ([`SchedulerMixin`]):
126
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
127
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
128
+ safety_checker ([`StableDiffusionSafetyChecker`]):
129
+ Classification module that estimates whether generated images could be considered offensive or harmful.
130
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
131
+ about a model's potential harms.
132
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
133
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
134
+ """
135
+
136
+ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
137
+ _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
138
+ _exclude_from_cpu_offload = ["safety_checker"]
139
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
140
+
141
+
142
+
143
+ def __init__(
144
+ self,
145
+ vae: AutoencoderKL,
146
+ text_encoder: CLIPTextModel,
147
+ tokenizer: CLIPTokenizer,
148
+ unet: UNet2DConditionModel,
149
+ controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel]],
150
+ scheduler: Union[DDIMScheduler],
151
+ safety_checker: StableDiffusionSafetyChecker,
152
+ feature_extractor: CLIPImageProcessor,
153
+ image_encoder: CLIPVisionModelWithProjection = None,
154
+ requires_safety_checker: bool = True,
155
+ default_denoising_steps: Optional[int] = 1,
156
+ default_processing_resolution: Optional[int] = 768,
157
+ prompt="remove glass reflection",
158
+ empty_text_embedding=None,
159
+ t_start: Optional[int] = 401,
160
+ ):
161
+ super().__init__(
162
+ vae,
163
+ text_encoder,
164
+ tokenizer,
165
+ unet,
166
+ controlnet,
167
+ scheduler,
168
+ safety_checker,
169
+ feature_extractor,
170
+ image_encoder,
171
+ requires_safety_checker,
172
+ )
173
+
174
+ self.image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor)
175
+ self.control_image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor)
176
+ self.default_denoising_steps = default_denoising_steps
177
+ self.default_processing_resolution = default_processing_resolution
178
+ self.prompt = prompt
179
+ self.prompt_embeds = None
180
+ self.empty_text_embedding = empty_text_embedding
181
+ self.t_start = t_start # timestep used for the single denoising step
182
+
183
+ def check_inputs(
184
+ self,
185
+ image: PipelineImageInput,
186
+ num_inference_steps: int,
187
+ ensemble_size: int,
188
+ processing_resolution: int,
189
+ resample_method_input: str,
190
+ resample_method_output: str,
191
+ batch_size: int,
192
+ ensembling_kwargs: Optional[Dict[str, Any]],
193
+ latents: Optional[torch.Tensor],
194
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]],
195
+ output_type: str,
196
+ output_uncertainty: bool,
197
+ ) -> int:
198
+ if num_inference_steps is None:
199
+ raise ValueError("`num_inference_steps` is not specified and could not be resolved from the model config.")
200
+ if num_inference_steps < 1:
201
+ raise ValueError("`num_inference_steps` must be positive.")
202
+ if ensemble_size < 1:
203
+ raise ValueError("`ensemble_size` must be positive.")
204
+ if ensemble_size == 2:
205
+ logger.warning(
206
+ "`ensemble_size` == 2 results are similar to no ensembling (1); "
207
+ "consider increasing the value to at least 3."
208
+ )
209
+ if ensemble_size == 1 and output_uncertainty:
210
+ raise ValueError(
211
+ "Computing uncertainty by setting `output_uncertainty=True` also requires setting `ensemble_size` "
212
+ "greater than 1."
213
+ )
214
+ if processing_resolution is None:
215
+ raise ValueError(
216
+ "`processing_resolution` is not specified and could not be resolved from the model config."
217
+ )
218
+ if processing_resolution < 0:
219
+ raise ValueError(
220
+ "`processing_resolution` must be non-negative: 0 for native resolution, or any positive value for "
221
+ "downsampled processing."
222
+ )
223
+ if processing_resolution % self.vae_scale_factor != 0:
224
+ raise ValueError(f"`processing_resolution` must be a multiple of {self.vae_scale_factor}.")
225
+ if resample_method_input not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"):
226
+ raise ValueError(
227
+ "`resample_method_input` takes string values compatible with PIL library: "
228
+ "nearest, nearest-exact, bilinear, bicubic, area."
229
+ )
230
+ if resample_method_output not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"):
231
+ raise ValueError(
232
+ "`resample_method_output` takes string values compatible with PIL library: "
233
+ "nearest, nearest-exact, bilinear, bicubic, area."
234
+ )
235
+ if batch_size < 1:
236
+ raise ValueError("`batch_size` must be positive.")
237
+ if output_type not in ["pt", "np"]:
238
+ raise ValueError("`output_type` must be one of `pt` or `np`.")
239
+ if latents is not None and generator is not None:
240
+ raise ValueError("`latents` and `generator` cannot be used together.")
241
+ if ensembling_kwargs is not None:
242
+ if not isinstance(ensembling_kwargs, dict):
243
+ raise ValueError("`ensembling_kwargs` must be a dictionary.")
244
+ if "reduction" in ensembling_kwargs and ensembling_kwargs["reduction"] not in ("closest", "mean"):
245
+ raise ValueError("`ensembling_kwargs['reduction']` can be either `'closest'` or `'mean'`.")
246
+
247
+ # image checks
248
+ num_images = 0
249
+ W, H = None, None
250
+ if not isinstance(image, list):
251
+ image = [image]
252
+ for i, img in enumerate(image):
253
+ if isinstance(img, np.ndarray) or torch.is_tensor(img):
254
+ if img.ndim not in (2, 3, 4):
255
+ raise ValueError(f"`image[{i}]` has unsupported dimensions or shape: {img.shape}.")
256
+ H_i, W_i = img.shape[-2:]
257
+ N_i = 1
258
+ if img.ndim == 4:
259
+ N_i = img.shape[0]
260
+ elif isinstance(img, Image.Image):
261
+ W_i, H_i = img.size
262
+ N_i = 1
263
+ else:
264
+ raise ValueError(f"Unsupported `image[{i}]` type: {type(img)}.")
265
+ if W is None:
266
+ W, H = W_i, H_i
267
+ elif (W, H) != (W_i, H_i):
268
+ raise ValueError(
269
+ f"Input `image[{i}]` has incompatible dimensions {(W_i, H_i)} with the previous images {(W, H)}"
270
+ )
271
+ num_images += N_i
272
+
273
+ # latents checks
274
+ if latents is not None:
275
+ if not torch.is_tensor(latents):
276
+ raise ValueError("`latents` must be a torch.Tensor.")
277
+ if latents.dim() != 4:
278
+ raise ValueError(f"`latents` has unsupported dimensions or shape: {latents.shape}.")
279
+
280
+ if processing_resolution > 0:
281
+ max_orig = max(H, W)
282
+ new_H = H * processing_resolution // max_orig
283
+ new_W = W * processing_resolution // max_orig
284
+ if new_H == 0 or new_W == 0:
285
+ raise ValueError(f"Extreme aspect ratio of the input image: [{W} x {H}]")
286
+ W, H = new_W, new_H
287
+ w = (W + self.vae_scale_factor - 1) // self.vae_scale_factor
288
+ h = (H + self.vae_scale_factor - 1) // self.vae_scale_factor
289
+ shape_expected = (num_images * ensemble_size, self.vae.config.latent_channels, h, w)
290
+
291
+ if latents.shape != shape_expected:
292
+ raise ValueError(f"`latents` has unexpected shape={latents.shape} expected={shape_expected}.")
293
+
294
+ # generator checks
295
+ if generator is not None:
296
+ if isinstance(generator, list):
297
+ if len(generator) != num_images * ensemble_size:
298
+ raise ValueError(
299
+ "The number of generators must match the total number of ensemble members for all input images."
300
+ )
301
+ if not all(g.device.type == generator[0].device.type for g in generator):
302
+ raise ValueError("`generator` device placement is not consistent in the list.")
303
+ elif not isinstance(generator, torch.Generator):
304
+ raise ValueError(f"Unsupported generator type: {type(generator)}.")
305
+
306
+ return num_images
307
+
308
+ def progress_bar(self, iterable=None, total=None, desc=None, leave=True):
309
+ if not hasattr(self, "_progress_bar_config"):
310
+ self._progress_bar_config = {}
311
+ elif not isinstance(self._progress_bar_config, dict):
312
+ raise ValueError(
313
+ f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
314
+ )
315
+
316
+ progress_bar_config = dict(**self._progress_bar_config)
317
+ progress_bar_config["desc"] = progress_bar_config.get("desc", desc)
318
+ progress_bar_config["leave"] = progress_bar_config.get("leave", leave)
319
+ if iterable is not None:
320
+ return tqdm(iterable, **progress_bar_config)
321
+ elif total is not None:
322
+ return tqdm(total=total, **progress_bar_config)
323
+ else:
324
+ raise ValueError("Either `total` or `iterable` has to be defined.")
325
+
326
+ @torch.no_grad()
327
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
328
+ def __call__(
329
+ self,
330
+ image: PipelineImageInput,
331
+ prompt: Union[str, List[str]] = None,
332
+ negative_prompt: Optional[Union[str, List[str]]] = None,
333
+ num_inference_steps: Optional[int] = None,
334
+ ensemble_size: int = 1,
335
+ processing_resolution: Optional[int] = None,
336
+ match_input_resolution: bool = True,
337
+ resample_method_input: str = "bilinear",
338
+ resample_method_output: str = "bilinear",
339
+ batch_size: int = 1,
340
+ ensembling_kwargs: Optional[Dict[str, Any]] = None,
341
+ latents: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
342
+ prompt_embeds: Optional[torch.Tensor] = None,
343
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
344
+ num_images_per_prompt: Optional[int] = 1,
345
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
346
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
347
+ output_type: str = "np",
348
+ output_uncertainty: bool = False,
349
+ output_latent: bool = False,
350
+ skip_preprocess: bool = False,
351
+ return_dict: bool = True,
352
+ **kwargs,
353
+ ):
354
+ """
355
+ Function invoked when calling the pipeline.
356
+
357
+ Args:
358
+ image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`),
359
+ `List[torch.Tensor]`: An input image or images used as an input for the normals estimation task. For
360
+ arrays and tensors, the expected value range is between `[0, 1]`. Passing a batch of images is possible
361
+ by providing a four-dimensional array or a tensor. Additionally, a list of images of two- or
362
+ three-dimensional arrays or tensors can be passed. In the latter case, all list elements must have the
363
+ same width and height.
364
+ num_inference_steps (`int`, *optional*, defaults to `None`):
365
+ Number of denoising diffusion steps during inference. The default value `None` results in automatic
366
+ selection. The number of steps should be at least 10 with the full Marigold models, and between 1 and 4
367
+ for Marigold-LCM models.
368
+ ensemble_size (`int`, defaults to `1`):
369
+ Number of ensemble predictions. Recommended values are 5 and higher for better precision, or 1 for
370
+ faster inference.
371
+ processing_resolution (`int`, *optional*, defaults to `None`):
372
+ Effective processing resolution. When set to `0`, matches the larger input image dimension. This
373
+ produces crisper predictions, but may also lead to the overall loss of global context. The default
374
+ value `None` resolves to the optimal value from the model config.
375
+ match_input_resolution (`bool`, *optional*, defaults to `True`):
376
+ When enabled, the output prediction is resized to match the input dimensions. When disabled, the longer
377
+ side of the output will equal `processing_resolution`.
378
+ resample_method_input (`str`, *optional*, defaults to `"bilinear"`):
379
+ Resampling method used to resize input images to `processing_resolution`. The accepted values are:
380
+ `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`.
381
+ resample_method_output (`str`, *optional*, defaults to `"bilinear"`):
382
+ Resampling method used to resize output predictions to match the input resolution. The accepted values
383
+ are `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`.
384
+ batch_size (`int`, *optional*, defaults to `1`):
385
+ Batch size; only matters when setting `ensemble_size` or passing a tensor of images.
386
+ ensembling_kwargs (`dict`, *optional*, defaults to `None`)
387
+ Extra dictionary with arguments for precise ensembling control. The following options are available:
388
+ - reduction (`str`, *optional*, defaults to `"closest"`): Defines the ensembling function applied in
389
+ every pixel location, can be either `"closest"` or `"mean"`.
390
+ latents (`torch.Tensor`, *optional*, defaults to `None`):
391
+ Latent noise tensors to replace the random initialization. These can be taken from the previous
392
+ function call's output.
393
+ generator (`torch.Generator`, or `List[torch.Generator]`, *optional*, defaults to `None`):
394
+ Random number generator object to ensure reproducibility.
395
+ output_type (`str`, *optional*, defaults to `"np"`):
396
+ Preferred format of the output's `prediction` and the optional `uncertainty` fields. The accepted
397
+ values are: `"np"` (numpy array) or `"pt"` (torch tensor).
398
+ output_uncertainty (`bool`, *optional*, defaults to `False`):
399
+ When enabled, the output's `uncertainty` field contains the predictive uncertainty map, provided that
400
+ the `ensemble_size` argument is set to a value above 2.
401
+ output_latent (`bool`, *optional*, defaults to `False`):
402
+ When enabled, the output's `latent` field contains the latent codes corresponding to the predictions
403
+ within the ensemble. These codes can be saved, modified, and used for subsequent calls with the
404
+ `latents` argument.
405
+ return_dict (`bool`, *optional*, defaults to `True`):
406
+ Whether or not to return a [`~pipelines.marigold.MarigoldDepthOutput`] instead of a plain tuple.
407
+
408
+ Examples:
409
+
410
+ Returns:
411
+ [`~pipelines.marigold.MarigoldNormalsOutput`] or `tuple`:
412
+ If `return_dict` is `True`, [`~pipelines.marigold.MarigoldNormalsOutput`] is returned, otherwise a
413
+ `tuple` is returned where the first element is the prediction, the second element is the uncertainty
414
+ (or `None`), and the third is the latent (or `None`).
415
+ """
416
+
417
+ # 0. Resolving variables.
418
+ device = self._execution_device
419
+ dtype = self.dtype
420
+
421
+ # Model-specific optimal default values leading to fast and reasonable results.
422
+ if num_inference_steps is None:
423
+ num_inference_steps = self.default_denoising_steps
424
+ if processing_resolution is None:
425
+ processing_resolution = self.default_processing_resolution
426
+
427
+ # 1. Check inputs.
428
+ num_images = self.check_inputs(
429
+ image,
430
+ num_inference_steps,
431
+ ensemble_size,
432
+ processing_resolution,
433
+ resample_method_input,
434
+ resample_method_output,
435
+ batch_size,
436
+ ensembling_kwargs,
437
+ latents,
438
+ generator,
439
+ output_type,
440
+ output_uncertainty,
441
+ )
442
+
443
+
444
+ # 2. Prepare empty text conditioning.
445
+ # Model invocation: self.tokenizer, self.text_encoder.
446
+ if self.empty_text_embedding is None:
447
+ prompt = ""
448
+ text_inputs = self.tokenizer(
449
+ prompt,
450
+ padding="do_not_pad",
451
+ max_length=self.tokenizer.model_max_length,
452
+ truncation=True,
453
+ return_tensors="pt",
454
+ )
455
+ text_input_ids = text_inputs.input_ids.to(device)
456
+ self.empty_text_embedding = self.text_encoder(text_input_ids)[0] # [1,2,1024]
457
+
458
+
459
+
460
+ # 3. prepare prompt
461
+ if self.prompt_embeds is None:
462
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
463
+ self.prompt,
464
+ device,
465
+ num_images_per_prompt,
466
+ False,
467
+ negative_prompt,
468
+ prompt_embeds=prompt_embeds,
469
+ negative_prompt_embeds=None,
470
+ lora_scale=None,
471
+ clip_skip=None,
472
+ )
473
+ self.prompt_embeds = prompt_embeds
474
+ self.negative_prompt_embeds = negative_prompt_embeds
475
+
476
+
477
+
478
+ # 4. Preprocess input images. This function loads input image or images of compatible dimensions `(H, W)`,
479
+ # optionally downsamples them to the `processing_resolution` `(PH, PW)`, where
480
+ # `max(PH, PW) == processing_resolution`, and pads the dimensions to `(PPH, PPW)` such that these values are
481
+ # divisible by the latent space downscaling factor (typically 8 in Stable Diffusion). The default value `None`
482
+ # of `processing_resolution` resolves to the optimal value from the model config. It is a recommended mode of
483
+ # operation and leads to the most reasonable results. Using the native image resolution or any other processing
484
+ # resolution can lead to loss of either fine details or global context in the output predictions.
485
+ if not skip_preprocess:
486
+ image, padding, original_resolution = self.image_processor.preprocess(
487
+ image, processing_resolution, resample_method_input, device, dtype
488
+ ) # [N,3,PPH,PPW]
489
+ else:
490
+ padding = (0, 0)
491
+ original_resolution = image.shape[2:]
492
+ # 5. Encode input image into latent space. At this step, each of the `N` input images is represented with `E`
493
+ # ensemble members. Each ensemble member is an independent diffused prediction, just initialized independently.
494
+ # Latents of each such predictions across all input images and all ensemble members are represented in the
495
+ # `pred_latent` variable. The variable `image_latent` is of the same shape: it contains each input image encoded
496
+ # into latent space and replicated `E` times. The latents can be either generated (see `generator` to ensure
497
+ # reproducibility), or passed explicitly via the `latents` argument. The latter can be set outside the pipeline
498
+ # code. For example, in the Marigold-LCM video processing demo, the latents initialization of a frame is taken
499
+ # as a convex combination of the latents output of the pipeline for the previous frame and a newly-sampled
500
+ # noise. This behavior can be achieved by setting the `output_latent` argument to `True`. The latent space
501
+ # dimensions are `(h, w)`. Encoding into latent space happens in batches of size `batch_size`.
502
+ # Model invocation: self.vae.encoder.
503
+ image_latent, pred_latent = self.prepare_latents(
504
+ image, latents, generator, ensemble_size, batch_size
505
+ ) # [N*E,4,h,w], [N*E,4,h,w]
506
+
507
+ gaus_noise = pred_latent.detach().clone()
508
+ del image
509
+
510
+
511
+ # 6. obtain control_output
512
+
513
+ cond_scale = controlnet_conditioning_scale
514
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
515
+ image_latent.detach(),
516
+ self.t_start,
517
+ encoder_hidden_states=self.prompt_embeds,
518
+ conditioning_scale=cond_scale,
519
+ guess_mode=False,
520
+ return_dict=False,
521
+ )
522
+
523
+ # 7. Onestep sampling
524
+ latent_x_t = self.unet(
525
+ pred_latent,
526
+ self.t_start,
527
+ encoder_hidden_states=self.prompt_embeds,
528
+ down_block_additional_residuals=down_block_res_samples,
529
+ mid_block_additional_residual=mid_block_res_sample,
530
+ return_dict=False,
531
+ )[0]
532
+
533
+
534
+ del (
535
+ pred_latent,
536
+ image_latent,
537
+ )
538
+
539
+ # decoder
540
+ prediction = self.decode_prediction(latent_x_t)
541
+
542
+ prediction = self.image_processor.unpad_image(prediction, padding) # [N*E,3,PH,PW]
543
+
544
+ prediction = self.image_processor.resize_antialias(
545
+ prediction, original_resolution, resample_method_output, is_aa=False
546
+ ) # [N,3,H,W]
547
+
548
+ if output_type == "np":
549
+ prediction = self.image_processor.pt_to_numpy(prediction) # [N,H,W,3]
550
+
551
+ # 11. Offload all models
552
+ self.maybe_free_model_hooks()
553
+
554
+ return DAIOutput(
555
+ prediction=prediction,
556
+ latent=latent_x_t,
557
+ gaus_noise=gaus_noise,
558
+ )
559
+
560
+ # Copied from diffusers.pipelines.marigold.pipeline_marigold_depth.MarigoldDepthPipeline.prepare_latents
561
+ def prepare_latents(
562
+ self,
563
+ image: torch.Tensor,
564
+ latents: Optional[torch.Tensor],
565
+ generator: Optional[torch.Generator],
566
+ ensemble_size: int,
567
+ batch_size: int,
568
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
569
+ def retrieve_latents(encoder_output):
570
+ if hasattr(encoder_output, "latent_dist"):
571
+ return encoder_output.latent_dist.mode()
572
+ elif hasattr(encoder_output, "latents"):
573
+ return encoder_output.latents
574
+ else:
575
+ raise AttributeError("Could not access latents of provided encoder_output")
576
+
577
+
578
+
579
+ image_latent = torch.cat(
580
+ [
581
+ retrieve_latents(self.vae.encode(image[i : i + batch_size]))
582
+ for i in range(0, image.shape[0], batch_size)
583
+ ],
584
+ dim=0,
585
+ ) # [N,4,h,w]
586
+ image_latent = image_latent * self.vae.config.scaling_factor
587
+ image_latent = image_latent.repeat_interleave(ensemble_size, dim=0) # [N*E,4,h,w]
588
+
589
+ pred_latent = latents
+ if pred_latent is None:
+ # the one-step prediction starts from an all-zero latent rather than Gaussian noise
+ pred_latent = torch.zeros_like(image_latent) # [N*E,4,h,w]
597
+
598
+ return image_latent, pred_latent
599
+
600
+ def decode_prediction(self, pred_latent: torch.Tensor) -> torch.Tensor:
601
+ if pred_latent.dim() != 4 or pred_latent.shape[1] != self.vae.config.latent_channels:
602
+ raise ValueError(
603
+ f"Expecting 4D tensor of shape [B,{self.vae.config.latent_channels},H,W]; got {pred_latent.shape}."
604
+ )
605
+
606
+ prediction = self.vae.decode(pred_latent / self.vae.config.scaling_factor, return_dict=False)[0] # [B,3,H,W]
607
+
608
+ return prediction # [B,3,H,W]
609
+
610
+ @staticmethod
611
+ def normalize_normals(normals: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
612
+ if normals.dim() != 4 or normals.shape[1] != 3:
613
+ raise ValueError(f"Expecting 4D tensor of shape [B,3,H,W]; got {normals.shape}.")
614
+
615
+ norm = torch.norm(normals, dim=1, keepdim=True)
616
+ normals /= norm.clamp(min=eps)
617
+
618
+ return normals
619
+
620
+ @staticmethod
621
+ def ensemble_normals(
622
+ normals: torch.Tensor, output_uncertainty: bool, reduction: str = "closest"
623
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
624
+ """
625
+ Ensembles the normals maps represented by the `normals` tensor with expected shape `(B, 3, H, W)`, where B is
626
+ the number of ensemble members for a given prediction of size `(H x W)`.
627
+
628
+ Args:
629
+ normals (`torch.Tensor`):
630
+ Input ensemble normals maps.
631
+ output_uncertainty (`bool`, *optional*, defaults to `False`):
632
+ Whether to output uncertainty map.
633
+ reduction (`str`, *optional*, defaults to `"closest"`):
634
+ Reduction method used to ensemble aligned predictions. The accepted values are: `"closest"` and
635
+ `"mean"`.
636
+
637
+ Returns:
638
+ A tensor of aligned and ensembled normals maps with shape `(1, 3, H, W)` and optionally a tensor of
639
+ uncertainties of shape `(1, 1, H, W)`.
640
+ """
641
+ if normals.dim() != 4 or normals.shape[1] != 3:
642
+ raise ValueError(f"Expecting 4D tensor of shape [B,3,H,W]; got {normals.shape}.")
643
+ if reduction not in ("closest", "mean"):
644
+ raise ValueError(f"Unrecognized reduction method: {reduction}.")
645
+
646
+ mean_normals = normals.mean(dim=0, keepdim=True) # [1,3,H,W]
647
+ mean_normals = OneStepPipeline.normalize_normals(mean_normals) # [1,3,H,W]
648
+
649
+ sim_cos = (mean_normals * normals).sum(dim=1, keepdim=True) # [E,1,H,W]
650
+ sim_cos = sim_cos.clamp(-1, 1) # required to avoid NaN in uncertainty with fp16
651
+
652
+ uncertainty = None
653
+ if output_uncertainty:
654
+ uncertainty = sim_cos.arccos() # [E,1,H,W]
655
+ uncertainty = uncertainty.mean(dim=0, keepdim=True) / np.pi # [1,1,H,W]
656
+
657
+ if reduction == "mean":
658
+ return mean_normals, uncertainty # [1,3,H,W], [1,1,H,W]
659
+
660
+ closest_indices = sim_cos.argmax(dim=0, keepdim=True) # [1,1,H,W]
661
+ closest_indices = closest_indices.repeat(1, 3, 1, 1) # [1,3,H,W]
662
+ closest_normals = torch.gather(normals, 0, closest_indices) # [1,3,H,W]
663
+
664
+ return closest_normals, uncertainty # [1,3,H,W], [1,1,H,W]
665
+
666
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
667
+ def retrieve_timesteps(
668
+ scheduler,
669
+ num_inference_steps: Optional[int] = None,
670
+ device: Optional[Union[str, torch.device]] = None,
671
+ timesteps: Optional[List[int]] = None,
672
+ sigmas: Optional[List[float]] = None,
673
+ **kwargs,
674
+ ):
675
+ """
676
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
677
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
678
+
679
+ Args:
680
+ scheduler (`SchedulerMixin`):
681
+ The scheduler to get timesteps from.
682
+ num_inference_steps (`int`):
683
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
684
+ must be `None`.
685
+ device (`str` or `torch.device`, *optional*):
686
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
687
+ timesteps (`List[int]`, *optional*):
688
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
689
+ `num_inference_steps` and `sigmas` must be `None`.
690
+ sigmas (`List[float]`, *optional*):
691
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
692
+ `num_inference_steps` and `timesteps` must be `None`.
693
+
694
+ Returns:
695
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
696
+ second element is the number of inference steps.
697
+ """
698
+ if timesteps is not None and sigmas is not None:
699
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
700
+ if timesteps is not None:
701
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
702
+ if not accepts_timesteps:
703
+ raise ValueError(
704
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
705
+ f" timestep schedules. Please check whether you are using the correct scheduler."
706
+ )
707
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
708
+ timesteps = scheduler.timesteps
709
+ num_inference_steps = len(timesteps)
710
+ elif sigmas is not None:
711
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
712
+ if not accept_sigmas:
713
+ raise ValueError(
714
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
715
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
716
+ )
717
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
718
+ timesteps = scheduler.timesteps
719
+ num_inference_steps = len(timesteps)
720
+ else:
721
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
722
+ timesteps = scheduler.timesteps
723
+ return timesteps, num_inference_steps
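To make the call signature documented above concrete, here is a hedged usage sketch. It assumes `pipe` is an already assembled `OneStepPipeline` whose components (VAE, text encoder, tokenizer, UNet, ControlNet) were loaded in the same way `app.py` below loads them for `DAIPipeline`; the input path is illustrative.

```python
from PIL import Image

image = Image.open("files/image/example.png").convert("RGB")  # illustrative path

out = pipe(image=image, processing_resolution=768, output_type="np")

prediction = out.prediction  # [N, H, W, 3] array with values in [-1, 1]
latent = out.latent          # latent features, reusable via the `latents` argument
gaus_noise = out.gaus_noise  # initial prediction latent recorded by the pipeline
```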
app.py CHANGED
@@ -1,7 +1,309 @@
1
  import gradio as gr
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
1
+ # Copyright 2024 Anton Obukhov, ETH Zurich. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # --------------------------------------------------------------------------
15
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
16
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
17
+ # More information about the method can be found at https://marigoldmonodepth.github.io
18
+ # --------------------------------------------------------------------------
19
+ from __future__ import annotations
20
+
21
+ import functools
22
+ import os
23
+ import tempfile
24
+
25
  import gradio as gr
26
+ import imageio as imageio
27
+ import numpy as np
28
+ import spaces
29
+ import torch as torch
30
+ torch.backends.cuda.matmul.allow_tf32 = True
31
+ from PIL import Image
32
+ from gradio_imageslider import ImageSlider
33
+ from tqdm import tqdm
34
+
35
+ from pathlib import Path
36
+ import gradio
37
+ from gradio.utils import get_cache_folder
38
+ from DAI.pipeline_all import DAIPipeline
39
+
40
+ from diffusers import (
41
+ AutoencoderKL,
42
+ UNet2DConditionModel,
43
+ )
44
+
45
+ from transformers import CLIPTextModel, AutoTokenizer
46
+
47
+ from DAI.controlnetvae import ControlNetVAEModel
48
+
49
+ from DAI.decoder import CustomAutoencoderKL
50
+
51
+
52
+ class Examples(gradio.helpers.Examples):
53
+ def __init__(self, *args, directory_name=None, **kwargs):
54
+ super().__init__(*args, **kwargs, _initiated_directly=False)
55
+ if directory_name is not None:
56
+ self.cached_folder = get_cache_folder() / directory_name
57
+ self.cached_file = Path(self.cached_folder) / "log.csv"
58
+ self.create()
59
+
60
+
61
+ default_seed = 2024
62
+ default_batch_size = 1
63
+
64
+ def process_image_check(path_input):
65
+ if path_input is None:
66
+ raise gr.Error(
67
+ "Missing image in the first pane: upload a file or use one from the gallery below."
68
+ )
69
+
70
+ def resize_image(input_image, resolution):
71
+ # Ensure input_image is a PIL Image object
72
+ if not isinstance(input_image, Image.Image):
73
+ raise ValueError("input_image should be a PIL Image object")
74
+
75
+ # Convert image to numpy array
76
+ input_image_np = np.asarray(input_image)
77
+
78
+ # Get image dimensions
79
+ H, W, C = input_image_np.shape
80
+ H = float(H)
81
+ W = float(W)
82
+
83
+ # Calculate the scaling factor
84
+ k = float(resolution) / min(H, W)
85
+
86
+ # Determine new dimensions
87
+ H *= k
88
+ W *= k
89
+ H = int(np.round(H / 64.0)) * 64
90
+ W = int(np.round(W / 64.0)) * 64
91
+
92
+ # Resize the image using PIL's resize method
93
+ img = input_image.resize((W, H), Image.Resampling.LANCZOS)
94
+
95
+ return img
96
+
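A short sketch of how `resize_image` above behaves: the shorter side is scaled toward the requested resolution and both sides are snapped to multiples of 64, which keeps the dimensions compatible with the 8x latent downsampling. The input path is illustrative.

```python
from PIL import Image

img = Image.open("files/image/example.png").convert("RGB")  # illustrative path
img_768 = resize_image(img, 768)

print(img_768.size)                             # both sides are multiples of 64
print(img_768.width % 64, img_768.height % 64)  # 0 0
```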
97
+ def process_image(
98
+ pipe,
99
+ vae_2,
100
+ path_input,
101
+ ):
102
+ name_base, name_ext = os.path.splitext(os.path.basename(path_input))
103
+ print(f"Processing image {name_base}{name_ext}")
104
+
105
+ path_output_dir = tempfile.mkdtemp()
106
+ path_out_png = os.path.join(path_output_dir, f"{name_base}_delight.png")
107
+ input_image = Image.open(path_input)
108
+ resolution = None
109
+
110
+ pipe_out = pipe(
111
+ image=input_image,
112
+ prompt="remove glass reflection",
113
+ vae_2=vae_2,
114
+ processing_resolution=resolution,
115
+ )
116
+
117
+ processed_frame = (pipe_out.prediction.clip(-1, 1) + 1) / 2
118
+ processed_frame = (processed_frame[0] * 255).astype(np.uint8)
119
+ processed_frame = Image.fromarray(processed_frame)
120
+ processed_frame.save(path_out_png)
121
+ yield [input_image, path_out_png]
122
+
123
+ def run_demo_server(pipe, vae_2):
124
+ process_pipe_image = spaces.GPU(functools.partial(process_image, pipe, vae_2))
125
+
126
+ gradio_theme = gr.themes.Default()
127
+
128
+ with gr.Blocks(
129
+ theme=gradio_theme,
130
+ title="Dereflection Any Image",
131
+ css="""
132
+ #download {
133
+ height: 118px;
134
+ }
135
+ .slider .inner {
136
+ width: 5px;
137
+ background: #FFF;
138
+ }
139
+ .viewport {
140
+ aspect-ratio: 4/3;
141
+ }
142
+ .tabs button.selected {
143
+ font-size: 20px !important;
144
+ color: crimson !important;
145
+ }
146
+ h1 {
147
+ text-align: center;
148
+ display: block;
149
+ }
150
+ h2 {
151
+ text-align: center;
152
+ display: block;
153
+ }
154
+ h3 {
155
+ text-align: center;
156
+ display: block;
157
+ }
158
+ .md_feedback li {
159
+ margin-bottom: 0px !important;
160
+ }
161
+ """,
162
+ head="""
163
+ <script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
164
+ <script>
165
+ window.dataLayer = window.dataLayer || [];
166
+ function gtag() {dataLayer.push(arguments);}
167
+ gtag('js', new Date());
168
+ gtag('config', 'G-1FWSVCGZTG');
169
+ </script>
170
+ """,
171
+ ) as demo:
172
+ gr.Markdown(
173
+ """
174
+ # Dereflection Any Image
175
+ <p align="center">
176
+ """
177
+ )
178
+
179
+ with gr.Tabs(elem_classes=["tabs"]):
180
+ with gr.Tab("Image"):
181
+ with gr.Row():
182
+ with gr.Column():
183
+ image_input = gr.Image(
184
+ label="Input Image",
185
+ type="filepath",
186
+ )
187
+ with gr.Row():
188
+ image_submit_btn = gr.Button(
189
+ value="remove reflection", variant="primary"
190
+ )
191
+ image_reset_btn = gr.Button(value="Reset")
192
+ with gr.Column():
193
+ image_output_slider = ImageSlider(
194
+ label="outputs",
195
+ type="filepath",
196
+ show_download_button=True,
197
+ show_share_button=True,
198
+ interactive=False,
199
+ elem_classes="slider",
200
+ # position=0.25,
201
+ )
202
+
203
+ Examples(
204
+ fn=process_pipe_image,
205
+ examples=sorted([
206
+ os.path.join("files", "image", name)
207
+ for name in os.listdir(os.path.join("files", "image"))
208
+ ]),
209
+ inputs=[image_input],
210
+ outputs=[image_output_slider],
211
+ cache_examples=False,
212
+ directory_name="examples_image",
213
+ )
214
+
215
+
216
+ ### Image tab
217
+ image_submit_btn.click(
218
+ fn=process_image_check,
219
+ inputs=image_input,
220
+ outputs=None,
221
+ preprocess=False,
222
+ queue=False,
223
+ ).success(
224
+ fn=process_pipe_image,
225
+ inputs=[
226
+ image_input,
227
+ ],
228
+ outputs=[image_output_slider],
229
+ concurrency_limit=1,
230
+ )
231
+
232
+ image_reset_btn.click(
233
+ fn=lambda: (
234
+ None,
235
+ None,
236
+ None,
237
+ ),
238
+ inputs=[],
239
+ outputs=[
240
+ image_input,
241
+ image_output_slider,
242
+ ],
243
+ queue=False,
244
+ )
245
+
246
+
247
+ ### Server launch
248
+
249
+ demo.queue(
250
+ api_open=False,
251
+ ).launch(
252
+ server_name="0.0.0.0",
253
+ server_port=7860,
254
+ )
255
+
256
+
257
+ def main():
258
+ os.system("pip freeze")
259
+
260
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
261
+
262
+ weight_dtype = torch.float32
263
+ model_dir = "./weights"
264
+ pretrained_model_name_or_path = "JichenHu/dereflection-any-image-v0"
265
+ pretrained_model_name_or_path2 = "stabilityai/stable-diffusion-2-1"
266
+ revision = None
267
+ variant = None
268
+ # Load the model
269
+ controlnet = ControlNetVAEModel.from_pretrained(pretrained_model_name_or_path, subfolder="controlnet", torch_dtype=weight_dtype).to(device)
270
+ unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet", torch_dtype=weight_dtype).to(device)
271
+ vae_2 = CustomAutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae_2", torch_dtype=weight_dtype).to(device)
272
+
273
+ # Load other components of the pipeline
274
+ vae = AutoencoderKL.from_pretrained(
275
+ pretrained_model_name_or_path2, subfolder="vae", revision=revision, variant=variant
276
+ ).to(device)
277
+
278
+ text_encoder = CLIPTextModel.from_pretrained(
279
+ pretrained_model_name_or_path2, subfolder="text_encoder", revision=revision, variant=variant
280
+ ).to(device)
281
+ tokenizer = AutoTokenizer.from_pretrained(
282
+ pretrained_model_name_or_path2,
283
+ subfolder="tokenizer",
284
+ revision=revision,
285
+ use_fast=False,
286
+ )
287
+ pipe = DAIPipeline(
288
+ vae=vae,
289
+ text_encoder=text_encoder,
290
+ tokenizer=tokenizer,
291
+ unet=unet,
292
+ controlnet=controlnet,
293
+ safety_checker=None,
294
+ scheduler=None,
295
+ feature_extractor=None,
296
+ t_start=0,
297
+ ).to(device)
298
+
299
+ try:
300
+ import xformers
301
+ pipe.enable_xformers_memory_efficient_attention()
302
+ except Exception:
303
+ pass # run without xformers
304
+
305
+ run_demo_server(pipe, vae_2)
306
 
 
 
307
 
308
+ if __name__ == "__main__":
309
+ main()
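For completeness, a headless sketch of the same inference path the Gradio handler wires up, assuming `pipe` (a `DAIPipeline`) and `vae_2` were constructed exactly as in `main()` above; input and output paths are illustrative.

```python
import numpy as np
from PIL import Image

image = Image.open("input.png").convert("RGB")  # illustrative path

out = pipe(
    image=image,
    prompt="remove glass reflection",
    vae_2=vae_2,
    processing_resolution=None,
)

# Predictions come back in [-1, 1]; map them to uint8 exactly as process_image does.
pred = (out.prediction.clip(-1, 1) + 1) / 2
Image.fromarray((pred[0] * 255).astype(np.uint8)).save("output_dereflected.png")
```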
requirements.txt ADDED
@@ -0,0 +1,8 @@
1
+ diffusers
2
+ gradio
3
+ gradio_imageslider
4
+ torch
5
+ transformers
6
+ pillow
7
+ numpy
8
+ xformers
utils/image_utils.py ADDED
@@ -0,0 +1,21 @@
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact [email protected]
10
+ #
11
+
12
+ import torch
13
+
14
+ def mse(img1, img2):
15
+ return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
16
+
17
+ def psnr(img1, img2):
+ return 20 * torch.log10(1.0 / torch.sqrt(mse(img1, img2)))
20
+
21
+ # torchmetrics
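A quick sanity check for the metrics above, assuming the module is importable as `utils.image_utils`; both helpers expect batched tensors with values in [0, 1].

```python
import torch
from utils.image_utils import mse, psnr  # path as laid out in this repo

img = torch.rand(1, 3, 64, 64)
noisy = (img + 0.05 * torch.randn_like(img)).clamp(0, 1)

print(mse(img, noisy))   # [1, 1] tensor, roughly 0.0025 for 0.05-sigma noise
print(psnr(img, noisy))  # roughly 26 dB
```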
utils/loss_utils.py ADDED
@@ -0,0 +1,64 @@
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact [email protected]
10
+ #
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from torch.autograd import Variable
15
+ from math import exp
16
+
17
+ def l1_loss(network_output, gt):
18
+ return torch.abs((network_output - gt)).mean()
19
+
20
+ def l2_loss(network_output, gt):
21
+ return ((network_output - gt) ** 2).mean()
22
+
23
+ def gaussian(window_size, sigma):
24
+ gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
25
+ return gauss / gauss.sum()
26
+
27
+ def create_window(window_size, channel):
28
+ _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
29
+ _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
30
+ window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
31
+ return window
32
+
33
+ def ssim(img1, img2, window_size=11, size_average=True):
34
+ channel = img1.size(-3)
35
+ window = create_window(window_size, channel)
36
+
37
+ if img1.is_cuda:
38
+ window = window.cuda(img1.get_device())
39
+ window = window.type_as(img1)
40
+
41
+ return _ssim(img1, img2, window, window_size, channel, size_average)
42
+
43
+ def _ssim(img1, img2, window, window_size, channel, size_average=True):
44
+ mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
45
+ mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
46
+
47
+ mu1_sq = mu1.pow(2)
48
+ mu2_sq = mu2.pow(2)
49
+ mu1_mu2 = mu1 * mu2
50
+
51
+ sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
52
+ sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
53
+ sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
54
+
55
+ C1 = 0.01 ** 2
56
+ C2 = 0.03 ** 2
57
+
58
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
59
+
60
+ if size_average:
61
+ return ssim_map.mean()
62
+ else:
63
+ return ssim_map.mean(1).mean(1).mean(1)
64
+
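Finally, a minimal sketch exercising the losses above, assuming the module is importable as `utils.loss_utils`; SSIM of an image with itself is 1 and decreases as the inputs diverge. Shapes follow the `[B, C, H, W]` convention used throughout.

```python
import torch
from utils.loss_utils import l1_loss, ssim  # path as laid out in this repo

a = torch.rand(1, 3, 64, 64)
b = (a + 0.1 * torch.randn_like(a)).clamp(0, 1)

print(l1_loss(a, b).item())  # mean absolute difference
print(ssim(a, a).item())     # ~1.0
print(ssim(a, b).item())     # below 1.0
```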