Upload 6 files
- utils/common.py +159 -0
- utils/cond_fn.py +98 -0
- utils/face_restoration_helper.py +517 -0
- utils/helpers.py +216 -0
- utils/inference.py +320 -0
- utils/sampler.py +341 -0
utils/common.py
ADDED
@@ -0,0 +1,159 @@
from typing import Mapping, Any, Tuple, Callable
import importlib
import os
from urllib.parse import urlparse

import torch
from torch import Tensor
from torch.nn import functional as F
import numpy as np

from torch.hub import download_url_to_file, get_dir


def get_obj_from_str(string: str, reload: bool = False) -> Any:
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def instantiate_from_config(config: Mapping[str, Any]) -> Any:
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


def wavelet_blur(image: Tensor, radius: int):
    """
    Apply wavelet blur to the input tensor.
    """
    # input shape: (1, 3, H, W)
    # convolution kernel
    kernel_vals = [
        [0.0625, 0.125, 0.0625],
        [0.125, 0.25, 0.125],
        [0.0625, 0.125, 0.0625],
    ]
    kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
    # add channel dimensions to the kernel to make it a 4D tensor
    kernel = kernel[None, None]
    # repeat the kernel across all input channels
    kernel = kernel.repeat(3, 1, 1, 1)
    image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
    # apply convolution
    output = F.conv2d(image, kernel, groups=3, dilation=radius)
    return output


def wavelet_decomposition(image: Tensor, levels=5):
    """
    Apply wavelet decomposition to the input tensor.
    This function only returns the low frequency & the high frequency.
    """
    high_freq = torch.zeros_like(image)
    for i in range(levels):
        radius = 2 ** i
        low_freq = wavelet_blur(image, radius)
        high_freq += (image - low_freq)
        image = low_freq

    return high_freq, low_freq


def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor):
    """
    Apply wavelet decomposition, so that the content will have the same color as the style.
    """
    # calculate the wavelet decomposition of the content feature
    content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
    del content_low_freq
    # calculate the wavelet decomposition of the style feature
    style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
    del style_high_freq
    # reconstruct the content's high frequency with the style's low frequency
    return content_high_freq + style_low_freq


# https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/utils/download_util.py/
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
    """Load file from http url, downloading the model if necessary.

    Ref: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py

    Args:
        url (str): URL to be downloaded.
        model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
            Default: None.
        progress (bool): Whether to show the download progress. Default: True.
        file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.

    Returns:
        str: The path to the downloaded file.
    """
    if model_dir is None:  # use the pytorch hub_dir
        hub_dir = get_dir()
        model_dir = os.path.join(hub_dir, 'checkpoints')

    os.makedirs(model_dir, exist_ok=True)

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    if file_name is not None:
        filename = file_name
    cached_file = os.path.abspath(os.path.join(model_dir, filename))
    if not os.path.exists(cached_file):
        print(f'Downloading: "{url}" to {cached_file}\n')
        download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
    return cached_file


def sliding_windows(h: int, w: int, tile_size: int, tile_stride: int) -> Tuple[int, int, int, int]:
    hi_list = list(range(0, h - tile_size + 1, tile_stride))
    if (h - tile_size) % tile_stride != 0:
        hi_list.append(h - tile_size)

    wi_list = list(range(0, w - tile_size + 1, tile_stride))
    if (w - tile_size) % tile_stride != 0:
        wi_list.append(w - tile_size)

    coords = []
    for hi in hi_list:
        for wi in wi_list:
            coords.append((hi, hi + tile_size, wi, wi + tile_size))
    return coords


# https://github.com/csslc/CCSR/blob/main/model/q_sampler.py#L503
def gaussian_weights(tile_width: int, tile_height: int) -> np.ndarray:
    """Generates a gaussian mask of weights for tile contributions"""
    latent_width = tile_width
    latent_height = tile_height
    var = 0.01
    midpoint = (latent_width - 1) / 2  # -1 because index goes from 0 to latent_width - 1
    x_probs = [
        np.exp(-(x - midpoint) * (x - midpoint) / (latent_width * latent_width) / (2 * var)) / np.sqrt(2 * np.pi * var)
        for x in range(latent_width)]
    midpoint = latent_height / 2
    y_probs = [
        np.exp(-(y - midpoint) * (y - midpoint) / (latent_height * latent_height) / (2 * var)) / np.sqrt(2 * np.pi * var)
        for y in range(latent_height)]
    weights = np.outer(y_probs, x_probs)
    return weights


COUNT_VRAM = bool(os.environ.get("COUNT_VRAM", False))

def count_vram_usage(func: Callable) -> Callable:
    if not COUNT_VRAM:
        return func

    def wrapper(*args, **kwargs):
        peak_before = torch.cuda.max_memory_allocated() / (1024 ** 3)
        ret = func(*args, **kwargs)
        torch.cuda.synchronize()
        peak_after = torch.cuda.max_memory_allocated() / (1024 ** 3)
        print(f"VRAM peak before {func.__name__}: {peak_before:.5f} GB, after: {peak_after:.5f} GB")
        return ret
    return wrapper
utils/cond_fn.py
ADDED
@@ -0,0 +1,98 @@
from typing import overload, Tuple
import torch
from torch.nn import functional as F


class Guidance:

    def __init__(self, scale: float, t_start: int, t_stop: int, space: str, repeat: int) -> "Guidance":
        """
        Initialize restoration guidance.

        Args:
            scale (float): Gradient scale (denoted as `s` in our paper). The larger the gradient scale,
                the closer the final result will be to the output of the first stage model.
            t_start (int), t_stop (int): The timesteps at which guidance starts and stops. Note that the
                sampling process runs from t=1000 to t=0, so `t_start` should be larger than `t_stop`.
            space (str): The data space for computing the loss function (rgb or latent).

        Our restoration guidance is based on [GDP](https://github.com/Fayeben/GenerativeDiffusionPrior).
        Thanks for their work!
        """
        self.scale = scale * 3000
        self.t_start = t_start
        self.t_stop = t_stop
        self.target = None
        self.space = space
        self.repeat = repeat

    def load_target(self, target: torch.Tensor) -> None:
        self.target = target

    def __call__(self, target_x0: torch.Tensor, pred_x0: torch.Tensor, t: int) -> Tuple[torch.Tensor, float]:
        # avoid propagating gradient out of this scope
        pred_x0 = pred_x0.detach().clone()
        target_x0 = target_x0.detach().clone()
        return self._forward(target_x0, pred_x0, t)

    @overload
    def _forward(self, target_x0: torch.Tensor, pred_x0: torch.Tensor, t: int) -> Tuple[torch.Tensor, float]:
        ...


class MSEGuidance(Guidance):

    def _forward(self, target_x0: torch.Tensor, pred_x0: torch.Tensor, t: int) -> Tuple[torch.Tensor, float]:
        # inputs: [-1, 1], nchw, rgb
        with torch.enable_grad():
            pred_x0.requires_grad_(True)
            loss = (pred_x0 - target_x0).pow(2).mean((1, 2, 3)).sum()
        scale = self.scale
        g = -torch.autograd.grad(loss, pred_x0)[0] * scale
        return g, loss.item()


class WeightedMSEGuidance(Guidance):

    def _get_weight(self, target: torch.Tensor) -> torch.Tensor:
        # convert RGB to G
        rgb_to_gray_kernel = torch.tensor([0.2989, 0.5870, 0.1140]).view(1, 3, 1, 1)
        target = torch.sum(target * rgb_to_gray_kernel.to(target.device), dim=1, keepdim=True)
        # initialize sobel kernel in x and y axis
        G_x = [
            [1, 0, -1],
            [2, 0, -2],
            [1, 0, -1]
        ]
        G_y = [
            [1, 2, 1],
            [0, 0, 0],
            [-1, -2, -1]
        ]
        G_x = torch.tensor(G_x, dtype=target.dtype, device=target.device)[None]
        G_y = torch.tensor(G_y, dtype=target.dtype, device=target.device)[None]
        G = torch.stack((G_x, G_y))

        target = F.pad(target, (1, 1, 1, 1), mode='replicate')  # padding = 1
        grad = F.conv2d(target, G, stride=1)
        mag = grad.pow(2).sum(dim=1, keepdim=True).sqrt()

        n, c, h, w = mag.size()
        block_size = 2
        blocks = mag.view(n, c, h // block_size, block_size, w // block_size, block_size).permute(0, 1, 2, 4, 3, 5).contiguous()
        block_mean = blocks.sum(dim=(-2, -1), keepdim=True).tanh().repeat(1, 1, 1, 1, block_size, block_size).permute(0, 1, 2, 4, 3, 5).contiguous()
        block_mean = block_mean.view(n, c, h, w)
        weight_map = 1 - block_mean

        return weight_map

    def _forward(self, target_x0: torch.Tensor, pred_x0: torch.Tensor, t: int) -> Tuple[torch.Tensor, float]:
        # inputs: [-1, 1], nchw, rgb
        with torch.no_grad():
            w = self._get_weight((target_x0 + 1) / 2)
        with torch.enable_grad():
            pred_x0.requires_grad_(True)
            loss = ((pred_x0 - target_x0).pow(2) * w).mean((1, 2, 3)).sum()
        scale = self.scale
        g = -torch.autograd.grad(loss, pred_x0)[0] * scale
        return g, loss.item()
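
Note (not part of the uploaded file): a small sketch of how a `Guidance` instance is driven; the actual call site lives in `utils/sampler.py`, which is not shown in full here. The scale and timestep values below are illustrative, not the project's defaults.

# Sketch only: load a target once, then ask the guidance for a gradient on each
# predicted x0 during sampling and nudge the prediction toward the target.
import torch
from utils.cond_fn import MSEGuidance

cond_fn = MSEGuidance(scale=0.1, t_start=601, t_stop=-1, space="rgb", repeat=5)
target = torch.rand(1, 3, 64, 64) * 2 - 1   # reference image in [-1, 1]
cond_fn.load_target(target)

pred_x0 = torch.rand(1, 3, 64, 64) * 2 - 1  # model's current x0 prediction
grad, loss = cond_fn(cond_fn.target, pred_x0, t=500)
pred_x0 = pred_x0 + grad                    # guided update of the prediction
print(f"guidance loss: {loss:.4f}, grad norm: {grad.norm():.4f}")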
utils/face_restoration_helper.py
ADDED
@@ -0,0 +1,517 @@
import cv2
import numpy as np
import os
import torch
from torchvision.transforms.functional import normalize

from facexlib.detection import init_detection_model
from facexlib.parsing import init_parsing_model
from facexlib.utils.misc import img2tensor, imwrite

from utils.common import load_file_from_url

def get_largest_face(det_faces, h, w):

    def get_location(val, length):
        if val < 0:
            return 0
        elif val > length:
            return length
        else:
            return val

    face_areas = []
    for det_face in det_faces:
        left = get_location(det_face[0], w)
        right = get_location(det_face[2], w)
        top = get_location(det_face[1], h)
        bottom = get_location(det_face[3], h)
        face_area = (right - left) * (bottom - top)
        face_areas.append(face_area)
    largest_idx = face_areas.index(max(face_areas))
    return det_faces[largest_idx], largest_idx


def get_center_face(det_faces, h=0, w=0, center=None):
    if center is not None:
        center = np.array(center)
    else:
        center = np.array([w / 2, h / 2])
    center_dist = []
    for det_face in det_faces:
        face_center = np.array([(det_face[0] + det_face[2]) / 2, (det_face[1] + det_face[3]) / 2])
        dist = np.linalg.norm(face_center - center)
        center_dist.append(dist)
    center_idx = center_dist.index(min(center_dist))
    return det_faces[center_idx], center_idx


class FaceRestoreHelper(object):
    """Helper for the face restoration pipeline (base class)."""

    def __init__(self,
                 upscale_factor,
                 face_size=512,
                 crop_ratio=(1, 1),
                 det_model='retinaface_resnet50',
                 save_ext='png',
                 template_3points=False,
                 pad_blur=False,
                 use_parse=False,
                 device=None):
        self.template_3points = template_3points  # improve robustness
        self.upscale_factor = int(upscale_factor)
        # the cropped face ratio based on the square face
        self.crop_ratio = crop_ratio  # (h, w)
        assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1), 'crop ratio only supports >=1'
        self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0]))
        self.det_model = det_model

        if self.det_model == 'dlib':
            # standard 5 landmarks for FFHQ faces with 1024 x 1024
            self.face_template = np.array([[686.77227723, 488.62376238], [586.77227723, 493.59405941],
                                           [337.91089109, 488.38613861], [437.95049505, 493.51485149],
                                           [513.58415842, 678.5049505]])
            self.face_template = self.face_template / (1024 // face_size)
        elif self.template_3points:
            self.face_template = np.array([[192, 240], [319, 240], [257, 371]])
        else:
            # standard 5 landmarks for FFHQ faces with 512 x 512
            # facexlib
            self.face_template = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935],
                                           [201.26117, 371.41043], [313.08905, 371.15118]])

            # dlib: left_eye: 36:41  right_eye: 42:47  nose: 30,32,33,34  left mouth corner: 48  right mouth corner: 54
            # self.face_template = np.array([[193.65928, 242.98541], [318.32558, 243.06108], [255.67984, 328.82894],
            #                                [198.22603, 372.82502], [313.91018, 372.75659]])

        self.face_template = self.face_template * (face_size / 512.0)
        if self.crop_ratio[0] > 1:
            self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2
        if self.crop_ratio[1] > 1:
            self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2
        self.save_ext = save_ext
        self.pad_blur = pad_blur
        if self.pad_blur is True:
            self.template_3points = False

        self.all_landmarks_5 = []
        self.det_faces = []
        self.affine_matrices = []
        self.inverse_affine_matrices = []
        self.cropped_faces = []
        self.restored_faces = []
        self.pad_input_imgs = []

        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            # self.device = get_device()
        else:
            self.device = device

        # init face detection model
        self.face_detector = init_detection_model(det_model, half=False, device=self.device)

        # init face parsing model
        self.use_parse = use_parse
        self.face_parse = init_parsing_model(model_name='parsenet', device=self.device)

    def set_upscale_factor(self, upscale_factor):
        self.upscale_factor = upscale_factor

    def read_image(self, img):
        """img can be an image path or a cv2 loaded image."""
        # self.input_img is a Numpy array, (h, w, c), BGR, uint8, [0, 255]
        if isinstance(img, str):
            img = cv2.imread(img)

        if np.max(img) > 256:  # 16-bit image
            img = img / 65535 * 255
        if len(img.shape) == 2:  # gray image
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif img.shape[2] == 4:  # BGRA image with alpha channel
            img = img[:, :, 0:3]

        self.input_img = img
        # self.is_gray = is_gray(img, threshold=10)
        # if self.is_gray:
        #     print('Grayscale input: True')

        if min(self.input_img.shape[:2]) < 512:
            f = 512.0 / min(self.input_img.shape[:2])
            self.input_img = cv2.resize(self.input_img, (0, 0), fx=f, fy=f, interpolation=cv2.INTER_LINEAR)

    def init_dlib(self, detection_path, landmark5_path):
        """Initialize the dlib detectors and predictors."""
        try:
            import dlib
        except ImportError:
            print('Please install dlib by running: conda install -c conda-forge dlib')
        detection_path = load_file_from_url(url=detection_path, model_dir='weights/dlib', progress=True, file_name=None)
        landmark5_path = load_file_from_url(url=landmark5_path, model_dir='weights/dlib', progress=True, file_name=None)
        face_detector = dlib.cnn_face_detection_model_v1(detection_path)
        shape_predictor_5 = dlib.shape_predictor(landmark5_path)
        return face_detector, shape_predictor_5

    def get_face_landmarks_5_dlib(self,
                                  only_keep_largest=False,
                                  scale=1):
        det_faces = self.face_detector(self.input_img, scale)

        if len(det_faces) == 0:
            print('No face detected. Try to increase upsample_num_times.')
            return 0
        else:
            if only_keep_largest:
                print('Detect several faces and only keep the largest.')
                face_areas = []
                for i in range(len(det_faces)):
                    face_area = (det_faces[i].rect.right() - det_faces[i].rect.left()) * (
                        det_faces[i].rect.bottom() - det_faces[i].rect.top())
                    face_areas.append(face_area)
                largest_idx = face_areas.index(max(face_areas))
                self.det_faces = [det_faces[largest_idx]]
            else:
                self.det_faces = det_faces

        if len(self.det_faces) == 0:
            return 0

        for face in self.det_faces:
            shape = self.shape_predictor_5(self.input_img, face.rect)
            landmark = np.array([[part.x, part.y] for part in shape.parts()])
            self.all_landmarks_5.append(landmark)

        return len(self.all_landmarks_5)

    def get_face_landmarks_5(self,
                             only_keep_largest=False,
                             only_center_face=False,
                             resize=None,
                             blur_ratio=0.01,
                             eye_dist_threshold=None):
        if self.det_model == 'dlib':
            return self.get_face_landmarks_5_dlib(only_keep_largest)

        if resize is None:
            scale = 1
            input_img = self.input_img
        else:
            h, w = self.input_img.shape[0:2]
            scale = resize / min(h, w)
            scale = max(1, scale)  # always scale up
            h, w = int(h * scale), int(w * scale)
            interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR
            input_img = cv2.resize(self.input_img, (w, h), interpolation=interp)

        with torch.no_grad():
            bboxes = self.face_detector.detect_faces(input_img)

        if bboxes is None or bboxes.shape[0] == 0:
            return 0
        else:
            bboxes = bboxes / scale

        for bbox in bboxes:
            # remove faces with too small eye distance: side faces or too small faces
            eye_dist = np.linalg.norm([bbox[6] - bbox[8], bbox[7] - bbox[9]])
            if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold):
                continue

            if self.template_3points:
                landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 11, 2)])
            else:
                landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 15, 2)])
            self.all_landmarks_5.append(landmark)
            self.det_faces.append(bbox[0:5])

        if len(self.det_faces) == 0:
            return 0
        if only_keep_largest:
            h, w, _ = self.input_img.shape
            self.det_faces, largest_idx = get_largest_face(self.det_faces, h, w)
            self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]]
        elif only_center_face:
            h, w, _ = self.input_img.shape
            self.det_faces, center_idx = get_center_face(self.det_faces, h, w)
            self.all_landmarks_5 = [self.all_landmarks_5[center_idx]]

        # pad blurry images
        if self.pad_blur:
            self.pad_input_imgs = []
            for landmarks in self.all_landmarks_5:
                # get landmarks
                eye_left = landmarks[0, :]
                eye_right = landmarks[1, :]
                eye_avg = (eye_left + eye_right) * 0.5
                mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5
                eye_to_eye = eye_right - eye_left
                eye_to_mouth = mouth_avg - eye_avg

                # Get the oriented crop rectangle
                # x: half width of the oriented crop rectangle
                x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
                #  - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
                # norm with the hypotenuse: get the direction
                x /= np.hypot(*x)  # get the hypotenuse of a right triangle
                rect_scale = 1.5
                x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
                # y: half height of the oriented crop rectangle
                y = np.flipud(x) * [-1, 1]

                # c: center
                c = eye_avg + eye_to_mouth * 0.1
                # quad: (left_top, left_bottom, right_bottom, right_top)
                quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
                # qsize: side length of the square
                qsize = np.hypot(*x) * 2
                border = max(int(np.rint(qsize * 0.1)), 3)

                # get pad
                # pad: (width_left, height_top, width_right, height_bottom)
                pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
                       int(np.ceil(max(quad[:, 1]))))
                pad = [
                    max(-pad[0] + border, 1),
                    max(-pad[1] + border, 1),
                    max(pad[2] - self.input_img.shape[0] + border, 1),
                    max(pad[3] - self.input_img.shape[1] + border, 1)
                ]

                if max(pad) > 1:
                    # pad image
                    pad_img = np.pad(self.input_img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
                    # modify landmark coords
                    landmarks[:, 0] += pad[0]
                    landmarks[:, 1] += pad[1]
                    # blur pad images
                    h, w, _ = pad_img.shape
                    y, x, _ = np.ogrid[:h, :w, :1]
                    mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
                                                       np.float32(w - 1 - x) / pad[2]),
                                      1.0 - np.minimum(np.float32(y) / pad[1],
                                                       np.float32(h - 1 - y) / pad[3]))
                    blur = int(qsize * blur_ratio)
                    if blur % 2 == 0:
                        blur += 1
                    blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur))
                    # blur_img = cv2.GaussianBlur(pad_img, (blur, blur), 0)

                    pad_img = pad_img.astype('float32')
                    pad_img += (blur_img - pad_img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
                    pad_img += (np.median(pad_img, axis=(0, 1)) - pad_img) * np.clip(mask, 0.0, 1.0)
                    pad_img = np.clip(pad_img, 0, 255)  # float32, [0, 255]
                    self.pad_input_imgs.append(pad_img)
                else:
                    self.pad_input_imgs.append(np.copy(self.input_img))

        return len(self.all_landmarks_5)

    def align_warp_face(self, save_cropped_path=None, border_mode='constant'):
        """Align and warp faces with the face template."""
        if self.pad_blur:
            assert len(self.pad_input_imgs) == len(
                self.all_landmarks_5), f'Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}'
        for idx, landmark in enumerate(self.all_landmarks_5):
            # use 5 landmarks to get affine matrix
            # use cv2.LMEDS method for the equivalence to skimage transform
            # ref: https://blog.csdn.net/yichxi/article/details/115827338
            affine_matrix = cv2.estimateAffinePartial2D(landmark, self.face_template, method=cv2.LMEDS)[0]
            self.affine_matrices.append(affine_matrix)
            # warp and crop faces
            if border_mode == 'constant':
                border_mode = cv2.BORDER_CONSTANT
            elif border_mode == 'reflect101':
                border_mode = cv2.BORDER_REFLECT101
            elif border_mode == 'reflect':
                border_mode = cv2.BORDER_REFLECT
            if self.pad_blur:
                input_img = self.pad_input_imgs[idx]
            else:
                input_img = self.input_img
            cropped_face = cv2.warpAffine(
                input_img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=(135, 133, 132))  # gray
            self.cropped_faces.append(cropped_face)
            # save the cropped face
            if save_cropped_path is not None:
                path = os.path.splitext(save_cropped_path)[0]
                save_path = f'{path}_{idx:02d}.{self.save_ext}'
                imwrite(cropped_face, save_path)

    def get_inverse_affine(self, save_inverse_affine_path=None):
        """Get inverse affine matrix."""
        for idx, affine_matrix in enumerate(self.affine_matrices):
            inverse_affine = cv2.invertAffineTransform(affine_matrix)
            inverse_affine *= self.upscale_factor
            self.inverse_affine_matrices.append(inverse_affine)
            # save inverse affine matrices
            if save_inverse_affine_path is not None:
                path, _ = os.path.splitext(save_inverse_affine_path)
                save_path = f'{path}_{idx:02d}.pth'
                torch.save(inverse_affine, save_path)

    def add_restored_face(self, restored_face, input_face=None):
        # if self.is_gray:
        #     restored_face = bgr2gray(restored_face)  # convert img into grayscale
        #     if input_face is not None:
        #         restored_face = adain_npy(restored_face, input_face)  # transfer the color
        self.restored_faces.append(restored_face)

    def paste_faces_to_input_image(self, save_path=None, upsample_img=None, draw_box=False, face_upsampler=None):
        h, w, _ = self.input_img.shape
        h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor)

        if upsample_img is None:
            # simply resize the background
            # upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
            upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LINEAR)
        else:
            upsample_img = cv2.resize(upsample_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)

        assert len(self.restored_faces) == len(
            self.inverse_affine_matrices), ('length of restored_faces and affine_matrices are different.')

        inv_mask_borders = []
        for restored_face, inverse_affine in zip(self.restored_faces, self.inverse_affine_matrices):
            if face_upsampler is not None:
                restored_face = face_upsampler.enhance(restored_face, outscale=self.upscale_factor)[0]
                inverse_affine /= self.upscale_factor
                inverse_affine[:, 2] *= self.upscale_factor
                face_size = (self.face_size[0] * self.upscale_factor, self.face_size[1] * self.upscale_factor)
            else:
                # Add an offset to the inverse affine matrix, for more precise back alignment
                if self.upscale_factor > 1:
                    extra_offset = 0.5 * self.upscale_factor
                else:
                    extra_offset = 0
                inverse_affine[:, 2] += extra_offset
                face_size = self.face_size
            inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up))

            # if draw_box or not self.use_parse:  # use square parse maps
            #     mask = np.ones(face_size, dtype=np.float32)
            #     inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
            #     # remove the black borders
            #     inv_mask_erosion = cv2.erode(
            #         inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
            #     pasted_face = inv_mask_erosion[:, :, None] * inv_restored
            #     total_face_area = np.sum(inv_mask_erosion)  # // 3
            #     # add border
            #     if draw_box:
            #         h, w = face_size
            #         mask_border = np.ones((h, w, 3), dtype=np.float32)
            #         border = int(1400 / np.sqrt(total_face_area))
            #         mask_border[border:h - border, border:w - border, :] = 0
            #         inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up))
            #         inv_mask_borders.append(inv_mask_border)
            #     if not self.use_parse:
            #         # compute the fusion edge based on the area of face
            #         w_edge = int(total_face_area**0.5) // 20
            #         erosion_radius = w_edge * 2
            #         inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
            #         blur_size = w_edge * 2
            #         inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
            #         if len(upsample_img.shape) == 2:  # upsample_img is gray image
            #             upsample_img = upsample_img[:, :, None]
            #             inv_soft_mask = inv_soft_mask[:, :, None]

            # always use square mask
            mask = np.ones(face_size, dtype=np.float32)
            inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
            # remove the black borders
            inv_mask_erosion = cv2.erode(
                inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
            pasted_face = inv_mask_erosion[:, :, None] * inv_restored
            total_face_area = np.sum(inv_mask_erosion)  # // 3
            # add border
            if draw_box:
                h, w = face_size
                mask_border = np.ones((h, w, 3), dtype=np.float32)
                border = int(1400 / np.sqrt(total_face_area))
                mask_border[border:h - border, border:w - border, :] = 0
                inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up))
                inv_mask_borders.append(inv_mask_border)
            # compute the fusion edge based on the area of face
            w_edge = int(total_face_area**0.5) // 20
            erosion_radius = w_edge * 2
            inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
            blur_size = w_edge * 2
            inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
            if len(upsample_img.shape) == 2:  # upsample_img is gray image
                upsample_img = upsample_img[:, :, None]
                inv_soft_mask = inv_soft_mask[:, :, None]

            # parse mask
            if self.use_parse:
                # inference
                face_input = cv2.resize(restored_face, (512, 512), interpolation=cv2.INTER_LINEAR)
                face_input = img2tensor(face_input.astype('float32') / 255., bgr2rgb=True, float32=True)
                normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
                face_input = torch.unsqueeze(face_input, 0).to(self.device)
                with torch.no_grad():
                    out = self.face_parse(face_input)[0]
                out = out.argmax(dim=1).squeeze().cpu().numpy()

                parse_mask = np.zeros(out.shape)
                MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0]
                for idx, color in enumerate(MASK_COLORMAP):
                    parse_mask[out == idx] = color
                # blur the mask
                parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
                parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
                # remove the black borders
                thres = 10
                parse_mask[:thres, :] = 0
                parse_mask[-thres:, :] = 0
                parse_mask[:, :thres] = 0
                parse_mask[:, -thres:] = 0
                parse_mask = parse_mask / 255.

                parse_mask = cv2.resize(parse_mask, face_size)
                parse_mask = cv2.warpAffine(parse_mask, inverse_affine, (w_up, h_up), flags=3)
                inv_soft_parse_mask = parse_mask[:, :, None]
                # pasted_face = inv_restored
                fuse_mask = (inv_soft_parse_mask < inv_soft_mask).astype('int')
                inv_soft_mask = inv_soft_parse_mask * fuse_mask + inv_soft_mask * (1 - fuse_mask)

            if len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4:  # alpha channel
                alpha = upsample_img[:, :, 3:]
                upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img[:, :, 0:3]
                upsample_img = np.concatenate((upsample_img, alpha), axis=2)
            else:
                upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img

        if np.max(upsample_img) > 256:  # 16-bit image
            upsample_img = upsample_img.astype(np.uint16)
        else:
            upsample_img = upsample_img.astype(np.uint8)

        # draw bounding box
        if draw_box:
            # upsample_input_img = cv2.resize(input_img, (w_up, h_up))
            img_color = np.ones([*upsample_img.shape], dtype=np.float32)
            img_color[:, :, 0] = 0
            img_color[:, :, 1] = 255
            img_color[:, :, 2] = 0
            for inv_mask_border in inv_mask_borders:
                upsample_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_img
                # upsample_input_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_input_img

        if save_path is not None:
            path = os.path.splitext(save_path)[0]
            save_path = f'{path}.{self.save_ext}'
            imwrite(upsample_img, save_path)
        return upsample_img

    def clean_all(self):
        self.all_landmarks_5 = []
        self.restored_faces = []
        self.affine_matrices = []
        self.cropped_faces = []
        self.inverse_affine_matrices = []
        self.det_faces = []
        self.pad_input_imgs = []
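
Note (not part of the uploaded file): a sketch of the intended `FaceRestoreHelper` workflow, detect, crop, restore each face, then paste back; the identity "restoration" step is a stand-in for the actual per-face model, and the facexlib detection/parsing weights are downloaded on first use.

# Sketch only: typical detect -> align -> restore -> paste-back loop.
from utils.face_restoration_helper import FaceRestoreHelper

helper = FaceRestoreHelper(upscale_factor=2, face_size=512, use_parse=True, det_model="retinaface_resnet50")
helper.read_image("input.jpg")                       # path or BGR uint8 array
num_faces = helper.get_face_landmarks_5(only_center_face=False, eye_dist_threshold=5)
helper.align_warp_face()                             # fills helper.cropped_faces

for cropped_face in helper.cropped_faces:
    restored_face = cropped_face                     # stand-in for the per-face restoration model
    helper.add_restored_face(restored_face, cropped_face)

helper.get_inverse_affine()
result = helper.paste_faces_to_input_image()         # BGR uint8, upscaled by upscale_factor
helper.clean_all()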
utils/helpers.py
ADDED
@@ -0,0 +1,216 @@
from typing import overload, Tuple, Optional

import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
from PIL import Image
from einops import rearrange

from model.cldm import ControlLDM
from model.gaussian_diffusion import Diffusion
from model.bsrnet import RRDBNet
from model.swinir import SwinIR
from model.scunet import SCUNet
from utils.sampler import SpacedSampler
from utils.cond_fn import Guidance
from utils.common import wavelet_decomposition, wavelet_reconstruction, count_vram_usage


def bicubic_resize(img: np.ndarray, scale: float) -> np.ndarray:
    pil = Image.fromarray(img)
    res = pil.resize(tuple(int(x * scale) for x in pil.size), Image.BICUBIC)
    return np.array(res)


def resize_short_edge_to(imgs: torch.Tensor, size: int) -> torch.Tensor:
    _, _, h, w = imgs.size()
    if h == w:
        new_h, new_w = size, size
    elif h < w:
        new_h, new_w = size, int(w * (size / h))
    else:
        new_h, new_w = int(h * (size / w)), size
    return F.interpolate(imgs, size=(new_h, new_w), mode="bicubic", antialias=True)


def pad_to_multiples_of(imgs: torch.Tensor, multiple: int) -> torch.Tensor:
    _, _, h, w = imgs.size()
    if h % multiple == 0 and w % multiple == 0:
        return imgs.clone()
    # get_pad = lambda x: (x // multiple + 1) * multiple - x
    get_pad = lambda x: (x // multiple + int(x % multiple != 0)) * multiple - x
    ph, pw = get_pad(h), get_pad(w)
    return F.pad(imgs, pad=(0, pw, 0, ph), mode="constant", value=0)


class Pipeline:

    def __init__(self, stage1_model: nn.Module, cldm: ControlLDM, diffusion: Diffusion, cond_fn: Optional[Guidance], device: str) -> None:
        self.stage1_model = stage1_model
        self.cldm = cldm
        self.diffusion = diffusion
        self.cond_fn = cond_fn
        self.device = device
        self.final_size: Tuple[int] = None

    def set_final_size(self, lq: torch.Tensor) -> None:
        h, w = lq.shape[2:]
        self.final_size = (h, w)

    @overload
    def run_stage1(self, lq: torch.Tensor) -> torch.Tensor:
        ...

    @count_vram_usage
    def run_stage2(
        self,
        clean: torch.Tensor,
        steps: int,
        strength: float,
        tiled: bool,
        tile_size: int,
        tile_stride: int,
        pos_prompt: str,
        neg_prompt: str,
        cfg_scale: float,
        better_start: float
    ) -> torch.Tensor:
        ### preprocess
        bs, _, ori_h, ori_w = clean.shape
        # pad: ensure that height & width are multiples of 64
        pad_clean = pad_to_multiples_of(clean, multiple=64)
        h, w = pad_clean.shape[2:]
        # prepare condition
        if not tiled:
            cond = self.cldm.prepare_condition(pad_clean, [pos_prompt] * bs)
            uncond = self.cldm.prepare_condition(pad_clean, [neg_prompt] * bs)
        else:
            cond = self.cldm.prepare_condition_tiled(pad_clean, [pos_prompt] * bs, tile_size, tile_stride)
            uncond = self.cldm.prepare_condition_tiled(pad_clean, [neg_prompt] * bs, tile_size, tile_stride)
        if self.cond_fn:
            self.cond_fn.load_target(pad_clean * 2 - 1)
        old_control_scales = self.cldm.control_scales
        self.cldm.control_scales = [strength] * 13
        if better_start:
            # use the noised low-frequency part of the condition as a better starting point
            # for reverse sampling, which can prevent our model from generating noise in
            # the image background.
            _, low_freq = wavelet_decomposition(pad_clean)
            if not tiled:
                x_0 = self.cldm.vae_encode(low_freq)
            else:
                x_0 = self.cldm.vae_encode_tiled(low_freq, tile_size, tile_stride)
            x_T = self.diffusion.q_sample(
                x_0,
                torch.full((bs, ), self.diffusion.num_timesteps - 1, dtype=torch.long, device=self.device),
                torch.randn(x_0.shape, dtype=torch.float32, device=self.device)
            )
            # print(f"diffusion sqrt_alphas_cumprod: {self.diffusion.sqrt_alphas_cumprod[-1]}")
        else:
            x_T = torch.randn((bs, 4, h // 8, w // 8), dtype=torch.float32, device=self.device)
        ### run sampler
        sampler = SpacedSampler(self.diffusion.betas)
        z = sampler.sample(
            model=self.cldm, device=self.device, steps=steps, batch_size=bs, x_size=(4, h // 8, w // 8),
            cond=cond, uncond=uncond, cfg_scale=cfg_scale, x_T=x_T, progress=True,
            progress_leave=True, cond_fn=self.cond_fn, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
        )
        if not tiled:
            x = self.cldm.vae_decode(z)
        else:
            x = self.cldm.vae_decode_tiled(z, tile_size // 8, tile_stride // 8)
        ### postprocess
        self.cldm.control_scales = old_control_scales
        sample = x[:, :, :ori_h, :ori_w]
        return sample

    @torch.no_grad()
    def run(
        self,
        lq: np.ndarray,
        steps: int,
        strength: float,
        tiled: bool,
        tile_size: int,
        tile_stride: int,
        pos_prompt: str,
        neg_prompt: str,
        cfg_scale: float,
        better_start: bool
    ) -> np.ndarray:
        # image to tensor
        lq = torch.tensor((lq / 255.).clip(0, 1), dtype=torch.float32, device=self.device)
        lq = rearrange(lq, "n h w c -> n c h w").contiguous()
        # set pipeline output size
        self.set_final_size(lq)
        clean = self.run_stage1(lq)
        sample = self.run_stage2(
            clean, steps, strength, tiled, tile_size, tile_stride,
            pos_prompt, neg_prompt, cfg_scale, better_start
        )
        # colorfix (borrowed from StableSR, thanks for their work)
        sample = (sample + 1) / 2
        sample = wavelet_reconstruction(sample, clean)
        # resize to desired output size
        sample = F.interpolate(sample, size=self.final_size, mode="bicubic", antialias=True)
        # tensor to image
        sample = rearrange(sample * 255., "n c h w -> n h w c")
        sample = sample.contiguous().clamp(0, 255).to(torch.uint8).cpu().numpy()
        return sample


class BSRNetPipeline(Pipeline):

    def __init__(self, bsrnet: RRDBNet, cldm: ControlLDM, diffusion: Diffusion, cond_fn: Optional[Guidance], device: str, upscale: float) -> None:
        super().__init__(bsrnet, cldm, diffusion, cond_fn, device)
        self.upscale = upscale

    def set_final_size(self, lq: torch.Tensor) -> None:
        h, w = lq.shape[2:]
        self.final_size = (int(h * self.upscale), int(w * self.upscale))

    @count_vram_usage
    def run_stage1(self, lq: torch.Tensor) -> torch.Tensor:
        # NOTE: upscale is always set to 4 in our experiments
        clean = self.stage1_model(lq)
        # if self.final_size[0] < 512 and self.final_size[1] < 512:
        if min(self.final_size) < 512:
            clean = resize_short_edge_to(clean, size=512)
        else:
            clean = F.interpolate(clean, size=self.final_size, mode="bicubic", antialias=True)
        return clean


class SwinIRPipeline(Pipeline):

    def __init__(self, swinir: SwinIR, cldm: ControlLDM, diffusion: Diffusion, cond_fn: Optional[Guidance], device: str) -> None:
        super().__init__(swinir, cldm, diffusion, cond_fn, device)

    @count_vram_usage
    def run_stage1(self, lq: torch.Tensor) -> torch.Tensor:
        # NOTE: lq size is always equal to 512 in our experiments
        # resize: ensure the input lq size is at least 512, since SwinIR is trained on 512 resolution
        if min(lq.shape[2:]) < 512:
            lq = resize_short_edge_to(lq, size=512)
        ori_h, ori_w = lq.shape[2:]
        # pad: ensure that height & width are multiples of 64
        pad_lq = pad_to_multiples_of(lq, multiple=64)
        # run
        clean = self.stage1_model(pad_lq)
        # remove padding
        clean = clean[:, :, :ori_h, :ori_w]
        return clean


class SCUNetPipeline(Pipeline):

    def __init__(self, scunet: SCUNet, cldm: ControlLDM, diffusion: Diffusion, cond_fn: Optional[Guidance], device: str) -> None:
        super().__init__(scunet, cldm, diffusion, cond_fn, device)

    @count_vram_usage
    def run_stage1(self, lq: torch.Tensor) -> torch.Tensor:
        clean = self.stage1_model(lq)
        if min(clean.shape[2:]) < 512:
            clean = resize_short_edge_to(clean, size=512)
        return clean
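
Note (not part of the uploaded file): a quick check of the two shape helpers defined at the top of this file; importing `utils.helpers` assumes the project's `model` package is on the path, since the module imports ControlLDM and friends at load time.

# Sketch only: resize_short_edge_to brings the short side to 512, while
# pad_to_multiples_of zero-pads height and width up to the next multiple of 64.
import torch
from utils.helpers import pad_to_multiples_of, resize_short_edge_to

x = torch.rand(1, 3, 300, 500)
print(resize_short_edge_to(x, size=512).shape)    # torch.Size([1, 3, 512, 853])
print(pad_to_multiples_of(x, multiple=64).shape)  # torch.Size([1, 3, 320, 512])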
utils/inference.py
ADDED
@@ -0,0 +1,320 @@
1 |
+
import os
|
2 |
+
from typing import overload, Generator, Dict
|
3 |
+
from argparse import Namespace
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from PIL import Image
|
8 |
+
from omegaconf import OmegaConf
|
9 |
+
|
10 |
+
from model.cldm import ControlLDM
|
11 |
+
from model.gaussian_diffusion import Diffusion
|
12 |
+
from model.bsrnet import RRDBNet
|
13 |
+
from model.scunet import SCUNet
|
14 |
+
from model.swinir import SwinIR
|
15 |
+
from utils.common import instantiate_from_config, load_file_from_url, count_vram_usage
|
16 |
+
from utils.face_restoration_helper import FaceRestoreHelper
|
17 |
+
from utils.helpers import (
|
18 |
+
Pipeline,
|
19 |
+
BSRNetPipeline, SwinIRPipeline, SCUNetPipeline,
|
20 |
+
bicubic_resize
|
21 |
+
)
|
22 |
+
from utils.cond_fn import MSEGuidance, WeightedMSEGuidance
|
23 |
+
|
24 |
+
|
25 |
+
MODELS = {
|
26 |
+
### stage_1 model weights
|
27 |
+
"bsrnet": "https://github.com/cszn/KAIR/releases/download/v1.0/BSRNet.pth",
|
28 |
+
# the following checkpoint is up-to-date, but we use the old version in our paper
|
29 |
+
# "swinir_face": "https://github.com/zsyOAOA/DifFace/releases/download/V1.0/General_Face_ffhq512.pth",
|
30 |
+
"swinir_face": "https://huggingface.co/lxq007/DiffBIR/resolve/main/face_swinir_v1.ckpt",
|
31 |
+
"scunet_psnr": "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_psnr.pth",
|
32 |
+
"swinir_general": "https://huggingface.co/lxq007/DiffBIR/resolve/main/general_swinir_v1.ckpt",
|
33 |
+
### stage_2 model weights
|
34 |
+
"sd_v21": "https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.ckpt",
|
35 |
+
"v1_face": "https://huggingface.co/lxq007/DiffBIR-v2/resolve/main/v1_face.pth",
|
36 |
+
"v1_general": "https://huggingface.co/lxq007/DiffBIR-v2/resolve/main/v1_general.pth",
|
37 |
+
"v2": "https://huggingface.co/lxq007/DiffBIR-v2/resolve/main/v2.pth"
|
38 |
+
}
|
39 |
+
|
40 |
+
|
41 |
+
def load_model_from_url(url: str) -> Dict[str, torch.Tensor]:
|
42 |
+
sd_path = load_file_from_url(url, model_dir="weights")
|
43 |
+
sd = torch.load(sd_path, map_location="cpu")
|
44 |
+
if "state_dict" in sd:
|
45 |
+
sd = sd["state_dict"]
|
46 |
+
if list(sd.keys())[0].startswith("module"):
|
47 |
+
sd = {k[len("module."):]: v for k, v in sd.items()}
|
48 |
+
return sd
|
49 |
+
|
50 |
+
|
51 |
+
class InferenceLoop:
|
52 |
+
|
53 |
+
def __init__(self, args: Namespace) -> "InferenceLoop":
|
54 |
+
self.args = args
|
55 |
+
self.loop_ctx = {}
|
56 |
+
self.pipeline: Pipeline = None
|
57 |
+
self.init_stage1_model()
|
58 |
+
self.init_stage2_model()
|
59 |
+
self.init_cond_fn()
|
60 |
+
self.init_pipeline()
|
61 |
+
|
62 |
+
@overload
|
63 |
+
def init_stage1_model(self) -> None:
|
64 |
+
...
|
65 |
+
|
66 |
+
@count_vram_usage
|
67 |
+
def init_stage2_model(self) -> None:
|
68 |
+
### load uent, vae, clip
|
69 |
+
self.cldm: ControlLDM = instantiate_from_config(OmegaConf.load("configs/inference/cldm.yaml"))
|
70 |
+
sd = load_model_from_url(MODELS["sd_v21"])
|
71 |
+
unused = self.cldm.load_pretrained_sd(sd)
|
72 |
+
print(f"strictly load pretrained sd_v2.1, unused weights: {unused}")
|
73 |
+
### load controlnet
|
74 |
+
if self.args.version == "v1":
|
75 |
+
if self.args.task == "fr":
|
76 |
+
control_sd = load_model_from_url(MODELS["v1_face"])
|
77 |
+
elif self.args.task == "sr":
|
78 |
+
control_sd = load_model_from_url(MODELS["v1_general"])
|
79 |
+
else:
|
80 |
+
raise ValueError(f"DiffBIR v1 doesn't support task: {self.args.task}, please use v2 by passsing '--version v2'")
|
81 |
+
else:
|
82 |
+
control_sd = load_model_from_url(MODELS["v2"])
|
83 |
+
self.cldm.load_controlnet_from_ckpt(control_sd)
|
84 |
+
print(f"strictly load controlnet weight")
|
85 |
+
self.cldm.eval().to(self.args.device)
|
86 |
+
### load diffusion
|
87 |
+
self.diffusion: Diffusion = instantiate_from_config(OmegaConf.load("configs/inference/diffusion.yaml"))
|
88 |
+
self.diffusion.to(self.args.device)
|
89 |
+
|
90 |
+
def init_cond_fn(self) -> None:
|
91 |
+
if not self.args.guidance:
|
92 |
+
self.cond_fn = None
|
93 |
+
return
|
94 |
+
if self.args.g_loss == "mse":
|
95 |
+
cond_fn_cls = MSEGuidance
|
96 |
+
elif self.args.g_loss == "w_mse":
|
97 |
+
cond_fn_cls = WeightedMSEGuidance
|
98 |
+
else:
|
99 |
+
raise ValueError(self.args.g_loss)
|
100 |
+
self.cond_fn = cond_fn_cls(
|
101 |
+
scale=self.args.g_scale, t_start=self.args.g_start, t_stop=self.args.g_stop,
|
102 |
+
space=self.args.g_space, repeat=self.args.g_repeat
|
103 |
+
)
|
104 |
+
|
105 |
+
@overload
|
106 |
+
def init_pipeline(self) -> None:
|
107 |
+
...
|
108 |
+
|
109 |
+
def setup(self) -> None:
|
110 |
+
self.output_dir = self.args.output
|
111 |
+
os.makedirs(self.output_dir, exist_ok=True)
|
112 |
+
|
113 |
+
def lq_loader(self) -> Generator[np.ndarray, None, None]:
|
114 |
+
img_exts = [".png", ".jpg", ".jpeg"]
|
115 |
+
if os.path.isdir(self.args.input):
|
116 |
+
file_names = sorted([
|
117 |
+
file_name for file_name in os.listdir(self.args.input) if os.path.splitext(file_name)[-1] in img_exts
|
118 |
+
])
|
119 |
+
file_paths = [os.path.join(self.args.input, file_name) for file_name in file_names]
|
120 |
+
else:
|
121 |
+
assert os.path.splitext(self.args.input)[-1] in img_exts
|
122 |
+
file_paths = [self.args.input]
|
123 |
+
|
124 |
+
def _loader() -> Generator[np.ndarray, None, None]:
|
125 |
+
for file_path in file_paths:
|
126 |
+
### load lq
|
127 |
+
lq = np.array(Image.open(file_path).convert("RGB"))
|
128 |
+
print(f"load lq: {file_path}")
|
129 |
+
### set context for saving results
|
130 |
+
self.loop_ctx["file_stem"] = os.path.splitext(os.path.basename(file_path))[0]
|
131 |
+
for i in range(self.args.n_samples):
|
132 |
+
self.loop_ctx["repeat_idx"] = i
|
133 |
+
yield lq
|
134 |
+
|
135 |
+
return _loader
|
136 |
+
|
137 |
+
def after_load_lq(self, lq: np.ndarray) -> np.ndarray:
|
138 |
+
return lq
|
139 |
+
|
140 |
+
@torch.no_grad()
|
141 |
+
def run(self) -> None:
|
142 |
+
self.setup()
|
143 |
+
# We don't support batch processing since input images may have different size
|
144 |
+
loader = self.lq_loader()
|
145 |
+
for lq in loader():
|
146 |
+
lq = self.after_load_lq(lq)
|
147 |
+
sample = self.pipeline.run(
|
148 |
+
lq[None], self.args.steps, 1.0, self.args.tiled,
|
149 |
+
self.args.tile_size, self.args.tile_stride,
|
150 |
+
self.args.pos_prompt, self.args.neg_prompt, self.args.cfg_scale,
|
151 |
+
self.args.better_start
|
152 |
+
)[0]
|
153 |
+
self.save(sample)
|
154 |
+
|
155 |
+
def save(self, sample: np.ndarray) -> None:
|
156 |
+
file_stem, repeat_idx = self.loop_ctx["file_stem"], self.loop_ctx["repeat_idx"]
|
157 |
+
file_name = f"{file_stem}_{repeat_idx}.png" if self.args.n_samples > 1 else f"{file_stem}.png"
|
158 |
+
save_path = os.path.join(self.args.output, file_name)
|
159 |
+
Image.fromarray(sample).save(save_path)
|
160 |
+
print(f"save result to {save_path}")
|
161 |
+
|
162 |
+
|
class BSRInferenceLoop(InferenceLoop):

    @count_vram_usage
    def init_stage1_model(self) -> None:
        self.bsrnet: RRDBNet = instantiate_from_config(OmegaConf.load("configs/inference/bsrnet.yaml"))
        sd = load_model_from_url(MODELS["bsrnet"])
        self.bsrnet.load_state_dict(sd, strict=True)
        self.bsrnet.eval().to(self.args.device)

    def init_pipeline(self) -> None:
        self.pipeline = BSRNetPipeline(self.bsrnet, self.cldm, self.diffusion, self.cond_fn, self.args.device, self.args.upscale)


class BFRInferenceLoop(InferenceLoop):

    @count_vram_usage
    def init_stage1_model(self) -> None:
        self.swinir_face: SwinIR = instantiate_from_config(OmegaConf.load("configs/inference/swinir.yaml"))
        sd = load_model_from_url(MODELS["swinir_face"])
        self.swinir_face.load_state_dict(sd, strict=True)
        self.swinir_face.eval().to(self.args.device)

    def init_pipeline(self) -> None:
        self.pipeline = SwinIRPipeline(self.swinir_face, self.cldm, self.diffusion, self.cond_fn, self.args.device)

    def after_load_lq(self, lq: np.ndarray) -> np.ndarray:
        # For the BFR task, super resolution is achieved by directly upscaling lq
        return bicubic_resize(lq, self.args.upscale)


class BIDInferenceLoop(InferenceLoop):

    @count_vram_usage
    def init_stage1_model(self) -> None:
        self.scunet_psnr: SCUNet = instantiate_from_config(OmegaConf.load("configs/inference/scunet.yaml"))
        sd = load_model_from_url(MODELS["scunet_psnr"])
        self.scunet_psnr.load_state_dict(sd, strict=True)
        self.scunet_psnr.eval().to(self.args.device)

    def init_pipeline(self) -> None:
        self.pipeline = SCUNetPipeline(self.scunet_psnr, self.cldm, self.diffusion, self.cond_fn, self.args.device)

    def after_load_lq(self, lq: np.ndarray) -> np.ndarray:
        # For the BID task, super resolution is achieved by directly upscaling lq
        return bicubic_resize(lq, self.args.upscale)


class V1InferenceLoop(InferenceLoop):

    @count_vram_usage
    def init_stage1_model(self) -> None:
        self.swinir: SwinIR = instantiate_from_config(OmegaConf.load("configs/inference/swinir.yaml"))
        if self.args.task == "fr":
            sd = load_model_from_url(MODELS["swinir_face"])
        elif self.args.task == "sr":
            sd = load_model_from_url(MODELS["swinir_general"])
        else:
            raise ValueError(f"DiffBIR v1 doesn't support task: {self.args.task}, please use v2 by passing '--version v2'")
        self.swinir.load_state_dict(sd, strict=True)
        self.swinir.eval().to(self.args.device)

    def init_pipeline(self) -> None:
        self.pipeline = SwinIRPipeline(self.swinir, self.cldm, self.diffusion, self.cond_fn, self.args.device)

    def after_load_lq(self, lq: np.ndarray) -> np.ndarray:
        # For the v1 model, super resolution is achieved by directly upscaling lq
        return bicubic_resize(lq, self.args.upscale)


class UnAlignedBFRInferenceLoop(InferenceLoop):

    @count_vram_usage
    def init_stage1_model(self) -> None:
        self.bsrnet: RRDBNet = instantiate_from_config(OmegaConf.load("configs/inference/bsrnet.yaml"))
        sd = load_model_from_url(MODELS["bsrnet"])
        self.bsrnet.load_state_dict(sd, strict=True)
        self.bsrnet.eval().to(self.args.device)

        self.swinir_face: SwinIR = instantiate_from_config(OmegaConf.load("configs/inference/swinir.yaml"))
        sd = load_model_from_url(MODELS["swinir_face"])
        self.swinir_face.load_state_dict(sd, strict=True)
        self.swinir_face.eval().to(self.args.device)

    def init_pipeline(self) -> None:
        self.pipes = {
            "bg": BSRNetPipeline(self.bsrnet, self.cldm, self.diffusion, self.cond_fn, self.args.device, self.args.upscale),
            "face": SwinIRPipeline(self.swinir_face, self.cldm, self.diffusion, self.cond_fn, self.args.device)
        }
        self.pipeline = self.pipes["face"]

    def setup(self) -> None:
        super().setup()
        self.cropped_face_dir = os.path.join(self.args.output, "cropped_faces")
        os.makedirs(self.cropped_face_dir, exist_ok=True)
        self.restored_face_dir = os.path.join(self.args.output, "restored_faces")
        os.makedirs(self.restored_face_dir, exist_ok=True)
        self.restored_bg_dir = os.path.join(self.args.output, "restored_backgrounds")
        os.makedirs(self.restored_bg_dir, exist_ok=True)

    def lq_loader(self) -> Generator[np.ndarray, None, None]:
        base_loader = super().lq_loader()
        self.face_helper = FaceRestoreHelper(
            device=self.args.device,
            upscale_factor=1,
            face_size=512,
            use_parse=True,
            det_model="retinaface_resnet50"
        )

        def _loader() -> Generator[np.ndarray, None, None]:
            for lq in base_loader():
                ### set input image
                self.face_helper.clean_all()
                upscaled_bg = bicubic_resize(lq, self.args.upscale)
                self.face_helper.read_image(upscaled_bg)
                ### get face landmarks for each face
                self.face_helper.get_face_landmarks_5(resize=640, eye_dist_threshold=5)
                self.face_helper.align_warp_face()
                print(f"detect {len(self.face_helper.cropped_faces)} faces")
                ### restore each face (has been upscaled)
                for i, lq_face in enumerate(self.face_helper.cropped_faces):
                    self.loop_ctx["is_face"] = True
                    self.loop_ctx["face_idx"] = i
                    self.loop_ctx["cropped_face"] = lq_face
                    yield lq_face
                ### restore background (hasn't been upscaled)
                self.loop_ctx["is_face"] = False
                yield lq

        return _loader

    def after_load_lq(self, lq: np.ndarray) -> np.ndarray:
        if self.loop_ctx["is_face"]:
            self.pipeline = self.pipes["face"]
        else:
            self.pipeline = self.pipes["bg"]
        return lq

    def save(self, sample: np.ndarray) -> None:
        file_stem, repeat_idx = self.loop_ctx["file_stem"], self.loop_ctx["repeat_idx"]
        if self.loop_ctx["is_face"]:
            face_idx = self.loop_ctx["face_idx"]
            file_name = f"{file_stem}_{repeat_idx}_face_{face_idx}.png"
            Image.fromarray(sample).save(os.path.join(self.restored_face_dir, file_name))

            cropped_face = self.loop_ctx["cropped_face"]
            Image.fromarray(cropped_face).save(os.path.join(self.cropped_face_dir, file_name))

            self.face_helper.add_restored_face(sample)
        else:
            self.face_helper.get_inverse_affine()
            # paste each restored face to the input image
            restored_img = self.face_helper.paste_faces_to_input_image(
                upsample_img=sample
            )
            file_name = f"{file_stem}_{repeat_idx}.png"
            Image.fromarray(sample).save(os.path.join(self.restored_bg_dir, file_name))
            Image.fromarray(restored_img).save(os.path.join(self.output_dir, file_name))
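Each of these loops is constructed with the parsed arguments and driven through `run()`. A hedged sketch of how an entry script might dispatch between them is shown below; the version/task flag names and the mapping keys are assumptions for illustration, not the repository's actual entry point.

# Hypothetical dispatch sketch: choose an InferenceLoop subclass from
# version/task flags and run it. The real entry script may differ.
def main(args) -> None:
    if args.version == "v1":
        loop_cls = V1InferenceLoop
    else:
        loop_cls = {
            "sr": BSRInferenceLoop,
            "dn": BIDInferenceLoop,
            "fr": BFRInferenceLoop,
            "fr_bg": UnAlignedBFRInferenceLoop,
        }[args.task]
    loop_cls(args).run()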
utils/sampler.py
ADDED
@@ -0,0 +1,341 @@
from typing import Optional, Tuple, Dict

import torch
from torch import nn
import numpy as np
from tqdm import tqdm

from model.gaussian_diffusion import extract_into_tensor
from model.cldm import ControlLDM
from utils.cond_fn import Guidance
from utils.common import sliding_windows, gaussian_weights


# https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/respace.py
def space_timesteps(num_timesteps, section_counts):
    """
    Create a list of timesteps to use from an original diffusion process,
    given the number of timesteps we want to take from equally-sized portions
    of the original process.
    For example, if there are 300 timesteps and the section counts are [10,15,20]
    then the first 100 timesteps are strided to be 10 timesteps, the second 100
    are strided to be 15 timesteps, and the final 100 are strided to be 20.
    If the stride is a string starting with "ddim", then the fixed striding
    from the DDIM paper is used, and only one section is allowed.
    :param num_timesteps: the number of diffusion steps in the original
                          process to divide up.
    :param section_counts: either a list of numbers, or a string containing
                           comma-separated numbers, indicating the step count
                           per section. As a special case, use "ddimN" where N
                           is a number of steps to use the striding from the
                           DDIM paper.
    :return: a set of diffusion steps from the original process to use.
    """
    if isinstance(section_counts, str):
        if section_counts.startswith("ddim"):
            desired_count = int(section_counts[len("ddim") :])
            for i in range(1, num_timesteps):
                if len(range(0, num_timesteps, i)) == desired_count:
                    return set(range(0, num_timesteps, i))
            raise ValueError(
                f"cannot create exactly {num_timesteps} steps with an integer stride"
            )
        section_counts = [int(x) for x in section_counts.split(",")]
    size_per = num_timesteps // len(section_counts)
    extra = num_timesteps % len(section_counts)
    start_idx = 0
    all_steps = []
    for i, section_count in enumerate(section_counts):
        size = size_per + (1 if i < extra else 0)
        if size < section_count:
            raise ValueError(
                f"cannot divide section of {size} steps into {section_count}"
            )
        if section_count <= 1:
            frac_stride = 1
        else:
            frac_stride = (size - 1) / (section_count - 1)
        cur_idx = 0.0
        taken_steps = []
        for _ in range(section_count):
            taken_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        all_steps += taken_steps
        start_idx += size
    return set(all_steps)
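As a quick check of the spacing logic (assuming `space_timesteps` above is importable): with 100 original steps and a single section of 10, the fractional stride is (100 - 1) / (10 - 1) = 11, so the kept timesteps are the multiples of 11.

# Worked example of the spacing logic above.
assert space_timesteps(100, "10") == {0, 11, 22, 33, 44, 55, 66, 77, 88, 99}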


class SpacedSampler(nn.Module):
    """
    Implementation of the spaced sampling schedule proposed in IDDPM. This class is designed
    for sampling ControlLDM.

    https://arxiv.org/pdf/2102.09672.pdf
    """

    def __init__(self, betas: np.ndarray) -> None:
        super().__init__()
        self.num_timesteps = len(betas)
        self.original_betas = betas
        self.original_alphas_cumprod = np.cumprod(1.0 - betas, axis=0)
        self.context = {}

    def register(self, name: str, value: np.ndarray) -> None:
        self.register_buffer(name, torch.tensor(value, dtype=torch.float32))

    def make_schedule(self, num_steps: int) -> None:
        # calculate betas for spaced sampling
        # https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/respace.py
        used_timesteps = space_timesteps(self.num_timesteps, str(num_steps))
        betas = []
        last_alpha_cumprod = 1.0
        for i, alpha_cumprod in enumerate(self.original_alphas_cumprod):
            if i in used_timesteps:
                # marginal distribution is the same as q(x_{S_t}|x_0)
                betas.append(1 - alpha_cumprod / last_alpha_cumprod)
                last_alpha_cumprod = alpha_cumprod
        assert len(betas) == num_steps
        self.timesteps = np.array(sorted(list(used_timesteps)), dtype=np.int32)  # e.g. [0, 10, 20, ...]

        betas = np.array(betas, dtype=np.float64)
        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        # print(f"sampler sqrt_alphas_cumprod: {np.sqrt(alphas_cumprod)[-1]}")
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
        sqrt_recip_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod)
        sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod - 1)
        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (
            betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )
        # log calculation clipped because the posterior variance is 0 at the
        # beginning of the diffusion chain.
        posterior_log_variance_clipped = np.log(
            np.append(posterior_variance[1], posterior_variance[1:])
        )
        posterior_mean_coef1 = (
            betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )
        posterior_mean_coef2 = (
            (1.0 - alphas_cumprod_prev)
            * np.sqrt(alphas)
            / (1.0 - alphas_cumprod)
        )

        self.register("sqrt_recip_alphas_cumprod", sqrt_recip_alphas_cumprod)
        self.register("sqrt_recipm1_alphas_cumprod", sqrt_recipm1_alphas_cumprod)
        self.register("posterior_variance", posterior_variance)
        self.register("posterior_log_variance_clipped", posterior_log_variance_clipped)
        self.register("posterior_mean_coef1", posterior_mean_coef1)
        self.register("posterior_mean_coef2", posterior_mean_coef2)
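The respaced betas are chosen so that the marginal distribution of the shortened chain matches the original chain at the kept timesteps $\tau_1 < \dots < \tau_S$ (with $\bar{\alpha}_{\tau_0} := 1$):

$$
1 - \tilde{\beta}_i = \frac{\bar{\alpha}_{\tau_i}}{\bar{\alpha}_{\tau_{i-1}}}
\quad\Longrightarrow\quad
\prod_{j \le i} \left(1 - \tilde{\beta}_j\right) = \bar{\alpha}_{\tau_i},
$$

which is exactly the `1 - alpha_cumprod / last_alpha_cumprod` term accumulated in the loop above.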

    def q_posterior_mean_variance(self, x_start: torch.Tensor, x_t: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Implement the posterior distribution q(x_{t-1}|x_t, x_0).

        Args:
            x_start (torch.Tensor): The predicted images (NCHW) in timestep `t`.
            x_t (torch.Tensor): The sampled intermediate variables (NCHW) of timestep `t`.
            t (torch.Tensor): Timestep (N) of `x_t`. `t` serves as an index to get
                parameters for each timestep.

        Returns:
            posterior_mean (torch.Tensor): Mean of the posterior distribution.
            posterior_variance (torch.Tensor): Variance of the posterior distribution.
            posterior_log_variance_clipped (torch.Tensor): Log variance of the posterior distribution.
        """
        posterior_mean = (
            extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract_into_tensor(
            self.posterior_log_variance_clipped, t, x_t.shape
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped
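The coefficients registered in `make_schedule` implement the standard Gaussian posterior of the diffusion process:

$$
q(x_{t-1} \mid x_t, x_0) = \mathcal{N}\!\left(x_{t-1};\ \tilde{\mu}_t(x_t, x_0),\ \tilde{\beta}_t I\right),
$$
$$
\tilde{\mu}_t(x_t, x_0) =
\underbrace{\frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1 - \bar{\alpha}_t}}_{\texttt{posterior\_mean\_coef1}} x_0
+ \underbrace{\frac{\sqrt{\alpha_t}\,(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t}}_{\texttt{posterior\_mean\_coef2}} x_t,
\qquad
\tilde{\beta}_t = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t}\,\beta_t .
$$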

    def _predict_xstart_from_eps(self, x_t: torch.Tensor, t: torch.Tensor, eps: torch.Tensor) -> torch.Tensor:
        return (
            extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
        )
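With the registered buffers `sqrt_recip_alphas_cumprod` and `sqrt_recipm1_alphas_cumprod`, this simply inverts the forward process $x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1 - \bar{\alpha}_t}\,\epsilon$:

$$
\hat{x}_0 = \frac{1}{\sqrt{\bar{\alpha}_t}}\, x_t \;-\; \sqrt{\frac{1}{\bar{\alpha}_t} - 1}\;\epsilon .
$$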

    def apply_cond_fn(
        self,
        model: ControlLDM,
        pred_x0: torch.Tensor,
        t: torch.Tensor,
        index: torch.Tensor,
        cond_fn: Guidance
    ) -> torch.Tensor:
        t_now = int(t[0].item()) + 1
        if not (cond_fn.t_stop < t_now and t_now < cond_fn.t_start):
            # stop guidance
            self.context["g_apply"] = False
            return pred_x0
        grad_rescale = 1 / extract_into_tensor(self.posterior_mean_coef1, index, pred_x0.shape)
        # apply guidance multiple times
        loss_vals = []
        for _ in range(cond_fn.repeat):
            # set target and pred for gradient computation
            target, pred = None, None
            if cond_fn.space == "latent":
                target = model.vae_encode(cond_fn.target)
                pred = pred_x0
            elif cond_fn.space == "rgb":
                # We need to backpropagate the gradient to x0 in latent space, so it's required
                # to trace the computation graph while decoding the latent.
                with torch.enable_grad():
                    target = cond_fn.target
                    pred_x0_rg = pred_x0.detach().clone().requires_grad_(True)
                    pred = model.vae_decode(pred_x0_rg)
                    assert pred.requires_grad
            else:
                raise NotImplementedError(cond_fn.space)
            # compute gradient
            delta_pred, loss_val = cond_fn(target, pred, t_now)
            loss_vals.append(loss_val)
            # update pred_x0 w.r.t. the gradient
            if cond_fn.space == "latent":
                delta_pred_x0 = delta_pred
                pred_x0 = pred_x0 + delta_pred_x0 * grad_rescale
            elif cond_fn.space == "rgb":
                pred.backward(delta_pred)
                delta_pred_x0 = pred_x0_rg.grad
                pred_x0 = pred_x0 + delta_pred_x0 * grad_rescale
            else:
                raise NotImplementedError(cond_fn.space)
        self.context["g_apply"] = True
        self.context["g_loss"] = float(np.mean(loss_vals))
        return pred_x0
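The `grad_rescale` factor compensates for how strongly the predicted $\hat{x}_0$ enters the posterior mean. Since $\tilde{\mu} = c_1 \hat{x}_0 + c_2 x_t$ with $c_1 = \texttt{posterior\_mean\_coef1}$, rescaling the guidance update by $1/c_1$ shifts the posterior mean by exactly the raw step $\delta$, keeping the effective guidance strength roughly independent of the timestep:

$$
\hat{x}_0 \leftarrow \hat{x}_0 + \frac{\delta}{c_1}
\quad\Longrightarrow\quad
\tilde{\mu} \leftarrow \tilde{\mu} + \delta .
$$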

    def predict_noise(
        self,
        model: ControlLDM,
        x: torch.Tensor,
        t: torch.Tensor,
        cond: Dict[str, torch.Tensor],
        uncond: Optional[Dict[str, torch.Tensor]],
        cfg_scale: float
    ) -> torch.Tensor:
        if uncond is None or cfg_scale == 1.:
            model_output = model(x, t, cond)
        else:
            # apply classifier-free guidance
            model_cond = model(x, t, cond)
            model_uncond = model(x, t, uncond)
            model_output = model_uncond + cfg_scale * (model_cond - model_uncond)
        return model_output
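The two predictions are combined with the usual classifier-free guidance interpolation, where $s = $ `cfg_scale`, $c = $ `cond`, and $\varnothing = $ `uncond`; $s = 1$ recovers the purely conditional prediction:

$$
\tilde{\epsilon} = \epsilon_\theta(x_t, \varnothing) + s \cdot \left(\epsilon_\theta(x_t, c) - \epsilon_\theta(x_t, \varnothing)\right).
$$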

    @torch.no_grad()
    def predict_noise_tiled(
        self,
        model: ControlLDM,
        x: torch.Tensor,
        t: torch.Tensor,
        cond: Dict[str, torch.Tensor],
        uncond: Optional[Dict[str, torch.Tensor]],
        cfg_scale: float,
        tile_size: int,
        tile_stride: int
    ):
        _, _, h, w = x.shape
        tiles = tqdm(sliding_windows(h, w, tile_size // 8, tile_stride // 8), unit="tile", leave=False)
        eps = torch.zeros_like(x)
        count = torch.zeros_like(x, dtype=torch.float32)
        weights = gaussian_weights(tile_size // 8, tile_size // 8)[None, None]
        weights = torch.tensor(weights, dtype=torch.float32, device=x.device)
        for hi, hi_end, wi, wi_end in tiles:
            tiles.set_description(f"Process tile ({hi} {hi_end}), ({wi} {wi_end})")
            tile_x = x[:, :, hi:hi_end, wi:wi_end]
            tile_cond = {
                "c_img": cond["c_img"][:, :, hi:hi_end, wi:wi_end],
                "c_txt": cond["c_txt"]
            }
            if uncond:
                tile_uncond = {
                    "c_img": uncond["c_img"][:, :, hi:hi_end, wi:wi_end],
                    "c_txt": uncond["c_txt"]
                }
            else:
                # without this branch, `tile_uncond` would be undefined when uncond is None
                tile_uncond = None
            tile_eps = self.predict_noise(model, tile_x, t, tile_cond, tile_uncond, cfg_scale)
            # accumulate noise
            eps[:, :, hi:hi_end, wi:wi_end] += tile_eps * weights
            count[:, :, hi:hi_end, wi:wi_end] += weights
        # average the accumulated noise (score) by the accumulated weights
        eps.div_(count)
        return eps
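`sliding_windows` and `gaussian_weights` come from utils/common.py and are not shown here. A minimal standalone sketch of the same weighted-averaging idea follows; the helper below is an illustrative stand-in, not the project's implementation.

import numpy as np

def gaussian_weights_sketch(h: int, w: int) -> np.ndarray:
    # A weight map that peaks at the tile center and decays towards the borders,
    # so overlapping tiles blend smoothly instead of producing visible seams.
    ys = np.exp(-((np.arange(h) - (h - 1) / 2) ** 2) / (2 * (h / 4) ** 2))
    xs = np.exp(-((np.arange(w) - (w - 1) / 2) ** 2) / (2 * (w / 4) ** 2))
    return np.outer(ys, xs)

# Each tile's noise prediction is accumulated as `weights * tile_eps`, the weights
# themselves are accumulated into `count`, and the final `eps / count` is a
# per-pixel weighted average over all tiles covering that pixel.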

    @torch.no_grad()
    def p_sample(
        self,
        model: ControlLDM,
        x: torch.Tensor,
        t: torch.Tensor,
        index: torch.Tensor,
        cond: Dict[str, torch.Tensor],
        uncond: Optional[Dict[str, torch.Tensor]],
        cfg_scale: float,
        cond_fn: Optional[Guidance],
        tiled: bool,
        tile_size: int,
        tile_stride: int
    ) -> torch.Tensor:
        if tiled:
            eps = self.predict_noise_tiled(model, x, t, cond, uncond, cfg_scale, tile_size, tile_stride)
        else:
            eps = self.predict_noise(model, x, t, cond, uncond, cfg_scale)
        pred_x0 = self._predict_xstart_from_eps(x, index, eps)
        if cond_fn:
            assert not tiled, "tiled sampling currently doesn't support guidance"
            pred_x0 = self.apply_cond_fn(model, pred_x0, t, index, cond_fn)
        model_mean, model_variance, _ = self.q_posterior_mean_variance(pred_x0, x, index)
        noise = torch.randn_like(x)
        nonzero_mask = (
            (index != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        )
        x_prev = model_mean + nonzero_mask * torch.sqrt(model_variance) * noise
        return x_prev
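Putting the pieces together, each `p_sample` call draws one ancestral step from the posterior, with the noise masked out at the final step (index 0):

$$
x_{t-1} = \tilde{\mu}_t(\hat{x}_0, x_t) + \mathbb{1}[\,\text{index} \neq 0\,]\cdot \sqrt{\tilde{\beta}_t}\; z, \qquad z \sim \mathcal{N}(0, I).
$$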

    @torch.no_grad()
    def sample(
        self,
        model: ControlLDM,
        device: str,
        steps: int,
        batch_size: int,
        x_size: Tuple[int],
        cond: Dict[str, torch.Tensor],
        uncond: Dict[str, torch.Tensor],
        cfg_scale: float,
        cond_fn: Optional[Guidance]=None,
        tiled: bool=False,
        tile_size: int=-1,
        tile_stride: int=-1,
        x_T: Optional[torch.Tensor]=None,
        progress: bool=True,
        progress_leave: bool=True,
    ) -> torch.Tensor:
        self.make_schedule(steps)
        self.to(device)
        if x_T is None:
            # TODO: not converting to float32 here may trigger an error
            img = torch.randn((batch_size, *x_size), device=device)
        else:
            img = x_T
        timesteps = np.flip(self.timesteps)  # iterate timesteps in descending order
        total_steps = len(self.timesteps)
        iterator = tqdm(timesteps, total=total_steps, leave=progress_leave, disable=not progress)
        for i, step in enumerate(iterator):
            ts = torch.full((batch_size,), step, device=device, dtype=torch.long)
            index = torch.full_like(ts, fill_value=total_steps - i - 1)
            img = self.p_sample(
                model, img, ts, index, cond, uncond, cfg_scale, cond_fn,
                tiled, tile_size, tile_stride
            )
            if cond_fn and self.context["g_apply"]:
                loss_val = self.context["g_loss"]
                desc = f"Spaced Sampler With Guidance, Loss: {loss_val:.6f}"
            else:
                desc = "Spaced Sampler"
            iterator.set_description(desc)
        return img
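For reference, a hedged usage sketch of the sampler: the latent size, the guidance scale, and how `cond`/`uncond` are built are assumptions here; in the repository they are produced by the pipeline code in utils/inference.py together with ControlLDM.

# Hypothetical usage sketch, assuming `cldm` is a loaded ControlLDM, `betas` is the
# training beta schedule, and cond/uncond are dicts with "c_img"/"c_txt" entries.
sampler = SpacedSampler(betas)
z = sampler.sample(
    model=cldm, device="cuda", steps=50, batch_size=1,
    x_size=(4, 64, 64),              # assumed latent shape for a 512x512 output
    cond=cond, uncond=uncond, cfg_scale=4.0,
    cond_fn=None, tiled=False, x_T=None, progress=True,
)
# `z` is the final latent; the pipeline would decode it with the VAE afterwards.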