import torch
import torch.nn as nn
import numpy as np
import pytorch_lightning as pl
from functools import partial
from ldm.util import default, instantiate_from_config
from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor


def disabled_train(self, mode=True):
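    """Overwrite model.train with this method so that train/eval mode does
    not change anymore; used to keep frozen submodules in eval mode."""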
    return self


class DiffusionWrapper(pl.LightningModule):
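    """Thin LightningModule wrapper that builds the UNet denoiser from a
    config and forwards inputs, timesteps and conditioning to it."""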
    def __init__(self, unet_config):
        super().__init__()
        self.diffusion_model = instantiate_from_config(unet_config)

    def forward(self, x, timesteps=None, context=None, control=None):
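        # x: noisy input, timesteps: diffusion step indices,
        # context: conditioning for cross-attention, control: extra control signal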
        out = self.diffusion_model(x, timesteps, context, control)
        return out


class DDPM(pl.LightningModule):
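    """Classic DDPM (Ho et al., 2020) with a Gaussian diffusion process in
    image space; the noise schedule is precomputed and stored as buffers."""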
    def __init__(self,
                 unet_config,
                 linear_start=1e-4,                 # 0.00085
                 linear_end=2e-2,                   # 0.0120
                 log_every_t=100,                   # 200
                 timesteps=1000,                    # 1000
                 image_size=256,                    # 32
                 channels=3,                        # 4
                 u_cond_percent=0,                  # 0.2
                 use_ema=True,                      # False
                 beta_schedule="linear",
                 loss_type="l2",
                 clip_denoised=True,
                 cosine_s=8e-3,
                 original_elbo_weight=0.,
                 v_posterior=0.,
                 l_simple_weight=1.,              
                 parameterization="eps",
                 use_positional_encodings=False,
                 learn_logvar=False,
                 logvar_init=0.):
        super().__init__()
        self.parameterization = parameterization
        self.cond_stage_model = None
        self.clip_denoised = clip_denoised
        self.log_every_t = log_every_t
        self.image_size = image_size
        self.channels = channels
        self.u_cond_percent = u_cond_percent
        self.use_positional_encodings = use_positional_encodings
        self.model = DiffusionWrapper(unet_config)  # instantiate the UNet denoiser

        self.use_ema = use_ema
        self.use_scheduler = True
        self.v_posterior = v_posterior
        self.original_elbo_weight = original_elbo_weight
        self.l_simple_weight = l_simple_weight
        self.register_schedule(beta_schedule=beta_schedule,     # "linear"
                               timesteps=timesteps,             # 1000
                               linear_start=linear_start,       # 0.00085
                               linear_end=linear_end,           # 0.0120
                               cosine_s=cosine_s)               # 8e-3
        self.loss_type = loss_type
        self.learn_logvar = learn_logvar
        self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
        if self.learn_logvar:
            # expose the per-timestep log-variance as a trainable parameter
            self.logvar = nn.Parameter(self.logvar, requires_grad=True)

    def register_schedule(self, 
                          beta_schedule="linear", 
                          timesteps=1000,
                          linear_start=0.00085, 
                          linear_end=0.0120, 
                          cosine_s=8e-3):
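        """Precompute the beta schedule and all derived quantities
        (cumulative alphas, posterior coefficients, ELBO weights) and
        register them as non-trainable buffers."""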
        betas = make_beta_schedule(beta_schedule, 
                                   timesteps, 
                                   linear_start=linear_start, 
                                   linear_end=linear_end,
                                   cosine_s=cosine_s)
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        to_torch = partial(torch.tensor, dtype=torch.float32)
        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
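        # variance of the true posterior q(x_{t-1} | x_t, x_0); v_posterior
        # interpolates between this value and beta_t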
        posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + self.v_posterior * betas
        self.register_buffer('posterior_variance', to_torch(posterior_variance))
        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
        self.register_buffer('posterior_mean_coef1', to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
        self.register_buffer('posterior_mean_coef2', to_torch((1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
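        # per-timestep ELBO weights for the eps-parameterization; t=0 is
        # copied from t=1 since the posterior variance is zero at the first
        # step and the weight there would be infinite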
        lvlb_weights = self.betas ** 2 / (2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
        lvlb_weights[0] = lvlb_weights[1]
        self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)

    def get_input(self, batch):
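        """Unpack a batch into ground truth, inpainting image, mask,
        reference image and hint, cast to contiguous float tensors."""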
        x = batch['GT']
        mask = batch['inpaint_mask']
        inpaint = batch['inpaint_image']
        reference = batch['ref_imgs']
        hint = batch['hint']

        x = x.to(memory_format=torch.contiguous_format).float()
        mask = mask.to(memory_format=torch.contiguous_format).float()
        inpaint = inpaint.to(memory_format=torch.contiguous_format).float()
        reference = reference.to(memory_format=torch.contiguous_format).float()
        hint = hint.to(memory_format=torch.contiguous_format).float()

        return x, inpaint, mask, reference, hint

    def q_sample(self, x_start, t, noise=None):
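        """Sample x_t ~ q(x_t | x_0) in closed form:
        x_t = sqrt(alphabar_t) * x_0 + sqrt(1 - alphabar_t) * eps."""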
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)

    def get_loss(self, pred, target, mean=True):
        # honor the configured loss_type ('l1' or 'l2')
        if self.loss_type == 'l1':
            loss = (target - pred).abs()
            if mean:
                loss = loss.mean()
        elif self.loss_type == 'l2':
            if mean:
                loss = torch.nn.functional.mse_loss(target, pred)
            else:
                loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
        else:
            raise NotImplementedError(f"unknown loss type '{self.loss_type}'")
        return loss
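

# --- Minimal sketch of the closed-form forward process implemented by
# q_sample above (illustrative only, not part of the training pipeline).
# It recomputes the schedule directly from make_beta_schedule, so no UNet
# config is required; shapes and schedule parameters here are assumptions.
if __name__ == "__main__":
    betas = make_beta_schedule("linear", 1000, linear_start=1e-4, linear_end=2e-2)
    alphas_cumprod = torch.tensor(np.cumprod(1. - betas, axis=0), dtype=torch.float32)
    x_start = torch.randn(2, 3, 64, 64)      # stand-in for clean images
    t = torch.randint(0, 1000, (2,))         # one random timestep per sample
    noise = torch.randn_like(x_start)
    # x_t = sqrt(alphabar_t) * x_0 + sqrt(1 - alphabar_t) * eps
    x_t = (extract_into_tensor(alphas_cumprod.sqrt(), t, x_start.shape) * x_start +
           extract_into_tensor((1. - alphas_cumprod).sqrt(), t, x_start.shape) * noise)
    print(x_t.shape)                          # torch.Size([2, 3, 64, 64])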