MultiMatrix committed on
Commit c46568a · verified · 1 Parent(s): 5d60839

Upload 6 files

utils/common.py ADDED
@@ -0,0 +1,159 @@
1
+ from typing import Mapping, Any, List, Tuple, Callable
2
+ import importlib
3
+ import os
4
+ from urllib.parse import urlparse
5
+
6
+ import torch
7
+ from torch import Tensor
8
+ from torch.nn import functional as F
9
+ import numpy as np
10
+
11
+ from torch.hub import download_url_to_file, get_dir
12
+
13
+
14
+ def get_obj_from_str(string: str, reload: bool=False) -> Any:
15
+ module, cls = string.rsplit(".", 1)
16
+ if reload:
17
+ module_imp = importlib.import_module(module)
18
+ importlib.reload(module_imp)
19
+ return getattr(importlib.import_module(module, package=None), cls)
20
+
21
+
22
+ def instantiate_from_config(config: Mapping[str, Any]) -> Any:
23
+ if "target" not in config:
24
+ raise KeyError("Expected key `target` to instantiate.")
25
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
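A quick usage sketch of the two loaders above; the config literal is illustrative, and any importable class with matching constructor kwargs works the same way:

from utils.common import instantiate_from_config

# `target` is a dotted import path; `params` are passed as constructor kwargs.
conv = instantiate_from_config({
    "target": "torch.nn.Conv2d",
    "params": {"in_channels": 3, "out_channels": 16, "kernel_size": 3},
})
print(type(conv).__name__)  # Conv2d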
26
+
27
+
28
+ def wavelet_blur(image: Tensor, radius: int):
29
+ """
30
+ Apply wavelet blur to the input tensor.
31
+ """
32
+ # input shape: (1, 3, H, W)
33
+ # convolution kernel
34
+ kernel_vals = [
35
+ [0.0625, 0.125, 0.0625],
36
+ [0.125, 0.25, 0.125],
37
+ [0.0625, 0.125, 0.0625],
38
+ ]
39
+ kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
40
+ # add channel dimensions to the kernel to make it a 4D tensor
41
+ kernel = kernel[None, None]
42
+ # repeat the kernel across all input channels
43
+ kernel = kernel.repeat(3, 1, 1, 1)
44
+ image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
45
+ # apply convolution
46
+ output = F.conv2d(image, kernel, groups=3, dilation=radius)
47
+ return output
48
+
49
+
50
+ def wavelet_decomposition(image: Tensor, levels=5):
51
+ """
52
+ Apply wavelet decomposition to the input tensor.
53
+ This function only returns the low frequency & the high frequency.
54
+ """
55
+ high_freq = torch.zeros_like(image)
56
+ for i in range(levels):
57
+ radius = 2 ** i
58
+ low_freq = wavelet_blur(image, radius)
59
+ high_freq += (image - low_freq)
60
+ image = low_freq
61
+
62
+ return high_freq, low_freq
63
+
64
+
65
+ def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor):
66
+ """
67
+ Recombine the content's high-frequency part with the style's low-frequency part, so the content keeps its detail but takes on the style's colors.
68
+ """
69
+ # calculate the wavelet decomposition of the content feature
70
+ content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
71
+ del content_low_freq
72
+ # calculate the wavelet decomposition of the style feature
73
+ style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
74
+ del style_high_freq
75
+ # reconstruct the content feature with the style's high frequency
76
+ return content_high_freq + style_low_freq
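A small sanity-check sketch of the three wavelet helpers above: the decomposition telescopes, so high + low reproduces the input, and the reconstruction keeps the content's detail while borrowing the style's low-frequency colors (the tensor sizes here are illustrative):

import torch
from utils.common import wavelet_decomposition, wavelet_reconstruction

content = torch.rand(1, 3, 64, 64)   # e.g. the diffusion output
style = torch.rand(1, 3, 64, 64)     # e.g. the stage-1 (color-faithful) image
high, low = wavelet_decomposition(content)
print(torch.allclose(high + low, content, atol=1e-5))  # True: the split is lossless
fixed = wavelet_reconstruction(content, style)          # content detail + style color
print(fixed.shape)                                       # torch.Size([1, 3, 64, 64])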
77
+
78
+
79
+ # https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/utils/download_util.py/
80
+ def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
81
+ """Load a file from an HTTP URL, downloading the model if necessary.
82
+
83
+ Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
84
+
85
+ Args:
86
+ url (str): URL to be downloaded.
87
+ model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
88
+ Default: None.
89
+ progress (bool): Whether to show the download progress. Default: True.
90
+ file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
91
+
92
+ Returns:
93
+ str: The path to the downloaded file.
94
+ """
95
+ if model_dir is None: # use the pytorch hub_dir
96
+ hub_dir = get_dir()
97
+ model_dir = os.path.join(hub_dir, 'checkpoints')
98
+
99
+ os.makedirs(model_dir, exist_ok=True)
100
+
101
+ parts = urlparse(url)
102
+ filename = os.path.basename(parts.path)
103
+ if file_name is not None:
104
+ filename = file_name
105
+ cached_file = os.path.abspath(os.path.join(model_dir, filename))
106
+ if not os.path.exists(cached_file):
107
+ print(f'Downloading: "{url}" to {cached_file}\n')
108
+ download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
109
+ return cached_file
110
+
111
+
112
+ def sliding_windows(h: int, w: int, tile_size: int, tile_stride: int) -> List[Tuple[int, int, int, int]]:
113
+ hi_list = list(range(0, h - tile_size + 1, tile_stride))
114
+ if (h - tile_size) % tile_stride != 0:
115
+ hi_list.append(h - tile_size)
116
+
117
+ wi_list = list(range(0, w - tile_size + 1, tile_stride))
118
+ if (w - tile_size) % tile_stride != 0:
119
+ wi_list.append(w - tile_size)
120
+
121
+ coords = []
122
+ for hi in hi_list:
123
+ for wi in wi_list:
124
+ coords.append((hi, hi + tile_size, wi, wi + tile_size))
125
+ return coords
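A worked example of the tiling helper; each returned tuple is (top, bottom, left, right) in pixels:

from utils.common import sliding_windows

coords = sliding_windows(h=768, w=768, tile_size=512, tile_stride=256)
print(coords)
# [(0, 512, 0, 512), (0, 512, 256, 768), (256, 768, 0, 512), (256, 768, 256, 768)]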
126
+
127
+
128
+ # https://github.com/csslc/CCSR/blob/main/model/q_sampler.py#L503
129
+ def gaussian_weights(tile_width: int, tile_height: int) -> np.ndarray:
130
+ """Generates a gaussian mask of weights for tile contributions"""
131
+ latent_width = tile_width
132
+ latent_height = tile_height
133
+ var = 0.01
134
+ midpoint = (latent_width - 1) / 2 # -1 because index goes from 0 to latent_width - 1
135
+ x_probs = [
136
+ np.exp(-(x - midpoint) * (x - midpoint) / (latent_width * latent_width) / (2 * var)) / np.sqrt(2 * np.pi * var)
137
+ for x in range(latent_width)]
138
+ midpoint = latent_height / 2
139
+ y_probs = [
140
+ np.exp(-(y - midpoint) * (y - midpoint) / (latent_height * latent_height) / (2 * var)) / np.sqrt(2 * np.pi * var)
141
+ for y in range(latent_height)]
142
+ weights = np.outer(y_probs, x_probs)
143
+ return weights
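A hedged sketch of how these weights are typically combined with the sliding windows above to blend overlapping tile outputs; the per-tile model call is a stand-in:

import numpy as np
from utils.common import gaussian_weights, sliding_windows

h = w = 768
tile_size, tile_stride = 512, 256
acc = np.zeros((h, w))
weight_sum = np.zeros((h, w))
tile_weights = gaussian_weights(tile_width=tile_size, tile_height=tile_size)
for top, bottom, left, right in sliding_windows(h, w, tile_size, tile_stride):
    tile_pred = np.random.rand(tile_size, tile_size)  # stand-in for a per-tile prediction
    acc[top:bottom, left:right] += tile_pred * tile_weights
    weight_sum[top:bottom, left:right] += tile_weights
blended = acc / weight_sum  # every pixel is covered by at least one tile here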
144
+
145
+
146
+ COUNT_VRAM = bool(os.environ.get("COUNT_VRAM", False))
147
+
148
+ def count_vram_usage(func: Callable) -> Callable:
149
+ if not COUNT_VRAM:
150
+ return func
151
+
152
+ def wrapper(*args, **kwargs):
153
+ peak_before = torch.cuda.max_memory_allocated() / (1024 ** 3)
154
+ ret = func(*args, **kwargs)
155
+ torch.cuda.synchronize()
156
+ peak_after = torch.cuda.max_memory_allocated() / (1024 ** 3)
157
+ print(f"VRAM peak before {func.__name__}: {peak_before:.5f} GB, after: {peak_after:.5f} GB")
158
+ return ret
159
+ return wrapper
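The decorator above is a no-op unless the COUNT_VRAM environment variable is set before utils.common is imported; a minimal sketch (requires a CUDA device, and the function body is illustrative):

import torch
from utils.common import count_vram_usage

@count_vram_usage
def allocate() -> torch.Tensor:
    # roughly 1 GiB of float32 on the default CUDA device
    return torch.zeros(1024, 1024, 256, device="cuda")

allocate()  # run with COUNT_VRAM=1 to get the before/after peak-VRAM print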
utils/cond_fn.py ADDED
@@ -0,0 +1,98 @@
1
+ from typing import overload, Tuple
2
+ import torch
3
+ from torch.nn import functional as F
4
+
5
+
6
+ class Guidance:
7
+
8
+ def __init__(self, scale: float, t_start: int, t_stop: int, space: str, repeat: int) -> None:
9
+ """
10
+ Initialize restoration guidance.
11
+
12
+ Args:
13
+ scale (float): Gradient scale (denoted as `s` in our paper). The larger the gradient scale,
14
+ the closer the final result will be to the output of the first stage model.
15
+ t_start (int), t_stop (int): The timestep to start or stop guidance. Note that the sampling
16
+ process runs from t=1000 down to t=0, so `t_start` should be larger than `t_stop`.
17
+ space (str): The data space for computing loss function (rgb or latent).
18
+
19
+ Our restoration guidance is based on [GDP](https://github.com/Fayeben/GenerativeDiffusionPrior).
20
+ Thanks for their work!
21
+ """
22
+ self.scale = scale * 3000
23
+ self.t_start = t_start
24
+ self.t_stop = t_stop
25
+ self.target = None
26
+ self.space = space
27
+ self.repeat = repeat
28
+
29
+ def load_target(self, target: torch.Tensor) -> None:
30
+ self.target = target
31
+
32
+ def __call__(self, target_x0: torch.Tensor, pred_x0: torch.Tensor, t: int) -> Tuple[torch.Tensor, float]:
33
+ # avoid propagating gradient out of this scope
34
+ pred_x0 = pred_x0.detach().clone()
35
+ target_x0 = target_x0.detach().clone()
36
+ return self._forward(target_x0, pred_x0, t)
37
+
38
+ @overload
39
+ def _forward(self, target_x0: torch.Tensor, pred_x0: torch.Tensor, t: int) -> Tuple[torch.Tensor, float]:
40
+ ...
41
+
42
+
43
+ class MSEGuidance(Guidance):
44
+
45
+ def _forward(self, target_x0: torch.Tensor, pred_x0: torch.Tensor, t: int) -> Tuple[torch.Tensor, float]:
46
+ # inputs: [-1, 1], nchw, rgb
47
+ with torch.enable_grad():
48
+ pred_x0.requires_grad_(True)
49
+ loss = (pred_x0 - target_x0).pow(2).mean((1, 2, 3)).sum()
50
+ scale = self.scale
51
+ g = -torch.autograd.grad(loss, pred_x0)[0] * scale
52
+ return g, loss.item()
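A minimal sketch of calling the guidance on dummy data; the constructor values are illustrative, not the project's documented defaults. Inputs are expected in [-1, 1], NCHW, RGB (see the comment above), and the returned gradient has the shape of pred_x0:

import torch
from utils.cond_fn import MSEGuidance

cond_fn = MSEGuidance(scale=0.1, t_start=601, t_stop=-1, space="rgb", repeat=1)
target_x0 = torch.rand(1, 3, 64, 64) * 2 - 1   # stage-1 output mapped to [-1, 1]
pred_x0 = torch.rand(1, 3, 64, 64) * 2 - 1     # current x0 estimate from the sampler
grad, loss = cond_fn(target_x0, pred_x0, t=500)
print(grad.shape, loss)                         # torch.Size([1, 3, 64, 64]) and a python float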
53
+
54
+
55
+ class WeightedMSEGuidance(Guidance):
56
+
57
+ def _get_weight(self, target: torch.Tensor) -> torch.Tensor:
58
+ # convert RGB to G
59
+ rgb_to_gray_kernel = torch.tensor([0.2989, 0.5870, 0.1140]).view(1, 3, 1, 1)
60
+ target = torch.sum(target * rgb_to_gray_kernel.to(target.device), dim=1, keepdim=True)
61
+ # initialize sobel kernel in x and y axis
62
+ G_x = [
63
+ [1, 0, -1],
64
+ [2, 0, -2],
65
+ [1, 0, -1]
66
+ ]
67
+ G_y = [
68
+ [1, 2, 1],
69
+ [0, 0, 0],
70
+ [-1, -2, -1]
71
+ ]
72
+ G_x = torch.tensor(G_x, dtype=target.dtype, device=target.device)[None]
73
+ G_y = torch.tensor(G_y, dtype=target.dtype, device=target.device)[None]
74
+ G = torch.stack((G_x, G_y))
75
+
76
+ target = F.pad(target, (1, 1, 1, 1), mode='replicate') # padding = 1
77
+ grad = F.conv2d(target, G, stride=1)
78
+ mag = grad.pow(2).sum(dim=1, keepdim=True).sqrt()
79
+
80
+ n, c, h, w = mag.size()
81
+ block_size = 2
82
+ blocks = mag.view(n, c, h // block_size, block_size, w // block_size, block_size).permute(0, 1, 2, 4, 3, 5).contiguous()
83
+ block_mean = blocks.sum(dim=(-2, -1), keepdim=True).tanh().repeat(1, 1, 1, 1, block_size, block_size).permute(0, 1, 2, 4, 3, 5).contiguous()
84
+ block_mean = block_mean.view(n, c, h, w)
85
+ weight_map = 1 - block_mean
86
+
87
+ return weight_map
88
+
89
+ def _forward(self, target_x0: torch.Tensor, pred_x0: torch.Tensor, t: int) -> Tuple[torch.Tensor, float]:
90
+ # inputs: [-1, 1], nchw, rgb
91
+ with torch.no_grad():
92
+ w = self._get_weight((target_x0 + 1) / 2)
93
+ with torch.enable_grad():
94
+ pred_x0.requires_grad_(True)
95
+ loss = ((pred_x0 - target_x0).pow(2) * w).mean((1, 2, 3)).sum()
96
+ scale = self.scale
97
+ g = -torch.autograd.grad(loss, pred_x0)[0] * scale
98
+ return g, loss.item()
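A small sketch of the weight map used above: it is roughly 1 in flat regions and roughly 0 around strong edges, so the weighted guidance constrains smooth areas toward the reference while leaving edges, where the diffusion model should add detail, unconstrained. The input to _get_weight is expected in [0, 1]:

import torch
from utils.cond_fn import WeightedMSEGuidance

g = WeightedMSEGuidance(scale=0.1, t_start=601, t_stop=-1, space="rgb", repeat=1)
img = torch.zeros(1, 3, 64, 64)
img[..., 32:] = 1.0                       # a hard vertical edge down the middle
w = g._get_weight(img)
print(w.shape)                             # torch.Size([1, 1, 64, 64])
print(round(w[0, 0, 0, 0].item(), 3),      # ~1.0 far from the edge
      round(w[0, 0, 0, 32].item(), 3))     # ~0.0 on the edge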
utils/face_restoration_helper.py ADDED
@@ -0,0 +1,517 @@
1
+ import cv2
2
+ import numpy as np
3
+ import os
4
+ import torch
5
+ from torchvision.transforms.functional import normalize
6
+
7
+ from facexlib.detection import init_detection_model
8
+ from facexlib.parsing import init_parsing_model
9
+ from facexlib.utils.misc import img2tensor, imwrite
10
+
11
+ from utils.common import load_file_from_url
12
+
13
+ def get_largest_face(det_faces, h, w):
14
+
15
+ def get_location(val, length):
16
+ if val < 0:
17
+ return 0
18
+ elif val > length:
19
+ return length
20
+ else:
21
+ return val
22
+
23
+ face_areas = []
24
+ for det_face in det_faces:
25
+ left = get_location(det_face[0], w)
26
+ right = get_location(det_face[2], w)
27
+ top = get_location(det_face[1], h)
28
+ bottom = get_location(det_face[3], h)
29
+ face_area = (right - left) * (bottom - top)
30
+ face_areas.append(face_area)
31
+ largest_idx = face_areas.index(max(face_areas))
32
+ return det_faces[largest_idx], largest_idx
33
+
34
+
35
+ def get_center_face(det_faces, h=0, w=0, center=None):
36
+ if center is not None:
37
+ center = np.array(center)
38
+ else:
39
+ center = np.array([w / 2, h / 2])
40
+ center_dist = []
41
+ for det_face in det_faces:
42
+ face_center = np.array([(det_face[0] + det_face[2]) / 2, (det_face[1] + det_face[3]) / 2])
43
+ dist = np.linalg.norm(face_center - center)
44
+ center_dist.append(dist)
45
+ center_idx = center_dist.index(min(center_dist))
46
+ return det_faces[center_idx], center_idx
47
+
48
+
49
+ class FaceRestoreHelper(object):
50
+ """Helper for the face restoration pipeline (base class)."""
51
+
52
+ def __init__(self,
53
+ upscale_factor,
54
+ face_size=512,
55
+ crop_ratio=(1, 1),
56
+ det_model='retinaface_resnet50',
57
+ save_ext='png',
58
+ template_3points=False,
59
+ pad_blur=False,
60
+ use_parse=False,
61
+ device=None):
62
+ self.template_3points = template_3points # improve robustness
63
+ self.upscale_factor = int(upscale_factor)
64
+ # the cropped face ratio based on the square face
65
+ self.crop_ratio = crop_ratio # (h, w)
66
+ assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1), 'crop ratio only supports >= 1'
67
+ self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0]))
68
+ self.det_model = det_model
69
+
70
+ if self.det_model == 'dlib':
71
+ # standard 5 landmarks for FFHQ faces with 1024 x 1024
72
+ self.face_template = np.array([[686.77227723, 488.62376238], [586.77227723, 493.59405941],
73
+ [337.91089109, 488.38613861], [437.95049505, 493.51485149],
74
+ [513.58415842, 678.5049505]])
75
+ self.face_template = self.face_template / (1024 // face_size)
76
+ elif self.template_3points:
77
+ self.face_template = np.array([[192, 240], [319, 240], [257, 371]])
78
+ else:
79
+ # standard 5 landmarks for FFHQ faces with 512 x 512
80
+ # facexlib
81
+ self.face_template = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935],
82
+ [201.26117, 371.41043], [313.08905, 371.15118]])
83
+
84
+ # dlib: left_eye: 36:41 right_eye: 42:47 nose: 30,32,33,34 left mouth corner: 48 right mouth corner: 54
85
+ # self.face_template = np.array([[193.65928, 242.98541], [318.32558, 243.06108], [255.67984, 328.82894],
86
+ # [198.22603, 372.82502], [313.91018, 372.75659]])
87
+
88
+ self.face_template = self.face_template * (face_size / 512.0)
89
+ if self.crop_ratio[0] > 1:
90
+ self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2
91
+ if self.crop_ratio[1] > 1:
92
+ self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2
93
+ self.save_ext = save_ext
94
+ self.pad_blur = pad_blur
95
+ if self.pad_blur is True:
96
+ self.template_3points = False
97
+
98
+ self.all_landmarks_5 = []
99
+ self.det_faces = []
100
+ self.affine_matrices = []
101
+ self.inverse_affine_matrices = []
102
+ self.cropped_faces = []
103
+ self.restored_faces = []
104
+ self.pad_input_imgs = []
105
+
106
+ if device is None:
107
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
108
+ # self.device = get_device()
109
+ else:
110
+ self.device = device
111
+
112
+ # init face detection model
113
+ self.face_detector = init_detection_model(det_model, half=False, device=self.device)
114
+
115
+ # init face parsing model
116
+ self.use_parse = use_parse
117
+ self.face_parse = init_parsing_model(model_name='parsenet', device=self.device)
118
+
119
+ def set_upscale_factor(self, upscale_factor):
120
+ self.upscale_factor = upscale_factor
121
+
122
+ def read_image(self, img):
123
+ """img can be image path or cv2 loaded image."""
124
+ # self.input_img is Numpy array, (h, w, c), BGR, uint8, [0, 255]
125
+ if isinstance(img, str):
126
+ img = cv2.imread(img)
127
+
128
+ if np.max(img) > 256: # 16-bit image
129
+ img = img / 65535 * 255
130
+ if len(img.shape) == 2: # gray image
131
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
132
+ elif img.shape[2] == 4: # BGRA image with alpha channel
133
+ img = img[:, :, 0:3]
134
+
135
+ self.input_img = img
136
+ # self.is_gray = is_gray(img, threshold=10)
137
+ # if self.is_gray:
138
+ # print('Grayscale input: True')
139
+
140
+ if min(self.input_img.shape[:2])<512:
141
+ f = 512.0/min(self.input_img.shape[:2])
142
+ self.input_img = cv2.resize(self.input_img, (0,0), fx=f, fy=f, interpolation=cv2.INTER_LINEAR)
143
+
144
+ def init_dlib(self, detection_path, landmark5_path):
145
+ """Initialize the dlib detectors and predictors."""
146
+ try:
147
+ import dlib
148
+ except ImportError:
149
+ print('Please install dlib by running: conda install -c conda-forge dlib')
150
+ detection_path = load_file_from_url(url=detection_path, model_dir='weights/dlib', progress=True, file_name=None)
151
+ landmark5_path = load_file_from_url(url=landmark5_path, model_dir='weights/dlib', progress=True, file_name=None)
152
+ face_detector = dlib.cnn_face_detection_model_v1(detection_path)
153
+ shape_predictor_5 = dlib.shape_predictor(landmark5_path)
154
+ return face_detector, shape_predictor_5
155
+
156
+ def get_face_landmarks_5_dlib(self,
157
+ only_keep_largest=False,
158
+ scale=1):
159
+ det_faces = self.face_detector(self.input_img, scale)
160
+
161
+ if len(det_faces) == 0:
162
+ print('No face detected. Try to increase upsample_num_times.')
163
+ return 0
164
+ else:
165
+ if only_keep_largest:
166
+ print('Detect several faces and only keep the largest.')
167
+ face_areas = []
168
+ for i in range(len(det_faces)):
169
+ face_area = (det_faces[i].rect.right() - det_faces[i].rect.left()) * (
170
+ det_faces[i].rect.bottom() - det_faces[i].rect.top())
171
+ face_areas.append(face_area)
172
+ largest_idx = face_areas.index(max(face_areas))
173
+ self.det_faces = [det_faces[largest_idx]]
174
+ else:
175
+ self.det_faces = det_faces
176
+
177
+ if len(self.det_faces) == 0:
178
+ return 0
179
+
180
+ for face in self.det_faces:
181
+ shape = self.shape_predictor_5(self.input_img, face.rect)
182
+ landmark = np.array([[part.x, part.y] for part in shape.parts()])
183
+ self.all_landmarks_5.append(landmark)
184
+
185
+ return len(self.all_landmarks_5)
186
+
187
+
188
+ def get_face_landmarks_5(self,
189
+ only_keep_largest=False,
190
+ only_center_face=False,
191
+ resize=None,
192
+ blur_ratio=0.01,
193
+ eye_dist_threshold=None):
194
+ if self.det_model == 'dlib':
195
+ return self.get_face_landmarks_5_dlib(only_keep_largest)
196
+
197
+ if resize is None:
198
+ scale = 1
199
+ input_img = self.input_img
200
+ else:
201
+ h, w = self.input_img.shape[0:2]
202
+ scale = resize / min(h, w)
203
+ scale = max(1, scale) # always scale up
204
+ h, w = int(h * scale), int(w * scale)
205
+ interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR
206
+ input_img = cv2.resize(self.input_img, (w, h), interpolation=interp)
207
+
208
+ with torch.no_grad():
209
+ bboxes = self.face_detector.detect_faces(input_img)
210
+
211
+ if bboxes is None or bboxes.shape[0] == 0:
212
+ return 0
213
+ else:
214
+ bboxes = bboxes / scale
215
+
216
+ for bbox in bboxes:
217
+ # remove faces with too small eye distance: side faces or too small faces
218
+ eye_dist = np.linalg.norm([bbox[6] - bbox[8], bbox[7] - bbox[9]])
219
+ if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold):
220
+ continue
221
+
222
+ if self.template_3points:
223
+ landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 11, 2)])
224
+ else:
225
+ landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 15, 2)])
226
+ self.all_landmarks_5.append(landmark)
227
+ self.det_faces.append(bbox[0:5])
228
+
229
+ if len(self.det_faces) == 0:
230
+ return 0
231
+ if only_keep_largest:
232
+ h, w, _ = self.input_img.shape
233
+ self.det_faces, largest_idx = get_largest_face(self.det_faces, h, w)
234
+ self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]]
235
+ elif only_center_face:
236
+ h, w, _ = self.input_img.shape
237
+ self.det_faces, center_idx = get_center_face(self.det_faces, h, w)
238
+ self.all_landmarks_5 = [self.all_landmarks_5[center_idx]]
239
+
240
+ # pad blurry images
241
+ if self.pad_blur:
242
+ self.pad_input_imgs = []
243
+ for landmarks in self.all_landmarks_5:
244
+ # get landmarks
245
+ eye_left = landmarks[0, :]
246
+ eye_right = landmarks[1, :]
247
+ eye_avg = (eye_left + eye_right) * 0.5
248
+ mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5
249
+ eye_to_eye = eye_right - eye_left
250
+ eye_to_mouth = mouth_avg - eye_avg
251
+
252
+ # Get the oriented crop rectangle
253
+ # x: half width of the oriented crop rectangle
254
+ x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
255
+ # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
256
+ # norm with the hypotenuse: get the direction
257
+ x /= np.hypot(*x) # get the hypotenuse of a right triangle
258
+ rect_scale = 1.5
259
+ x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
260
+ # y: half height of the oriented crop rectangle
261
+ y = np.flipud(x) * [-1, 1]
262
+
263
+ # c: center
264
+ c = eye_avg + eye_to_mouth * 0.1
265
+ # quad: (left_top, left_bottom, right_bottom, right_top)
266
+ quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
267
+ # qsize: side length of the square
268
+ qsize = np.hypot(*x) * 2
269
+ border = max(int(np.rint(qsize * 0.1)), 3)
270
+
271
+ # get pad
272
+ # pad: (width_left, height_top, width_right, height_bottom)
273
+ pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
274
+ int(np.ceil(max(quad[:, 1]))))
275
+ pad = [
276
+ max(-pad[0] + border, 1),
277
+ max(-pad[1] + border, 1),
278
+ max(pad[2] - self.input_img.shape[0] + border, 1),
279
+ max(pad[3] - self.input_img.shape[1] + border, 1)
280
+ ]
281
+
282
+ if max(pad) > 1:
283
+ # pad image
284
+ pad_img = np.pad(self.input_img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
285
+ # modify landmark coords
286
+ landmarks[:, 0] += pad[0]
287
+ landmarks[:, 1] += pad[1]
288
+ # blur pad images
289
+ h, w, _ = pad_img.shape
290
+ y, x, _ = np.ogrid[:h, :w, :1]
291
+ mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
292
+ np.float32(w - 1 - x) / pad[2]),
293
+ 1.0 - np.minimum(np.float32(y) / pad[1],
294
+ np.float32(h - 1 - y) / pad[3]))
295
+ blur = int(qsize * blur_ratio)
296
+ if blur % 2 == 0:
297
+ blur += 1
298
+ blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur))
299
+ # blur_img = cv2.GaussianBlur(pad_img, (blur, blur), 0)
300
+
301
+ pad_img = pad_img.astype('float32')
302
+ pad_img += (blur_img - pad_img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
303
+ pad_img += (np.median(pad_img, axis=(0, 1)) - pad_img) * np.clip(mask, 0.0, 1.0)
304
+ pad_img = np.clip(pad_img, 0, 255) # float32, [0, 255]
305
+ self.pad_input_imgs.append(pad_img)
306
+ else:
307
+ self.pad_input_imgs.append(np.copy(self.input_img))
308
+
309
+ return len(self.all_landmarks_5)
310
+
311
+ def align_warp_face(self, save_cropped_path=None, border_mode='constant'):
312
+ """Align and warp faces with face template.
313
+ """
314
+ if self.pad_blur:
315
+ assert len(self.pad_input_imgs) == len(
316
+ self.all_landmarks_5), f'Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}'
317
+ for idx, landmark in enumerate(self.all_landmarks_5):
318
+ # use 5 landmarks to get affine matrix
319
+ # use cv2.LMEDS method for the equivalence to skimage transform
320
+ # ref: https://blog.csdn.net/yichxi/article/details/115827338
321
+ affine_matrix = cv2.estimateAffinePartial2D(landmark, self.face_template, method=cv2.LMEDS)[0]
322
+ self.affine_matrices.append(affine_matrix)
323
+ # warp and crop faces
324
+ if border_mode == 'constant':
325
+ border_mode = cv2.BORDER_CONSTANT
326
+ elif border_mode == 'reflect101':
327
+ border_mode = cv2.BORDER_REFLECT101
328
+ elif border_mode == 'reflect':
329
+ border_mode = cv2.BORDER_REFLECT
330
+ if self.pad_blur:
331
+ input_img = self.pad_input_imgs[idx]
332
+ else:
333
+ input_img = self.input_img
334
+ cropped_face = cv2.warpAffine(
335
+ input_img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=(135, 133, 132)) # gray
336
+ self.cropped_faces.append(cropped_face)
337
+ # save the cropped face
338
+ if save_cropped_path is not None:
339
+ path = os.path.splitext(save_cropped_path)[0]
340
+ save_path = f'{path}_{idx:02d}.{self.save_ext}'
341
+ imwrite(cropped_face, save_path)
342
+
343
+ def get_inverse_affine(self, save_inverse_affine_path=None):
344
+ """Get inverse affine matrix."""
345
+ for idx, affine_matrix in enumerate(self.affine_matrices):
346
+ inverse_affine = cv2.invertAffineTransform(affine_matrix)
347
+ inverse_affine *= self.upscale_factor
348
+ self.inverse_affine_matrices.append(inverse_affine)
349
+ # save inverse affine matrices
350
+ if save_inverse_affine_path is not None:
351
+ path, _ = os.path.splitext(save_inverse_affine_path)
352
+ save_path = f'{path}_{idx:02d}.pth'
353
+ torch.save(inverse_affine, save_path)
354
+
355
+
356
+ def add_restored_face(self, restored_face, input_face=None):
357
+ # if self.is_gray:
358
+ # restored_face = bgr2gray(restored_face) # convert img into grayscale
359
+ # if input_face is not None:
360
+ # restored_face = adain_npy(restored_face, input_face) # transfer the color
361
+ self.restored_faces.append(restored_face)
362
+
363
+
364
+ def paste_faces_to_input_image(self, save_path=None, upsample_img=None, draw_box=False, face_upsampler=None):
365
+ h, w, _ = self.input_img.shape
366
+ h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor)
367
+
368
+ if upsample_img is None:
369
+ # simply resize the background
370
+ # upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
371
+ upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LINEAR)
372
+ else:
373
+ upsample_img = cv2.resize(upsample_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
374
+
375
+ assert len(self.restored_faces) == len(
376
+ self.inverse_affine_matrices), ('lengths of restored_faces and affine_matrices differ.')
377
+
378
+ inv_mask_borders = []
379
+ for restored_face, inverse_affine in zip(self.restored_faces, self.inverse_affine_matrices):
380
+ if face_upsampler is not None:
381
+ restored_face = face_upsampler.enhance(restored_face, outscale=self.upscale_factor)[0]
382
+ inverse_affine /= self.upscale_factor
383
+ inverse_affine[:, 2] *= self.upscale_factor
384
+ face_size = (self.face_size[0]*self.upscale_factor, self.face_size[1]*self.upscale_factor)
385
+ else:
386
+ # Add an offset to inverse affine matrix, for more precise back alignment
387
+ if self.upscale_factor > 1:
388
+ extra_offset = 0.5 * self.upscale_factor
389
+ else:
390
+ extra_offset = 0
391
+ inverse_affine[:, 2] += extra_offset
392
+ face_size = self.face_size
393
+ inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up))
394
+
395
+ # if draw_box or not self.use_parse: # use square parse maps
396
+ # mask = np.ones(face_size, dtype=np.float32)
397
+ # inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
398
+ # # remove the black borders
399
+ # inv_mask_erosion = cv2.erode(
400
+ # inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
401
+ # pasted_face = inv_mask_erosion[:, :, None] * inv_restored
402
+ # total_face_area = np.sum(inv_mask_erosion) # // 3
403
+ # # add border
404
+ # if draw_box:
405
+ # h, w = face_size
406
+ # mask_border = np.ones((h, w, 3), dtype=np.float32)
407
+ # border = int(1400/np.sqrt(total_face_area))
408
+ # mask_border[border:h-border, border:w-border,:] = 0
409
+ # inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up))
410
+ # inv_mask_borders.append(inv_mask_border)
411
+ # if not self.use_parse:
412
+ # # compute the fusion edge based on the area of face
413
+ # w_edge = int(total_face_area**0.5) // 20
414
+ # erosion_radius = w_edge * 2
415
+ # inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
416
+ # blur_size = w_edge * 2
417
+ # inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
418
+ # if len(upsample_img.shape) == 2: # upsample_img is gray image
419
+ # upsample_img = upsample_img[:, :, None]
420
+ # inv_soft_mask = inv_soft_mask[:, :, None]
421
+
422
+ # always use square mask
423
+ mask = np.ones(face_size, dtype=np.float32)
424
+ inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
425
+ # remove the black borders
426
+ inv_mask_erosion = cv2.erode(
427
+ inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
428
+ pasted_face = inv_mask_erosion[:, :, None] * inv_restored
429
+ total_face_area = np.sum(inv_mask_erosion) # // 3
430
+ # add border
431
+ if draw_box:
432
+ h, w = face_size
433
+ mask_border = np.ones((h, w, 3), dtype=np.float32)
434
+ border = int(1400/np.sqrt(total_face_area))
435
+ mask_border[border:h-border, border:w-border,:] = 0
436
+ inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up))
437
+ inv_mask_borders.append(inv_mask_border)
438
+ # compute the fusion edge based on the area of face
439
+ w_edge = int(total_face_area**0.5) // 20
440
+ erosion_radius = w_edge * 2
441
+ inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
442
+ blur_size = w_edge * 2
443
+ inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
444
+ if len(upsample_img.shape) == 2: # upsample_img is gray image
445
+ upsample_img = upsample_img[:, :, None]
446
+ inv_soft_mask = inv_soft_mask[:, :, None]
447
+
448
+ # parse mask
449
+ if self.use_parse:
450
+ # inference
451
+ face_input = cv2.resize(restored_face, (512, 512), interpolation=cv2.INTER_LINEAR)
452
+ face_input = img2tensor(face_input.astype('float32') / 255., bgr2rgb=True, float32=True)
453
+ normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
454
+ face_input = torch.unsqueeze(face_input, 0).to(self.device)
455
+ with torch.no_grad():
456
+ out = self.face_parse(face_input)[0]
457
+ out = out.argmax(dim=1).squeeze().cpu().numpy()
458
+
459
+ parse_mask = np.zeros(out.shape)
460
+ MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0]
461
+ for idx, color in enumerate(MASK_COLORMAP):
462
+ parse_mask[out == idx] = color
463
+ # blur the mask
464
+ parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
465
+ parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
466
+ # remove the black borders
467
+ thres = 10
468
+ parse_mask[:thres, :] = 0
469
+ parse_mask[-thres:, :] = 0
470
+ parse_mask[:, :thres] = 0
471
+ parse_mask[:, -thres:] = 0
472
+ parse_mask = parse_mask / 255.
473
+
474
+ parse_mask = cv2.resize(parse_mask, face_size)
475
+ parse_mask = cv2.warpAffine(parse_mask, inverse_affine, (w_up, h_up), flags=3)
476
+ inv_soft_parse_mask = parse_mask[:, :, None]
477
+ # pasted_face = inv_restored
478
+ fuse_mask = (inv_soft_parse_mask<inv_soft_mask).astype('int')
479
+ inv_soft_mask = inv_soft_parse_mask*fuse_mask + inv_soft_mask*(1-fuse_mask)
480
+
481
+ if len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4: # alpha channel
482
+ alpha = upsample_img[:, :, 3:]
483
+ upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img[:, :, 0:3]
484
+ upsample_img = np.concatenate((upsample_img, alpha), axis=2)
485
+ else:
486
+ upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img
487
+
488
+ if np.max(upsample_img) > 256: # 16-bit image
489
+ upsample_img = upsample_img.astype(np.uint16)
490
+ else:
491
+ upsample_img = upsample_img.astype(np.uint8)
492
+
493
+ # draw bounding box
494
+ if draw_box:
495
+ # upsample_input_img = cv2.resize(input_img, (w_up, h_up))
496
+ img_color = np.ones([*upsample_img.shape], dtype=np.float32)
497
+ img_color[:,:,0] = 0
498
+ img_color[:,:,1] = 255
499
+ img_color[:,:,2] = 0
500
+ for inv_mask_border in inv_mask_borders:
501
+ upsample_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_img
502
+ # upsample_input_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_input_img
503
+
504
+ if save_path is not None:
505
+ path = os.path.splitext(save_path)[0]
506
+ save_path = f'{path}.{self.save_ext}'
507
+ imwrite(upsample_img, save_path)
508
+ return upsample_img
509
+
510
+ def clean_all(self):
511
+ self.all_landmarks_5 = []
512
+ self.restored_faces = []
513
+ self.affine_matrices = []
514
+ self.cropped_faces = []
515
+ self.inverse_affine_matrices = []
516
+ self.det_faces = []
517
+ self.pad_input_imgs = []
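A hedged end-to-end sketch of how this helper is meant to be driven, mirroring UnAlignedBFRInferenceLoop in utils/inference.py; the input path is illustrative, and the cropped faces are passed through unchanged where a real face restorer's output would go:

from utils.face_restoration_helper import FaceRestoreHelper

helper = FaceRestoreHelper(upscale_factor=1, face_size=512, use_parse=True,
                           det_model="retinaface_resnet50")
helper.clean_all()
helper.read_image("inputs/whole_image.png")             # path or BGR uint8 array
helper.get_face_landmarks_5(resize=640, eye_dist_threshold=5)
helper.align_warp_face()                                 # fills helper.cropped_faces
for cropped_face in helper.cropped_faces:
    helper.add_restored_face(cropped_face)               # plug a restored face in here
helper.get_inverse_affine()
result = helper.paste_faces_to_input_image()             # BGR uint8, faces pasted back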
utils/helpers.py ADDED
@@ -0,0 +1,216 @@
1
+ from typing import overload, Tuple, Optional
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+ import numpy as np
7
+ from PIL import Image
8
+ from einops import rearrange
9
+
10
+ from model.cldm import ControlLDM
11
+ from model.gaussian_diffusion import Diffusion
12
+ from model.bsrnet import RRDBNet
13
+ from model.swinir import SwinIR
14
+ from model.scunet import SCUNet
15
+ from utils.sampler import SpacedSampler
16
+ from utils.cond_fn import Guidance
17
+ from utils.common import wavelet_decomposition, wavelet_reconstruction, count_vram_usage
18
+
19
+
20
+ def bicubic_resize(img: np.ndarray, scale: float) -> np.ndarray:
21
+ pil = Image.fromarray(img)
22
+ res = pil.resize(tuple(int(x * scale) for x in pil.size), Image.BICUBIC)
23
+ return np.array(res)
24
+
25
+
26
+ def resize_short_edge_to(imgs: torch.Tensor, size: int) -> torch.Tensor:
27
+ _, _, h, w = imgs.size()
28
+ if h == w:
29
+ new_h, new_w = size, size
30
+ elif h < w:
31
+ new_h, new_w = size, int(w * (size / h))
32
+ else:
33
+ new_h, new_w = int(h * (size / w)), size
34
+ return F.interpolate(imgs, size=(new_h, new_w), mode="bicubic", antialias=True)
35
+
36
+
37
+ def pad_to_multiples_of(imgs: torch.Tensor, multiple: int) -> torch.Tensor:
38
+ _, _, h, w = imgs.size()
39
+ if h % multiple == 0 and w % multiple == 0:
40
+ return imgs.clone()
41
+ # get_pad = lambda x: (x // multiple + 1) * multiple - x
42
+ get_pad = lambda x: (x // multiple + int(x % multiple != 0)) * multiple - x
43
+ ph, pw = get_pad(h), get_pad(w)
44
+ return F.pad(imgs, pad=(0, pw, 0, ph), mode="constant", value=0)
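Worked examples for the three small preprocessing helpers above; bicubic_resize works on HWC uint8 arrays, while the other two operate on NCHW tensors, with padding applied only on the bottom and right:

import numpy as np
import torch
from utils.helpers import bicubic_resize, resize_short_edge_to, pad_to_multiples_of

img = (np.random.rand(100, 80, 3) * 255).astype(np.uint8)   # HWC uint8, as loaded by PIL
print(bicubic_resize(img, 4).shape)               # (400, 320, 3)

x = torch.rand(1, 3, 500, 700)                    # NCHW float tensors from here on
print(resize_short_edge_to(x, size=512).shape)    # torch.Size([1, 3, 512, 716])
print(pad_to_multiples_of(x, multiple=64).shape)  # torch.Size([1, 3, 512, 704])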
45
+
46
+
47
+ class Pipeline:
48
+
49
+ def __init__(self, stage1_model: nn.Module, cldm: ControlLDM, diffusion: Diffusion, cond_fn: Optional[Guidance], device: str) -> None:
50
+ self.stage1_model = stage1_model
51
+ self.cldm = cldm
52
+ self.diffusion = diffusion
53
+ self.cond_fn = cond_fn
54
+ self.device = device
55
+ self.final_size: Optional[Tuple[int, int]] = None
56
+
57
+ def set_final_size(self, lq: torch.Tensor) -> None:
58
+ h, w = lq.shape[2:]
59
+ self.final_size = (h, w)
60
+
61
+ @overload
62
+ def run_stage1(self, lq: torch.Tensor) -> torch.Tensor:
63
+ ...
64
+
65
+ @count_vram_usage
66
+ def run_stage2(
67
+ self,
68
+ clean: torch.Tensor,
69
+ steps: int,
70
+ strength: float,
71
+ tiled: bool,
72
+ tile_size: int,
73
+ tile_stride: int,
74
+ pos_prompt: str,
75
+ neg_prompt: str,
76
+ cfg_scale: float,
77
+ better_start: bool
78
+ ) -> torch.Tensor:
79
+ ### preprocess
80
+ bs, _, ori_h, ori_w = clean.shape
81
+ # pad: ensure that height & width are multiples of 64
82
+ pad_clean = pad_to_multiples_of(clean, multiple=64)
83
+ h, w = pad_clean.shape[2:]
84
+ # prepare condition
85
+ if not tiled:
86
+ cond = self.cldm.prepare_condition(pad_clean, [pos_prompt] * bs)
87
+ uncond = self.cldm.prepare_condition(pad_clean, [neg_prompt] * bs)
88
+ else:
89
+ cond = self.cldm.prepare_condition_tiled(pad_clean, [pos_prompt] * bs, tile_size, tile_stride)
90
+ uncond = self.cldm.prepare_condition_tiled(pad_clean, [neg_prompt] * bs, tile_size, tile_stride)
91
+ if self.cond_fn:
92
+ self.cond_fn.load_target(pad_clean * 2 - 1)
93
+ old_control_scales = self.cldm.control_scales
94
+ self.cldm.control_scales = [strength] * 13
95
+ if better_start:
96
+ # using noised low frequency part of condition as a better start point of
97
+ # reverse sampling, which can prevent our model from generating noise in
98
+ # image background.
99
+ _, low_freq = wavelet_decomposition(pad_clean)
100
+ if not tiled:
101
+ x_0 = self.cldm.vae_encode(low_freq)
102
+ else:
103
+ x_0 = self.cldm.vae_encode_tiled(low_freq, tile_size, tile_stride)
104
+ x_T = self.diffusion.q_sample(
105
+ x_0,
106
+ torch.full((bs, ), self.diffusion.num_timesteps - 1, dtype=torch.long, device=self.device),
107
+ torch.randn(x_0.shape, dtype=torch.float32, device=self.device)
108
+ )
109
+ # print(f"diffusion sqrt_alphas_cumprod: {self.diffusion.sqrt_alphas_cumprod[-1]}")
110
+ else:
111
+ x_T = torch.randn((bs, 4, h // 8, w // 8), dtype=torch.float32, device=self.device)
112
+ ### run sampler
113
+ sampler = SpacedSampler(self.diffusion.betas)
114
+ z = sampler.sample(
115
+ model=self.cldm, device=self.device, steps=steps, batch_size=bs, x_size=(4, h // 8, w // 8),
116
+ cond=cond, uncond=uncond, cfg_scale=cfg_scale, x_T=x_T, progress=True,
117
+ progress_leave=True, cond_fn=self.cond_fn, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
118
+ )
119
+ if not tiled:
120
+ x = self.cldm.vae_decode(z)
121
+ else:
122
+ x = self.cldm.vae_decode_tiled(z, tile_size // 8, tile_stride // 8)
123
+ ### postprocess
124
+ self.cldm.control_scales = old_control_scales
125
+ sample = x[:, :, :ori_h, :ori_w]
126
+ return sample
127
+
128
+ @torch.no_grad()
129
+ def run(
130
+ self,
131
+ lq: np.ndarray,
132
+ steps: int,
133
+ strength: float,
134
+ tiled: bool,
135
+ tile_size: int,
136
+ tile_stride: int,
137
+ pos_prompt: str,
138
+ neg_prompt: str,
139
+ cfg_scale: float,
140
+ better_start: bool
141
+ ) -> np.ndarray:
142
+ # image to tensor
143
+ lq = torch.tensor((lq / 255.).clip(0, 1), dtype=torch.float32, device=self.device)
144
+ lq = rearrange(lq, "n h w c -> n c h w").contiguous()
145
+ # set pipeline output size
146
+ self.set_final_size(lq)
147
+ clean = self.run_stage1(lq)
148
+ sample = self.run_stage2(
149
+ clean, steps, strength, tiled, tile_size, tile_stride,
150
+ pos_prompt, neg_prompt, cfg_scale, better_start
151
+ )
152
+ # colorfix (borrowed from StableSR, thanks for their work)
153
+ sample = (sample + 1) / 2
154
+ sample = wavelet_reconstruction(sample, clean)
155
+ # resize to desired output size
156
+ sample = F.interpolate(sample, size=self.final_size, mode="bicubic", antialias=True)
157
+ # tensor to image
158
+ sample = rearrange(sample * 255., "n c h w -> n h w c")
159
+ sample = sample.contiguous().clamp(0, 255).to(torch.uint8).cpu().numpy()
160
+ return sample
161
+
162
+
163
+ class BSRNetPipeline(Pipeline):
164
+
165
+ def __init__(self, bsrnet: RRDBNet, cldm: ControlLDM, diffusion: Diffusion, cond_fn: Optional[Guidance], device: str, upscale: float) -> None:
166
+ super().__init__(bsrnet, cldm, diffusion, cond_fn, device)
167
+ self.upscale = upscale
168
+
169
+ def set_final_size(self, lq: torch.Tensor) -> None:
170
+ h, w = lq.shape[2:]
171
+ self.final_size = (int(h * self.upscale), int(w * self.upscale))
172
+
173
+ @count_vram_usage
174
+ def run_stage1(self, lq: torch.Tensor) -> torch.Tensor:
175
+ # NOTE: upscale is always set to 4 in our experiments
176
+ clean = self.stage1_model(lq)
177
+ # if self.final_size[0] < 512 and self.final_size[1] < 512:
178
+ if min(self.final_size) < 512:
179
+ clean = resize_short_edge_to(clean, size=512)
180
+ else:
181
+ clean = F.interpolate(clean, size=self.final_size, mode="bicubic", antialias=True)
182
+ return clean
183
+
184
+
185
+ class SwinIRPipeline(Pipeline):
186
+
187
+ def __init__(self, swinir: SwinIR, cldm: ControlLDM, diffusion: Diffusion, cond_fn: Optional[Guidance], device: str) -> None:
188
+ super().__init__(swinir, cldm, diffusion, cond_fn, device)
189
+
190
+ @count_vram_usage
191
+ def run_stage1(self, lq: torch.Tensor) -> torch.Tensor:
192
+ # NOTE: lq size is always equal to 512 in our experiments
193
+ # resize: ensure the input lq size is at least 512, since SwinIR is trained on 512 resolution
194
+ if min(lq.shape[2:]) < 512:
195
+ lq = resize_short_edge_to(lq, size=512)
196
+ ori_h, ori_w = lq.shape[2:]
197
+ # pad: ensure that height & width are multiples of 64
198
+ pad_lq = pad_to_multiples_of(lq, multiple=64)
199
+ # run
200
+ clean = self.stage1_model(pad_lq)
201
+ # remove padding
202
+ clean = clean[:, :, :ori_h, :ori_w]
203
+ return clean
204
+
205
+
206
+ class SCUNetPipeline(Pipeline):
207
+
208
+ def __init__(self, scunet: SCUNet, cldm: ControlLDM, diffusion: Diffusion, cond_fn: Optional[Guidance], device: str) -> None:
209
+ super().__init__(scunet, cldm, diffusion, cond_fn, device)
210
+
211
+ @count_vram_usage
212
+ def run_stage1(self, lq: torch.Tensor) -> torch.Tensor:
213
+ clean = self.stage1_model(lq)
214
+ if min(clean.shape[2:]) < 512:
215
+ clean = resize_short_edge_to(clean, size=512)
216
+ return clean
utils/inference.py ADDED
@@ -0,0 +1,320 @@
1
+ import os
2
+ from typing import overload, Generator, Dict
3
+ from argparse import Namespace
4
+
5
+ import numpy as np
6
+ import torch
7
+ from PIL import Image
8
+ from omegaconf import OmegaConf
9
+
10
+ from model.cldm import ControlLDM
11
+ from model.gaussian_diffusion import Diffusion
12
+ from model.bsrnet import RRDBNet
13
+ from model.scunet import SCUNet
14
+ from model.swinir import SwinIR
15
+ from utils.common import instantiate_from_config, load_file_from_url, count_vram_usage
16
+ from utils.face_restoration_helper import FaceRestoreHelper
17
+ from utils.helpers import (
18
+ Pipeline,
19
+ BSRNetPipeline, SwinIRPipeline, SCUNetPipeline,
20
+ bicubic_resize
21
+ )
22
+ from utils.cond_fn import MSEGuidance, WeightedMSEGuidance
23
+
24
+
25
+ MODELS = {
26
+ ### stage_1 model weights
27
+ "bsrnet": "https://github.com/cszn/KAIR/releases/download/v1.0/BSRNet.pth",
28
+ # the following checkpoint is up-to-date, but we use the old version in our paper
29
+ # "swinir_face": "https://github.com/zsyOAOA/DifFace/releases/download/V1.0/General_Face_ffhq512.pth",
30
+ "swinir_face": "https://huggingface.co/lxq007/DiffBIR/resolve/main/face_swinir_v1.ckpt",
31
+ "scunet_psnr": "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_psnr.pth",
32
+ "swinir_general": "https://huggingface.co/lxq007/DiffBIR/resolve/main/general_swinir_v1.ckpt",
33
+ ### stage_2 model weights
34
+ "sd_v21": "https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.ckpt",
35
+ "v1_face": "https://huggingface.co/lxq007/DiffBIR-v2/resolve/main/v1_face.pth",
36
+ "v1_general": "https://huggingface.co/lxq007/DiffBIR-v2/resolve/main/v1_general.pth",
37
+ "v2": "https://huggingface.co/lxq007/DiffBIR-v2/resolve/main/v2.pth"
38
+ }
39
+
40
+
41
+ def load_model_from_url(url: str) -> Dict[str, torch.Tensor]:
42
+ sd_path = load_file_from_url(url, model_dir="weights")
43
+ sd = torch.load(sd_path, map_location="cpu")
44
+ if "state_dict" in sd:
45
+ sd = sd["state_dict"]
46
+ if list(sd.keys())[0].startswith("module"):
47
+ sd = {k[len("module."):]: v for k, v in sd.items()}
48
+ return sd
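For example, fetching the stage-1 BSRNet weights listed in MODELS below; the checkpoint is cached under ./weights by load_file_from_url, so repeated calls do not re-download:

from utils.inference import MODELS, load_model_from_url

sd = load_model_from_url(MODELS["bsrnet"])
print(f"{len(sd)} tensors in the state dict")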
49
+
50
+
51
+ class InferenceLoop:
52
+
53
+ def __init__(self, args: Namespace) -> None:
54
+ self.args = args
55
+ self.loop_ctx = {}
56
+ self.pipeline: Pipeline = None
57
+ self.init_stage1_model()
58
+ self.init_stage2_model()
59
+ self.init_cond_fn()
60
+ self.init_pipeline()
61
+
62
+ @overload
63
+ def init_stage1_model(self) -> None:
64
+ ...
65
+
66
+ @count_vram_usage
67
+ def init_stage2_model(self) -> None:
68
+ ### load unet, vae, clip
69
+ self.cldm: ControlLDM = instantiate_from_config(OmegaConf.load("configs/inference/cldm.yaml"))
70
+ sd = load_model_from_url(MODELS["sd_v21"])
71
+ unused = self.cldm.load_pretrained_sd(sd)
72
+ print(f"strictly load pretrained sd_v2.1, unused weights: {unused}")
73
+ ### load controlnet
74
+ if self.args.version == "v1":
75
+ if self.args.task == "fr":
76
+ control_sd = load_model_from_url(MODELS["v1_face"])
77
+ elif self.args.task == "sr":
78
+ control_sd = load_model_from_url(MODELS["v1_general"])
79
+ else:
80
+ raise ValueError(f"DiffBIR v1 doesn't support task: {self.args.task}, please use v2 by passing '--version v2'")
81
+ else:
82
+ control_sd = load_model_from_url(MODELS["v2"])
83
+ self.cldm.load_controlnet_from_ckpt(control_sd)
84
+ print(f"strictly load controlnet weight")
85
+ self.cldm.eval().to(self.args.device)
86
+ ### load diffusion
87
+ self.diffusion: Diffusion = instantiate_from_config(OmegaConf.load("configs/inference/diffusion.yaml"))
88
+ self.diffusion.to(self.args.device)
89
+
90
+ def init_cond_fn(self) -> None:
91
+ if not self.args.guidance:
92
+ self.cond_fn = None
93
+ return
94
+ if self.args.g_loss == "mse":
95
+ cond_fn_cls = MSEGuidance
96
+ elif self.args.g_loss == "w_mse":
97
+ cond_fn_cls = WeightedMSEGuidance
98
+ else:
99
+ raise ValueError(self.args.g_loss)
100
+ self.cond_fn = cond_fn_cls(
101
+ scale=self.args.g_scale, t_start=self.args.g_start, t_stop=self.args.g_stop,
102
+ space=self.args.g_space, repeat=self.args.g_repeat
103
+ )
104
+
105
+ @overload
106
+ def init_pipeline(self) -> None:
107
+ ...
108
+
109
+ def setup(self) -> None:
110
+ self.output_dir = self.args.output
111
+ os.makedirs(self.output_dir, exist_ok=True)
112
+
113
+ def lq_loader(self) -> Generator[np.ndarray, None, None]:
114
+ img_exts = [".png", ".jpg", ".jpeg"]
115
+ if os.path.isdir(self.args.input):
116
+ file_names = sorted([
117
+ file_name for file_name in os.listdir(self.args.input) if os.path.splitext(file_name)[-1] in img_exts
118
+ ])
119
+ file_paths = [os.path.join(self.args.input, file_name) for file_name in file_names]
120
+ else:
121
+ assert os.path.splitext(self.args.input)[-1] in img_exts
122
+ file_paths = [self.args.input]
123
+
124
+ def _loader() -> Generator[np.ndarray, None, None]:
125
+ for file_path in file_paths:
126
+ ### load lq
127
+ lq = np.array(Image.open(file_path).convert("RGB"))
128
+ print(f"load lq: {file_path}")
129
+ ### set context for saving results
130
+ self.loop_ctx["file_stem"] = os.path.splitext(os.path.basename(file_path))[0]
131
+ for i in range(self.args.n_samples):
132
+ self.loop_ctx["repeat_idx"] = i
133
+ yield lq
134
+
135
+ return _loader
136
+
137
+ def after_load_lq(self, lq: np.ndarray) -> np.ndarray:
138
+ return lq
139
+
140
+ @torch.no_grad()
141
+ def run(self) -> None:
142
+ self.setup()
143
+ # We don't support batch processing since input images may have different sizes
144
+ loader = self.lq_loader()
145
+ for lq in loader():
146
+ lq = self.after_load_lq(lq)
147
+ sample = self.pipeline.run(
148
+ lq[None], self.args.steps, 1.0, self.args.tiled,
149
+ self.args.tile_size, self.args.tile_stride,
150
+ self.args.pos_prompt, self.args.neg_prompt, self.args.cfg_scale,
151
+ self.args.better_start
152
+ )[0]
153
+ self.save(sample)
154
+
155
+ def save(self, sample: np.ndarray) -> None:
156
+ file_stem, repeat_idx = self.loop_ctx["file_stem"], self.loop_ctx["repeat_idx"]
157
+ file_name = f"{file_stem}_{repeat_idx}.png" if self.args.n_samples > 1 else f"{file_stem}.png"
158
+ save_path = os.path.join(self.args.output, file_name)
159
+ Image.fromarray(sample).save(save_path)
160
+ print(f"save result to {save_path}")
161
+
162
+
163
+ class BSRInferenceLoop(InferenceLoop):
164
+
165
+ @count_vram_usage
166
+ def init_stage1_model(self) -> None:
167
+ self.bsrnet: RRDBNet = instantiate_from_config(OmegaConf.load("configs/inference/bsrnet.yaml"))
168
+ sd = load_model_from_url(MODELS["bsrnet"])
169
+ self.bsrnet.load_state_dict(sd, strict=True)
170
+ self.bsrnet.eval().to(self.args.device)
171
+
172
+ def init_pipeline(self) -> None:
173
+ self.pipeline = BSRNetPipeline(self.bsrnet, self.cldm, self.diffusion, self.cond_fn, self.args.device, self.args.upscale)
174
+
175
+
176
+ class BFRInferenceLoop(InferenceLoop):
177
+
178
+ @count_vram_usage
179
+ def init_stage1_model(self) -> None:
180
+ self.swinir_face: SwinIR = instantiate_from_config(OmegaConf.load("configs/inference/swinir.yaml"))
181
+ sd = load_model_from_url(MODELS["swinir_face"])
182
+ self.swinir_face.load_state_dict(sd, strict=True)
183
+ self.swinir_face.eval().to(self.args.device)
184
+
185
+ def init_pipeline(self) -> None:
186
+ self.pipeline = SwinIRPipeline(self.swinir_face, self.cldm, self.diffusion, self.cond_fn, self.args.device)
187
+
188
+ def after_load_lq(self, lq: np.ndarray) -> np.ndarray:
189
+ # For BFR task, super resolution is achieved by directly upscaling lq
190
+ return bicubic_resize(lq, self.args.upscale)
191
+
192
+
193
+ class BIDInferenceLoop(InferenceLoop):
194
+
195
+ @count_vram_usage
196
+ def init_stage1_model(self) -> None:
197
+ self.scunet_psnr: SCUNet = instantiate_from_config(OmegaConf.load("configs/inference/scunet.yaml"))
198
+ sd = load_model_from_url(MODELS["scunet_psnr"])
199
+ self.scunet_psnr.load_state_dict(sd, strict=True)
200
+ self.scunet_psnr.eval().to(self.args.device)
201
+
202
+ def init_pipeline(self) -> None:
203
+ self.pipeline = SCUNetPipeline(self.scunet_psnr, self.cldm, self.diffusion, self.cond_fn, self.args.device)
204
+
205
+ def after_load_lq(self, lq: np.ndarray) -> np.ndarray:
206
+ # For BID task, super resolution is achieved by directly upscaling lq
207
+ return bicubic_resize(lq, self.args.upscale)
208
+
209
+
210
+ class V1InferenceLoop(InferenceLoop):
211
+
212
+ @count_vram_usage
213
+ def init_stage1_model(self) -> None:
214
+ self.swinir: SwinIR = instantiate_from_config(OmegaConf.load("configs/inference/swinir.yaml"))
215
+ if self.args.task == "fr":
216
+ sd = load_model_from_url(MODELS["swinir_face"])
217
+ elif self.args.task == "sr":
218
+ sd = load_model_from_url(MODELS["swinir_general"])
219
+ else:
220
+ raise ValueError(f"DiffBIR v1 doesn't support task: {self.args.task}, please use v2 by passing '--version v2'")
221
+ self.swinir.load_state_dict(sd, strict=True)
222
+ self.swinir.eval().to(self.args.device)
223
+
224
+ def init_pipeline(self) -> None:
225
+ self.pipeline = SwinIRPipeline(self.swinir, self.cldm, self.diffusion, self.cond_fn, self.args.device)
226
+
227
+ def after_load_lq(self, lq: np.ndarray) -> np.ndarray:
228
+ # For BFR task, super resolution is achieved by directly upscaling lq
229
+ return bicubic_resize(lq, self.args.upscale)
230
+
231
+
232
+ class UnAlignedBFRInferenceLoop(InferenceLoop):
233
+
234
+ @count_vram_usage
235
+ def init_stage1_model(self) -> None:
236
+ self.bsrnet: RRDBNet = instantiate_from_config(OmegaConf.load("configs/inference/bsrnet.yaml"))
237
+ sd = load_model_from_url(MODELS["bsrnet"])
238
+ self.bsrnet.load_state_dict(sd, strict=True)
239
+ self.bsrnet.eval().to(self.args.device)
240
+
241
+ self.swinir_face: SwinIR = instantiate_from_config(OmegaConf.load("configs/inference/swinir.yaml"))
242
+ sd = load_model_from_url(MODELS["swinir_face"])
243
+ self.swinir_face.load_state_dict(sd, strict=True)
244
+ self.swinir_face.eval().to(self.args.device)
245
+
246
+ def init_pipeline(self) -> None:
247
+ self.pipes = {
248
+ "bg": BSRNetPipeline(self.bsrnet, self.cldm, self.diffusion, self.cond_fn, self.args.device, self.args.upscale),
249
+ "face": SwinIRPipeline(self.swinir_face, self.cldm, self.diffusion, self.cond_fn, self.args.device)
250
+ }
251
+ self.pipeline = self.pipes["face"]
252
+
253
+ def setup(self) -> None:
254
+ super().setup()
255
+ self.cropped_face_dir = os.path.join(self.args.output, "cropped_faces")
256
+ os.makedirs(self.cropped_face_dir, exist_ok=True)
257
+ self.restored_face_dir = os.path.join(self.args.output, "restored_faces")
258
+ os.makedirs(self.restored_face_dir, exist_ok=True)
259
+ self.restored_bg_dir = os.path.join(self.args.output, "restored_backgrounds")
260
+ os.makedirs(self.restored_bg_dir, exist_ok=True)
261
+
262
+ def lq_loader(self) -> Generator[np.ndarray, None, None]:
263
+ base_loader = super().lq_loader()
264
+ self.face_helper = FaceRestoreHelper(
265
+ device=self.args.device,
266
+ upscale_factor=1,
267
+ face_size=512,
268
+ use_parse=True,
269
+ det_model="retinaface_resnet50"
270
+ )
271
+
272
+ def _loader() -> Generator[np.ndarray, None, None]:
273
+ for lq in base_loader():
274
+ ### set input image
275
+ self.face_helper.clean_all()
276
+ upscaled_bg = bicubic_resize(lq, self.args.upscale)
277
+ self.face_helper.read_image(upscaled_bg)
278
+ ### get face landmarks for each face
279
+ self.face_helper.get_face_landmarks_5(resize=640, eye_dist_threshold=5)
280
+ self.face_helper.align_warp_face()
281
+ print(f"detect {len(self.face_helper.cropped_faces)} faces")
282
+ ### restore each face (has been upscaled)
283
+ for i, lq_face in enumerate(self.face_helper.cropped_faces):
284
+ self.loop_ctx["is_face"] = True
285
+ self.loop_ctx["face_idx"] = i
286
+ self.loop_ctx["cropped_face"] = lq_face
287
+ yield lq_face
288
+ ### restore background (hasn't been upscaled)
289
+ self.loop_ctx["is_face"] = False
290
+ yield lq
291
+
292
+ return _loader
293
+
294
+ def after_load_lq(self, lq: np.ndarray) -> np.ndarray:
295
+ if self.loop_ctx["is_face"]:
296
+ self.pipeline = self.pipes["face"]
297
+ else:
298
+ self.pipeline = self.pipes["bg"]
299
+ return lq
300
+
301
+ def save(self, sample: np.ndarray) -> None:
302
+ file_stem, repeat_idx = self.loop_ctx["file_stem"], self.loop_ctx["repeat_idx"]
303
+ if self.loop_ctx["is_face"]:
304
+ face_idx = self.loop_ctx["face_idx"]
305
+ file_name = f"{file_stem}_{repeat_idx}_face_{face_idx}.png"
306
+ Image.fromarray(sample).save(os.path.join(self.restored_face_dir, file_name))
307
+
308
+ cropped_face = self.loop_ctx["cropped_face"]
309
+ Image.fromarray(cropped_face).save(os.path.join(self.cropped_face_dir, file_name))
310
+
311
+ self.face_helper.add_restored_face(sample)
312
+ else:
313
+ self.face_helper.get_inverse_affine()
314
+ # paste each restored face to the input image
315
+ restored_img = self.face_helper.paste_faces_to_input_image(
316
+ upsample_img=sample
317
+ )
318
+ file_name = f"{file_stem}_{repeat_idx}.png"
319
+ Image.fromarray(sample).save(os.path.join(self.restored_bg_dir, file_name))
320
+ Image.fromarray(restored_img).save(os.path.join(self.output_dir, file_name))
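A hedged sketch of driving one of the loops programmatically instead of via the project's CLI; every attribute below is read somewhere in this file or in utils/helpers.py, but the concrete values are illustrative rather than documented defaults:

from argparse import Namespace
from utils.inference import BSRInferenceLoop

args = Namespace(
    version="v2", task="sr", upscale=4, device="cuda",
    input="inputs/", output="results/", n_samples=1,
    steps=50, better_start=True, tiled=False, tile_size=512, tile_stride=256,
    pos_prompt="", neg_prompt="low quality, blurry, low resolution", cfg_scale=4.0,
    guidance=False, g_loss="w_mse", g_scale=0.0, g_start=1001, g_stop=-1,
    g_space="latent", g_repeat=1,
)
BSRInferenceLoop(args).run()   # restored images are written under args.output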
utils/sampler.py ADDED
@@ -0,0 +1,341 @@
1
+ from typing import Optional, Tuple, Dict
+
+ import torch
+ from torch import nn
+ import numpy as np
+ from tqdm import tqdm
+
+ from model.gaussian_diffusion import extract_into_tensor
+ from model.cldm import ControlLDM
+ from utils.cond_fn import Guidance
+ from utils.common import sliding_windows, gaussian_weights
+
+
+ # https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/respace.py
+ def space_timesteps(num_timesteps, section_counts):
+     """
+     Create a list of timesteps to use from an original diffusion process,
+     given the number of timesteps we want to take from equally-sized portions
+     of the original process.
+     For example, if there's 300 timesteps and the section counts are [10,15,20]
+     then the first 100 timesteps are strided to be 10 timesteps, the second 100
+     are strided to be 15 timesteps, and the final 100 are strided to be 20.
+     If the stride is a string starting with "ddim", then the fixed striding
+     from the DDIM paper is used, and only one section is allowed.
+     :param num_timesteps: the number of diffusion steps in the original
+                           process to divide up.
+     :param section_counts: either a list of numbers, or a string containing
+                            comma-separated numbers, indicating the step count
+                            per section. As a special case, use "ddimN" where N
+                            is a number of steps to use the striding from the
+                            DDIM paper.
+     :return: a set of diffusion steps from the original process to use.
+     """
+     if isinstance(section_counts, str):
+         if section_counts.startswith("ddim"):
+             desired_count = int(section_counts[len("ddim"):])
+             for i in range(1, num_timesteps):
+                 if len(range(0, num_timesteps, i)) == desired_count:
+                     return set(range(0, num_timesteps, i))
+             raise ValueError(
+                 f"cannot create exactly {num_timesteps} steps with an integer stride"
+             )
+         section_counts = [int(x) for x in section_counts.split(",")]
+     size_per = num_timesteps // len(section_counts)
+     extra = num_timesteps % len(section_counts)
+     start_idx = 0
+     all_steps = []
+     for i, section_count in enumerate(section_counts):
+         size = size_per + (1 if i < extra else 0)
+         if size < section_count:
+             raise ValueError(
+                 f"cannot divide section of {size} steps into {section_count}"
+             )
+         if section_count <= 1:
+             frac_stride = 1
+         else:
+             frac_stride = (size - 1) / (section_count - 1)
+         cur_idx = 0.0
+         taken_steps = []
+         for _ in range(section_count):
+             taken_steps.append(start_idx + round(cur_idx))
+             cur_idx += frac_stride
+         all_steps += taken_steps
+         start_idx += size
+     return set(all_steps)
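+
+ # Worked examples (illustrative, values follow from the logic above):
+ #   space_timesteps(10, [3])        -> {0, 4, 9}
+ #       one section of 10 steps, 3 samples, fractional stride (10-1)/(3-1) = 4.5
+ #   space_timesteps(1000, "50")     -> 50 roughly evenly spaced steps in [0, 999]
+ #   space_timesteps(1000, "ddim50") -> set(range(0, 1000, 20))  # fixed DDIM stride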
+
+
+ class SpacedSampler(nn.Module):
+     """
+     Implementation of the spaced sampling schedule proposed in IDDPM. This class is
+     designed for sampling ControlLDM.
+
+     https://arxiv.org/pdf/2102.09672.pdf
+     """
+
+     def __init__(self, betas: np.ndarray) -> None:
+         super().__init__()
+         self.num_timesteps = len(betas)
+         self.original_betas = betas
+         self.original_alphas_cumprod = np.cumprod(1.0 - betas, axis=0)
+         self.context = {}
+
+     def register(self, name: str, value: np.ndarray) -> None:
+         self.register_buffer(name, torch.tensor(value, dtype=torch.float32))
+
+     def make_schedule(self, num_steps: int) -> None:
+         # calculate betas for spaced sampling
+         # https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/respace.py
+         used_timesteps = space_timesteps(self.num_timesteps, str(num_steps))
+         betas = []
+         last_alpha_cumprod = 1.0
+         for i, alpha_cumprod in enumerate(self.original_alphas_cumprod):
+             if i in used_timesteps:
+                 # marginal distribution is the same as q(x_{S_t}|x_0)
+                 betas.append(1 - alpha_cumprod / last_alpha_cumprod)
+                 last_alpha_cumprod = alpha_cumprod
+         assert len(betas) == num_steps
+         self.timesteps = np.array(sorted(list(used_timesteps)), dtype=np.int32)  # e.g. [0, 10, 20, ...]
+
+         betas = np.array(betas, dtype=np.float64)
+         alphas = 1.0 - betas
+         alphas_cumprod = np.cumprod(alphas, axis=0)
+         # print(f"sampler sqrt_alphas_cumprod: {np.sqrt(alphas_cumprod)[-1]}")
+         alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
+         sqrt_recip_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod)
+         sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod - 1)
+         # calculations for posterior q(x_{t-1} | x_t, x_0)
+         posterior_variance = (
+             betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
+         )
+         # log calculation clipped because the posterior variance is 0 at the
+         # beginning of the diffusion chain.
+         posterior_log_variance_clipped = np.log(
+             np.append(posterior_variance[1], posterior_variance[1:])
+         )
+         posterior_mean_coef1 = (
+             betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
+         )
+         posterior_mean_coef2 = (
+             (1.0 - alphas_cumprod_prev)
+             * np.sqrt(alphas)
+             / (1.0 - alphas_cumprod)
+         )
+
+         self.register("sqrt_recip_alphas_cumprod", sqrt_recip_alphas_cumprod)
+         self.register("sqrt_recipm1_alphas_cumprod", sqrt_recipm1_alphas_cumprod)
+         self.register("posterior_variance", posterior_variance)
+         self.register("posterior_log_variance_clipped", posterior_log_variance_clipped)
+         self.register("posterior_mean_coef1", posterior_mean_coef1)
+         self.register("posterior_mean_coef2", posterior_mean_coef2)
+
+     def q_posterior_mean_variance(self, x_start: torch.Tensor, x_t: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         """
+         Implement the posterior distribution q(x_{t-1}|x_t, x_0).
+
+         Args:
+             x_start (torch.Tensor): The predicted images (NCHW) in timestep `t`.
+             x_t (torch.Tensor): The sampled intermediate variables (NCHW) of timestep `t`.
+             t (torch.Tensor): Timestep (N) of `x_t`. `t` serves as an index to get
+                 parameters for each timestep.
+
+         Returns:
+             posterior_mean (torch.Tensor): Mean of the posterior distribution.
+             posterior_variance (torch.Tensor): Variance of the posterior distribution.
+             posterior_log_variance_clipped (torch.Tensor): Log variance of the posterior distribution.
+         """
+         posterior_mean = (
+             extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
+             + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
+         )
+         posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
+         posterior_log_variance_clipped = extract_into_tensor(
+             self.posterior_log_variance_clipped, t, x_t.shape
+         )
+         return posterior_mean, posterior_variance, posterior_log_variance_clipped
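+
+     # For reference, the coefficients registered in `make_schedule` implement
+     #     mu_t(x_t, x_0) = coef1 * x_0 + coef2 * x_t
+     #                    = beta_t*sqrt(abar_{t-1})/(1-abar_t) * x_0
+     #                      + (1-abar_{t-1})*sqrt(alpha_t)/(1-abar_t) * x_t
+     #     sigma_t^2      = beta_t * (1-abar_{t-1}) / (1-abar_t)
+     # i.e. the standard DDPM posterior q(x_{t-1} | x_t, x_0).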
+
+     def _predict_xstart_from_eps(self, x_t: torch.Tensor, t: torch.Tensor, eps: torch.Tensor) -> torch.Tensor:
+         return (
+             extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+             - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
+         )
+
+     def apply_cond_fn(
+         self,
+         model: ControlLDM,
+         pred_x0: torch.Tensor,
+         t: torch.Tensor,
+         index: torch.Tensor,
+         cond_fn: Guidance
+     ) -> torch.Tensor:
+         t_now = int(t[0].item()) + 1
+         if not (cond_fn.t_stop < t_now < cond_fn.t_start):
+             # stop guidance
+             self.context["g_apply"] = False
+             return pred_x0
+         grad_rescale = 1 / extract_into_tensor(self.posterior_mean_coef1, index, pred_x0.shape)
+         # apply guidance multiple times
+         loss_vals = []
+         for _ in range(cond_fn.repeat):
+             # set target and pred for gradient computation
+             target, pred = None, None
+             if cond_fn.space == "latent":
+                 target = model.vae_encode(cond_fn.target)
+                 pred = pred_x0
+             elif cond_fn.space == "rgb":
+                 # We need to backpropagate the gradient to x0 in latent space, so it's
+                 # required to trace the computation graph while decoding the latent.
+                 with torch.enable_grad():
+                     target = cond_fn.target
+                     pred_x0_rg = pred_x0.detach().clone().requires_grad_(True)
+                     pred = model.vae_decode(pred_x0_rg)
+                     assert pred.requires_grad
+             else:
+                 raise NotImplementedError(cond_fn.space)
+             # compute gradient
+             delta_pred, loss_val = cond_fn(target, pred, t_now)
+             loss_vals.append(loss_val)
+             # update pred_x0 w.r.t. the gradient
+             if cond_fn.space == "latent":
+                 delta_pred_x0 = delta_pred
+                 pred_x0 = pred_x0 + delta_pred_x0 * grad_rescale
+             elif cond_fn.space == "rgb":
+                 pred.backward(delta_pred)
+                 delta_pred_x0 = pred_x0_rg.grad
+                 pred_x0 = pred_x0 + delta_pred_x0 * grad_rescale
+             else:
+                 raise NotImplementedError(cond_fn.space)
+         self.context["g_apply"] = True
+         self.context["g_loss"] = float(np.mean(loss_vals))
+         return pred_x0
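+
+     # Note on `grad_rescale` above: the posterior mean is coef1 * x_0 + coef2 * x_t,
+     # so scaling the update on x_0 by 1/coef1 makes the resulting shift of the
+     # posterior mean (coef1 * delta / coef1 = delta) independent of the
+     # timestep-dependent coefficient.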
+
+     def predict_noise(
+         self,
+         model: ControlLDM,
+         x: torch.Tensor,
+         t: torch.Tensor,
+         cond: Dict[str, torch.Tensor],
+         uncond: Optional[Dict[str, torch.Tensor]],
+         cfg_scale: float
+     ) -> torch.Tensor:
+         if uncond is None or cfg_scale == 1.:
+             model_output = model(x, t, cond)
+         else:
+             # apply classifier-free guidance
+             model_cond = model(x, t, cond)
+             model_uncond = model(x, t, uncond)
+             model_output = model_uncond + cfg_scale * (model_cond - model_uncond)
+         return model_output
+
+     @torch.no_grad()
+     def predict_noise_tiled(
+         self,
+         model: ControlLDM,
+         x: torch.Tensor,
+         t: torch.Tensor,
+         cond: Dict[str, torch.Tensor],
+         uncond: Optional[Dict[str, torch.Tensor]],
+         cfg_scale: float,
+         tile_size: int,
+         tile_stride: int
+     ) -> torch.Tensor:
+         _, _, h, w = x.shape
+         tiles = tqdm(sliding_windows(h, w, tile_size // 8, tile_stride // 8), unit="tile", leave=False)
+         eps = torch.zeros_like(x)
+         count = torch.zeros_like(x, dtype=torch.float32)
+         weights = gaussian_weights(tile_size // 8, tile_size // 8)[None, None]
+         weights = torch.tensor(weights, dtype=torch.float32, device=x.device)
+         for hi, hi_end, wi, wi_end in tiles:
+             tiles.set_description(f"Process tile ({hi} {hi_end}), ({wi} {wi_end})")
+             tile_x = x[:, :, hi:hi_end, wi:wi_end]
+             tile_cond = {
+                 "c_img": cond["c_img"][:, :, hi:hi_end, wi:wi_end],
+                 "c_txt": cond["c_txt"]
+             }
+             if uncond:
+                 tile_uncond = {
+                     "c_img": uncond["c_img"][:, :, hi:hi_end, wi:wi_end],
+                     "c_txt": uncond["c_txt"]
+                 }
+             else:
+                 tile_uncond = None
+             tile_eps = self.predict_noise(model, tile_x, t, tile_cond, tile_uncond, cfg_scale)
+             # accumulate noise
+             eps[:, :, hi:hi_end, wi:wi_end] += tile_eps * weights
+             count[:, :, hi:hi_end, wi:wi_end] += weights
+         # average the noise (score) over overlapping tiles
+         eps.div_(count)
+         return eps
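+
+     # Where tiles overlap, each latent position receives a weighted average of the
+     # per-tile noise predictions: `tile_eps * weights` is accumulated and then
+     # divided by the accumulated `weights` in `count`. The Gaussian weighting
+     # emphasizes tile centers over borders, which helps avoid visible seams.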
+
+     @torch.no_grad()
+     def p_sample(
+         self,
+         model: ControlLDM,
+         x: torch.Tensor,
+         t: torch.Tensor,
+         index: torch.Tensor,
+         cond: Dict[str, torch.Tensor],
+         uncond: Optional[Dict[str, torch.Tensor]],
+         cfg_scale: float,
+         cond_fn: Optional[Guidance],
+         tiled: bool,
+         tile_size: int,
+         tile_stride: int
+     ) -> torch.Tensor:
+         if tiled:
+             eps = self.predict_noise_tiled(model, x, t, cond, uncond, cfg_scale, tile_size, tile_stride)
+         else:
+             eps = self.predict_noise(model, x, t, cond, uncond, cfg_scale)
+         pred_x0 = self._predict_xstart_from_eps(x, index, eps)
+         if cond_fn:
+             assert not tiled, "tiled sampling currently doesn't support guidance"
+             pred_x0 = self.apply_cond_fn(model, pred_x0, t, index, cond_fn)
+         model_mean, model_variance, _ = self.q_posterior_mean_variance(pred_x0, x, index)
+         noise = torch.randn_like(x)
+         nonzero_mask = (
+             (index != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+         )
+         x_prev = model_mean + nonzero_mask * torch.sqrt(model_variance) * noise
+         return x_prev
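+
+     # One ancestral step: x_{t-1} = mu_t(x_t, x0_hat) + sigma_t * z with z ~ N(0, I);
+     # `nonzero_mask` zeroes the noise term on the final step (index == 0), so the
+     # last update returns the posterior mean directly.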
+
+     @torch.no_grad()
+     def sample(
+         self,
+         model: ControlLDM,
+         device: str,
+         steps: int,
+         batch_size: int,
+         x_size: Tuple[int, ...],
+         cond: Dict[str, torch.Tensor],
+         uncond: Dict[str, torch.Tensor],
+         cfg_scale: float,
+         cond_fn: Optional[Guidance]=None,
+         tiled: bool=False,
+         tile_size: int=-1,
+         tile_stride: int=-1,
+         x_T: Optional[torch.Tensor]=None,
+         progress: bool=True,
+         progress_leave: bool=True,
+     ) -> torch.Tensor:
+         self.make_schedule(steps)
+         self.to(device)
+         if x_T is None:
+             # TODO: not converting the noise to float32 may trigger an error
+             img = torch.randn((batch_size, *x_size), device=device)
+         else:
+             img = x_T
+         timesteps = np.flip(self.timesteps)  # descending, e.g. [999, 979, ..., 0]
+         total_steps = len(self.timesteps)
+         iterator = tqdm(timesteps, total=total_steps, leave=progress_leave, disable=not progress)
+         for i, step in enumerate(iterator):
+             ts = torch.full((batch_size,), step, device=device, dtype=torch.long)
+             index = torch.full_like(ts, fill_value=total_steps - i - 1)
+             img = self.p_sample(
+                 model, img, ts, index, cond, uncond, cfg_scale, cond_fn,
+                 tiled, tile_size, tile_stride
+             )
+             if cond_fn and self.context["g_apply"]:
+                 loss_val = self.context["g_loss"]
+                 desc = f"Spaced Sampler With Guidance, Loss: {loss_val:.6f}"
+             else:
+                 desc = "Spaced Sampler"
+             iterator.set_description(desc)
+         return img
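+
+ # Minimal usage sketch (illustrative; everything other than the SpacedSampler API
+ # shown above is an assumption about the surrounding project, e.g. how the
+ # ControlLDM and its conditioning dict are built):
+ #
+ #   betas = model.betas.cpu().numpy()      # noise schedule of the trained ControlLDM (assumed attribute)
+ #   sampler = SpacedSampler(betas)
+ #   cond = {"c_img": cond_img_latents, "c_txt": text_embeddings}
+ #   uncond = {"c_img": cond_img_latents, "c_txt": negative_text_embeddings}
+ #   z = sampler.sample(
+ #       model, device="cuda", steps=50, batch_size=1,
+ #       x_size=(4, 64, 64),                # latent shape (channels, h, w), assumed
+ #       cond=cond, uncond=uncond, cfg_scale=4.0,
+ #   )
+ #   restored = model.vae_decode(z)         # decode latents back to image space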