doevent commited on
Commit
5004324
·
verified ·
1 Parent(s): 6dcbc25

Upload 14 files

Browse files
KandiSuperRes/__init__.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import Optional, Union
3
+ from huggingface_hub import hf_hub_download
4
+
5
+ from .sr_pipeline import KandiSuperResPipeline
6
+ from KandiSuperRes.model.unet import UNet
7
+ from KandiSuperRes.model.unet_sr import UNet as UNet_sr
8
+ from KandiSuperRes.movq import MoVQ
9
+
10
+
11
+ def get_sr_model(
12
+ device: Union[str, torch.device],
13
+ weights_path: Optional[str] = None,
14
+ dtype: Union[str, torch.dtype] = torch.float16
15
+ ) -> (UNet_sr, Optional[dict], Optional[torch.Tensor]):
16
+ unet = UNet_sr(
17
+ init_channels=128,
18
+ model_channels=128,
19
+ num_channels=3,
20
+ time_embed_dim=512,
21
+ groups=32,
22
+ dim_mult=(1, 2, 4, 8),
23
+ num_resnet_blocks=(2,4,8,8),
24
+ add_cross_attention=(False, False, False, False),
25
+ add_self_attention=(False, False, False, False),
26
+ feature_pooling_type='attention',
27
+ lowres_cond =True
28
+ )
29
+
30
+ if weights_path:
31
+ state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
32
+ try:
33
+ unet.load_state_dict(state_dict['unet'])
34
+ except:
35
+ unet.load_state_dict(state_dict)
36
+ unet.to(device=device, dtype=dtype).eval()
37
+ return unet
38
+
39
+
40
+ def get_T2I_unet(
41
+ device: Union[str, torch.device],
42
+ weights_path: Optional[str] = None,
43
+ dtype: Union[str, torch.dtype] = torch.float32,
44
+ ) -> (UNet, Optional[torch.Tensor], Optional[dict]):
45
+ unet = UNet(
46
+ model_channels=384,
47
+ num_channels=4,
48
+ init_channels=192,
49
+ time_embed_dim=1536,
50
+ context_dim=4096,
51
+ groups=32,
52
+ head_dim=64,
53
+ expansion_ratio=4,
54
+ compression_ratio=2,
55
+ dim_mult=(1, 2, 4, 8),
56
+ num_blocks=(3, 3, 3, 3),
57
+ add_cross_attention=(False, True, True, True),
58
+ add_self_attention=(False, True, True, True),
59
+ )
60
+
61
+ null_embedding = None
62
+ if weights_path:
63
+ state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
64
+ null_embedding = state_dict['null_embedding']
65
+ unet.load_state_dict(state_dict['unet'])
66
+
67
+ unet.to(device=device, dtype=dtype).eval()
68
+ return unet, null_embedding
69
+
70
+
71
+ def get_movq(
72
+ device: Union[str, torch.device],
73
+ weights_path: Optional[str] = None,
74
+ dtype: Union[str, torch.dtype] = torch.float32,
75
+ ) -> MoVQ:
76
+ generator_config = {
77
+ 'double_z': False,
78
+ 'z_channels': 4,
79
+ 'resolution': 256,
80
+ 'in_channels': 3,
81
+ 'out_ch': 3,
82
+ 'ch': 256,
83
+ 'ch_mult': [1, 2, 2, 4],
84
+ 'num_res_blocks': 2,
85
+ 'attn_resolutions': [32],
86
+ 'dropout': 0.0,
87
+ 'tile_sample_min_size': 1024,
88
+ 'tile_overlap_factor_enc': 0.0,
89
+ 'tile_overlap_factor_dec': 0.25,
90
+ 'use_tiling': True
91
+ }
92
+ movq = MoVQ(generator_config)
93
+
94
+ if weights_path:
95
+ state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
96
+ movq.load_state_dict(state_dict)
97
+
98
+ movq.to(device=device, dtype=dtype).eval()
99
+ return movq
100
+
101
+
102
+ def get_SR_pipeline(
103
+ device: Union[str, torch.device],
104
+ fp16: bool = True,
105
+ flash: bool = True,
106
+ scale: int = 2,
107
+ cache_dir: str = '/tmp/KandiSuperRes/',
108
+ movq_path: str = None,
109
+ refiner_path: str = None,
110
+ unet_sr_path: str = None,
111
+ ) -> KandiSuperResPipeline:
112
+
113
+ if flash:
114
+ if scale == 2:
115
+ device_map = {
116
+ 'movq': device, 'refiner': device, 'sr_model': device
117
+ }
118
+ dtype = torch.float16 if fp16 else torch.float32
119
+ dtype_map = {
120
+ 'movq': torch.float32, 'refiner': dtype, 'sr_model': dtype
121
+ }
122
+ if movq_path is None:
123
+ print('Download movq weights')
124
+ movq_path = hf_hub_download(
125
+ repo_id="ai-forever/Kandinsky3.1", filename='weights/movq.pt', cache_dir=cache_dir
126
+ )
127
+ if refiner_path is None:
128
+ print('Download refiner weights')
129
+ refiner_path = hf_hub_download(
130
+ repo_id="ai-forever/Kandinsky3.1", filename='weights/kandinsky3_flash.pt', cache_dir=cache_dir
131
+ )
132
+ if unet_sr_path is None:
133
+ print('Download KandiSuperRes Flash weights')
134
+ unet_sr_path = hf_hub_download(
135
+ repo_id="ai-forever/KandiSuperRes", filename='KandiSuperRes_flash_x2.pt', cache_dir=cache_dir
136
+ )
137
+ sr_model = get_sr_model(device_map['sr_model'], unet_sr_path, dtype=dtype_map['sr_model'])
138
+ movq = get_movq(device_map['movq'], movq_path, dtype=dtype_map['movq'])
139
+ refiner, _ = get_T2I_unet(device_map['refiner'], refiner_path, dtype=dtype_map['refiner'])
140
+ return KandiSuperResPipeline(
141
+ scale, device_map, dtype_map, flash, sr_model, movq, refiner
142
+ )
143
+ else:
144
+ print('Flash model for x4 scale is not implemented.')
145
+ else:
146
+ if unet_sr_path is None:
147
+ if scale == 4:
148
+ unet_sr_path = hf_hub_download(
149
+ repo_id="ai-forever/KandiSuperRes", filename='KandiSuperRes.ckpt', cache_dir=cache_dir
150
+ )
151
+ elif scale == 2:
152
+ unet_sr_path = hf_hub_download(
153
+ repo_id="ai-forever/KandiSuperRes", filename='KandiSuperRes_x2.ckpt', cache_dir=cache_dir
154
+ )
155
+ dtype = torch.float16 if fp16 else torch.float32
156
+ sr_model = get_sr_model(device, unet_sr_path, dtype=dtype)
157
+ return KandiSuperResPipeline(scale, device, dtype, flash, sr_model)
KandiSuperRes/model/__init__.py ADDED
File without changes
KandiSuperRes/model/diffusion_refine.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from tqdm import tqdm
4
+ from .utils import get_tensor_items
5
+ import torch.nn.functional as F
6
+
7
+
8
+ def get_named_beta_schedule(schedule_name, timesteps):
9
+ if schedule_name == "linear":
10
+ scale = 1000 / timesteps
11
+ beta_start = scale * 0.0001
12
+ beta_end = scale * 0.02
13
+ return torch.linspace(
14
+ beta_start, beta_end, timesteps, dtype=torch.float32
15
+ )
16
+ elif schedule_name == "cosine":
17
+ alpha_bar = lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
18
+ betas = []
19
+ for i in range(timesteps):
20
+ t1 = i / timesteps
21
+ t2 = (i + 1) / timesteps
22
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), 0.999))
23
+ return torch.tensor(betas, dtype=torch.float32)
24
+
25
+
26
+ class BaseDiffusion:
27
+
28
+ def __init__(self, betas, percentile=None, gen_noise=torch.randn_like):
29
+ self.betas = betas
30
+ self.num_timesteps = betas.shape[0]
31
+
32
+ alphas = 1. - betas
33
+ self.alphas_cumprod = torch.cumprod(alphas, dim=0)
34
+ self.alphas_cumprod_prev = torch.cat([torch.ones(1, dtype=betas.dtype), self.alphas_cumprod[:-1]])
35
+
36
+ # calculate q(x_t | x_{t-1})
37
+ self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
38
+ self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)
39
+
40
+ self.time_scale = 1000 // self.num_timesteps
41
+ self.gen_noise = gen_noise
42
+
43
+ def get_x_start(self, x, t, noise):
44
+ sqrt_one_minus_alphas_cumprod = get_tensor_items(self.sqrt_one_minus_alphas_cumprod, t, noise.shape)
45
+ sqrt_alphas_cumprod = get_tensor_items(self.sqrt_alphas_cumprod, t, noise.shape)
46
+ pred_x_start = (x - sqrt_one_minus_alphas_cumprod * noise) / sqrt_alphas_cumprod
47
+ return pred_x_start
48
+
49
+ def q_sample(self, x_start, t, noise=None):
50
+ if noise is None:
51
+ noise = self.gen_noise(x_start)
52
+ sqrt_alphas_cumprod = get_tensor_items(self.sqrt_alphas_cumprod, t, x_start.shape)
53
+ sqrt_one_minus_alphas_cumprod = get_tensor_items(self.sqrt_one_minus_alphas_cumprod, t, noise.shape)
54
+ x_t = sqrt_alphas_cumprod * x_start + sqrt_one_minus_alphas_cumprod * noise
55
+ return x_t
56
+
57
+ @torch.no_grad()
58
+ def refine(self, model, img, context, context_mask):
59
+ # for time in tqdm([479, 229]):
60
+ for time in [229]:
61
+ time = torch.tensor([time,] * img.shape[0], device=img.device)
62
+ x_t = self.q_sample(img, time)
63
+ pred_noise = model(x_t, time.type(x_t.dtype), context, context_mask.bool())
64
+ img = self.get_x_start(x_t, time, pred_noise)
65
+ return img
66
+
67
+ def blend_v(
68
+ self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
69
+ ) -> torch.Tensor:
70
+ blend_extent = min(a.shape[2], b.shape[2], blend_extent)
71
+ for y in range(blend_extent):
72
+ b[ :, :, y, :] = a[ :, :, -blend_extent + y, :] * (
73
+ 1 - y / blend_extent
74
+ ) + b[ :, :, y, :] * (y / blend_extent)
75
+ return b
76
+
77
+ def blend_h(
78
+ self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
79
+ ) -> torch.Tensor:
80
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
81
+ for x in range(blend_extent):
82
+ b[ :, :, :, x] = a[ :, :, :, -blend_extent + x] * (
83
+ 1 - x / blend_extent
84
+ ) + b[ :, :, :, x] * (x / blend_extent)
85
+ return b
86
+
87
+
88
+ def refine_tiled(self, model, img, context, context_mask):
89
+ tile_sample_min_size = 352
90
+ tile_overlap_factor = 0.25
91
+
92
+ overlap_size = int(tile_sample_min_size * (1 - tile_overlap_factor))
93
+ tile_latent_min_size = int(tile_sample_min_size)
94
+ blend_extent = int(tile_latent_min_size * tile_overlap_factor)
95
+ row_limit = tile_latent_min_size - blend_extent
96
+
97
+ # Split the image into tiles and encode them separately.
98
+ rows = []
99
+ for i in tqdm(range(0, img.shape[2], overlap_size)):
100
+ row = []
101
+ for j in range(0, img.shape[3], overlap_size):
102
+ tile = img[
103
+ :,
104
+ :,
105
+ i : i + tile_sample_min_size,
106
+ j : j + tile_sample_min_size,
107
+ ]
108
+ tile = self.refine(model, tile, context, context_mask)
109
+ row.append(tile)
110
+ rows.append(row)
111
+ result_rows = []
112
+ for i, row in enumerate(rows):
113
+ result_row = []
114
+ for j, tile in enumerate(row):
115
+ # blend the above tile and the left tile
116
+ # to the current tile and add the current tile to the result row
117
+ if i > 0:
118
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
119
+ if j > 0:
120
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
121
+ result_row.append(tile[ :, :, :row_limit, :row_limit])
122
+ result_rows.append(torch.cat(result_row, dim=3))
123
+
124
+ refine_img = torch.cat(result_rows, dim=2)
125
+ return refine_img
126
+
127
+
128
+ def get_diffusion(conf):
129
+ betas = get_named_beta_schedule(**conf.schedule_params)
130
+ base_diffusion = BaseDiffusion(betas, **conf.diffusion_params)
131
+ return base_diffusion
KandiSuperRes/model/diffusion_sr.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import DDIMScheduler, DPMSolverMultistepScheduler
2
+ from einops import repeat
3
+ import copy
4
+ import inspect
5
+ import math
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from tqdm import tqdm
9
+
10
+
11
+ class DPMSolver:
12
+
13
+ def __init__(self, num_timesteps):
14
+ self.dpm_solver = DPMSolverMultistepScheduler(
15
+ beta_schedule="linear",
16
+ prediction_type= "sample",
17
+ # algorithm_type="sde-dpmsolver++",
18
+ thresholding=False
19
+ )
20
+ self.dpm_solver.set_timesteps(num_timesteps)
21
+
22
+
23
+ @torch.no_grad()
24
+ def pred_noise(self, model, x, t, lowres_img, dtype):
25
+ pred_noise = model(x.to(dtype), t.to(dtype), lowres_img=lowres_img.to(dtype))
26
+ pred_noise = pred_noise.to(dtype=torch.float32)
27
+ return pred_noise
28
+
29
+
30
+ def prepare_extra_step_kwargs(self, generator, eta):
31
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
32
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
33
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
34
+ # and should be between [0, 1]
35
+
36
+ accepts_eta = "eta" in set(inspect.signature(self.dpm_solver.step).parameters.keys())
37
+ extra_step_kwargs = {}
38
+ if accepts_eta:
39
+ extra_step_kwargs["eta"] = eta
40
+
41
+ # check if the scheduler accepts generator
42
+ accepts_generator = "generator" in set(inspect.signature(self.dpm_solver.step).parameters.keys())
43
+ if accepts_generator:
44
+ extra_step_kwargs["generator"] = generator
45
+ return extra_step_kwargs
46
+
47
+
48
+ def get_views(self, panorama_height, panorama_width, window_size=1024, stride=800):
49
+ # Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113)
50
+ # if panorama's height/width < window_size, num_blocks of height/width should return 1
51
+ num_blocks_height = round(math.ceil((panorama_height - window_size) / stride)) + 1 if panorama_height > window_size else 1
52
+ num_blocks_width = round(math.ceil((panorama_width - window_size) / stride)) + 1 if panorama_width > window_size else 1
53
+ total_num_blocks = int(num_blocks_height * num_blocks_width)
54
+ views = []
55
+ for i in range(total_num_blocks):
56
+ h_start = int((i // num_blocks_width) * stride)
57
+ h_end = h_start + window_size
58
+ if h_end > panorama_height and num_blocks_height > 1:
59
+ h_end = panorama_height
60
+ h_start = panorama_height - window_size
61
+ w_start = int((i % num_blocks_width) * stride)
62
+ w_end = w_start + window_size
63
+ if w_end > panorama_width and num_blocks_width > 1:
64
+ w_end = panorama_width
65
+ w_start = panorama_width - window_size
66
+ views.append((h_start, h_end, w_start, w_end))
67
+ return views
68
+
69
+
70
+ def generate_panorama(self, height, width, device, dtype, num_inference_steps,
71
+ unet, lowres_img, view_batch_size=15, eta=0, seed=0):
72
+ # 6. Define panorama grid and initialize views for synthesis.
73
+ # prepare batch grid
74
+ views = self.get_views(height, width)
75
+ views_batch = [views[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)]
76
+ views_scheduler_status = [copy.deepcopy(self.dpm_solver.__dict__)] * len(views_batch)
77
+
78
+ shape = (1, 3, height, width)
79
+ count = torch.zeros(*shape, device=device)
80
+ value = torch.zeros(*shape, device=device)
81
+
82
+ generator = torch.Generator(device=device)
83
+ if seed is not None:
84
+ generator = generator.manual_seed(seed)
85
+
86
+ img = torch.randn(*shape, device=device, generator=generator)
87
+ up_lowres_img = F.interpolate(lowres_img, (shape[2], shape[3]), mode="bilinear")
88
+
89
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
90
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
91
+
92
+ # 8. Denoising loop
93
+ # Each denoising step also includes refinement of the latents with respect to the
94
+ # views.
95
+ timesteps = self.dpm_solver.timesteps
96
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.dpm_solver.order
97
+
98
+ for i, time in tqdm(enumerate(self.dpm_solver.timesteps)):
99
+ count.zero_()
100
+ value.zero_()
101
+
102
+ # generate views
103
+ # Here, we iterate through different spatial crops of the latents and denoise them. These
104
+ # denoised (latent) crops are then averaged to produce the final latent
105
+ # for the current timestep via MultiDiffusion. Please see Sec. 4.1 in the
106
+ # MultiDiffusion paper for more details: https://arxiv.org/abs/2302.08113
107
+ # Batch views denoise
108
+ for j, batch_view in enumerate(views_batch):
109
+ vb_size = len(batch_view)
110
+ # get the latents corresponding to the current view coordinates
111
+ img_for_view = torch.cat(
112
+ [
113
+ img[:, :, h_start:h_end, w_start:w_end]
114
+ for h_start, h_end, w_start, w_end in batch_view
115
+ ]
116
+ )
117
+ lowres_img_for_view = torch.cat(
118
+ [
119
+ up_lowres_img[:, :, h_start:h_end, w_start:w_end]
120
+ for h_start, h_end, w_start, w_end in batch_view
121
+ ]
122
+ )
123
+
124
+ # rematch block's scheduler status
125
+ self.dpm_solver.__dict__.update(views_scheduler_status[j])
126
+
127
+ t = torch.tensor([time] * img_for_view.shape[0], device=device)
128
+ pred_noise = self.pred_noise(
129
+ unet, img_for_view, t, lowres_img_for_view, dtype
130
+ )
131
+ img_denoised_batch = self.dpm_solver.step(pred_noise, time, img_for_view, **extra_step_kwargs).prev_sample
132
+
133
+ # save views scheduler status after sample
134
+ views_scheduler_status[j] = copy.deepcopy(self.dpm_solver.__dict__)
135
+
136
+ # extract value from batch
137
+ for img_view_denoised, (h_start, h_end, w_start, w_end) in zip(
138
+ img_denoised_batch.chunk(vb_size), batch_view
139
+ ):
140
+ value[:, :, h_start:h_end, w_start:w_end] += img_view_denoised
141
+ count[:, :, h_start:h_end, w_start:w_end] += 1
142
+
143
+ # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
144
+ img = torch.where(count > 0, value / count, value)
145
+
146
+ return img
KandiSuperRes/model/diffusion_sr_turbo.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ from einops import rearrange
5
+ from tqdm import tqdm
6
+
7
+ from .utils import get_tensor_items, exist
8
+ import numpy as np
9
+
10
+
11
+ def get_named_beta_schedule(schedule_name, timesteps):
12
+ if schedule_name == "linear":
13
+ scale = 1000 / timesteps
14
+ beta_start = scale * 0.0001
15
+ beta_end = scale * 0.02
16
+ return torch.linspace(
17
+ beta_start, beta_end, timesteps, dtype=torch.float32
18
+ )
19
+ elif schedule_name == "cosine":
20
+ alpha_bar = lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
21
+ betas = []
22
+ for i in range(timesteps):
23
+ t1 = i / timesteps
24
+ t2 = (i + 1) / timesteps
25
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), 0.999))
26
+ return torch.tensor(betas, dtype=torch.float32)
27
+
28
+
29
+ class BaseDiffusion:
30
+
31
+ def __init__(self, betas, percentile=None, gen_noise=torch.randn_like):
32
+ self.betas = betas
33
+ self.num_timesteps = betas.shape[0]
34
+
35
+ alphas = 1. - betas
36
+ self.alphas_cumprod = torch.cumprod(alphas, dim=0)
37
+ self.alphas_cumprod_prev = torch.cat([torch.ones(1, dtype=betas.dtype), self.alphas_cumprod[:-1]])
38
+
39
+ # calculate q(x_t | x_{t-1})
40
+ self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
41
+ self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)
42
+
43
+ # calculate q(x_{t-1} | x_t, x_0)
44
+ self.posterior_mean_coef_1 = (torch.sqrt(self.alphas_cumprod_prev) * betas / (1. - self.alphas_cumprod))
45
+ self.posterior_mean_coef_2 = (torch.sqrt(alphas) * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod))
46
+ self.posterior_variance = betas * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)
47
+ self.posterior_log_variance = (torch.log(
48
+ torch.cat([self.posterior_variance[1].unsqueeze(0), self.posterior_variance[1:]])
49
+ ))
50
+
51
+ self.percentile = percentile
52
+ self.time_scale = 1000 // self.num_timesteps
53
+ self.gen_noise = gen_noise
54
+
55
+ def q_sample(self, x_start, t, noise=None):
56
+ if noise is None:
57
+ noise = self.gen_noise(x_start)
58
+ sqrt_alphas_cumprod = get_tensor_items(self.sqrt_alphas_cumprod, t, x_start.shape)
59
+ sqrt_one_minus_alphas_cumprod = get_tensor_items(self.sqrt_one_minus_alphas_cumprod, t, noise.shape)
60
+ x_t = sqrt_alphas_cumprod * x_start + sqrt_one_minus_alphas_cumprod * noise
61
+ return x_t
62
+
63
+ @torch.no_grad()
64
+ def p_sample_loop(
65
+ self, model, shape, device, dtype, lowres_img, times=[979, 729, 479, 229]
66
+ ):
67
+ img = torch.randn(*shape, device=device).to(dtype=dtype)
68
+ times = times + [0,]
69
+ times = list(zip(times[:-1], times[1:]))
70
+ for time, prev_time in tqdm(times):
71
+ time = torch.tensor([time] * shape[0], device=device)
72
+ x_t = self.q_sample(img, time)
73
+ img = model(x_t.to(dtype), time.to(dtype), lowres_img=lowres_img.to(dtype))
74
+ return img
75
+
76
+ @torch.no_grad()
77
+ def refine(self, model, img, **large_model_kwargs):
78
+ for time in tqdm([729, 479, 229]):
79
+ time = torch.tensor([time,] * img.shape[0], device=img.device)
80
+ x_t = self.q_sample(img, time)
81
+ img = model(x_t, time.type(x_t.dtype), **large_model_kwargs)
82
+ return img
83
+
84
+ def get_diffusion(conf):
85
+ betas = get_named_beta_schedule(**conf.schedule_params)
86
+ base_diffusion = BaseDiffusion(betas, **conf.diffusion_params)
87
+ return base_diffusion
KandiSuperRes/model/nn.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ from torch import nn, einsum
5
+ from einops import rearrange, repeat
6
+
7
+ from .utils import exist, set_default_layer
8
+
9
+
10
+ class Identity(nn.Module):
11
+ def __init__(self, *args, **kwargs):
12
+ super().__init__()
13
+
14
+ @staticmethod
15
+ def forward(x, *args, **kwargs):
16
+ return x
17
+
18
+
19
+ class SinusoidalPosEmb_sr(nn.Module):
20
+
21
+ def __init__(self, dim):
22
+ super().__init__()
23
+ self.dim = dim
24
+
25
+ def forward(self, x):
26
+ half_dim = self.dim // 2
27
+ emb = math.log(10000) / (half_dim - 1)
28
+ emb = torch.exp(torch.arange(half_dim, device=x.device) * -emb)
29
+ emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j').to(dtype=x.dtype)
30
+ return torch.cat((emb.sin(), emb.cos()), dim=-1)
31
+
32
+
33
+ class UpDownResolution(nn.Module):
34
+
35
+ def __init__(self, num_channels, up_resolution, change_type='conv'):
36
+ super().__init__()
37
+ if change_type == 'pooling':
38
+ self.change_resolution = set_default_layer(
39
+ up_resolution,
40
+ layer_1=nn.Upsample, kwargs_1={'scale_factor': 2., 'mode': 'nearest'},
41
+ layer_2=nn.AvgPool2d, kwargs_2={'kernel_size': 2, 'stride': 2}
42
+ )
43
+
44
+ elif change_type == 'conv':
45
+ self.change_resolution = set_default_layer(
46
+ up_resolution,
47
+ nn.ConvTranspose2d, (num_channels, num_channels), {'kernel_size': 4, 'stride': 2, 'padding': 1},
48
+ nn.Conv2d, (num_channels, num_channels), {'kernel_size': 4, 'stride': 2, 'padding': 1},
49
+ )
50
+ else:
51
+ raise NotImplementedError
52
+
53
+ def forward(self, x):
54
+ x = self.change_resolution(x)
55
+ return x
56
+
57
+ class SinusoidalPosEmb(nn.Module):
58
+
59
+ def __init__(self, dim):
60
+ super().__init__()
61
+ self.dim = dim
62
+
63
+ def forward(self, x):
64
+ half_dim = self.dim // 2
65
+ emb = math.log(10000) / (half_dim - 1)
66
+ emb = torch.exp(torch.arange(half_dim, device=x.device, dtype=x.dtype) * -emb)
67
+ emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
68
+ return torch.cat((emb.sin(), emb.cos()), dim=-1)
69
+
70
+
71
+ class ConditionalGroupNorm(nn.Module):
72
+
73
+ def __init__(self, groups, normalized_shape, context_dim):
74
+ super().__init__()
75
+ self.norm = nn.GroupNorm(groups, normalized_shape, affine=False)
76
+ self.context_mlp = nn.Sequential(
77
+ nn.SiLU(),
78
+ nn.Linear(context_dim, 2 * normalized_shape)
79
+ )
80
+ self.context_mlp[1].weight.data.zero_()
81
+ self.context_mlp[1].bias.data.zero_()
82
+
83
+ def forward(self, x, context):
84
+ context = self.context_mlp(context)
85
+ ndims = ' 1' * len(x.shape[2:])
86
+ context = rearrange(context, f'b c -> b c{ndims}')
87
+
88
+ scale, shift = context.chunk(2, dim=1)
89
+ x = self.norm(x) * (scale + 1.) + shift
90
+ return x
91
+
92
+
93
+ class Attention(nn.Module):
94
+
95
+ def __init__(self, in_channels, out_channels, context_dim, head_dim=64):
96
+ super().__init__()
97
+ assert out_channels % head_dim == 0
98
+ self.num_heads = out_channels // head_dim
99
+ self.scale = head_dim ** -0.5
100
+
101
+ self.to_query = nn.Linear(in_channels, out_channels, bias=False)
102
+ self.to_key = nn.Linear(context_dim, out_channels, bias=False)
103
+ self.to_value = nn.Linear(context_dim, out_channels, bias=False)
104
+
105
+ self.output_layer = nn.Linear(out_channels, out_channels, bias=False)
106
+
107
+ def forward(self, x, context, context_mask=None):
108
+ query = rearrange(self.to_query(x), 'b n (h d) -> b h n d', h=self.num_heads)
109
+ key = rearrange(self.to_key(context), 'b n (h d) -> b h n d', h=self.num_heads)
110
+ value = rearrange(self.to_value(context), 'b n (h d) -> b h n d', h=self.num_heads)
111
+
112
+ attention_matrix = einsum('b h i d, b h j d -> b h i j', query, key) * self.scale
113
+ if exist(context_mask):
114
+ max_neg_value = -torch.finfo(attention_matrix.dtype).max
115
+ context_mask = rearrange(context_mask, 'b j -> b 1 1 j')
116
+ attention_matrix = attention_matrix.masked_fill(~context_mask, max_neg_value)
117
+ attention_matrix = attention_matrix.softmax(dim=-1)
118
+
119
+ out = einsum('b h i j, b h j d -> b h i d', attention_matrix, value)
120
+ out = rearrange(out, 'b h n d -> b n (h d)')
121
+ out = self.output_layer(out)
122
+ return out
KandiSuperRes/model/unet.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn, einsum
3
+ from einops import rearrange
4
+
5
+ from .nn import Identity, Attention, SinusoidalPosEmb, ConditionalGroupNorm
6
+ from .utils import exist, set_default_item, set_default_layer
7
+ import torch.nn.functional as F
8
+
9
+
10
+ class Block(nn.Module):
11
+
12
+ def __init__(self, in_channels, out_channels, time_embed_dim, kernel_size=3, norm_groups=32, up_resolution=None):
13
+ super().__init__()
14
+ self.group_norm = ConditionalGroupNorm(norm_groups, in_channels, time_embed_dim)
15
+ self.activation = nn.SiLU()
16
+ self.up_sample = set_default_layer(
17
+ exist(up_resolution) and up_resolution,
18
+ nn.ConvTranspose2d, (in_channels, in_channels), {'kernel_size': 2, 'stride': 2}
19
+ )
20
+ padding = set_default_item(kernel_size == 1, 0, 1)
21
+ self.projection = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
22
+ self.down_sample = set_default_layer(
23
+ exist(up_resolution) and not up_resolution,
24
+ nn.Conv2d, (out_channels, out_channels), {'kernel_size': 2, 'stride': 2}
25
+ )
26
+
27
+ def forward(self, x, time_embed):
28
+ x = self.group_norm(x, time_embed)
29
+ x = self.activation(x)
30
+ x = self.up_sample(x)
31
+ x = self.projection(x)
32
+ x = self.down_sample(x)
33
+ return x
34
+
35
+
36
+ class ResNetBlock(nn.Module):
37
+
38
+ def __init__(
39
+ self, in_channels, out_channels, time_embed_dim, norm_groups=32, compression_ratio=2, up_resolutions=4*[None]
40
+ ):
41
+ super().__init__()
42
+ kernel_sizes = [1, 3, 3, 1]
43
+ hidden_channel = max(in_channels, out_channels) // compression_ratio
44
+ hidden_channels = [(in_channels, hidden_channel)] + [(hidden_channel, hidden_channel)] * 2 + [(hidden_channel, out_channels)]
45
+ self.resnet_blocks = nn.ModuleList([
46
+ Block(in_channel, out_channel, time_embed_dim, kernel_size, norm_groups, up_resolution)
47
+ for (in_channel, out_channel), kernel_size, up_resolution in zip(hidden_channels, kernel_sizes, up_resolutions)
48
+ ])
49
+
50
+ self.shortcut_up_sample = set_default_layer(
51
+ True in up_resolutions,
52
+ nn.ConvTranspose2d, (in_channels, in_channels), {'kernel_size': 2, 'stride': 2}
53
+ )
54
+ self.shortcut_projection = set_default_layer(
55
+ in_channels != out_channels,
56
+ nn.Conv2d, (in_channels, out_channels), {'kernel_size': 1}
57
+ )
58
+ self.shortcut_down_sample = set_default_layer(
59
+ False in up_resolutions,
60
+ nn.Conv2d, (out_channels, out_channels), {'kernel_size': 2, 'stride': 2}
61
+ )
62
+
63
+ def forward(self, x, time_embed):
64
+ out = x
65
+ for resnet_block in self.resnet_blocks:
66
+ out = resnet_block(out, time_embed)
67
+
68
+ x = self.shortcut_up_sample(x)
69
+ x = self.shortcut_projection(x)
70
+ x = self.shortcut_down_sample(x)
71
+ x = x + out
72
+ return x
73
+
74
+
75
+ class AttentionPolling(nn.Module):
76
+
77
+ def __init__(self, num_channels, context_dim, head_dim=64):
78
+ super().__init__()
79
+ self.attention = Attention(context_dim, num_channels, context_dim, head_dim)
80
+
81
+ def forward(self, x, context, context_mask=None):
82
+ context = self.attention(context.mean(dim=1, keepdim=True), context, context_mask)
83
+ return x + context.squeeze(1)
84
+
85
+
86
+ class AttentionBlock(nn.Module):
87
+
88
+ def __init__(self, num_channels, time_embed_dim, context_dim=None, norm_groups=32, head_dim=64, expansion_ratio=4):
89
+ super().__init__()
90
+ self.in_norm = ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim)
91
+ self.attention = Attention(num_channels, num_channels, context_dim or num_channels, head_dim)
92
+
93
+ hidden_channels = expansion_ratio * num_channels
94
+ self.out_norm = ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim)
95
+ self.feed_forward = nn.Sequential(
96
+ nn.Conv2d(num_channels, hidden_channels, kernel_size=1, bias=False),
97
+ nn.SiLU(),
98
+ nn.Conv2d(hidden_channels, num_channels, kernel_size=1, bias=False),
99
+ )
100
+
101
+ def forward(self, x, time_embed, context=None, context_mask=None):
102
+ height, width = x.shape[-2:]
103
+ out = self.in_norm(x, time_embed)
104
+ out = rearrange(out, 'b c h w -> b (h w) c', h=height, w=width)
105
+ context = set_default_item(exist(context), context, out)
106
+ out = self.attention(out, context, context_mask)
107
+ out = rearrange(out, 'b (h w) c -> b c h w', h=height, w=width)
108
+ x = x + out
109
+
110
+ out = self.out_norm(x, time_embed)
111
+ out = self.feed_forward(out)
112
+ x = x + out
113
+ return x
114
+
115
+
116
+ class DownSampleBlock(nn.Module):
117
+
118
+ def __init__(
119
+ self, in_channels, out_channels, time_embed_dim, context_dim=None,
120
+ num_blocks=3, groups=32, head_dim=64, expansion_ratio=4, compression_ratio=2,
121
+ down_sample=True, self_attention=True
122
+ ):
123
+ super().__init__()
124
+ self.self_attention_block = set_default_layer(
125
+ self_attention,
126
+ AttentionBlock,
127
+ (in_channels, time_embed_dim, None, groups, head_dim, expansion_ratio),
128
+ layer_2=Identity
129
+ )
130
+
131
+ up_resolutions = [[None] * 4] * (num_blocks - 1) + [[None, None, set_default_item(down_sample, False), None]]
132
+ hidden_channels = [(in_channels, out_channels)] + [(out_channels, out_channels)] * (num_blocks - 1)
133
+ self.resnet_attn_blocks = nn.ModuleList([
134
+ nn.ModuleList([
135
+ ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio),
136
+ set_default_layer(
137
+ exist(context_dim),
138
+ AttentionBlock,
139
+ (out_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio),
140
+ layer_2=Identity
141
+ ),
142
+ ResNetBlock(out_channel, out_channel, time_embed_dim, groups, compression_ratio, up_resolution),
143
+ ]) for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions)
144
+ ])
145
+
146
+ def forward(self, x, time_embed, context=None, context_mask=None, control_net_residual=None):
147
+ x = self.self_attention_block(x, time_embed)
148
+ for in_resnet_block, attention, out_resnet_block in self.resnet_attn_blocks:
149
+ x = in_resnet_block(x, time_embed)
150
+ x = attention(x, time_embed, context, context_mask)
151
+ x = out_resnet_block(x, time_embed)
152
+ return x
153
+
154
+
155
+ class UpSampleBlock(nn.Module):
156
+
157
+ def __init__(
158
+ self, in_channels, cat_dim, out_channels, time_embed_dim, context_dim=None,
159
+ num_blocks=3, groups=32, head_dim=64, expansion_ratio=4, compression_ratio=2,
160
+ up_sample=True, self_attention=True
161
+ ):
162
+ super().__init__()
163
+ up_resolutions = [[None, set_default_item(up_sample, True), None, None]] + [[None] * 4] * (num_blocks - 1)
164
+ hidden_channels = [(in_channels + cat_dim, in_channels)] + [(in_channels, in_channels)] * (num_blocks - 2) + [(in_channels, out_channels)]
165
+ self.resnet_attn_blocks = nn.ModuleList([
166
+ nn.ModuleList([
167
+ ResNetBlock(in_channel, in_channel, time_embed_dim, groups, compression_ratio, up_resolution),
168
+ set_default_layer(
169
+ exist(context_dim),
170
+ AttentionBlock,
171
+ (in_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio),
172
+ layer_2=Identity
173
+ ),
174
+ ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio),
175
+ ]) for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions)
176
+ ])
177
+
178
+ self.self_attention_block = set_default_layer(
179
+ self_attention,
180
+ AttentionBlock,
181
+ (out_channels, time_embed_dim, None, groups, head_dim, expansion_ratio),
182
+ layer_2=Identity
183
+ )
184
+
185
+ def forward(self, x, time_embed, context=None, context_mask=None):
186
+ for in_resnet_block, attention, out_resnet_block in self.resnet_attn_blocks:
187
+ x = in_resnet_block(x, time_embed)
188
+ x = attention(x, time_embed, context, context_mask)
189
+ x = out_resnet_block(x, time_embed)
190
+ x = self.self_attention_block(x, time_embed)
191
+ return x
192
+
193
+
194
+ class UNet(nn.Module):
195
+
196
+ def __init__(self,
197
+ model_channels,
198
+ init_channels=None,
199
+ num_channels=3,
200
+ out_channels=4,
201
+ time_embed_dim=None,
202
+ context_dim=None,
203
+ groups=32,
204
+ head_dim=64,
205
+ expansion_ratio=4,
206
+ compression_ratio=2,
207
+ dim_mult=(1, 2, 4, 8),
208
+ num_blocks=(3, 3, 3, 3),
209
+ add_cross_attention=(False, True, True, True),
210
+ add_self_attention=(False, True, True, True),
211
+ *args,
212
+ **kwargs,
213
+ ):
214
+ super().__init__()
215
+ init_channels = init_channels or model_channels
216
+
217
+ self.to_time_embed = nn.Sequential(
218
+ SinusoidalPosEmb(init_channels),
219
+ nn.Linear(init_channels, time_embed_dim),
220
+ nn.SiLU(),
221
+ nn.Linear(time_embed_dim, time_embed_dim)
222
+ )
223
+ self.feature_pooling = AttentionPolling(time_embed_dim, context_dim, head_dim)
224
+
225
+ self.in_layer = nn.Conv2d(num_channels, init_channels, kernel_size=3, padding=1)
226
+
227
+ hidden_dims = [init_channels, *map(lambda mult: model_channels * mult, dim_mult)]
228
+ in_out_dims = list(zip(hidden_dims[:-1], hidden_dims[1:]))
229
+ text_dims = [set_default_item(is_exist, context_dim) for is_exist in add_cross_attention]
230
+ layer_params = [num_blocks, text_dims, add_self_attention]
231
+ rev_layer_params = map(reversed, layer_params)
232
+
233
+ cat_dims = []
234
+ self.num_levels = len(in_out_dims)
235
+ self.down_samples = nn.ModuleList([])
236
+ for level, ((in_dim, out_dim), res_block_num, text_dim, self_attention) in enumerate(zip(in_out_dims, *layer_params)):
237
+ down_sample = level != (self.num_levels - 1)
238
+ cat_dims.append(set_default_item(level != (self.num_levels - 1), out_dim, 0))
239
+ self.down_samples.append(
240
+ DownSampleBlock(
241
+ in_dim, out_dim, time_embed_dim, text_dim, res_block_num, groups, head_dim, expansion_ratio,
242
+ compression_ratio, down_sample, self_attention
243
+ )
244
+ )
245
+
246
+ self.up_samples = nn.ModuleList([])
247
+ for level, ((out_dim, in_dim), res_block_num, text_dim, self_attention) in enumerate(zip(reversed(in_out_dims), *rev_layer_params)):
248
+ up_sample = level != 0
249
+ self.up_samples.append(
250
+ UpSampleBlock(
251
+ in_dim, cat_dims.pop(), out_dim, time_embed_dim, text_dim, res_block_num, groups, head_dim,
252
+ expansion_ratio, compression_ratio, up_sample, self_attention
253
+ )
254
+ )
255
+
256
+ self.out_layer = nn.Sequential(
257
+ nn.GroupNorm(groups, init_channels),
258
+ nn.SiLU(),
259
+ nn.Conv2d(init_channels, out_channels, kernel_size=3, padding=1)
260
+ )
261
+
262
+ def forward(self, x, time, context=None, context_mask=None, is_text=None, null_embedding=None, control_net_residual=None):
263
+
264
+ time_embed = self.to_time_embed(time)
265
+ if exist(context):
266
+ time_embed = self.feature_pooling(time_embed, context, context_mask)
267
+
268
+ hidden_states = []
269
+ x = self.in_layer(x)
270
+ for level, down_sample in enumerate(self.down_samples):
271
+ x = down_sample(x, time_embed, context, context_mask, control_net_residual)
272
+ if level != self.num_levels - 1:
273
+ hidden_states.append(x)
274
+ for level, up_sample in enumerate(self.up_samples):
275
+ if level != 0:
276
+ x = torch.cat([x, hidden_states.pop()], dim=1)
277
+ x = up_sample(x, time_embed, context, context_mask)
278
+ x = self.out_layer(x)
279
+ return x
280
+
281
+
282
+ def get_unet(conf):
283
+ unet = UNet(**conf)
284
+ return unet
KandiSuperRes/model/unet_sr.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from einops import rearrange
4
+ from .nn import Identity, Attention, SinusoidalPosEmb, UpDownResolution
5
+ from .utils import exist, set_default_item, set_default_layer
6
+ import torch.nn.functional as F
7
+
8
+
9
+ class Block(nn.Module):
10
+
11
+ def __init__(self, in_channels, out_channels, time_embed_dim=None, groups=32, activation=None, up_resolution=None, dropout=None):
12
+ super().__init__()
13
+ self.group_norm = nn.GroupNorm(groups, in_channels)
14
+ self.activation = set_default_layer(
15
+ exist(activation),
16
+ nn.SiLU
17
+ )
18
+ self.change_resolution = set_default_layer(
19
+ exist(up_resolution),
20
+ UpDownResolution, (in_channels, up_resolution)
21
+ )
22
+ self.dropout = set_default_layer(
23
+ exist(dropout),
24
+ nn.Dropout, (), {'p': 0.1}
25
+ )
26
+ self.projection = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
27
+
28
+ def forward(self, x, scale_shift=None):
29
+ x = self.group_norm(x)
30
+ if exist(scale_shift):
31
+ scale, shift = scale_shift
32
+ x = x * (scale + 1) + shift
33
+ x = self.activation(x)
34
+ x = self.dropout(x)
35
+ x = self.change_resolution(x)
36
+ x = self.projection(x)
37
+ return x
38
+
39
+
40
+ class ResNetBlock(nn.Module):
41
+
42
+ def __init__(self, in_channels, out_channels, time_embed_dim=None, groups=32, up_resolution=None):
43
+ super().__init__()
44
+ self.time_mlp = set_default_item(
45
+ exist(time_embed_dim),
46
+ nn.Sequential(
47
+ nn.SiLU(),
48
+ nn.Linear(time_embed_dim, 2 * out_channels)
49
+ )
50
+ )
51
+ self.in_block = Block(in_channels, out_channels, time_embed_dim, groups, up_resolution=up_resolution)
52
+ self.out_block = Block(out_channels, out_channels, time_embed_dim, groups, activation=True, up_resolution=None, dropout=True)
53
+
54
+ self.change_resolution = set_default_layer(
55
+ exist(up_resolution),
56
+ UpDownResolution, (in_channels, up_resolution)
57
+ )
58
+ self.res_block = set_default_layer(
59
+ in_channels != out_channels or exist(up_resolution),
60
+ nn.Conv2d, (in_channels, out_channels), {'kernel_size': 1}
61
+ )
62
+
63
+ def forward(self, x, time_embed=None):
64
+ scale_shift = None
65
+ if exist(time_embed) and exist(self.time_mlp):
66
+ time_embed = self.time_mlp(time_embed)
67
+ time_embed = rearrange(time_embed, 'b c -> b c 1 1')
68
+ scale_shift = time_embed.chunk(2, dim=1)
69
+ out = self.in_block(x)
70
+ out = self.out_block(out, scale_shift=scale_shift)
71
+ x = self.change_resolution(x)
72
+ out = out + self.res_block(x)
73
+ return out
74
+
75
+
76
+ class AttentionBlock(nn.Module):
77
+
78
+ def __init__(
79
+ self, dim, context_dim=None, groups=32, num_heads=8, num_conditions=1, feed_forward_mult=2
80
+ ):
81
+ super().__init__()
82
+ self.in_norm = nn.GroupNorm(groups, dim)
83
+ self.attention = Attention(
84
+ dim, context_dim or dim, num_heads, num_conditions=num_conditions
85
+ )
86
+
87
+ hidden_dim = feed_forward_mult * dim
88
+ self.out_norm = nn.GroupNorm(groups, dim)
89
+ self.feed_forward = nn.Sequential(
90
+ nn.Conv2d(dim, hidden_dim, kernel_size=1, bias=False),
91
+ nn.SiLU(),
92
+ nn.Conv2d(hidden_dim, dim, kernel_size=1, bias=False),
93
+ )
94
+
95
+ def forward(self, x, context=None, context_mask=None, context_idx=None):
96
+ width = x.shape[-1]
97
+ out = self.in_norm(x)
98
+ out = rearrange(out, 'b c h w -> b (h w) c')
99
+ context = set_default_item(exist(context), context, out)
100
+ out = self.attention(out, context, context_mask, context_idx)
101
+ out = rearrange(out, 'b (h w) c -> b c h w', w=width)
102
+ x = x + out
103
+
104
+ out = self.out_norm(x)
105
+ out = self.feed_forward(out)
106
+ x = x + out
107
+ return x
108
+
109
+
110
+ class DownSampleBlock(nn.Module):
111
+
112
+ def __init__(
113
+ self, in_channels, out_channels, time_embed_dim,
114
+ num_resnet_blocks=3, groups=32, down_sample=True, context_dim=None, self_attention=True, num_conditions=1):
115
+ super().__init__()
116
+ up_resolutions = [set_default_item(down_sample, False)] + [None] * (num_resnet_blocks - 1)
117
+ hidden_channels = [(in_channels, out_channels)] + [(out_channels, out_channels)] * (num_resnet_blocks - 1)
118
+ self.resnet_attn_blocks = nn.ModuleList([
119
+ nn.ModuleList([
120
+ ResNetBlock(in_channel, out_channel, time_embed_dim, groups, up_resolution),
121
+ set_default_layer(
122
+ exist(context_dim),
123
+ AttentionBlock, (out_channel, context_dim), {'num_conditions': num_conditions, 'groups': groups},
124
+ layer_2=Identity
125
+ )
126
+ ]) for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions)
127
+ ])
128
+
129
+ self.self_attention_block = set_default_layer(
130
+ self_attention,
131
+ AttentionBlock, (out_channels,), {'feed_forward_mult': 4, 'groups': groups},
132
+ layer_2=Identity
133
+ )
134
+
135
+ def forward(self, x, time_embed, context=None, context_mask=None, context_idx=None):
136
+ for resnet_block, attention in self.resnet_attn_blocks:
137
+ x = resnet_block(x, time_embed)
138
+ x = attention(x, context, context_mask, context_idx)
139
+ x = self.self_attention_block(x)
140
+ return x
141
+
142
+
143
+ class UpSampleBlock(nn.Module):
144
+
145
+ def __init__(
146
+ self, in_channels, cat_dim, out_channels, time_embed_dim,
147
+ num_resnet_blocks=3, groups=32, up_sample=True, context_dim=None, self_attention=True, num_conditions=1):
148
+ super().__init__()
149
+ up_resolutions = [None] * (num_resnet_blocks - 1) + [set_default_item(up_sample, True)]
150
+ hidden_channels = [(in_channels + cat_dim, in_channels)] + [(in_channels, in_channels)] * (num_resnet_blocks - 2) + [(in_channels, out_channels)]
151
+ self.resnet_attn_blocks = nn.ModuleList([
152
+ nn.ModuleList([
153
+ ResNetBlock(in_channel, out_channel, time_embed_dim, groups, up_resolution),
154
+ set_default_layer(
155
+ exist(context_dim),
156
+ AttentionBlock, (out_channel, context_dim), {'num_conditions': num_conditions, 'groups': groups, 'feed_forward_mult': 4},
157
+ layer_2=Identity
158
+ )
159
+ ]) for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions)
160
+ ])
161
+
162
+ self.self_attention_block = set_default_layer(
163
+ self_attention,
164
+ AttentionBlock, (out_channels,), {'feed_forward_mult': 4, 'groups': groups},
165
+ layer_2=Identity
166
+ )
167
+
168
+ def forward(self, x, time_embed, context=None, context_mask=None, context_idx=None):
169
+ for resnet_block, attention in self.resnet_attn_blocks:
170
+ x = resnet_block(x, time_embed)
171
+ x = attention(x, context, context_mask, context_idx)
172
+ x = self.self_attention_block(x)
173
+ return x
174
+
175
+
176
+ class UNet(nn.Module):
177
+
178
+ def __init__(self,
179
+ model_channels,
180
+ init_channels=128,
181
+ num_channels=3,
182
+ time_embed_dim=512,
183
+ context_dim=None,
184
+ groups=32,
185
+ feature_pooling_type='attention',
186
+ dim_mult=(1, 2, 4, 8),
187
+ num_resnet_blocks=(2, 4, 8, 8),
188
+ num_conditions=1,
189
+ skip_connect_scale=1.,
190
+ add_cross_attention=(False, False, False, False),
191
+ add_self_attention=(False, False, False, False),
192
+ lowres_cond=True,
193
+ ):
194
+ super().__init__()
195
+ out_channels = num_channels
196
+ num_channels = set_default_item(lowres_cond, num_channels * 2, num_channels)
197
+ init_channels = init_channels or model_channels
198
+ self.num_conditions = num_conditions
199
+ self.skip_connect_scale = skip_connect_scale
200
+ self.to_time_embed = nn.Sequential(
201
+ SinusoidalPosEmb(init_channels),
202
+ nn.Linear(init_channels, time_embed_dim),
203
+ nn.SiLU(),
204
+ nn.Linear(time_embed_dim, time_embed_dim)
205
+ )
206
+
207
+ self.init_conv = nn.Conv2d(num_channels, init_channels, kernel_size=3, padding=1)
208
+
209
+ hidden_dims = [init_channels, *map(lambda mult: model_channels * mult, dim_mult)]
210
+ in_out_dims = list(zip(hidden_dims[:-1], hidden_dims[1:]))
211
+ text_dims = [set_default_item(is_exist, context_dim) for is_exist in add_cross_attention]
212
+ layer_params = [num_resnet_blocks, text_dims, add_self_attention]
213
+ rev_layer_params = map(reversed, layer_params)
214
+
215
+ cat_dims = []
216
+ self.num_levels = len(in_out_dims)
217
+ self.down_samples = nn.ModuleList([])
218
+ for level, ((in_dim, out_dim), res_block_num, text_dim, self_attention) in enumerate(zip(in_out_dims, *layer_params)):
219
+ down_sample = level != (self.num_levels - 1)
220
+ cat_dims.append(set_default_item(level != (self.num_levels - 1), out_dim, 0))
221
+ self.down_samples.append(
222
+ DownSampleBlock(
223
+ in_dim, out_dim, time_embed_dim, res_block_num, groups, down_sample, text_dim, self_attention, num_conditions
224
+ )
225
+ )
226
+
227
+ self.up_samples = nn.ModuleList([])
228
+ for level, ((out_dim, in_dim), res_block_num, text_dim, self_attention) in enumerate(zip(reversed(in_out_dims), *rev_layer_params)):
229
+ up_sample = level != 0
230
+ self.up_samples.append(
231
+ UpSampleBlock(
232
+ in_dim, cat_dims.pop(), out_dim, time_embed_dim, res_block_num, groups, up_sample, text_dim, self_attention, num_conditions
233
+ )
234
+ )
235
+
236
+ self.norm = nn.GroupNorm(groups, init_channels)
237
+ self.activation = nn.SiLU()
238
+ self.out_conv = nn.Conv2d(init_channels, out_channels, kernel_size=3, padding=1)
239
+
240
+ def forward(self, x, time, context=None, context_mask=None, context_idx=None, lowres_img=None):
241
+ if exist(lowres_img):
242
+ _, _, new_height, new_width = x.shape
243
+ upsampled = F.interpolate(lowres_img, (new_height, new_width), mode="bilinear")
244
+ x = torch.cat([x, upsampled], dim=1)
245
+ time_embed = self.to_time_embed(time)
246
+
247
+ hidden_states = []
248
+ x = self.init_conv(x)
249
+ for level, down_sample in enumerate(self.down_samples):
250
+ x = down_sample(x, time_embed, context, context_mask, context_idx)
251
+ if level != self.num_levels - 1:
252
+ hidden_states.append(x)
253
+ for level, up_sample in enumerate(self.up_samples):
254
+ if level != 0:
255
+ x = torch.cat([x, hidden_states.pop() / self.skip_connect_scale], dim=1)
256
+ x = up_sample(x, time_embed, context, context_mask, context_idx)
257
+ x = self.norm(x)
258
+ x = self.activation(x)
259
+ x = self.out_conv(x)
260
+ return x
KandiSuperRes/model/utils.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.nn import Identity
2
+ from einops import rearrange
3
+
4
+
5
+ def exist(item):
6
+ return item is not None
7
+
8
+
9
+ def set_default_item(condition, item_1, item_2=None):
10
+ if condition:
11
+ return item_1
12
+ else:
13
+ return item_2
14
+
15
+
16
+ def set_default_layer(condition, layer_1, args_1=[], kwargs_1={}, layer_2=Identity, args_2=[], kwargs_2={}):
17
+ if condition:
18
+ return layer_1(*args_1, **kwargs_1)
19
+ else:
20
+ return layer_2(*args_2, **kwargs_2)
21
+
22
+
23
+ def get_tensor_items(x, pos, broadcast_shape):
24
+ device = pos.device
25
+ bs = pos.shape[0]
26
+ ndims = len(broadcast_shape[1:])
27
+ x = x.cpu()[pos.cpu()]
28
+ return x.reshape(bs, *((1,) * ndims)).to(device)
29
+
30
+
31
+ def local_patching(x, height, width, group_size):
32
+ if group_size > 0:
33
+ x = rearrange(
34
+ x, 'b c (h g1) (w g2) -> b (h w) (g1 g2) c',
35
+ h=height//group_size, w=width//group_size, g1=group_size, g2=group_size
36
+ )
37
+ else:
38
+ x = rearrange(x, 'b c h w -> b (h w) c', h=height, w=width)
39
+ return x
40
+
41
+
42
+ def local_merge(x, height, width, group_size):
43
+ if group_size > 0:
44
+ x = rearrange(
45
+ x, 'b (h w) (g1 g2) c -> b c (h g1) (w g2)',
46
+ h=height//group_size, w=width//group_size, g1=group_size, g2=group_size
47
+ )
48
+ else:
49
+ x = rearrange(x, 'b (h w) c -> b c h w', h=height, w=width)
50
+ return x
51
+
52
+
53
+ def global_patching(x, height, width, group_size):
54
+ x = local_patching(x, height, width, height//group_size)
55
+ x = x.transpose(-2, -3)
56
+ return x
57
+
58
+
59
+ def global_merge(x, height, width, group_size):
60
+ x = x.transpose(-2, -3)
61
+ x = local_merge(x, height, width, height//group_size)
62
+ return x
KandiSuperRes/movq.py ADDED
@@ -0,0 +1,541 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ import torch.nn.functional as F
6
+
7
+ from .utils import freeze
8
+ from tqdm import tqdm
9
+
10
+ import time
11
+
12
+ def nonlinearity(x):
13
+ return x*torch.sigmoid(x)
14
+
15
+
16
+ class SpatialNorm(nn.Module):
17
+ def __init__(
18
+ self, f_channels, zq_channels=None, norm_layer=nn.GroupNorm, freeze_norm_layer=False, add_conv=False, **norm_layer_params
19
+ ):
20
+ super().__init__()
21
+ self.norm_layer = norm_layer(num_channels=f_channels, **norm_layer_params)
22
+ if zq_channels is not None:
23
+ if freeze_norm_layer:
24
+ for p in self.norm_layer.parameters:
25
+ p.requires_grad = False
26
+ self.add_conv = add_conv
27
+ if self.add_conv:
28
+ self.conv = nn.Conv2d(zq_channels, zq_channels, kernel_size=3, stride=1, padding=1)
29
+ self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
30
+ self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
31
+ def forward(self, f, zq=None):
32
+ norm_f = self.norm_layer(f)
33
+ if zq is not None:
34
+ f_size = f.shape[-2:]
35
+ zq = torch.nn.functional.interpolate(zq, size=f_size, mode="nearest")
36
+ if self.add_conv:
37
+ zq = self.conv(zq)
38
+ norm_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
39
+ return norm_f
40
+
41
+
42
+ def Normalize(in_channels, zq_ch=None, add_conv=None):
43
+ return SpatialNorm(
44
+ in_channels, zq_ch, norm_layer=nn.GroupNorm,
45
+ freeze_norm_layer=False, add_conv=add_conv, num_groups=32, eps=1e-6, affine=True
46
+ )
47
+
48
+
49
+ class Upsample(nn.Module):
50
+ def __init__(self, in_channels, with_conv):
51
+ super().__init__()
52
+ self.with_conv = with_conv
53
+ if self.with_conv:
54
+ self.conv = torch.nn.Conv2d(in_channels,
55
+ in_channels,
56
+ kernel_size=3,
57
+ stride=1,
58
+ padding=1)
59
+
60
+ def forward(self, x):
61
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
62
+ if self.with_conv:
63
+ x = self.conv(x)
64
+ return x
65
+
66
+
67
+ class Downsample(nn.Module):
68
+ def __init__(self, in_channels, with_conv):
69
+ super().__init__()
70
+ self.with_conv = with_conv
71
+ if self.with_conv:
72
+ self.conv = torch.nn.Conv2d(in_channels,
73
+ in_channels,
74
+ kernel_size=3,
75
+ stride=2,
76
+ padding=0)
77
+
78
+ def forward(self, x):
79
+ if self.with_conv:
80
+ pad = (0,1,0,1)
81
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
82
+ x = self.conv(x)
83
+ else:
84
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
85
+ return x
86
+
87
+
88
+ class ResnetBlock(nn.Module):
89
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
90
+ dropout, temb_channels=512, zq_ch=None, add_conv=False):
91
+ super().__init__()
92
+ self.in_channels = in_channels
93
+ out_channels = in_channels if out_channels is None else out_channels
94
+ self.out_channels = out_channels
95
+ self.use_conv_shortcut = conv_shortcut
96
+
97
+ self.norm1 = Normalize(in_channels, zq_ch, add_conv=add_conv)
98
+ self.conv1 = torch.nn.Conv2d(in_channels,
99
+ out_channels,
100
+ kernel_size=3,
101
+ stride=1,
102
+ padding=1)
103
+ if temb_channels > 0:
104
+ self.temb_proj = torch.nn.Linear(temb_channels,
105
+ out_channels)
106
+ self.norm2 = Normalize(out_channels, zq_ch, add_conv=add_conv)
107
+ self.dropout = torch.nn.Dropout(dropout)
108
+ self.conv2 = torch.nn.Conv2d(out_channels,
109
+ out_channels,
110
+ kernel_size=3,
111
+ stride=1,
112
+ padding=1)
113
+ if self.in_channels != self.out_channels:
114
+ if self.use_conv_shortcut:
115
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
116
+ out_channels,
117
+ kernel_size=3,
118
+ stride=1,
119
+ padding=1)
120
+ else:
121
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
122
+ out_channels,
123
+ kernel_size=1,
124
+ stride=1,
125
+ padding=0)
126
+
127
+ def forward(self, x, temb, zq=None):
128
+ h = x
129
+ h = self.norm1(h, zq)
130
+ h = nonlinearity(h)
131
+ h = self.conv1(h)
132
+
133
+ if temb is not None:
134
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
135
+
136
+ h = self.norm2(h, zq)
137
+ h = nonlinearity(h)
138
+ h = self.dropout(h)
139
+ h = self.conv2(h)
140
+
141
+ if self.in_channels != self.out_channels:
142
+ if self.use_conv_shortcut:
143
+ x = self.conv_shortcut(x)
144
+ else:
145
+ x = self.nin_shortcut(x)
146
+ return x+h
147
+
148
+
149
+ class AttnBlock(nn.Module):
+     def __init__(self, in_channels, zq_ch=None, add_conv=False):
+         super().__init__()
+         self.in_channels = in_channels
+ 
+         self.norm = Normalize(in_channels, zq_ch, add_conv=add_conv)
+         self.q = torch.nn.Conv2d(in_channels,
+                                  in_channels,
+                                  kernel_size=1,
+                                  stride=1,
+                                  padding=0)
+         self.k = torch.nn.Conv2d(in_channels,
+                                  in_channels,
+                                  kernel_size=1,
+                                  stride=1,
+                                  padding=0)
+         self.v = torch.nn.Conv2d(in_channels,
+                                  in_channels,
+                                  kernel_size=1,
+                                  stride=1,
+                                  padding=0)
+         self.proj_out = torch.nn.Conv2d(in_channels,
+                                         in_channels,
+                                         kernel_size=1,
+                                         stride=1,
+                                         padding=0)
+ 
+     def forward(self, x, zq=None):
+         h_ = x
+         h_ = self.norm(h_, zq)
+         q = self.q(h_)
+         k = self.k(h_)
+         v = self.v(h_)
+ 
+         # compute attention
+         b,c,h,w = q.shape
+         q = q.reshape(b,c,h*w)
+         q = q.permute(0,2,1)    # b,hw,c
+         k = k.reshape(b,c,h*w)  # b,c,hw
+         w_ = torch.bmm(q,k)     # b,hw,hw    w[b,i,j] = sum_c q[b,i,c] k[b,c,j]
+         w_ = w_ * (int(c)**(-0.5))
+         w_ = torch.nn.functional.softmax(w_, dim=2)
+ 
+         # attend to values
+         v = v.reshape(b,c,h*w)
+         w_ = w_.permute(0,2,1)  # b,hw,hw (first hw of k, second of q)
+         h_ = torch.bmm(v,w_)    # b,c,hw (hw of q)    h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+         h_ = h_.reshape(b,c,h,w)
+ 
+         h_ = self.proj_out(h_)
+         return x+h_
+ 
+ 
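+ # Encoder: standard VQGAN-style convolutional encoder; with len(ch_mult) == 4 it
+ # downsamples the input spatially by a factor of 8 before projecting to z_channels.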
+ class Encoder(nn.Module):
+     def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+                  attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+                  resolution, z_channels, double_z=True, **ignore_kwargs):
+         super().__init__()
+         self.ch = ch
+         self.temb_ch = 0
+         self.num_resolutions = len(ch_mult)
+         self.num_res_blocks = num_res_blocks
+         self.resolution = resolution
+         self.in_channels = in_channels
+ 
+         # downsampling
+         self.conv_in = torch.nn.Conv2d(in_channels,
+                                        self.ch,
+                                        kernel_size=3,
+                                        stride=1,
+                                        padding=1)
+ 
+         curr_res = resolution
+         in_ch_mult = (1,)+tuple(ch_mult)
+         self.down = nn.ModuleList()
+         for i_level in range(self.num_resolutions):
+             block = nn.ModuleList()
+             attn = nn.ModuleList()
+             block_in = ch*in_ch_mult[i_level]
+             block_out = ch*ch_mult[i_level]
+             for i_block in range(self.num_res_blocks):
+                 block.append(ResnetBlock(in_channels=block_in,
+                                          out_channels=block_out,
+                                          temb_channels=self.temb_ch,
+                                          dropout=dropout))
+                 block_in = block_out
+                 if curr_res in attn_resolutions:
+                     attn.append(AttnBlock(block_in))
+             down = nn.Module()
+             down.block = block
+             down.attn = attn
+             if i_level != self.num_resolutions-1:
+                 down.downsample = Downsample(block_in, resamp_with_conv)
+                 curr_res = curr_res // 2
+             self.down.append(down)
+ 
+         # middle
+         self.mid = nn.Module()
+         self.mid.block_1 = ResnetBlock(in_channels=block_in,
+                                        out_channels=block_in,
+                                        temb_channels=self.temb_ch,
+                                        dropout=dropout)
+         self.mid.attn_1 = AttnBlock(block_in)
+         self.mid.block_2 = ResnetBlock(in_channels=block_in,
+                                        out_channels=block_in,
+                                        temb_channels=self.temb_ch,
+                                        dropout=dropout)
+ 
+         # end
+         self.norm_out = Normalize(block_in)
+         self.conv_out = torch.nn.Conv2d(block_in,
+                                         2*z_channels if double_z else z_channels,
+                                         kernel_size=3,
+                                         stride=1,
+                                         padding=1)
+ 
+     def forward(self, x):
+         temb = None
+ 
+         # downsampling
+         hs = [self.conv_in(x)]
+         for i_level in range(self.num_resolutions):
+             for i_block in range(self.num_res_blocks):
+                 h = self.down[i_level].block[i_block](hs[-1], temb)
+                 if len(self.down[i_level].attn) > 0:
+                     h = self.down[i_level].attn[i_block](h)
+                 hs.append(h)
+             if i_level != self.num_resolutions-1:
+                 hs.append(self.down[i_level].downsample(hs[-1]))
+ 
+         # middle
+         h = hs[-1]
+         h = self.mid.block_1(h, temb)
+         h = self.mid.attn_1(h)
+         h = self.mid.block_2(h, temb)
+ 
+         # end
+         h = self.norm_out(h)
+         h = nonlinearity(h)
+         h = self.conv_out(h)
+         return h
+ 
+ 
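+ # Decoder: mirror of the Encoder. The extra `zq` argument threads the quantized latent
+ # into every Normalize layer (spatially conditioned normalization), which is the
+ # MoVQ-specific change over the plain VQGAN decoder.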
+ class Decoder(nn.Module):
+     def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+                  attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+                  resolution, z_channels, give_pre_end=False, zq_ch=None, add_conv=False, **ignorekwargs):
+         super().__init__()
+         self.ch = ch
+         self.temb_ch = 0
+         self.num_resolutions = len(ch_mult)
+         self.num_res_blocks = num_res_blocks
+         self.resolution = resolution
+         self.in_channels = in_channels
+         self.give_pre_end = give_pre_end
+ 
+         # compute in_ch_mult, block_in and curr_res at lowest res
+         in_ch_mult = (1,)+tuple(ch_mult)
+         block_in = ch*ch_mult[self.num_resolutions-1]
+         curr_res = resolution // 2**(self.num_resolutions-1)
+         self.z_shape = (1,z_channels,curr_res,curr_res)
+ 
+         # z to block_in
+         self.conv_in = torch.nn.Conv2d(z_channels,
+                                        block_in,
+                                        kernel_size=3,
+                                        stride=1,
+                                        padding=1)
+ 
+         # middle
+         self.mid = nn.Module()
+         self.mid.block_1 = ResnetBlock(in_channels=block_in,
+                                        out_channels=block_in,
+                                        temb_channels=self.temb_ch,
+                                        dropout=dropout,
+                                        zq_ch=zq_ch,
+                                        add_conv=add_conv)
+         self.mid.attn_1 = AttnBlock(block_in, zq_ch, add_conv=add_conv)
+         self.mid.block_2 = ResnetBlock(in_channels=block_in,
+                                        out_channels=block_in,
+                                        temb_channels=self.temb_ch,
+                                        dropout=dropout,
+                                        zq_ch=zq_ch,
+                                        add_conv=add_conv)
+ 
+         # upsampling
+         self.up = nn.ModuleList()
+         for i_level in reversed(range(self.num_resolutions)):
+             block = nn.ModuleList()
+             attn = nn.ModuleList()
+             block_out = ch*ch_mult[i_level]
+             for i_block in range(self.num_res_blocks+1):
+                 block.append(ResnetBlock(in_channels=block_in,
+                                          out_channels=block_out,
+                                          temb_channels=self.temb_ch,
+                                          dropout=dropout,
+                                          zq_ch=zq_ch,
+                                          add_conv=add_conv))
+                 block_in = block_out
+                 if curr_res in attn_resolutions:
+                     attn.append(AttnBlock(block_in, zq_ch, add_conv=add_conv))
+             up = nn.Module()
+             up.block = block
+             up.attn = attn
+             if i_level != 0:
+                 up.upsample = Upsample(block_in, resamp_with_conv)
+                 curr_res = curr_res * 2
+             self.up.insert(0, up)  # prepend to get consistent order
+ 
+         # end
+         self.norm_out = Normalize(block_in, zq_ch, add_conv=add_conv)
+         self.conv_out = torch.nn.Conv2d(block_in,
+                                         out_ch,
+                                         kernel_size=3,
+                                         stride=1,
+                                         padding=1)
+ 
+     def forward(self, z, zq):
+         # assert z.shape[1:] == self.z_shape[1:]
+         self.last_z_shape = z.shape
+ 
+         # timestep embedding
+         temb = None
+ 
+         # z to block_in
+         h = self.conv_in(z)
+ 
+         # middle
+         h = self.mid.block_1(h, temb, zq)
+         h = self.mid.attn_1(h, zq)
+         h = self.mid.block_2(h, temb, zq)
+ 
+         # upsampling
+         for i_level in reversed(range(self.num_resolutions)):
+             for i_block in range(self.num_res_blocks+1):
+                 h = self.up[i_level].block[i_block](h, temb, zq)
+                 if len(self.up[i_level].attn) > 0:
+                     h = self.up[i_level].attn[i_block](h, zq)
+             if i_level != 0:
+                 h = self.up[i_level].upsample(h)
+ 
+         # end
+         if self.give_pre_end:
+             return h
+ 
+         h = self.norm_out(h, zq)
+         h = nonlinearity(h)
+         h = self.conv_out(h)
+         return h
+ 
+ 
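+ # MoVQ wraps the Encoder/Decoder above and adds optional tiled encoding/decoding:
+ # large inputs are split into overlapping tiles, each tile is processed on its own,
+ # and the overlaps are linearly cross-faded (blend_v / blend_h) to hide the seams.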
+ class MoVQ(nn.Module):
+ 
+     def __init__(self, generator_params):
+         super().__init__()
+         z_channels = generator_params["z_channels"]
+         self.encoder = Encoder(**generator_params)
+         self.quant_conv = torch.nn.Conv2d(z_channels, z_channels, 1)
+         self.post_quant_conv = torch.nn.Conv2d(z_channels, z_channels, 1)
+         self.decoder = Decoder(zq_ch=z_channels, **generator_params)
+ 
+         self.tile_sample_min_size = generator_params["tile_sample_min_size"]
+         self.scale_factor = 8
+         self.tile_latent_min_size = int(self.tile_sample_min_size / self.scale_factor)
+         self.tile_overlap_factor_enc = generator_params["tile_overlap_factor_enc"]
+         self.tile_overlap_factor_dec = generator_params["tile_overlap_factor_dec"]
+         self.use_tiling = generator_params["use_tiling"]
+ 
+     @torch.no_grad()
+     def encode(self, x):
+         if self.use_tiling and (
+             x.shape[-1] > self.tile_sample_min_size
+             or x.shape[-2] > self.tile_sample_min_size
+         ):
+             print('tiled_encode')
+             return self.tiled_encode(x)
+         h = self.encoder(x)
+         h = self.quant_conv(h)
+         return h
+ 
+     @torch.no_grad()
+     def decode(self, quant):
+         if self.use_tiling and (
+             quant.shape[-1] > self.tile_latent_min_size
+             or quant.shape[-2] > self.tile_latent_min_size
+         ):
+             print('tiled_decode')
+             return self.tiled_decode(quant)
+         decoder_input = self.post_quant_conv(quant)
+         decoded = self.decoder(decoder_input, quant)
+         return decoded
+ 
+     def blend_v(
+         self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
+     ) -> torch.Tensor:
+         # cross-fade the bottom rows of the tile above (a) into the top rows of the current tile (b)
+         blend_extent = min(a.shape[2], b.shape[2], blend_extent)
+         for y in range(blend_extent):
+             b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (
+                 1 - y / blend_extent
+             ) + b[:, :, y, :] * (y / blend_extent)
+         return b
+ 
+     def blend_h(
+         self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
+     ) -> torch.Tensor:
+         # cross-fade the right columns of the left tile (a) into the left columns of the current tile (b)
+         blend_extent = min(a.shape[3], b.shape[3], blend_extent)
+         for x in range(blend_extent):
+             b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (
+                 1 - x / blend_extent
+             ) + b[:, :, :, x] * (x / blend_extent)
+         return b
+ 
+     def tiled_encode(self, x):
+         overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor_enc))
+         blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor_enc)
+         row_limit = self.tile_latent_min_size - blend_extent
+ 
+         # Split the image into tiles and encode them separately.
+         rows = []
+         for i in tqdm(range(0, x.shape[2], overlap_size)):
+             row = []
+             for j in range(0, x.shape[3], overlap_size):
+                 tile = x[
+                     :,
+                     :,
+                     i : i + self.tile_sample_min_size,
+                     j : j + self.tile_sample_min_size,
+                 ]
+                 tile = self.encode(tile)
+                 row.append(tile)
+             rows.append(row)
+         result_rows = []
+         for i, row in enumerate(rows):
+             result_row = []
+             for j, tile in enumerate(row):
+                 # blend the above tile and the left tile
+                 # to the current tile and add the current tile to the result row
+                 if i > 0:
+                     tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+                 if j > 0:
+                     tile = self.blend_h(row[j - 1], tile, blend_extent)
+                 result_row.append(tile[:, :, :row_limit, :row_limit])
+             result_rows.append(torch.cat(result_row, dim=3))
+ 
+         h = torch.cat(result_rows, dim=2)
+         return h
+ 
+     def tiled_decode(self, z):
+         overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor_dec))
+         blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor_dec)
+         row_limit = self.tile_sample_min_size - blend_extent
+ 
+         # Split z into overlapping tiles and decode them separately.
+         # The tiles have an overlap to avoid seams between tiles.
+         rows = []
+         for i in tqdm(range(0, z.shape[2], overlap_size)):
+             row = []
+             for j in range(0, z.shape[3], overlap_size):
+                 tile = z[
+                     :,
+                     :,
+                     i : i + self.tile_latent_min_size,
+                     j : j + self.tile_latent_min_size,
+                 ]
+                 decoded = self.decode(tile)
+                 row.append(decoded)
+             rows.append(row)
+         result_rows = []
+         for i, row in enumerate(rows):
+             result_row = []
+             for j, tile in enumerate(row):
+                 # blend the above tile and the left tile
+                 # to the current tile and add the current tile to the result row
+                 if i > 0:
+                     tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+                 if j > 0:
+                     tile = self.blend_h(row[j - 1], tile, blend_extent)
+                 result_row.append(tile[:, :, :row_limit, :row_limit])
+             result_rows.append(torch.cat(result_row, dim=3))
+ 
+         dec = torch.cat(result_rows, dim=2)
+         return dec
+ 
+ 
+ def get_vae(conf):
+     movq = MoVQ(conf.params)
+     if conf.checkpoint is not None:
+         movq_state_dict = torch.load(conf.checkpoint)
+         movq.load_state_dict(movq_state_dict)
+     movq = freeze(movq)
+     return movq
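
The tile bookkeeping in tiled_decode is easier to follow with concrete numbers. A minimal sketch, assuming tile_sample_min_size=1024 and tile_overlap_factor_dec=0.25 as example values; the formulas are the ones used above, the numbers are only an illustration:

    tile_sample_min_size = 1024     # output-space tile size (assumed example value)
    scale_factor = 8                # MoVQ spatial compression, as set in __init__ above
    tile_latent_min_size = int(tile_sample_min_size / scale_factor)           # 128 latent pixels per tile
    tile_overlap_factor_dec = 0.25  # fraction of each tile shared with its neighbour (assumed)

    overlap_size = int(tile_latent_min_size * (1 - tile_overlap_factor_dec))  # 96: latent stride between tile origins
    blend_extent = int(tile_sample_min_size * tile_overlap_factor_dec)        # 256: decoded pixels cross-faded at each seam
    row_limit = tile_sample_min_size - blend_extent                           # 768: decoded pixels kept from each tile

    print(overlap_size, blend_extent, row_limit)  # -> 96 256 768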
KandiSuperRes/sr_pipeline.py ADDED
@@ -0,0 +1,116 @@
+ import torch
+ import numpy as np
+ import PIL
+ import torchvision.transforms as T
+ import torch.nn.functional as F
+ from KandiSuperRes.model.unet import UNet
+ from KandiSuperRes.model.unet_sr import UNet as UNet_sr
+ from KandiSuperRes.movq import MoVQ
+ from KandiSuperRes.model.diffusion_sr import DPMSolver
+ from KandiSuperRes.model.diffusion_refine import BaseDiffusion, get_named_beta_schedule
+ from KandiSuperRes.model.diffusion_sr_turbo import BaseDiffusion as BaseDiffusion_turbo
+ 
+ 
+ class KandiSuperResPipeline:
+ 
+     def __init__(
+         self,
+         scale: int,
+         device: str,
+         dtype: str,
+         flash: bool,
+         sr_model: UNet_sr,
+         movq: MoVQ = None,
+         refiner: UNet = None,
+     ):
+         self.device = device
+         self.dtype = dtype
+         self.scale = scale
+         self.flash = flash
+         self.to_pil = T.ToPILImage()
+         self.image_transform = T.Compose([
+             T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+             T.ToTensor(),
+             T.Lambda(lambda img: 2. * img - 1.),
+         ])
+ 
+         self.sr_model = sr_model
+         self.movq = movq
+         self.refiner = refiner
+ 
+     def __call__(
+         self,
+         pil_image: PIL.Image.Image = None,
+         steps: int = 5,
+         view_batch_size: int = 15,
+         seed: int = 0,
+         refine=True
+     ) -> PIL.Image.Image:
+ 
+         if self.flash:
+             # flash (turbo) mode: a few-step diffusion SR pass, optionally followed by a
+             # latent refiner; here self.device / self.dtype are dicts keyed per sub-model
+             betas_turbo = get_named_beta_schedule('linear', 1000)
+             base_diffusion_sr = BaseDiffusion_turbo(betas_turbo)
+ 
+             old_height = pil_image.size[1]
+             old_width = pil_image.size[0]
+             height = int(old_height - np.mod(old_height, 32))
+             width = int(old_width - np.mod(old_width, 32))
+ 
+             pil_image = pil_image.resize((width, height))
+             lr_image = self.image_transform(pil_image).unsqueeze(0).to(self.device['sr_model'])
+ 
+             sr_image = base_diffusion_sr.p_sample_loop(
+                 self.sr_model, (1, 3, height*self.scale, width*self.scale), self.device['sr_model'], self.dtype['sr_model'], lowres_img=lr_image
+             )
+ 
+             if refine:
+                 betas = get_named_beta_schedule('cosine', 1000)
+                 base_diffusion = BaseDiffusion(betas, 0.99)
+ 
+                 with torch.cuda.amp.autocast(dtype=self.dtype['movq']):
+                     lr_image_latent = self.movq.encode(sr_image)
+ 
+                 context = torch.load('weights/context.pt').to(self.dtype['refiner'])
+                 context_mask = torch.load('weights/context_mask.pt').to(self.dtype['refiner'])
+ 
+                 with torch.no_grad():
+                     with torch.cuda.amp.autocast(dtype=self.dtype['refiner']):
+                         refiner_image = base_diffusion.refine_tiled(self.refiner, lr_image_latent, context, context_mask)
+ 
+                 with torch.cuda.amp.autocast(dtype=self.dtype['movq']):
+                     refiner_image = self.movq.decode(refiner_image)
+                 refiner_image = torch.clip((refiner_image + 1.) / 2., 0., 1.)
+ 
+                 if old_height*self.scale != refiner_image.shape[2] or old_width*self.scale != refiner_image.shape[3]:
+                     refiner_image = F.interpolate(refiner_image, [old_height*self.scale, old_width*self.scale], mode='bilinear', align_corners=True)
+                 refined_pil_image = self.to_pil(refiner_image[0])
+                 return refined_pil_image
+ 
+             sr_image = torch.clip((sr_image + 1.) / 2., 0., 1.)
+             if old_height*self.scale != sr_image.shape[2] or old_width*self.scale != sr_image.shape[3]:
+                 sr_image = F.interpolate(sr_image, [old_height*self.scale, old_width*self.scale], mode='bilinear', align_corners=True)
+             pil_sr_image = self.to_pil(sr_image[0])
+             return pil_sr_image
+ 
+         else:
+             # non-flash mode: panorama-style DPM-Solver sampling over sliding views;
+             # here self.device / self.dtype are plain values
+             base_diffusion = DPMSolver(steps)
+ 
+             lr_image = self.image_transform(pil_image).unsqueeze(0).to(self.device)
+ 
+             old_height = pil_image.size[1]
+             old_width = pil_image.size[0]
+ 
+             height = int(old_height + np.mod(old_height, 2)) * self.scale
+             width = int(old_width + np.mod(old_width, 2)) * self.scale
+ 
+             sr_image = base_diffusion.generate_panorama(height, width, self.device, self.dtype, steps,
+                                                         self.sr_model, lowres_img=lr_image,
+                                                         view_batch_size=view_batch_size, eta=0.0, seed=seed)
+ 
+             sr_image = torch.clip((sr_image + 1.) / 2., 0., 1.)
+             if old_height*self.scale != height or old_width*self.scale != width:
+                 sr_image = F.interpolate(sr_image, [old_height*self.scale, old_width*self.scale], mode='bilinear', align_corners=True)
+ 
+             pil_sr_image = self.to_pil(sr_image[0])
+             return pil_sr_image
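
For orientation, a minimal usage sketch of the non-flash (DPM-Solver) path of the pipeline above. `sr_model` stands for an already-loaded UNet_sr instance and the file paths are placeholders; this is an illustration, not part of the uploaded files:

    import PIL.Image
    import torch

    from KandiSuperRes.sr_pipeline import KandiSuperResPipeline

    # assumes `sr_model` is a UNet_sr already moved to the GPU in eval mode
    pipeline = KandiSuperResPipeline(scale=2, device='cuda', dtype=torch.float16,
                                     flash=False, sr_model=sr_model)

    lr_image = PIL.Image.open('input_lr.png')       # placeholder input path
    sr_image = pipeline(lr_image, steps=5, seed=0)  # returns a PIL.Image upscaled by `scale`
    sr_image.save('output_sr.png')                  # placeholder output path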
KandiSuperRes/utils.py ADDED
@@ -0,0 +1,9 @@
+ def freeze(model):
+     for p in model.parameters():
+         p.requires_grad = False
+     return model
+ 
+ 
+ def unfreeze(model):
+     for p in model.parameters():
+         p.requires_grad = True
+     return model
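
A quick sanity check of these helpers on a throwaway module (illustration only):

    import torch.nn as nn

    from KandiSuperRes.utils import freeze, unfreeze

    layer = nn.Linear(4, 4)
    assert all(not p.requires_grad for p in freeze(layer).parameters())
    assert all(p.requires_grad for p in unfreeze(layer).parameters())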
weights/context.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:92cc62ca3e341bd4ea03df187c06aceb43505e79107b6a2fef717a86051a6296
+ size 1049756
weights/context_mask.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6ff4ea52d9deb41f4732dd422fef48b1382247d9dbe0493d0a274e0d2e13591f
+ size 2229