Upload 14 files
- KandiSuperRes/__init__.py +157 -0
- KandiSuperRes/model/__init__.py +0 -0
- KandiSuperRes/model/diffusion_refine.py +131 -0
- KandiSuperRes/model/diffusion_sr.py +146 -0
- KandiSuperRes/model/diffusion_sr_turbo.py +87 -0
- KandiSuperRes/model/nn.py +122 -0
- KandiSuperRes/model/unet.py +284 -0
- KandiSuperRes/model/unet_sr.py +260 -0
- KandiSuperRes/model/utils.py +62 -0
- KandiSuperRes/movq.py +541 -0
- KandiSuperRes/sr_pipeline.py +116 -0
- KandiSuperRes/utils.py +9 -0
- weights/context.pt +3 -0
- weights/context_mask.pt +3 -0
KandiSuperRes/__init__.py
ADDED
@@ -0,0 +1,157 @@
import torch
from typing import Optional, Union
from huggingface_hub import hf_hub_download

from .sr_pipeline import KandiSuperResPipeline
from KandiSuperRes.model.unet import UNet
from KandiSuperRes.model.unet_sr import UNet as UNet_sr
from KandiSuperRes.movq import MoVQ


def get_sr_model(
    device: Union[str, torch.device],
    weights_path: Optional[str] = None,
    dtype: Union[str, torch.dtype] = torch.float16
) -> (UNet_sr, Optional[dict], Optional[torch.Tensor]):
    unet = UNet_sr(
        init_channels=128,
        model_channels=128,
        num_channels=3,
        time_embed_dim=512,
        groups=32,
        dim_mult=(1, 2, 4, 8),
        num_resnet_blocks=(2, 4, 8, 8),
        add_cross_attention=(False, False, False, False),
        add_self_attention=(False, False, False, False),
        feature_pooling_type='attention',
        lowres_cond=True
    )

    if weights_path:
        state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
        try:
            unet.load_state_dict(state_dict['unet'])
        except:
            unet.load_state_dict(state_dict)
    unet.to(device=device, dtype=dtype).eval()
    return unet


def get_T2I_unet(
    device: Union[str, torch.device],
    weights_path: Optional[str] = None,
    dtype: Union[str, torch.dtype] = torch.float32,
) -> (UNet, Optional[torch.Tensor], Optional[dict]):
    unet = UNet(
        model_channels=384,
        num_channels=4,
        init_channels=192,
        time_embed_dim=1536,
        context_dim=4096,
        groups=32,
        head_dim=64,
        expansion_ratio=4,
        compression_ratio=2,
        dim_mult=(1, 2, 4, 8),
        num_blocks=(3, 3, 3, 3),
        add_cross_attention=(False, True, True, True),
        add_self_attention=(False, True, True, True),
    )

    null_embedding = None
    if weights_path:
        state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
        null_embedding = state_dict['null_embedding']
        unet.load_state_dict(state_dict['unet'])

    unet.to(device=device, dtype=dtype).eval()
    return unet, null_embedding


def get_movq(
    device: Union[str, torch.device],
    weights_path: Optional[str] = None,
    dtype: Union[str, torch.dtype] = torch.float32,
) -> MoVQ:
    generator_config = {
        'double_z': False,
        'z_channels': 4,
        'resolution': 256,
        'in_channels': 3,
        'out_ch': 3,
        'ch': 256,
        'ch_mult': [1, 2, 2, 4],
        'num_res_blocks': 2,
        'attn_resolutions': [32],
        'dropout': 0.0,
        'tile_sample_min_size': 1024,
        'tile_overlap_factor_enc': 0.0,
        'tile_overlap_factor_dec': 0.25,
        'use_tiling': True
    }
    movq = MoVQ(generator_config)

    if weights_path:
        state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
        movq.load_state_dict(state_dict)

    movq.to(device=device, dtype=dtype).eval()
    return movq


def get_SR_pipeline(
    device: Union[str, torch.device],
    fp16: bool = True,
    flash: bool = True,
    scale: int = 2,
    cache_dir: str = '/tmp/KandiSuperRes/',
    movq_path: str = None,
    refiner_path: str = None,
    unet_sr_path: str = None,
) -> KandiSuperResPipeline:

    if flash:
        if scale == 2:
            device_map = {
                'movq': device, 'refiner': device, 'sr_model': device
            }
            dtype = torch.float16 if fp16 else torch.float32
            dtype_map = {
                'movq': torch.float32, 'refiner': dtype, 'sr_model': dtype
            }
            if movq_path is None:
                print('Download movq weights')
                movq_path = hf_hub_download(
                    repo_id="ai-forever/Kandinsky3.1", filename='weights/movq.pt', cache_dir=cache_dir
                )
            if refiner_path is None:
                print('Download refiner weights')
                refiner_path = hf_hub_download(
                    repo_id="ai-forever/Kandinsky3.1", filename='weights/kandinsky3_flash.pt', cache_dir=cache_dir
                )
            if unet_sr_path is None:
                print('Download KandiSuperRes Flash weights')
                unet_sr_path = hf_hub_download(
                    repo_id="ai-forever/KandiSuperRes", filename='KandiSuperRes_flash_x2.pt', cache_dir=cache_dir
                )
            sr_model = get_sr_model(device_map['sr_model'], unet_sr_path, dtype=dtype_map['sr_model'])
            movq = get_movq(device_map['movq'], movq_path, dtype=dtype_map['movq'])
            refiner, _ = get_T2I_unet(device_map['refiner'], refiner_path, dtype=dtype_map['refiner'])
            return KandiSuperResPipeline(
                scale, device_map, dtype_map, flash, sr_model, movq, refiner
            )
        else:
            print('Flash model for x4 scale is not implemented.')
    else:
        if unet_sr_path is None:
            if scale == 4:
                unet_sr_path = hf_hub_download(
                    repo_id="ai-forever/KandiSuperRes", filename='KandiSuperRes.ckpt', cache_dir=cache_dir
                )
            elif scale == 2:
                unet_sr_path = hf_hub_download(
                    repo_id="ai-forever/KandiSuperRes", filename='KandiSuperRes_x2.ckpt', cache_dir=cache_dir
                )
        dtype = torch.float16 if fp16 else torch.float32
        sr_model = get_sr_model(device, unet_sr_path, dtype=dtype)
        return KandiSuperResPipeline(scale, device, dtype, flash, sr_model)
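For reference, a minimal usage sketch of the loader above, not part of the uploaded files. get_SR_pipeline and its defaults come from KandiSuperRes/__init__.py as shown; the call signature of the returned KandiSuperResPipeline (taking a PIL image and returning the upscaled image) is an assumption based on sr_pipeline.py, which is listed in this upload but not reproduced here.

import torch
from PIL import Image
from KandiSuperRes import get_SR_pipeline

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Flash x2 pipeline; weights are fetched from the Hugging Face Hub on first use.
sr_pipe = get_SR_pipeline(device=device, fp16=True, flash=True, scale=2)
lr_image = Image.open('low_res.png')  # hypothetical input file
sr_image = sr_pipe(lr_image)          # assumed __call__ signature of KandiSuperResPipeline
sr_image.save('high_res.png')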
KandiSuperRes/model/__init__.py
ADDED
File without changes
KandiSuperRes/model/diffusion_refine.py
ADDED
@@ -0,0 +1,131 @@
import math
import torch
from tqdm import tqdm
from .utils import get_tensor_items
import torch.nn.functional as F


def get_named_beta_schedule(schedule_name, timesteps):
    if schedule_name == "linear":
        scale = 1000 / timesteps
        beta_start = scale * 0.0001
        beta_end = scale * 0.02
        return torch.linspace(
            beta_start, beta_end, timesteps, dtype=torch.float32
        )
    elif schedule_name == "cosine":
        alpha_bar = lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
        betas = []
        for i in range(timesteps):
            t1 = i / timesteps
            t2 = (i + 1) / timesteps
            betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), 0.999))
        return torch.tensor(betas, dtype=torch.float32)


class BaseDiffusion:

    def __init__(self, betas, percentile=None, gen_noise=torch.randn_like):
        self.betas = betas
        self.num_timesteps = betas.shape[0]

        alphas = 1. - betas
        self.alphas_cumprod = torch.cumprod(alphas, dim=0)
        self.alphas_cumprod_prev = torch.cat([torch.ones(1, dtype=betas.dtype), self.alphas_cumprod[:-1]])

        # calculate q(x_t | x_{t-1})
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)

        self.time_scale = 1000 // self.num_timesteps
        self.gen_noise = gen_noise

    def get_x_start(self, x, t, noise):
        sqrt_one_minus_alphas_cumprod = get_tensor_items(self.sqrt_one_minus_alphas_cumprod, t, noise.shape)
        sqrt_alphas_cumprod = get_tensor_items(self.sqrt_alphas_cumprod, t, noise.shape)
        pred_x_start = (x - sqrt_one_minus_alphas_cumprod * noise) / sqrt_alphas_cumprod
        return pred_x_start

    def q_sample(self, x_start, t, noise=None):
        if noise is None:
            noise = self.gen_noise(x_start)
        sqrt_alphas_cumprod = get_tensor_items(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod = get_tensor_items(self.sqrt_one_minus_alphas_cumprod, t, noise.shape)
        x_t = sqrt_alphas_cumprod * x_start + sqrt_one_minus_alphas_cumprod * noise
        return x_t

    @torch.no_grad()
    def refine(self, model, img, context, context_mask):
        # for time in tqdm([479, 229]):
        for time in [229]:
            time = torch.tensor([time,] * img.shape[0], device=img.device)
            x_t = self.q_sample(img, time)
            pred_noise = model(x_t, time.type(x_t.dtype), context, context_mask.bool())
            img = self.get_x_start(x_t, time, pred_noise)
        return img

    def blend_v(
        self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
    ) -> torch.Tensor:
        blend_extent = min(a.shape[2], b.shape[2], blend_extent)
        for y in range(blend_extent):
            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (
                1 - y / blend_extent
            ) + b[:, :, y, :] * (y / blend_extent)
        return b

    def blend_h(
        self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
    ) -> torch.Tensor:
        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (
                1 - x / blend_extent
            ) + b[:, :, :, x] * (x / blend_extent)
        return b


    def refine_tiled(self, model, img, context, context_mask):
        tile_sample_min_size = 352
        tile_overlap_factor = 0.25

        overlap_size = int(tile_sample_min_size * (1 - tile_overlap_factor))
        tile_latent_min_size = int(tile_sample_min_size)
        blend_extent = int(tile_latent_min_size * tile_overlap_factor)
        row_limit = tile_latent_min_size - blend_extent

        # Split the image into tiles and encode them separately.
        rows = []
        for i in tqdm(range(0, img.shape[2], overlap_size)):
            row = []
            for j in range(0, img.shape[3], overlap_size):
                tile = img[
                    :,
                    :,
                    i : i + tile_sample_min_size,
                    j : j + tile_sample_min_size,
                ]
                tile = self.refine(model, tile, context, context_mask)
                row.append(tile)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=3))

        refine_img = torch.cat(result_rows, dim=2)
        return refine_img


def get_diffusion(conf):
    betas = get_named_beta_schedule(**conf.schedule_params)
    base_diffusion = BaseDiffusion(betas, **conf.diffusion_params)
    return base_diffusion
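A standalone sketch, not part of the upload, of the linear cross-fade that blend_h above applies to two horizontally adjacent tiles; blend_v does the same along the height axis.

import torch

def blend_h(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # Same arithmetic as BaseDiffusion.blend_h: column x of the right tile becomes a
    # weighted mix of the last blend_extent columns of the left tile and its own values.
    blend_extent = min(a.shape[3], b.shape[3], blend_extent)
    for x in range(blend_extent):
        b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
    return b

left = torch.zeros(1, 1, 2, 6)   # left tile, all zeros
right = torch.ones(1, 1, 2, 6)   # right tile, all ones
blended = blend_h(left, right, blend_extent=4)
print(blended[0, 0, 0])  # first 4 columns ramp 0.00, 0.25, 0.50, 0.75; the rest stay 1.0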
KandiSuperRes/model/diffusion_sr.py
ADDED
@@ -0,0 +1,146 @@
from diffusers import DDIMScheduler, DPMSolverMultistepScheduler
from einops import repeat
import copy
import inspect
import math
import torch
import torch.nn.functional as F
from tqdm import tqdm


class DPMSolver:

    def __init__(self, num_timesteps):
        self.dpm_solver = DPMSolverMultistepScheduler(
            beta_schedule="linear",
            prediction_type="sample",
            # algorithm_type="sde-dpmsolver++",
            thresholding=False
        )
        self.dpm_solver.set_timesteps(num_timesteps)


    @torch.no_grad()
    def pred_noise(self, model, x, t, lowres_img, dtype):
        pred_noise = model(x.to(dtype), t.to(dtype), lowres_img=lowres_img.to(dtype))
        pred_noise = pred_noise.to(dtype=torch.float32)
        return pred_noise


    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.dpm_solver.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.dpm_solver.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs


    def get_views(self, panorama_height, panorama_width, window_size=1024, stride=800):
        # Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113)
        # if panorama's height/width < window_size, num_blocks of height/width should return 1
        num_blocks_height = round(math.ceil((panorama_height - window_size) / stride)) + 1 if panorama_height > window_size else 1
        num_blocks_width = round(math.ceil((panorama_width - window_size) / stride)) + 1 if panorama_width > window_size else 1
        total_num_blocks = int(num_blocks_height * num_blocks_width)
        views = []
        for i in range(total_num_blocks):
            h_start = int((i // num_blocks_width) * stride)
            h_end = h_start + window_size
            if h_end > panorama_height and num_blocks_height > 1:
                h_end = panorama_height
                h_start = panorama_height - window_size
            w_start = int((i % num_blocks_width) * stride)
            w_end = w_start + window_size
            if w_end > panorama_width and num_blocks_width > 1:
                w_end = panorama_width
                w_start = panorama_width - window_size
            views.append((h_start, h_end, w_start, w_end))
        return views


    def generate_panorama(self, height, width, device, dtype, num_inference_steps,
                          unet, lowres_img, view_batch_size=15, eta=0, seed=0):
        # 6. Define panorama grid and initialize views for synthesis.
        # prepare batch grid
        views = self.get_views(height, width)
        views_batch = [views[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)]
        views_scheduler_status = [copy.deepcopy(self.dpm_solver.__dict__)] * len(views_batch)

        shape = (1, 3, height, width)
        count = torch.zeros(*shape, device=device)
        value = torch.zeros(*shape, device=device)

        generator = torch.Generator(device=device)
        if seed is not None:
            generator = generator.manual_seed(seed)

        img = torch.randn(*shape, device=device, generator=generator)
        up_lowres_img = F.interpolate(lowres_img, (shape[2], shape[3]), mode="bilinear")

        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 8. Denoising loop
        # Each denoising step also includes refinement of the latents with respect to the
        # views.
        timesteps = self.dpm_solver.timesteps
        num_warmup_steps = len(timesteps) - num_inference_steps * self.dpm_solver.order

        for i, time in tqdm(enumerate(self.dpm_solver.timesteps)):
            count.zero_()
            value.zero_()

            # generate views
            # Here, we iterate through different spatial crops of the latents and denoise them. These
            # denoised (latent) crops are then averaged to produce the final latent
            # for the current timestep via MultiDiffusion. Please see Sec. 4.1 in the
            # MultiDiffusion paper for more details: https://arxiv.org/abs/2302.08113
            # Batch views denoise
            for j, batch_view in enumerate(views_batch):
                vb_size = len(batch_view)
                # get the latents corresponding to the current view coordinates
                img_for_view = torch.cat(
                    [
                        img[:, :, h_start:h_end, w_start:w_end]
                        for h_start, h_end, w_start, w_end in batch_view
                    ]
                )
                lowres_img_for_view = torch.cat(
                    [
                        up_lowres_img[:, :, h_start:h_end, w_start:w_end]
                        for h_start, h_end, w_start, w_end in batch_view
                    ]
                )

                # rematch block's scheduler status
                self.dpm_solver.__dict__.update(views_scheduler_status[j])

                t = torch.tensor([time] * img_for_view.shape[0], device=device)
                pred_noise = self.pred_noise(
                    unet, img_for_view, t, lowres_img_for_view, dtype
                )
                img_denoised_batch = self.dpm_solver.step(pred_noise, time, img_for_view, **extra_step_kwargs).prev_sample

                # save views scheduler status after sample
                views_scheduler_status[j] = copy.deepcopy(self.dpm_solver.__dict__)

                # extract value from batch
                for img_view_denoised, (h_start, h_end, w_start, w_end) in zip(
                    img_denoised_batch.chunk(vb_size), batch_view
                ):
                    value[:, :, h_start:h_end, w_start:w_end] += img_view_denoised
                    count[:, :, h_start:h_end, w_start:w_end] += 1

            # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
            img = torch.where(count > 0, value / count, value)

        return img
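A standalone sketch, not part of the upload, reproducing the window arithmetic of DPMSolver.get_views on a hypothetical 2048x1824 target, so the overlapping 1024-pixel windows with stride 800 are visible; the last row and column are shifted back so they never run past the image border.

import math

def get_views(panorama_height, panorama_width, window_size=1024, stride=800):
    # Same index arithmetic as DPMSolver.get_views above.
    num_blocks_height = round(math.ceil((panorama_height - window_size) / stride)) + 1 if panorama_height > window_size else 1
    num_blocks_width = round(math.ceil((panorama_width - window_size) / stride)) + 1 if panorama_width > window_size else 1
    views = []
    for i in range(int(num_blocks_height * num_blocks_width)):
        h_start = int((i // num_blocks_width) * stride)
        h_end = h_start + window_size
        if h_end > panorama_height and num_blocks_height > 1:
            h_end, h_start = panorama_height, panorama_height - window_size
        w_start = int((i % num_blocks_width) * stride)
        w_end = w_start + window_size
        if w_end > panorama_width and num_blocks_width > 1:
            w_end, w_start = panorama_width, panorama_width - window_size
        views.append((h_start, h_end, w_start, w_end))
    return views

for view in get_views(2048, 1824):
    print(view)  # a 3x2 grid of (h_start, h_end, w_start, w_end) crops, each 1024x1024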
KandiSuperRes/model/diffusion_sr_turbo.py
ADDED
@@ -0,0 +1,87 @@
import math

import torch
from einops import rearrange
from tqdm import tqdm

from .utils import get_tensor_items, exist
import numpy as np


def get_named_beta_schedule(schedule_name, timesteps):
    if schedule_name == "linear":
        scale = 1000 / timesteps
        beta_start = scale * 0.0001
        beta_end = scale * 0.02
        return torch.linspace(
            beta_start, beta_end, timesteps, dtype=torch.float32
        )
    elif schedule_name == "cosine":
        alpha_bar = lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
        betas = []
        for i in range(timesteps):
            t1 = i / timesteps
            t2 = (i + 1) / timesteps
            betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), 0.999))
        return torch.tensor(betas, dtype=torch.float32)


class BaseDiffusion:

    def __init__(self, betas, percentile=None, gen_noise=torch.randn_like):
        self.betas = betas
        self.num_timesteps = betas.shape[0]

        alphas = 1. - betas
        self.alphas_cumprod = torch.cumprod(alphas, dim=0)
        self.alphas_cumprod_prev = torch.cat([torch.ones(1, dtype=betas.dtype), self.alphas_cumprod[:-1]])

        # calculate q(x_t | x_{t-1})
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)

        # calculate q(x_{t-1} | x_t, x_0)
        self.posterior_mean_coef_1 = (torch.sqrt(self.alphas_cumprod_prev) * betas / (1. - self.alphas_cumprod))
        self.posterior_mean_coef_2 = (torch.sqrt(alphas) * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod))
        self.posterior_variance = betas * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)
        self.posterior_log_variance = (torch.log(
            torch.cat([self.posterior_variance[1].unsqueeze(0), self.posterior_variance[1:]])
        ))

        self.percentile = percentile
        self.time_scale = 1000 // self.num_timesteps
        self.gen_noise = gen_noise

    def q_sample(self, x_start, t, noise=None):
        if noise is None:
            noise = self.gen_noise(x_start)
        sqrt_alphas_cumprod = get_tensor_items(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod = get_tensor_items(self.sqrt_one_minus_alphas_cumprod, t, noise.shape)
        x_t = sqrt_alphas_cumprod * x_start + sqrt_one_minus_alphas_cumprod * noise
        return x_t

    @torch.no_grad()
    def p_sample_loop(
        self, model, shape, device, dtype, lowres_img, times=[979, 729, 479, 229]
    ):
        img = torch.randn(*shape, device=device).to(dtype=dtype)
        times = times + [0,]
        times = list(zip(times[:-1], times[1:]))
        for time, prev_time in tqdm(times):
            time = torch.tensor([time] * shape[0], device=device)
            x_t = self.q_sample(img, time)
            img = model(x_t.to(dtype), time.to(dtype), lowres_img=lowres_img.to(dtype))
        return img

    @torch.no_grad()
    def refine(self, model, img, **large_model_kwargs):
        for time in tqdm([729, 479, 229]):
            time = torch.tensor([time,] * img.shape[0], device=img.device)
            x_t = self.q_sample(img, time)
            img = model(x_t, time.type(x_t.dtype), **large_model_kwargs)
        return img


def get_diffusion(conf):
    betas = get_named_beta_schedule(**conf.schedule_params)
    base_diffusion = BaseDiffusion(betas, **conf.diffusion_params)
    return base_diffusion
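A tiny sketch, not part of the upload, of the fixed step schedule that p_sample_loop above walks: each timestep is paired with the next lower one, and the current estimate is re-noised with q_sample before every model call.

# The turbo sampler uses only four model evaluations.
times = [979, 729, 479, 229]
times = times + [0]
print(list(zip(times[:-1], times[1:])))
# [(979, 729), (729, 479), (479, 229), (229, 0)]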
KandiSuperRes/model/nn.py
ADDED
@@ -0,0 +1,122 @@
import math

import torch
from torch import nn, einsum
from einops import rearrange, repeat

from .utils import exist, set_default_layer


class Identity(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()

    @staticmethod
    def forward(x, *args, **kwargs):
        return x


class SinusoidalPosEmb_sr(nn.Module):

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=x.device) * -emb)
        emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j').to(dtype=x.dtype)
        return torch.cat((emb.sin(), emb.cos()), dim=-1)


class UpDownResolution(nn.Module):

    def __init__(self, num_channels, up_resolution, change_type='conv'):
        super().__init__()
        if change_type == 'pooling':
            self.change_resolution = set_default_layer(
                up_resolution,
                layer_1=nn.Upsample, kwargs_1={'scale_factor': 2., 'mode': 'nearest'},
                layer_2=nn.AvgPool2d, kwargs_2={'kernel_size': 2, 'stride': 2}
            )

        elif change_type == 'conv':
            self.change_resolution = set_default_layer(
                up_resolution,
                nn.ConvTranspose2d, (num_channels, num_channels), {'kernel_size': 4, 'stride': 2, 'padding': 1},
                nn.Conv2d, (num_channels, num_channels), {'kernel_size': 4, 'stride': 2, 'padding': 1},
            )
        else:
            raise NotImplementedError

    def forward(self, x):
        x = self.change_resolution(x)
        return x

class SinusoidalPosEmb(nn.Module):

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=x.device, dtype=x.dtype) * -emb)
        emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
        return torch.cat((emb.sin(), emb.cos()), dim=-1)


class ConditionalGroupNorm(nn.Module):

    def __init__(self, groups, normalized_shape, context_dim):
        super().__init__()
        self.norm = nn.GroupNorm(groups, normalized_shape, affine=False)
        self.context_mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(context_dim, 2 * normalized_shape)
        )
        self.context_mlp[1].weight.data.zero_()
        self.context_mlp[1].bias.data.zero_()

    def forward(self, x, context):
        context = self.context_mlp(context)
        ndims = ' 1' * len(x.shape[2:])
        context = rearrange(context, f'b c -> b c{ndims}')

        scale, shift = context.chunk(2, dim=1)
        x = self.norm(x) * (scale + 1.) + shift
        return x


class Attention(nn.Module):

    def __init__(self, in_channels, out_channels, context_dim, head_dim=64):
        super().__init__()
        assert out_channels % head_dim == 0
        self.num_heads = out_channels // head_dim
        self.scale = head_dim ** -0.5

        self.to_query = nn.Linear(in_channels, out_channels, bias=False)
        self.to_key = nn.Linear(context_dim, out_channels, bias=False)
        self.to_value = nn.Linear(context_dim, out_channels, bias=False)

        self.output_layer = nn.Linear(out_channels, out_channels, bias=False)

    def forward(self, x, context, context_mask=None):
        query = rearrange(self.to_query(x), 'b n (h d) -> b h n d', h=self.num_heads)
        key = rearrange(self.to_key(context), 'b n (h d) -> b h n d', h=self.num_heads)
        value = rearrange(self.to_value(context), 'b n (h d) -> b h n d', h=self.num_heads)

        attention_matrix = einsum('b h i d, b h j d -> b h i j', query, key) * self.scale
        if exist(context_mask):
            max_neg_value = -torch.finfo(attention_matrix.dtype).max
            context_mask = rearrange(context_mask, 'b j -> b 1 1 j')
            attention_matrix = attention_matrix.masked_fill(~context_mask, max_neg_value)
        attention_matrix = attention_matrix.softmax(dim=-1)

        out = einsum('b h i j, b h j d -> b h i d', attention_matrix, value)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.output_layer(out)
        return out
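A quick shape check, not part of the upload, for SinusoidalPosEmb as defined above; it assumes the KandiSuperRes package from this upload is importable.

import torch
from KandiSuperRes.model.nn import SinusoidalPosEmb

# A batch of timesteps of shape (b,) becomes an embedding of shape (b, dim):
# the first half sine features, the second half cosine features.
emb = SinusoidalPosEmb(dim=128)
t = torch.tensor([229., 479., 729.])
print(emb(t).shape)  # torch.Size([3, 128])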
KandiSuperRes/model/unet.py
ADDED
@@ -0,0 +1,284 @@
import torch
from torch import nn, einsum
from einops import rearrange

from .nn import Identity, Attention, SinusoidalPosEmb, ConditionalGroupNorm
from .utils import exist, set_default_item, set_default_layer
import torch.nn.functional as F


class Block(nn.Module):

    def __init__(self, in_channels, out_channels, time_embed_dim, kernel_size=3, norm_groups=32, up_resolution=None):
        super().__init__()
        self.group_norm = ConditionalGroupNorm(norm_groups, in_channels, time_embed_dim)
        self.activation = nn.SiLU()
        self.up_sample = set_default_layer(
            exist(up_resolution) and up_resolution,
            nn.ConvTranspose2d, (in_channels, in_channels), {'kernel_size': 2, 'stride': 2}
        )
        padding = set_default_item(kernel_size == 1, 0, 1)
        self.projection = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
        self.down_sample = set_default_layer(
            exist(up_resolution) and not up_resolution,
            nn.Conv2d, (out_channels, out_channels), {'kernel_size': 2, 'stride': 2}
        )

    def forward(self, x, time_embed):
        x = self.group_norm(x, time_embed)
        x = self.activation(x)
        x = self.up_sample(x)
        x = self.projection(x)
        x = self.down_sample(x)
        return x


class ResNetBlock(nn.Module):

    def __init__(
        self, in_channels, out_channels, time_embed_dim, norm_groups=32, compression_ratio=2, up_resolutions=4*[None]
    ):
        super().__init__()
        kernel_sizes = [1, 3, 3, 1]
        hidden_channel = max(in_channels, out_channels) // compression_ratio
        hidden_channels = [(in_channels, hidden_channel)] + [(hidden_channel, hidden_channel)] * 2 + [(hidden_channel, out_channels)]
        self.resnet_blocks = nn.ModuleList([
            Block(in_channel, out_channel, time_embed_dim, kernel_size, norm_groups, up_resolution)
            for (in_channel, out_channel), kernel_size, up_resolution in zip(hidden_channels, kernel_sizes, up_resolutions)
        ])

        self.shortcut_up_sample = set_default_layer(
            True in up_resolutions,
            nn.ConvTranspose2d, (in_channels, in_channels), {'kernel_size': 2, 'stride': 2}
        )
        self.shortcut_projection = set_default_layer(
            in_channels != out_channels,
            nn.Conv2d, (in_channels, out_channels), {'kernel_size': 1}
        )
        self.shortcut_down_sample = set_default_layer(
            False in up_resolutions,
            nn.Conv2d, (out_channels, out_channels), {'kernel_size': 2, 'stride': 2}
        )

    def forward(self, x, time_embed):
        out = x
        for resnet_block in self.resnet_blocks:
            out = resnet_block(out, time_embed)

        x = self.shortcut_up_sample(x)
        x = self.shortcut_projection(x)
        x = self.shortcut_down_sample(x)
        x = x + out
        return x


class AttentionPolling(nn.Module):

    def __init__(self, num_channels, context_dim, head_dim=64):
        super().__init__()
        self.attention = Attention(context_dim, num_channels, context_dim, head_dim)

    def forward(self, x, context, context_mask=None):
        context = self.attention(context.mean(dim=1, keepdim=True), context, context_mask)
        return x + context.squeeze(1)


class AttentionBlock(nn.Module):

    def __init__(self, num_channels, time_embed_dim, context_dim=None, norm_groups=32, head_dim=64, expansion_ratio=4):
        super().__init__()
        self.in_norm = ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim)
        self.attention = Attention(num_channels, num_channels, context_dim or num_channels, head_dim)

        hidden_channels = expansion_ratio * num_channels
        self.out_norm = ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim)
        self.feed_forward = nn.Sequential(
            nn.Conv2d(num_channels, hidden_channels, kernel_size=1, bias=False),
            nn.SiLU(),
            nn.Conv2d(hidden_channels, num_channels, kernel_size=1, bias=False),
        )

    def forward(self, x, time_embed, context=None, context_mask=None):
        height, width = x.shape[-2:]
        out = self.in_norm(x, time_embed)
        out = rearrange(out, 'b c h w -> b (h w) c', h=height, w=width)
        context = set_default_item(exist(context), context, out)
        out = self.attention(out, context, context_mask)
        out = rearrange(out, 'b (h w) c -> b c h w', h=height, w=width)
        x = x + out

        out = self.out_norm(x, time_embed)
        out = self.feed_forward(out)
        x = x + out
        return x


class DownSampleBlock(nn.Module):

    def __init__(
        self, in_channels, out_channels, time_embed_dim, context_dim=None,
        num_blocks=3, groups=32, head_dim=64, expansion_ratio=4, compression_ratio=2,
        down_sample=True, self_attention=True
    ):
        super().__init__()
        self.self_attention_block = set_default_layer(
            self_attention,
            AttentionBlock,
            (in_channels, time_embed_dim, None, groups, head_dim, expansion_ratio),
            layer_2=Identity
        )

        up_resolutions = [[None] * 4] * (num_blocks - 1) + [[None, None, set_default_item(down_sample, False), None]]
        hidden_channels = [(in_channels, out_channels)] + [(out_channels, out_channels)] * (num_blocks - 1)
        self.resnet_attn_blocks = nn.ModuleList([
            nn.ModuleList([
                ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio),
                set_default_layer(
                    exist(context_dim),
                    AttentionBlock,
                    (out_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio),
                    layer_2=Identity
                ),
                ResNetBlock(out_channel, out_channel, time_embed_dim, groups, compression_ratio, up_resolution),
            ]) for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions)
        ])

    def forward(self, x, time_embed, context=None, context_mask=None, control_net_residual=None):
        x = self.self_attention_block(x, time_embed)
        for in_resnet_block, attention, out_resnet_block in self.resnet_attn_blocks:
            x = in_resnet_block(x, time_embed)
            x = attention(x, time_embed, context, context_mask)
            x = out_resnet_block(x, time_embed)
        return x


class UpSampleBlock(nn.Module):

    def __init__(
        self, in_channels, cat_dim, out_channels, time_embed_dim, context_dim=None,
        num_blocks=3, groups=32, head_dim=64, expansion_ratio=4, compression_ratio=2,
        up_sample=True, self_attention=True
    ):
        super().__init__()
        up_resolutions = [[None, set_default_item(up_sample, True), None, None]] + [[None] * 4] * (num_blocks - 1)
        hidden_channels = [(in_channels + cat_dim, in_channels)] + [(in_channels, in_channels)] * (num_blocks - 2) + [(in_channels, out_channels)]
        self.resnet_attn_blocks = nn.ModuleList([
            nn.ModuleList([
                ResNetBlock(in_channel, in_channel, time_embed_dim, groups, compression_ratio, up_resolution),
                set_default_layer(
                    exist(context_dim),
                    AttentionBlock,
                    (in_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio),
                    layer_2=Identity
                ),
                ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio),
            ]) for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions)
        ])

        self.self_attention_block = set_default_layer(
            self_attention,
            AttentionBlock,
            (out_channels, time_embed_dim, None, groups, head_dim, expansion_ratio),
            layer_2=Identity
        )

    def forward(self, x, time_embed, context=None, context_mask=None):
        for in_resnet_block, attention, out_resnet_block in self.resnet_attn_blocks:
            x = in_resnet_block(x, time_embed)
            x = attention(x, time_embed, context, context_mask)
            x = out_resnet_block(x, time_embed)
        x = self.self_attention_block(x, time_embed)
        return x


class UNet(nn.Module):

    def __init__(self,
                 model_channels,
                 init_channels=None,
                 num_channels=3,
                 out_channels=4,
                 time_embed_dim=None,
                 context_dim=None,
                 groups=32,
                 head_dim=64,
                 expansion_ratio=4,
                 compression_ratio=2,
                 dim_mult=(1, 2, 4, 8),
                 num_blocks=(3, 3, 3, 3),
                 add_cross_attention=(False, True, True, True),
                 add_self_attention=(False, True, True, True),
                 *args,
                 **kwargs,
                 ):
        super().__init__()
        init_channels = init_channels or model_channels

        self.to_time_embed = nn.Sequential(
            SinusoidalPosEmb(init_channels),
            nn.Linear(init_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim)
        )
        self.feature_pooling = AttentionPolling(time_embed_dim, context_dim, head_dim)

        self.in_layer = nn.Conv2d(num_channels, init_channels, kernel_size=3, padding=1)

        hidden_dims = [init_channels, *map(lambda mult: model_channels * mult, dim_mult)]
        in_out_dims = list(zip(hidden_dims[:-1], hidden_dims[1:]))
        text_dims = [set_default_item(is_exist, context_dim) for is_exist in add_cross_attention]
        layer_params = [num_blocks, text_dims, add_self_attention]
        rev_layer_params = map(reversed, layer_params)

        cat_dims = []
        self.num_levels = len(in_out_dims)
        self.down_samples = nn.ModuleList([])
        for level, ((in_dim, out_dim), res_block_num, text_dim, self_attention) in enumerate(zip(in_out_dims, *layer_params)):
            down_sample = level != (self.num_levels - 1)
            cat_dims.append(set_default_item(level != (self.num_levels - 1), out_dim, 0))
            self.down_samples.append(
                DownSampleBlock(
                    in_dim, out_dim, time_embed_dim, text_dim, res_block_num, groups, head_dim, expansion_ratio,
                    compression_ratio, down_sample, self_attention
                )
            )

        self.up_samples = nn.ModuleList([])
        for level, ((out_dim, in_dim), res_block_num, text_dim, self_attention) in enumerate(zip(reversed(in_out_dims), *rev_layer_params)):
            up_sample = level != 0
            self.up_samples.append(
                UpSampleBlock(
                    in_dim, cat_dims.pop(), out_dim, time_embed_dim, text_dim, res_block_num, groups, head_dim,
                    expansion_ratio, compression_ratio, up_sample, self_attention
                )
            )

        self.out_layer = nn.Sequential(
            nn.GroupNorm(groups, init_channels),
            nn.SiLU(),
            nn.Conv2d(init_channels, out_channels, kernel_size=3, padding=1)
        )

    def forward(self, x, time, context=None, context_mask=None, is_text=None, null_embedding=None, control_net_residual=None):

        time_embed = self.to_time_embed(time)
        if exist(context):
            time_embed = self.feature_pooling(time_embed, context, context_mask)

        hidden_states = []
        x = self.in_layer(x)
        for level, down_sample in enumerate(self.down_samples):
            x = down_sample(x, time_embed, context, context_mask, control_net_residual)
            if level != self.num_levels - 1:
                hidden_states.append(x)
        for level, up_sample in enumerate(self.up_samples):
            if level != 0:
                x = torch.cat([x, hidden_states.pop()], dim=1)
            x = up_sample(x, time_embed, context, context_mask)
        x = self.out_layer(x)
        return x


def get_unet(conf):
    unet = UNet(**conf)
    return unet
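A small sketch, not part of the upload, of the channel plan that UNet.__init__ above derives for the refiner configuration used in KandiSuperRes/__init__.py (model_channels=384, init_channels=192, dim_mult=(1, 2, 4, 8)).

# Same arithmetic as the hidden_dims / in_out_dims lines in UNet.__init__.
init_channels, model_channels = 192, 384
dim_mult = (1, 2, 4, 8)
hidden_dims = [init_channels, *map(lambda mult: model_channels * mult, dim_mult)]
in_out_dims = list(zip(hidden_dims[:-1], hidden_dims[1:]))
print(hidden_dims)  # [192, 384, 768, 1536, 3072]
print(in_out_dims)  # [(192, 384), (384, 768), (768, 1536), (1536, 3072)]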
KandiSuperRes/model/unet_sr.py
ADDED
@@ -0,0 +1,260 @@
import torch
from torch import nn
from einops import rearrange
from .nn import Identity, Attention, SinusoidalPosEmb, UpDownResolution
from .utils import exist, set_default_item, set_default_layer
import torch.nn.functional as F


class Block(nn.Module):

    def __init__(self, in_channels, out_channels, time_embed_dim=None, groups=32, activation=None, up_resolution=None, dropout=None):
        super().__init__()
        self.group_norm = nn.GroupNorm(groups, in_channels)
        self.activation = set_default_layer(
            exist(activation),
            nn.SiLU
        )
        self.change_resolution = set_default_layer(
            exist(up_resolution),
            UpDownResolution, (in_channels, up_resolution)
        )
        self.dropout = set_default_layer(
            exist(dropout),
            nn.Dropout, (), {'p': 0.1}
        )
        self.projection = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x, scale_shift=None):
        x = self.group_norm(x)
        if exist(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift
        x = self.activation(x)
        x = self.dropout(x)
        x = self.change_resolution(x)
        x = self.projection(x)
        return x


class ResNetBlock(nn.Module):

    def __init__(self, in_channels, out_channels, time_embed_dim=None, groups=32, up_resolution=None):
        super().__init__()
        self.time_mlp = set_default_item(
            exist(time_embed_dim),
            nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_embed_dim, 2 * out_channels)
            )
        )
        self.in_block = Block(in_channels, out_channels, time_embed_dim, groups, up_resolution=up_resolution)
        self.out_block = Block(out_channels, out_channels, time_embed_dim, groups, activation=True, up_resolution=None, dropout=True)

        self.change_resolution = set_default_layer(
            exist(up_resolution),
            UpDownResolution, (in_channels, up_resolution)
        )
        self.res_block = set_default_layer(
            in_channels != out_channels or exist(up_resolution),
            nn.Conv2d, (in_channels, out_channels), {'kernel_size': 1}
        )

    def forward(self, x, time_embed=None):
        scale_shift = None
        if exist(time_embed) and exist(self.time_mlp):
            time_embed = self.time_mlp(time_embed)
            time_embed = rearrange(time_embed, 'b c -> b c 1 1')
            scale_shift = time_embed.chunk(2, dim=1)
        out = self.in_block(x)
        out = self.out_block(out, scale_shift=scale_shift)
        x = self.change_resolution(x)
        out = out + self.res_block(x)
        return out


class AttentionBlock(nn.Module):

    def __init__(
        self, dim, context_dim=None, groups=32, num_heads=8, num_conditions=1, feed_forward_mult=2
    ):
        super().__init__()
        self.in_norm = nn.GroupNorm(groups, dim)
        self.attention = Attention(
            dim, context_dim or dim, num_heads, num_conditions=num_conditions
        )

        hidden_dim = feed_forward_mult * dim
        self.out_norm = nn.GroupNorm(groups, dim)
        self.feed_forward = nn.Sequential(
            nn.Conv2d(dim, hidden_dim, kernel_size=1, bias=False),
            nn.SiLU(),
            nn.Conv2d(hidden_dim, dim, kernel_size=1, bias=False),
        )

    def forward(self, x, context=None, context_mask=None, context_idx=None):
        width = x.shape[-1]
        out = self.in_norm(x)
        out = rearrange(out, 'b c h w -> b (h w) c')
        context = set_default_item(exist(context), context, out)
        out = self.attention(out, context, context_mask, context_idx)
        out = rearrange(out, 'b (h w) c -> b c h w', w=width)
        x = x + out

        out = self.out_norm(x)
        out = self.feed_forward(out)
        x = x + out
        return x


class DownSampleBlock(nn.Module):

    def __init__(
        self, in_channels, out_channels, time_embed_dim,
        num_resnet_blocks=3, groups=32, down_sample=True, context_dim=None, self_attention=True, num_conditions=1):
        super().__init__()
        up_resolutions = [set_default_item(down_sample, False)] + [None] * (num_resnet_blocks - 1)
        hidden_channels = [(in_channels, out_channels)] + [(out_channels, out_channels)] * (num_resnet_blocks - 1)
        self.resnet_attn_blocks = nn.ModuleList([
            nn.ModuleList([
                ResNetBlock(in_channel, out_channel, time_embed_dim, groups, up_resolution),
                set_default_layer(
                    exist(context_dim),
                    AttentionBlock, (out_channel, context_dim), {'num_conditions': num_conditions, 'groups': groups},
                    layer_2=Identity
                )
            ]) for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions)
        ])

        self.self_attention_block = set_default_layer(
            self_attention,
            AttentionBlock, (out_channels,), {'feed_forward_mult': 4, 'groups': groups},
            layer_2=Identity
        )

    def forward(self, x, time_embed, context=None, context_mask=None, context_idx=None):
        for resnet_block, attention in self.resnet_attn_blocks:
            x = resnet_block(x, time_embed)
            x = attention(x, context, context_mask, context_idx)
        x = self.self_attention_block(x)
        return x


class UpSampleBlock(nn.Module):

    def __init__(
        self, in_channels, cat_dim, out_channels, time_embed_dim,
        num_resnet_blocks=3, groups=32, up_sample=True, context_dim=None, self_attention=True, num_conditions=1):
        super().__init__()
        up_resolutions = [None] * (num_resnet_blocks - 1) + [set_default_item(up_sample, True)]
        hidden_channels = [(in_channels + cat_dim, in_channels)] + [(in_channels, in_channels)] * (num_resnet_blocks - 2) + [(in_channels, out_channels)]
        self.resnet_attn_blocks = nn.ModuleList([
            nn.ModuleList([
                ResNetBlock(in_channel, out_channel, time_embed_dim, groups, up_resolution),
                set_default_layer(
                    exist(context_dim),
                    AttentionBlock, (out_channel, context_dim), {'num_conditions': num_conditions, 'groups': groups, 'feed_forward_mult': 4},
                    layer_2=Identity
                )
            ]) for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions)
        ])

        self.self_attention_block = set_default_layer(
            self_attention,
            AttentionBlock, (out_channels,), {'feed_forward_mult': 4, 'groups': groups},
            layer_2=Identity
        )

    def forward(self, x, time_embed, context=None, context_mask=None, context_idx=None):
        for resnet_block, attention in self.resnet_attn_blocks:
            x = resnet_block(x, time_embed)
            x = attention(x, context, context_mask, context_idx)
        x = self.self_attention_block(x)
        return x


class UNet(nn.Module):

    def __init__(self,
                 model_channels,
                 init_channels=128,
                 num_channels=3,
                 time_embed_dim=512,
                 context_dim=None,
                 groups=32,
                 feature_pooling_type='attention',
                 dim_mult=(1, 2, 4, 8),
                 num_resnet_blocks=(2, 4, 8, 8),
                 num_conditions=1,
                 skip_connect_scale=1.,
                 add_cross_attention=(False, False, False, False),
                 add_self_attention=(False, False, False, False),
                 lowres_cond=True,
                 ):
        super().__init__()
        out_channels = num_channels
        num_channels = set_default_item(lowres_cond, num_channels * 2, num_channels)
        init_channels = init_channels or model_channels
        self.num_conditions = num_conditions
        self.skip_connect_scale = skip_connect_scale
        self.to_time_embed = nn.Sequential(
            SinusoidalPosEmb(init_channels),
            nn.Linear(init_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim)
        )

        self.init_conv = nn.Conv2d(num_channels, init_channels, kernel_size=3, padding=1)

        hidden_dims = [init_channels, *map(lambda mult: model_channels * mult, dim_mult)]
        in_out_dims = list(zip(hidden_dims[:-1], hidden_dims[1:]))
        text_dims = [set_default_item(is_exist, context_dim) for is_exist in add_cross_attention]
        layer_params = [num_resnet_blocks, text_dims, add_self_attention]
        rev_layer_params = map(reversed, layer_params)

        cat_dims = []
        self.num_levels = len(in_out_dims)
        self.down_samples = nn.ModuleList([])
        for level, ((in_dim, out_dim), res_block_num, text_dim, self_attention) in enumerate(zip(in_out_dims, *layer_params)):
            down_sample = level != (self.num_levels - 1)
            cat_dims.append(set_default_item(level != (self.num_levels - 1), out_dim, 0))
            self.down_samples.append(
                DownSampleBlock(
                    in_dim, out_dim, time_embed_dim, res_block_num, groups, down_sample, text_dim, self_attention, num_conditions
                )
            )

        self.up_samples = nn.ModuleList([])
        for level, ((out_dim, in_dim), res_block_num, text_dim, self_attention) in enumerate(zip(reversed(in_out_dims), *rev_layer_params)):
            up_sample = level != 0
            self.up_samples.append(
                UpSampleBlock(
                    in_dim, cat_dims.pop(), out_dim, time_embed_dim, res_block_num, groups, up_sample, text_dim, self_attention, num_conditions
                )
            )

        self.norm = nn.GroupNorm(groups, init_channels)
        self.activation = nn.SiLU()
        self.out_conv = nn.Conv2d(init_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x, time, context=None, context_mask=None, context_idx=None, lowres_img=None):
        if exist(lowres_img):
            _, _, new_height, new_width = x.shape
            upsampled = F.interpolate(lowres_img, (new_height, new_width), mode="bilinear")
            x = torch.cat([x, upsampled], dim=1)
        time_embed = self.to_time_embed(time)

        hidden_states = []
        x = self.init_conv(x)
        for level, down_sample in enumerate(self.down_samples):
            x = down_sample(x, time_embed, context, context_mask, context_idx)
            if level != self.num_levels - 1:
                hidden_states.append(x)
        for level, up_sample in enumerate(self.up_samples):
            if level != 0:
                x = torch.cat([x, hidden_states.pop() / self.skip_connect_scale], dim=1)
            x = up_sample(x, time_embed, context, context_mask, context_idx)
        x = self.norm(x)
        x = self.activation(x)
        x = self.out_conv(x)
        return x
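A small sketch, not part of the upload, of the low-resolution conditioning path at the top of UNet.forward above: with lowres_cond=True the input convolution expects 2 * num_channels = 6 channels, the noisy target concatenated with the bilinearly upsampled low-res image. The tensor sizes below are hypothetical.

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 512, 512)           # noisy high-res estimate
lowres_img = torch.randn(1, 3, 256, 256)  # low-res input to be upscaled x2
upsampled = F.interpolate(lowres_img, (512, 512), mode="bilinear")
x = torch.cat([x, upsampled], dim=1)
print(x.shape)  # torch.Size([1, 6, 512, 512]) -> fed to init_conv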
KandiSuperRes/model/utils.py
ADDED
@@ -0,0 +1,62 @@
from torch.nn import Identity
from einops import rearrange


def exist(item):
    return item is not None


def set_default_item(condition, item_1, item_2=None):
    if condition:
        return item_1
    else:
        return item_2


def set_default_layer(condition, layer_1, args_1=[], kwargs_1={}, layer_2=Identity, args_2=[], kwargs_2={}):
    if condition:
        return layer_1(*args_1, **kwargs_1)
    else:
        return layer_2(*args_2, **kwargs_2)


def get_tensor_items(x, pos, broadcast_shape):
    device = pos.device
    bs = pos.shape[0]
    ndims = len(broadcast_shape[1:])
    x = x.cpu()[pos.cpu()]
    return x.reshape(bs, *((1,) * ndims)).to(device)


def local_patching(x, height, width, group_size):
    if group_size > 0:
        x = rearrange(
            x, 'b c (h g1) (w g2) -> b (h w) (g1 g2) c',
            h=height//group_size, w=width//group_size, g1=group_size, g2=group_size
        )
    else:
        x = rearrange(x, 'b c h w -> b (h w) c', h=height, w=width)
    return x


def local_merge(x, height, width, group_size):
    if group_size > 0:
        x = rearrange(
            x, 'b (h w) (g1 g2) c -> b c (h g1) (w g2)',
            h=height//group_size, w=width//group_size, g1=group_size, g2=group_size
        )
    else:
        x = rearrange(x, 'b (h w) c -> b c h w', h=height, w=width)
    return x


def global_patching(x, height, width, group_size):
    x = local_patching(x, height, width, height//group_size)
    x = x.transpose(-2, -3)
    return x


def global_merge(x, height, width, group_size):
    x = x.transpose(-2, -3)
    x = local_merge(x, height, width, height//group_size)
    return x
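The patching helpers above are pure einops rearrangements, so local_merge exactly inverts local_patching. A small sanity-check sketch, assuming the package is importable and using arbitrary shapes:

import torch
from KandiSuperRes.model.utils import local_patching, local_merge

x = torch.randn(2, 8, 16, 16)               # b c h w
patches = local_patching(x, 16, 16, 4)      # -> (2, 16, 16, 8): 16 patches of 4x4 positions each
restored = local_merge(patches, 16, 16, 4)  # exact inverse rearrangement
assert torch.equal(x, restored)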
KandiSuperRes/movq.py
ADDED
@@ -0,0 +1,541 @@
import math
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

from .utils import freeze
from tqdm import tqdm

import time


def nonlinearity(x):
    return x * torch.sigmoid(x)


class SpatialNorm(nn.Module):
    def __init__(
        self, f_channels, zq_channels=None, norm_layer=nn.GroupNorm, freeze_norm_layer=False, add_conv=False, **norm_layer_params
    ):
        super().__init__()
        self.norm_layer = norm_layer(num_channels=f_channels, **norm_layer_params)
        if zq_channels is not None:
            if freeze_norm_layer:
                for p in self.norm_layer.parameters():
                    p.requires_grad = False
            self.add_conv = add_conv
            if self.add_conv:
                self.conv = nn.Conv2d(zq_channels, zq_channels, kernel_size=3, stride=1, padding=1)
            self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
            self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, f, zq=None):
        norm_f = self.norm_layer(f)
        if zq is not None:
            f_size = f.shape[-2:]
            zq = torch.nn.functional.interpolate(zq, size=f_size, mode="nearest")
            if self.add_conv:
                zq = self.conv(zq)
            norm_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
        return norm_f


def Normalize(in_channels, zq_ch=None, add_conv=None):
    return SpatialNorm(
        in_channels, zq_ch, norm_layer=nn.GroupNorm,
        freeze_norm_layer=False, add_conv=add_conv, num_groups=32, eps=1e-6, affine=True
    )


class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x


class ResnetBlock(nn.Module):
    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512, zq_ch=None, add_conv=False):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels, zq_ch, add_conv=add_conv)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels, zq_ch, add_conv=add_conv)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, temb, zq=None):
        h = x
        h = self.norm1(h, zq)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h, zq)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)
        return x + h


class AttnBlock(nn.Module):
    def __init__(self, in_channels, zq_ch=None, add_conv=False):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels, zq_ch, add_conv=add_conv)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, zq=None):
        h_ = x
        h_ = self.norm(h_, zq)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)      # b,hw,c
        k = k.reshape(b, c, h * w)  # b,c,hw
        w_ = torch.bmm(q, k)        # b,hw,hw    w[b,i,j] = sum_c q[b,i,c] k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)    # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_)       # b,c,hw (hw of q)   h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)
        return x + h_


class Encoder(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, double_z=True, **ignore_kwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in,
                                       temb_channels=self.temb_ch, dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in,
                                       temb_channels=self.temb_ch, dropout=dropout)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, 2 * z_channels if double_z else z_channels,
                                        kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class Decoder(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, give_pre_end=False, zq_ch=None, add_conv=False, **ignorekwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,) + tuple(ch_mult)
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in,
                                       temb_channels=self.temb_ch, dropout=dropout,
                                       zq_ch=zq_ch, add_conv=add_conv)
        self.mid.attn_1 = AttnBlock(block_in, zq_ch, add_conv=add_conv)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in,
                                       temb_channels=self.temb_ch, dropout=dropout,
                                       zq_ch=zq_ch, add_conv=add_conv)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout,
                                         zq_ch=zq_ch,
                                         add_conv=add_conv))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in, zq_ch, add_conv=add_conv))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in, zq_ch, add_conv=add_conv)
        self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z, zq):
        # assert z.shape[1:] == self.z_shape[1:]
        self.last_z_shape = z.shape

        # timestep embedding
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb, zq)
        h = self.mid.attn_1(h, zq)
        h = self.mid.block_2(h, temb, zq)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb, zq)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h, zq)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h, zq)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class MoVQ(nn.Module):

    def __init__(self, generator_params):
        super().__init__()
        z_channels = generator_params["z_channels"]
        self.encoder = Encoder(**generator_params)
        self.quant_conv = torch.nn.Conv2d(z_channels, z_channels, 1)
        self.post_quant_conv = torch.nn.Conv2d(z_channels, z_channels, 1)
        self.decoder = Decoder(zq_ch=z_channels, **generator_params)

        self.tile_sample_min_size = generator_params["tile_sample_min_size"]
        self.scale_factor = 8
        self.tile_latent_min_size = int(self.tile_sample_min_size / self.scale_factor)
        self.tile_overlap_factor_enc = generator_params["tile_overlap_factor_enc"]
        self.tile_overlap_factor_dec = generator_params["tile_overlap_factor_dec"]
        self.use_tiling = generator_params["use_tiling"]

    @torch.no_grad()
    def encode(self, x):
        if self.use_tiling and (
            x.shape[-1] > self.tile_sample_min_size
            or x.shape[-2] > self.tile_sample_min_size
        ):
            print('tiled_encode')
            return self.tiled_encode(x)
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    @torch.no_grad()
    def decode(self, quant):
        if self.use_tiling and (
            quant.shape[-1] > self.tile_latent_min_size
            or quant.shape[-2] > self.tile_latent_min_size
        ):
            print('tiled_decode')
            return self.tiled_decode(quant)
        decoder_input = self.post_quant_conv(quant)
        decoded = self.decoder(decoder_input, quant)
        return decoded

    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        blend_extent = min(a.shape[2], b.shape[2], blend_extent)
        for y in range(blend_extent):
            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (
                1 - y / blend_extent
            ) + b[:, :, y, :] * (y / blend_extent)
        return b

    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (
                1 - x / blend_extent
            ) + b[:, :, :, x] * (x / blend_extent)
        return b

    def tiled_encode(self, x):
        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor_enc))
        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor_enc)
        row_limit = self.tile_latent_min_size - blend_extent

        # Split the image into tiles and encode them separately.
        rows = []
        for i in tqdm(range(0, x.shape[2], overlap_size)):
            row = []
            for j in range(0, x.shape[3], overlap_size):
                tile = x[
                    :,
                    :,
                    i : i + self.tile_sample_min_size,
                    j : j + self.tile_sample_min_size,
                ]
                tile = self.encode(tile)
                row.append(tile)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=3))

        h = torch.cat(result_rows, dim=2)
        return h

    def tiled_decode(self, z):
        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor_dec))
        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor_dec)
        row_limit = self.tile_sample_min_size - blend_extent

        # Split z into overlapping tiles and decode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in tqdm(range(0, z.shape[2], overlap_size)):
            row = []
            for j in range(0, z.shape[3], overlap_size):
                tile = z[
                    :,
                    :,
                    i : i + self.tile_latent_min_size,
                    j : j + self.tile_latent_min_size,
                ]
                decoded = self.decode(tile)
                row.append(decoded)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=3))

        dec = torch.cat(result_rows, dim=2)
        return dec


def get_vae(conf):
    movq = MoVQ(conf.params)
    if conf.checkpoint is not None:
        movq_state_dict = torch.load(conf.checkpoint)
        movq.load_state_dict(movq_state_dict)
    movq = freeze(movq)
    return movq
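The seam removal in tiled_encode and tiled_decode comes entirely from blend_v and blend_h, which linearly cross-fade the overlapping border rows or columns of neighbouring tiles. A minimal numeric sketch of the same blending rule in isolation (the tensors and blend width are illustrative):

import torch

def blend_h(a, b, blend_extent):
    # Same rule as MoVQ.blend_h: fade from tile `a`'s right edge into tile `b`'s left edge.
    blend_extent = min(a.shape[3], b.shape[3], blend_extent)
    for x in range(blend_extent):
        b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
    return b

a = torch.zeros(1, 1, 1, 8)   # left tile
b = torch.ones(1, 1, 1, 8)    # right tile
print(blend_h(a, b, 4)[0, 0, 0])  # tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000, 1.0000, 1.0000, 1.0000])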
KandiSuperRes/sr_pipeline.py
ADDED
@@ -0,0 +1,116 @@
import torch
import numpy as np
import PIL
import torchvision.transforms as T
import torch.nn.functional as F
from KandiSuperRes.model.unet import UNet
from KandiSuperRes.model.unet_sr import UNet as UNet_sr
from KandiSuperRes.movq import MoVQ
from KandiSuperRes.model.diffusion_sr import DPMSolver
from KandiSuperRes.model.diffusion_refine import BaseDiffusion, get_named_beta_schedule
from KandiSuperRes.model.diffusion_sr_turbo import BaseDiffusion as BaseDiffusion_turbo


class KandiSuperResPipeline:

    def __init__(
        self,
        scale: int,
        device: str,
        dtype: str,
        flash: bool,
        sr_model: UNet_sr,
        movq: MoVQ = None,
        refiner: UNet = None,
    ):
        self.device = device
        self.dtype = dtype
        self.scale = scale
        self.flash = flash
        self.to_pil = T.ToPILImage()
        self.image_transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.ToTensor(),
            T.Lambda(lambda img: 2. * img - 1.),
        ])

        self.sr_model = sr_model
        self.movq = movq
        self.refiner = refiner

    def __call__(
        self,
        pil_image: PIL.Image.Image = None,
        steps: int = 5,
        view_batch_size: int = 15,
        seed: int = 0,
        refine=True
    ) -> PIL.Image.Image:

        if self.flash:
            betas_turbo = get_named_beta_schedule('linear', 1000)
            base_diffusion_sr = BaseDiffusion_turbo(betas_turbo)

            old_height = pil_image.size[1]
            old_width = pil_image.size[0]
            height = int(old_height - np.mod(old_height, 32))
            width = int(old_width - np.mod(old_width, 32))

            pil_image = pil_image.resize((width, height))
            lr_image = self.image_transform(pil_image).unsqueeze(0).to(self.device['sr_model'])

            sr_image = base_diffusion_sr.p_sample_loop(
                self.sr_model, (1, 3, height*self.scale, width*self.scale), self.device['sr_model'], self.dtype['sr_model'], lowres_img=lr_image
            )

            if refine:
                betas = get_named_beta_schedule('cosine', 1000)
                base_diffusion = BaseDiffusion(betas, 0.99)

                with torch.cuda.amp.autocast(dtype=self.dtype['movq']):
                    lr_image_latent = self.movq.encode(sr_image)

                pil_images = []
                context = torch.load('weights/context.pt').to(self.dtype['refiner'])
                context_mask = torch.load('weights/context_mask.pt').to(self.dtype['refiner'])

                with torch.no_grad():
                    with torch.cuda.amp.autocast(dtype=self.dtype['refiner']):
                        refiner_image = base_diffusion.refine_tiled(self.refiner, lr_image_latent, context, context_mask)

                with torch.cuda.amp.autocast(dtype=self.dtype['movq']):
                    refiner_image = self.movq.decode(refiner_image)
                refiner_image = torch.clip((refiner_image + 1.) / 2., 0., 1.)

                if old_height*self.scale != refiner_image.shape[2] or old_width*self.scale != refiner_image.shape[3]:
                    refiner_image = F.interpolate(refiner_image, [old_height*self.scale, old_width*self.scale], mode='bilinear', align_corners=True)
                refined_pil_image = self.to_pil(refiner_image[0])
                return refined_pil_image

            sr_image = torch.clip((sr_image + 1.) / 2., 0., 1.)
            if old_height*self.scale != sr_image.shape[2] or old_width*self.scale != sr_image.shape[3]:
                sr_image = F.interpolate(sr_image, [old_height*self.scale, old_width*self.scale], mode='bilinear', align_corners=True)
            pil_sr_image = self.to_pil(sr_image[0])
            return pil_sr_image

        else:
            base_diffusion = DPMSolver(steps)

            lr_image = self.image_transform(pil_image).unsqueeze(0).to(self.device)

            old_height = pil_image.size[1]
            old_width = pil_image.size[0]

            height = int(old_height + np.mod(old_height, 2)) * self.scale
            width = int(old_width + np.mod(old_width, 2)) * self.scale

            sr_image = base_diffusion.generate_panorama(height, width, self.device, self.dtype, steps,
                                                        self.sr_model, lowres_img=lr_image,
                                                        view_batch_size=view_batch_size, eta=0.0, seed=seed)

            sr_image = torch.clip((sr_image + 1.) / 2., 0., 1.)
            if old_height*self.scale != height or old_width*self.scale != width:
                sr_image = F.interpolate(sr_image, [old_height*self.scale, old_width*self.scale], mode='bilinear', align_corners=True)

            pil_sr_image = self.to_pil(sr_image[0])
            return pil_sr_image
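One detail of the flash path worth noting: the input is first snapped down to a multiple of 32 so the diffusion model sees a compatible size, and the result is later interpolated back so the caller always receives exactly (old_width*scale, old_height*scale). A small sketch of that size bookkeeping (the scale and input size are arbitrary assumptions):

import numpy as np
import PIL.Image

scale = 2                                      # illustrative upscale factor
pil_image = PIL.Image.new('RGB', (517, 389))   # arbitrary input size (width, height)

old_width, old_height = pil_image.size
height = int(old_height - np.mod(old_height, 32))   # 389 -> 384
width = int(old_width - np.mod(old_width, 32))      # 517 -> 512
pil_image = pil_image.resize((width, height))

# The SR result has shape (height*scale, width*scale); the pipeline then interpolates it
# back to (old_height*scale, old_width*scale), so the output size always matches the input.
print((width * scale, height * scale), '->', (old_width * scale, old_height * scale))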
KandiSuperRes/utils.py
ADDED
@@ -0,0 +1,9 @@
def freeze(model):
    for p in model.parameters():
        p.requires_grad = False
    return model


def unfreeze(model):
    for p in model.parameters():
        p.requires_grad = True
    return model
weights/context.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:92cc62ca3e341bd4ea03df187c06aceb43505e79107b6a2fef717a86051a6296
size 1049756
weights/context_mask.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6ff4ea52d9deb41f4732dd422fef48b1382247d9dbe0493d0a274e0d2e13591f
size 2229