Delik commited on
Commit
f1ea451
·
verified ·
1 Parent(s): 33faaeb

Upload 32 files

Browse files
LIA_Model.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from networks.encoder import Encoder
4
+ from networks.styledecoder import Synthesis
5
+
6
+ # This part is modified from: https://github.com/wyhsirius/LIA
7
+ class LIA_Model(torch.nn.Module):
8
+ def __init__(self, size = 256, style_dim = 512, motion_dim = 20, channel_multiplier=1, blur_kernel=[1, 3, 3, 1], fusion_type=''):
9
+ super().__init__()
10
+ self.enc = Encoder(size, style_dim, motion_dim, fusion_type)
11
+ self.dec = Synthesis(size, style_dim, motion_dim, blur_kernel, channel_multiplier)
12
+
13
+ def get_start_direction_code(self, x_start, x_target, x_face, x_aug):
14
+ enc_dic = self.enc(x_start, x_target, x_face, x_aug)
15
+
16
+ wa, alpha, feats = enc_dic['h_source'], enc_dic['h_motion'], enc_dic['feats']
17
+
18
+ return wa, alpha, feats
19
+
20
+ def render(self, start, direction, feats):
21
+ return self.dec(start, direction, feats)
22
+
23
+ def load_lightning_model(self, lia_pretrained_model_path):
24
+ selfState = self.state_dict()
25
+
26
+ state = torch.load(lia_pretrained_model_path, map_location='cpu')
27
+ for name, param in state.items():
28
+ origName = name;
29
+
30
+ if name not in selfState:
31
+ name = name.replace("lia.", "")
32
+ if name not in selfState:
33
+ print("%s is not in the model."%origName)
34
+ # You can ignore those errors as some parameters are only used for training
35
+ continue
36
+ if selfState[name].size() != state[origName].size():
37
+ print("Wrong parameter length: %s, model: %s, loaded: %s"%(origName, selfState[name].size(), state[origName].size()))
38
+ continue
39
+ selfState[name].copy_(param)
choices.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+ from torch import nn
3
+
4
+
5
+ class TrainMode(Enum):
6
+ # manipulate mode = training the classifier
7
+ manipulate = 'manipulate'
8
+ # default trainin mode!
9
+ diffusion = 'diffusion'
10
+ # default latent training mode!
11
+ # fitting the a DDPM to a given latent
12
+ latent_diffusion = 'latentdiffusion'
13
+
14
+ def is_manipulate(self):
15
+ return self in [
16
+ TrainMode.manipulate,
17
+ ]
18
+
19
+ def is_diffusion(self):
20
+ return self in [
21
+ TrainMode.diffusion,
22
+ TrainMode.latent_diffusion,
23
+ ]
24
+
25
+ def is_autoenc(self):
26
+ # the network possibly does autoencoding
27
+ return self in [
28
+ TrainMode.diffusion,
29
+ ]
30
+
31
+ def is_latent_diffusion(self):
32
+ return self in [
33
+ TrainMode.latent_diffusion,
34
+ ]
35
+
36
+ def use_latent_net(self):
37
+ return self.is_latent_diffusion()
38
+
39
+ def require_dataset_infer(self):
40
+ """
41
+ whether training in this mode requires the latent variables to be available?
42
+ """
43
+ # this will precalculate all the latents before hand
44
+ # and the dataset will be all the predicted latents
45
+ return self in [
46
+ TrainMode.latent_diffusion,
47
+ TrainMode.manipulate,
48
+ ]
49
+
50
+
51
+ class ManipulateMode(Enum):
52
+ """
53
+ how to train the classifier to manipulate
54
+ """
55
+ # train on whole celeba attr dataset
56
+ celebahq_all = 'celebahq_all'
57
+ # celeba with D2C's crop
58
+ d2c_fewshot = 'd2cfewshot'
59
+ d2c_fewshot_allneg = 'd2cfewshotallneg'
60
+
61
+ def is_celeba_attr(self):
62
+ return self in [
63
+ ManipulateMode.d2c_fewshot,
64
+ ManipulateMode.d2c_fewshot_allneg,
65
+ ManipulateMode.celebahq_all,
66
+ ]
67
+
68
+ def is_single_class(self):
69
+ return self in [
70
+ ManipulateMode.d2c_fewshot,
71
+ ManipulateMode.d2c_fewshot_allneg,
72
+ ]
73
+
74
+ def is_fewshot(self):
75
+ return self in [
76
+ ManipulateMode.d2c_fewshot,
77
+ ManipulateMode.d2c_fewshot_allneg,
78
+ ]
79
+
80
+ def is_fewshot_allneg(self):
81
+ return self in [
82
+ ManipulateMode.d2c_fewshot_allneg,
83
+ ]
84
+
85
+
86
+ class ModelType(Enum):
87
+ """
88
+ Kinds of the backbone models
89
+ """
90
+
91
+ # unconditional ddpm
92
+ ddpm = 'ddpm'
93
+ # autoencoding ddpm cannot do unconditional generation
94
+ autoencoder = 'autoencoder'
95
+
96
+ def has_autoenc(self):
97
+ return self in [
98
+ ModelType.autoencoder,
99
+ ]
100
+
101
+ def can_sample(self):
102
+ return self in [ModelType.ddpm]
103
+
104
+
105
+ class ModelName(Enum):
106
+ """
107
+ List of all supported model classes
108
+ """
109
+
110
+ beatgans_ddpm = 'beatgans_ddpm'
111
+ beatgans_autoenc = 'beatgans_autoenc'
112
+
113
+
114
+ class ModelMeanType(Enum):
115
+ """
116
+ Which type of output the model predicts.
117
+ """
118
+
119
+ eps = 'eps' # the model predicts epsilon
120
+
121
+
122
+ class ModelVarType(Enum):
123
+ """
124
+ What is used as the model's output variance.
125
+
126
+ The LEARNED_RANGE option has been added to allow the model to predict
127
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
128
+ """
129
+
130
+ # posterior beta_t
131
+ fixed_small = 'fixed_small'
132
+ # beta_t
133
+ fixed_large = 'fixed_large'
134
+
135
+
136
+ class LossType(Enum):
137
+ mse = 'mse' # use raw MSE loss (and KL when learning variances)
138
+ l1 = 'l1'
139
+
140
+
141
+ class GenerativeType(Enum):
142
+ """
143
+ How's a sample generated
144
+ """
145
+
146
+ ddpm = 'ddpm'
147
+ ddim = 'ddim'
148
+
149
+
150
+ class OptimizerType(Enum):
151
+ adam = 'adam'
152
+ adamw = 'adamw'
153
+
154
+
155
+ class Activation(Enum):
156
+ none = 'none'
157
+ relu = 'relu'
158
+ lrelu = 'lrelu'
159
+ silu = 'silu'
160
+ tanh = 'tanh'
161
+
162
+ def get_act(self):
163
+ if self == Activation.none:
164
+ return nn.Identity()
165
+ elif self == Activation.relu:
166
+ return nn.ReLU()
167
+ elif self == Activation.lrelu:
168
+ return nn.LeakyReLU(negative_slope=0.2)
169
+ elif self == Activation.silu:
170
+ return nn.SiLU()
171
+ elif self == Activation.tanh:
172
+ return nn.Tanh()
173
+ else:
174
+ raise NotImplementedError()
175
+
176
+
177
+ class ManipulateLossType(Enum):
178
+ bce = 'bce'
179
+ mse = 'mse'
config.py ADDED
@@ -0,0 +1,388 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from model.unet import ScaleAt
2
+ from model.latentnet import *
3
+ from diffusion.resample import UniformSampler
4
+ from diffusion.diffusion import space_timesteps
5
+ from typing import Tuple
6
+
7
+ from torch.utils.data import DataLoader
8
+
9
+ from config_base import BaseConfig
10
+ from diffusion import *
11
+ from diffusion.base import GenerativeType, LossType, ModelMeanType, ModelVarType, get_named_beta_schedule
12
+ from model import *
13
+ from choices import *
14
+ from multiprocessing import get_context
15
+ import os
16
+ from dataset_util import *
17
+ from torch.utils.data.distributed import DistributedSampler
18
+ from dataset import LatentDataLoader
19
+
20
+ @dataclass
21
+ class PretrainConfig(BaseConfig):
22
+ name: str
23
+ path: str
24
+
25
+
26
+ @dataclass
27
+ class TrainConfig(BaseConfig):
28
+ # random seed
29
+ seed: int = 0
30
+ train_mode: TrainMode = TrainMode.diffusion
31
+ train_cond0_prob: float = 0
32
+ train_pred_xstart_detach: bool = True
33
+ train_interpolate_prob: float = 0
34
+ train_interpolate_img: bool = False
35
+ manipulate_mode: ManipulateMode = ManipulateMode.celebahq_all
36
+ manipulate_cls: str = None
37
+ manipulate_shots: int = None
38
+ manipulate_loss: ManipulateLossType = ManipulateLossType.bce
39
+ manipulate_znormalize: bool = False
40
+ manipulate_seed: int = 0
41
+ accum_batches: int = 1
42
+ autoenc_mid_attn: bool = True
43
+ batch_size: int = 16
44
+ batch_size_eval: int = None
45
+ beatgans_gen_type: GenerativeType = GenerativeType.ddim
46
+ beatgans_loss_type: LossType = LossType.mse
47
+ beatgans_model_mean_type: ModelMeanType = ModelMeanType.eps
48
+ beatgans_model_var_type: ModelVarType = ModelVarType.fixed_large
49
+ beatgans_rescale_timesteps: bool = False
50
+ latent_infer_path: str = None
51
+ latent_znormalize: bool = False
52
+ latent_gen_type: GenerativeType = GenerativeType.ddim
53
+ latent_loss_type: LossType = LossType.mse
54
+ latent_model_mean_type: ModelMeanType = ModelMeanType.eps
55
+ latent_model_var_type: ModelVarType = ModelVarType.fixed_large
56
+ latent_rescale_timesteps: bool = False
57
+ latent_T_eval: int = 1_000
58
+ latent_clip_sample: bool = False
59
+ latent_beta_scheduler: str = 'linear'
60
+ beta_scheduler: str = 'linear'
61
+ data_name: str = ''
62
+ data_val_name: str = None
63
+ diffusion_type: str = None
64
+ dropout: float = 0.1
65
+ ema_decay: float = 0.9999
66
+ eval_num_images: int = 5_000
67
+ eval_every_samples: int = 200_000
68
+ eval_ema_every_samples: int = 200_000
69
+ fid_use_torch: bool = True
70
+ fp16: bool = False
71
+ grad_clip: float = 1
72
+ img_size: int = 64
73
+ lr: float = 0.0001
74
+ optimizer: OptimizerType = OptimizerType.adam
75
+ weight_decay: float = 0
76
+ model_conf: ModelConfig = None
77
+ model_name: ModelName = None
78
+ model_type: ModelType = None
79
+ net_attn: Tuple[int] = None
80
+ net_beatgans_attn_head: int = 1
81
+ # not necessarily the same as the the number of style channels
82
+ net_beatgans_embed_channels: int = 512
83
+ net_resblock_updown: bool = True
84
+ net_enc_use_time: bool = False
85
+ net_enc_pool: str = 'adaptivenonzero'
86
+ net_beatgans_gradient_checkpoint: bool = False
87
+ net_beatgans_resnet_two_cond: bool = False
88
+ net_beatgans_resnet_use_zero_module: bool = True
89
+ net_beatgans_resnet_scale_at: ScaleAt = ScaleAt.after_norm
90
+ net_beatgans_resnet_cond_channels: int = None
91
+ net_ch_mult: Tuple[int] = None
92
+ net_ch: int = 64
93
+ net_enc_attn: Tuple[int] = None
94
+ net_enc_k: int = None
95
+ # number of resblocks for the encoder (half-unet)
96
+ net_enc_num_res_blocks: int = 2
97
+ net_enc_channel_mult: Tuple[int] = None
98
+ net_enc_grad_checkpoint: bool = False
99
+ net_autoenc_stochastic: bool = False
100
+ net_latent_activation: Activation = Activation.silu
101
+ net_latent_channel_mult: Tuple[int] = (1, 2, 4)
102
+ net_latent_condition_bias: float = 0
103
+ net_latent_dropout: float = 0
104
+ net_latent_layers: int = None
105
+ net_latent_net_last_act: Activation = Activation.none
106
+ net_latent_net_type: LatentNetType = LatentNetType.none
107
+ net_latent_num_hid_channels: int = 1024
108
+ net_latent_num_time_layers: int = 2
109
+ net_latent_skip_layers: Tuple[int] = None
110
+ net_latent_time_emb_channels: int = 64
111
+ net_latent_use_norm: bool = False
112
+ net_latent_time_last_act: bool = False
113
+ net_num_res_blocks: int = 2
114
+ # number of resblocks for the UNET
115
+ net_num_input_res_blocks: int = None
116
+ net_enc_num_cls: int = None
117
+ num_workers: int = 4
118
+ parallel: bool = False
119
+ postfix: str = ''
120
+ sample_size: int = 64
121
+ sample_every_samples: int = 20_000
122
+ save_every_samples: int = 100_000
123
+ style_ch: int = 512
124
+ T_eval: int = 1_000
125
+ T_sampler: str = 'uniform'
126
+ T: int = 1_000
127
+ total_samples: int = 10_000_000
128
+ warmup: int = 0
129
+ pretrain: PretrainConfig = None
130
+ continue_from: PretrainConfig = None
131
+ eval_programs: Tuple[str] = None
132
+ # if present load the checkpoint from this path instead
133
+ eval_path: str = None
134
+ base_dir: str = 'checkpoints'
135
+ use_cache_dataset: bool = False
136
+ data_cache_dir: str = os.path.expanduser('~/cache')
137
+ work_cache_dir: str = os.path.expanduser('~/mycache')
138
+ # to be overridden
139
+ name: str = ''
140
+
141
+ def __post_init__(self):
142
+ self.batch_size_eval = self.batch_size_eval or self.batch_size
143
+ self.data_val_name = self.data_val_name or self.data_name
144
+
145
+ def scale_up_gpus(self, num_gpus, num_nodes=1):
146
+ self.eval_ema_every_samples *= num_gpus * num_nodes
147
+ self.eval_every_samples *= num_gpus * num_nodes
148
+ self.sample_every_samples *= num_gpus * num_nodes
149
+ self.batch_size *= num_gpus * num_nodes
150
+ self.batch_size_eval *= num_gpus * num_nodes
151
+ return self
152
+
153
+ @property
154
+ def batch_size_effective(self):
155
+ return self.batch_size * self.accum_batches
156
+
157
+ @property
158
+ def fid_cache(self):
159
+ # we try to use the local dirs to reduce the load over network drives
160
+ # hopefully, this would reduce the disconnection problems with sshfs
161
+ return f'{self.work_cache_dir}/eval_images/{self.data_name}_size{self.img_size}_{self.eval_num_images}'
162
+
163
+ @property
164
+ def data_path(self):
165
+ # may use the cache dir
166
+ path = data_paths[self.data_name]
167
+ if self.use_cache_dataset and path is not None:
168
+ path = use_cached_dataset_path(
169
+ path, f'{self.data_cache_dir}/{self.data_name}')
170
+ return path
171
+
172
+ @property
173
+ def logdir(self):
174
+ return f'{self.base_dir}/{self.name}'
175
+
176
+ @property
177
+ def generate_dir(self):
178
+ # we try to use the local dirs to reduce the load over network drives
179
+ # hopefully, this would reduce the disconnection problems with sshfs
180
+ return f'{self.work_cache_dir}/gen_images/{self.name}'
181
+
182
+ def _make_diffusion_conf(self, T=None):
183
+ if self.diffusion_type == 'beatgans':
184
+ # can use T < self.T for evaluation
185
+ # follows the guided-diffusion repo conventions
186
+ # t's are evenly spaced
187
+ if self.beatgans_gen_type == GenerativeType.ddpm:
188
+ section_counts = [T]
189
+ elif self.beatgans_gen_type == GenerativeType.ddim:
190
+ section_counts = f'ddim{T}'
191
+ else:
192
+ raise NotImplementedError()
193
+
194
+ return SpacedDiffusionBeatGansConfig(
195
+ gen_type=self.beatgans_gen_type,
196
+ model_type=self.model_type,
197
+ betas=get_named_beta_schedule(self.beta_scheduler, self.T),
198
+ model_mean_type=self.beatgans_model_mean_type,
199
+ model_var_type=self.beatgans_model_var_type,
200
+ loss_type=self.beatgans_loss_type,
201
+ rescale_timesteps=self.beatgans_rescale_timesteps,
202
+ use_timesteps=space_timesteps(num_timesteps=self.T,
203
+ section_counts=section_counts),
204
+ fp16=self.fp16,
205
+ )
206
+ else:
207
+ raise NotImplementedError()
208
+
209
+ def _make_latent_diffusion_conf(self, T=None):
210
+ # can use T < self.T for evaluation
211
+ # follows the guided-diffusion repo conventions
212
+ # t's are evenly spaced
213
+ if self.latent_gen_type == GenerativeType.ddpm:
214
+ section_counts = [T]
215
+ elif self.latent_gen_type == GenerativeType.ddim:
216
+ section_counts = f'ddim{T}'
217
+ else:
218
+ raise NotImplementedError()
219
+
220
+ return SpacedDiffusionBeatGansConfig(
221
+ train_pred_xstart_detach=self.train_pred_xstart_detach,
222
+ gen_type=self.latent_gen_type,
223
+ # latent's model is always ddpm
224
+ model_type=ModelType.ddpm,
225
+ # latent shares the beta scheduler and full T
226
+ betas=get_named_beta_schedule(self.latent_beta_scheduler, self.T),
227
+ model_mean_type=self.latent_model_mean_type,
228
+ model_var_type=self.latent_model_var_type,
229
+ loss_type=self.latent_loss_type,
230
+ rescale_timesteps=self.latent_rescale_timesteps,
231
+ use_timesteps=space_timesteps(num_timesteps=self.T,
232
+ section_counts=section_counts),
233
+ fp16=self.fp16,
234
+ )
235
+
236
+ @property
237
+ def model_out_channels(self):
238
+ return 3
239
+
240
+ def make_T_sampler(self):
241
+ if self.T_sampler == 'uniform':
242
+ return UniformSampler(self.T)
243
+ else:
244
+ raise NotImplementedError()
245
+
246
+ def make_diffusion_conf(self):
247
+ return self._make_diffusion_conf(self.T)
248
+
249
+ def make_eval_diffusion_conf(self):
250
+ return self._make_diffusion_conf(T=self.T_eval)
251
+
252
+ def make_latent_diffusion_conf(self):
253
+ return self._make_latent_diffusion_conf(T=self.T)
254
+
255
+ def make_latent_eval_diffusion_conf(self):
256
+ # latent can have different eval T
257
+ return self._make_latent_diffusion_conf(T=self.latent_T_eval)
258
+
259
+ def make_dataset(self, path=None, **kwargs):
260
+ return LatentDataLoader(self.window_size,
261
+ self.frame_jpgs,
262
+ self.lmd_feats_prefix,
263
+ self.audio_prefix,
264
+ self.raw_audio_prefix,
265
+ self.motion_latents_prefix,
266
+ self.pose_prefix,
267
+ self.db_name,
268
+ audio_hz=self.audio_hz)
269
+
270
+ def make_loader(self,
271
+ dataset,
272
+ shuffle: bool,
273
+ num_worker: bool = None,
274
+ drop_last: bool = True,
275
+ batch_size: int = None,
276
+ parallel: bool = False):
277
+ if parallel and distributed.is_initialized():
278
+ # drop last to make sure that there is no added special indexes
279
+ sampler = DistributedSampler(dataset,
280
+ shuffle=shuffle,
281
+ drop_last=True)
282
+ else:
283
+ sampler = None
284
+ return DataLoader(
285
+ dataset,
286
+ batch_size=batch_size or self.batch_size,
287
+ sampler=sampler,
288
+ # with sampler, use the sample instead of this option
289
+ shuffle=False if sampler else shuffle,
290
+ num_workers=num_worker or self.num_workers,
291
+ pin_memory=True,
292
+ drop_last=drop_last,
293
+ multiprocessing_context=get_context('fork'),
294
+ )
295
+
296
+ def make_model_conf(self):
297
+ if self.model_name == ModelName.beatgans_ddpm:
298
+ self.model_type = ModelType.ddpm
299
+ self.model_conf = BeatGANsUNetConfig(
300
+ attention_resolutions=self.net_attn,
301
+ channel_mult=self.net_ch_mult,
302
+ conv_resample=True,
303
+ dims=2,
304
+ dropout=self.dropout,
305
+ embed_channels=self.net_beatgans_embed_channels,
306
+ image_size=self.img_size,
307
+ in_channels=3,
308
+ model_channels=self.net_ch,
309
+ num_classes=None,
310
+ num_head_channels=-1,
311
+ num_heads_upsample=-1,
312
+ num_heads=self.net_beatgans_attn_head,
313
+ num_res_blocks=self.net_num_res_blocks,
314
+ num_input_res_blocks=self.net_num_input_res_blocks,
315
+ out_channels=self.model_out_channels,
316
+ resblock_updown=self.net_resblock_updown,
317
+ use_checkpoint=self.net_beatgans_gradient_checkpoint,
318
+ use_new_attention_order=False,
319
+ resnet_two_cond=self.net_beatgans_resnet_two_cond,
320
+ resnet_use_zero_module=self.
321
+ net_beatgans_resnet_use_zero_module,
322
+ )
323
+ elif self.model_name in [
324
+ ModelName.beatgans_autoenc,
325
+ ]:
326
+ cls = BeatGANsAutoencConfig
327
+ # supports both autoenc and vaeddpm
328
+ if self.model_name == ModelName.beatgans_autoenc:
329
+ self.model_type = ModelType.autoencoder
330
+ else:
331
+ raise NotImplementedError()
332
+
333
+ if self.net_latent_net_type == LatentNetType.none:
334
+ latent_net_conf = None
335
+ elif self.net_latent_net_type == LatentNetType.skip:
336
+ latent_net_conf = MLPSkipNetConfig(
337
+ num_channels=self.style_ch,
338
+ skip_layers=self.net_latent_skip_layers,
339
+ num_hid_channels=self.net_latent_num_hid_channels,
340
+ num_layers=self.net_latent_layers,
341
+ num_time_emb_channels=self.net_latent_time_emb_channels,
342
+ activation=self.net_latent_activation,
343
+ use_norm=self.net_latent_use_norm,
344
+ condition_bias=self.net_latent_condition_bias,
345
+ dropout=self.net_latent_dropout,
346
+ last_act=self.net_latent_net_last_act,
347
+ num_time_layers=self.net_latent_num_time_layers,
348
+ time_last_act=self.net_latent_time_last_act,
349
+ )
350
+ else:
351
+ raise NotImplementedError()
352
+
353
+ self.model_conf = cls(
354
+ attention_resolutions=self.net_attn,
355
+ channel_mult=self.net_ch_mult,
356
+ conv_resample=True,
357
+ dims=2,
358
+ dropout=self.dropout,
359
+ embed_channels=self.net_beatgans_embed_channels,
360
+ enc_out_channels=self.style_ch,
361
+ enc_pool=self.net_enc_pool,
362
+ enc_num_res_block=self.net_enc_num_res_blocks,
363
+ enc_channel_mult=self.net_enc_channel_mult,
364
+ enc_grad_checkpoint=self.net_enc_grad_checkpoint,
365
+ enc_attn_resolutions=self.net_enc_attn,
366
+ image_size=self.img_size,
367
+ in_channels=3,
368
+ model_channels=self.net_ch,
369
+ num_classes=None,
370
+ num_head_channels=-1,
371
+ num_heads_upsample=-1,
372
+ num_heads=self.net_beatgans_attn_head,
373
+ num_res_blocks=self.net_num_res_blocks,
374
+ num_input_res_blocks=self.net_num_input_res_blocks,
375
+ out_channels=self.model_out_channels,
376
+ resblock_updown=self.net_resblock_updown,
377
+ use_checkpoint=self.net_beatgans_gradient_checkpoint,
378
+ use_new_attention_order=False,
379
+ resnet_two_cond=self.net_beatgans_resnet_two_cond,
380
+ resnet_use_zero_module=self.
381
+ net_beatgans_resnet_use_zero_module,
382
+ latent_net_conf=latent_net_conf,
383
+ resnet_cond_channels=self.net_beatgans_resnet_cond_channels,
384
+ )
385
+ else:
386
+ raise NotImplementedError(self.model_name)
387
+
388
+ return self.model_conf
config_base.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from copy import deepcopy
4
+ from dataclasses import dataclass
5
+
6
+
7
+ @dataclass
8
+ class BaseConfig:
9
+ def clone(self):
10
+ return deepcopy(self)
11
+
12
+ def inherit(self, another):
13
+ """inherit common keys from a given config"""
14
+ common_keys = set(self.__dict__.keys()) & set(another.__dict__.keys())
15
+ for k in common_keys:
16
+ setattr(self, k, getattr(another, k))
17
+
18
+ def propagate(self):
19
+ """push down the configuration to all members"""
20
+ for k, v in self.__dict__.items():
21
+ if isinstance(v, BaseConfig):
22
+ v.inherit(self)
23
+ v.propagate()
24
+
25
+ def save(self, save_path):
26
+ """save config to json file"""
27
+ dirname = os.path.dirname(save_path)
28
+ if not os.path.exists(dirname):
29
+ os.makedirs(dirname)
30
+ conf = self.as_dict_jsonable()
31
+ with open(save_path, 'w') as f:
32
+ json.dump(conf, f)
33
+
34
+ def load(self, load_path):
35
+ """load json config"""
36
+ with open(load_path) as f:
37
+ conf = json.load(f)
38
+ self.from_dict(conf)
39
+
40
+ def from_dict(self, dict, strict=False):
41
+ for k, v in dict.items():
42
+ if not hasattr(self, k):
43
+ if strict:
44
+ raise ValueError(f"loading extra '{k}'")
45
+ else:
46
+ print(f"loading extra '{k}'")
47
+ continue
48
+ if isinstance(self.__dict__[k], BaseConfig):
49
+ self.__dict__[k].from_dict(v)
50
+ else:
51
+ self.__dict__[k] = v
52
+
53
+ def as_dict_jsonable(self):
54
+ conf = {}
55
+ for k, v in self.__dict__.items():
56
+ if isinstance(v, BaseConfig):
57
+ conf[k] = v.as_dict_jsonable()
58
+ else:
59
+ if jsonable(v):
60
+ conf[k] = v
61
+ else:
62
+ # ignore not jsonable
63
+ pass
64
+ return conf
65
+
66
+
67
+ def jsonable(x):
68
+ try:
69
+ json.dumps(x)
70
+ return True
71
+ except TypeError:
72
+ return False
dataset.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import librosa
3
+ from PIL import Image
4
+ from torchvision import transforms
5
+ import python_speech_features
6
+ import random
7
+ import os
8
+ import numpy as np
9
+ from tqdm import tqdm
10
+ import torchvision
11
+ import torchvision.transforms as transforms
12
+ from PIL import Image
13
+
14
+ class LatentDataLoader(object):
15
+
16
+ def __init__(
17
+ self,
18
+ window_size,
19
+ frame_jpgs,
20
+ lmd_feats_prefix,
21
+ audio_prefix,
22
+ raw_audio_prefix,
23
+ motion_latents_prefix,
24
+ pose_prefix,
25
+ db_name,
26
+ video_fps=25,
27
+ audio_hz=50,
28
+ size=256,
29
+ mfcc_mode=False,
30
+ ):
31
+ self.window_size = window_size
32
+ self.lmd_feats_prefix = lmd_feats_prefix
33
+ self.audio_prefix = audio_prefix
34
+ self.pose_prefix = pose_prefix
35
+ self.video_fps = video_fps
36
+ self.audio_hz = audio_hz
37
+ self.db_name = db_name
38
+ self.raw_audio_prefix = raw_audio_prefix
39
+ self.mfcc_mode = mfcc_mode
40
+
41
+
42
+ self.transform = torchvision.transforms.Compose([
43
+ transforms.Resize((size, size)),
44
+ transforms.ToTensor(),
45
+ transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))]
46
+ )
47
+
48
+ self.data = []
49
+ for db_name in [ 'VoxCeleb2', 'HDTF' ]:
50
+ db_png_path = os.path.join(frame_jpgs, db_name)
51
+ for clip_name in tqdm(os.listdir(db_png_path)):
52
+
53
+ item_dict = dict()
54
+ item_dict['clip_name'] = clip_name
55
+ item_dict['frame_count'] = len(list(os.listdir(os.path.join(frame_jpgs, db_name, clip_name))))
56
+ item_dict['hubert_path'] = os.path.join(audio_prefix, db_name, clip_name +".npy")
57
+ item_dict['wav_path'] = os.path.join(raw_audio_prefix, db_name, clip_name +".wav")
58
+
59
+ item_dict['yaw_pitch_roll_path'] = os.path.join(pose_prefix, db_name, 'raw_videos_pose_yaw_pitch_roll', clip_name +".npy")
60
+ if not os.path.exists(item_dict['yaw_pitch_roll_path']):
61
+ print(f"{db_name}'s {clip_name} miss yaw_pitch_roll_path")
62
+ continue
63
+
64
+ item_dict['yaw_pitch_roll'] = np.load(item_dict['yaw_pitch_roll_path'])
65
+ item_dict['yaw_pitch_roll'] = np.clip(item_dict['yaw_pitch_roll'], -90, 90) / 90.0
66
+
67
+ if not os.path.exists(item_dict['wav_path']):
68
+ print(f"{db_name}'s {clip_name} miss wav_path")
69
+ continue
70
+
71
+ if not os.path.exists(item_dict['hubert_path']):
72
+ print(f"{db_name}'s {clip_name} miss hubert_path")
73
+ continue
74
+
75
+
76
+ if self.mfcc_mode:
77
+ wav, sr = librosa.load(item_dict['wav_path'], sr=16000)
78
+ input_values = python_speech_features.mfcc(signal=wav,samplerate=sr,numcep=13,winlen=0.025,winstep=0.01)
79
+ d_mfcc_feat = python_speech_features.base.delta(input_values, 1)
80
+ d_mfcc_feat2 = python_speech_features.base.delta(input_values, 2)
81
+ input_values = np.hstack((input_values, d_mfcc_feat, d_mfcc_feat2))
82
+ item_dict['hubert_obj'] = input_values
83
+ else:
84
+ item_dict['hubert_obj'] = np.load(item_dict['hubert_path'], mmap_mode='r')
85
+ item_dict['lmd_path'] = os.path.join(lmd_feats_prefix, db_name, clip_name +".txt")
86
+ item_dict['lmd_obj_full'] = self.read_landmark_info(item_dict['lmd_path'], upper_face=False)
87
+
88
+ motion_start_path = os.path.join(motion_latents_prefix, db_name, 'motions', clip_name +".npy")
89
+ motion_direction_path = os.path.join(motion_latents_prefix, db_name, 'directions', clip_name +".npy")
90
+
91
+ if not os.path.exists(motion_start_path):
92
+ print(f"{db_name}'s {clip_name} miss motion_start_path")
93
+ continue
94
+ if not os.path.exists(motion_direction_path):
95
+ print(f"{db_name}'s {clip_name} miss motion_direction_path")
96
+ continue
97
+
98
+ item_dict['motion_start_obj'] = np.load(motion_start_path)
99
+ item_dict['motion_direction_obj'] = np.load(motion_direction_path)
100
+
101
+ if self.mfcc_mode:
102
+ min_len = min(
103
+ item_dict['lmd_obj_full'].shape[0],
104
+ item_dict['yaw_pitch_roll'].shape[0],
105
+ item_dict['motion_start_obj'].shape[0],
106
+ item_dict['motion_direction_obj'].shape[0],
107
+ int(item_dict['hubert_obj'].shape[0]/4),
108
+ item_dict['frame_count']
109
+ )
110
+ item_dict['frame_count'] = min_len
111
+ item_dict['hubert_obj'] = item_dict['hubert_obj'][:min_len*4,:]
112
+ else:
113
+ min_len = min(
114
+ item_dict['lmd_obj_full'].shape[0],
115
+ item_dict['yaw_pitch_roll'].shape[0],
116
+ item_dict['motion_start_obj'].shape[0],
117
+ item_dict['motion_direction_obj'].shape[0],
118
+ int(item_dict['hubert_obj'].shape[1]/2),
119
+ item_dict['frame_count']
120
+ )
121
+
122
+ item_dict['frame_count'] = min_len
123
+ item_dict['hubert_obj'] = item_dict['hubert_obj'][:, :min_len*2, :]
124
+
125
+ if min_len < self.window_size * self.video_fps + 5:
126
+ continue
127
+
128
+ print('Db count:', len(self.data))
129
+
130
+ def get_single_image(self, image_path):
131
+ img_source = Image.open(image_path).convert('RGB')
132
+ img_source = self.transform(img_source)
133
+ return img_source
134
+
135
+ def get_multiple_ranges(self, lists, multi_ranges):
136
+ # Ensure that multi_ranges is a list of tuples
137
+ if not all(isinstance(item, tuple) and len(item) == 2 for item in multi_ranges):
138
+ raise ValueError("multi_ranges must be a list of (start, end) tuples with exactly two elements each")
139
+ extracted_elements = [lists[start:end] for start, end in multi_ranges]
140
+ flat_list = [item for sublist in extracted_elements for item in sublist]
141
+ return flat_list
142
+
143
+
144
+ def read_landmark_info(self, lmd_path, upper_face=True):
145
+ with open(lmd_path, 'r') as file:
146
+ lmd_lines = file.readlines()
147
+ lmd_lines.sort()
148
+
149
+ total_lmd_obj = []
150
+ for i, line in enumerate(lmd_lines):
151
+ # Split the coordinates and filter out any empty strings
152
+ coords = [c for c in line.strip().split(' ') if c]
153
+ coords = coords[1:] # do not include the file name in the first row
154
+ lmd_obj = []
155
+ if upper_face:
156
+ # Ensure that the coordinates are parsed as integers
157
+ for coord_pair in self.get_multiple_ranges(coords, [(0, 3), (14, 27), (36, 48)]): # 28个
158
+ x, y = coord_pair.split('_')
159
+ lmd_obj.append((int(x)/512, int(y)/512))
160
+ else:
161
+ for coord_pair in coords:
162
+ x, y = coord_pair.split('_')
163
+ lmd_obj.append((int(x)/512, int(y)/512))
164
+ total_lmd_obj.append(lmd_obj)
165
+
166
+ return np.array(total_lmd_obj, dtype=np.float32)
167
+
168
+ def calculate_face_height(self, landmarks):
169
+ forehead_center = (landmarks[ :, 21, :] + landmarks[:, 22, :]) / 2
170
+ chin_bottom = landmarks[:, 8, :]
171
+ distances = np.linalg.norm(forehead_center - chin_bottom, axis=1, keepdims=True)
172
+ return distances
173
+
174
+ def __getitem__(self, index):
175
+
176
+ data_item = self.data[index]
177
+ hubert_obj = data_item['hubert_obj']
178
+ frame_count = data_item['frame_count']
179
+ lmd_obj_full = data_item['lmd_obj_full']
180
+ yaw_pitch_roll = data_item['yaw_pitch_roll']
181
+ motion_start_obj = data_item['motion_start_obj']
182
+ motion_direction_obj = data_item['motion_direction_obj']
183
+
184
+ frame_end_index = random.randint(self.window_size * self.video_fps + 1, frame_count - 1)
185
+ frame_start_index = frame_end_index - self.window_size * self.video_fps
186
+ frame_hint_index = frame_start_index - 1
187
+
188
+ audio_start_index = int(frame_start_index * (self.audio_hz / self.video_fps))
189
+ audio_end_index = int(frame_end_index * (self.audio_hz / self.video_fps))
190
+
191
+ if self.mfcc_mode:
192
+ audio_feats = hubert_obj[audio_start_index:audio_end_index, :]
193
+ else:
194
+ audio_feats = hubert_obj[:, audio_start_index:audio_end_index, :]
195
+
196
+ lmd_obj_full = lmd_obj_full[frame_hint_index:frame_end_index, :]
197
+
198
+ yaw_pitch_roll = yaw_pitch_roll[frame_start_index:frame_end_index, :]
199
+
200
+ motion_start = motion_start_obj[frame_hint_index]
201
+ motion_direction_start = motion_direction_obj[frame_hint_index]
202
+ motion_direction = motion_direction_obj[frame_start_index:frame_end_index, :]
203
+
204
+
205
+
206
+ return {
207
+ 'motion_start': motion_start,
208
+ 'motion_direction': motion_direction,
209
+ 'audio_feats': audio_feats,
210
+ 'face_location': lmd_obj_full[1:, 30, 0], # '1:' means taking the first frame as the driven frame. '30' is the noise location, '0' means x coordinate
211
+ 'face_scale': self.calculate_face_height(lmd_obj_full[1:,:,:]),
212
+ 'yaw_pitch_roll': yaw_pitch_roll,
213
+ 'motion_direction_start': motion_direction_start,
214
+ }
215
+
216
+ def __len__(self):
217
+ return len(self.data)
218
+
dataset_util.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import shutil
2
+ import os
3
+ from dist_utils import *
4
+
5
+
6
+ def use_cached_dataset_path(source_path, cache_path):
7
+ if get_rank() == 0:
8
+ if not os.path.exists(cache_path):
9
+ # shutil.rmtree(cache_path)
10
+ print(f'copying the data: {source_path} to {cache_path}')
11
+ shutil.copytree(source_path, cache_path)
12
+ barrier()
13
+ return cache_path
demo.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from LIA_Model import LIA_Model
2
+ import torch
3
+ import numpy as np
4
+ import os
5
+ from PIL import Image
6
+ from tqdm import tqdm
7
+ import argparse
8
+ import numpy as np
9
+ from torchvision import transforms
10
+ from templates import *
11
+ import argparse
12
+ import shutil
13
+ from moviepy.editor import *
14
+ import librosa
15
+ import python_speech_features
16
+ import importlib.util
17
+ import time
18
+
19
+ def check_package_installed(package_name):
20
+ package_spec = importlib.util.find_spec(package_name)
21
+ if package_spec is None:
22
+ print(f"{package_name} is not installed.")
23
+ return False
24
+ else:
25
+ print(f"{package_name} is installed.")
26
+ return True
27
+
28
+ def frames_to_video(input_path, audio_path, output_path, fps=25):
29
+ image_files = [os.path.join(input_path, img) for img in sorted(os.listdir(input_path))]
30
+ clips = [ImageClip(m).set_duration(1/fps) for m in image_files]
31
+ video = concatenate_videoclips(clips, method="compose")
32
+
33
+ audio = AudioFileClip(audio_path)
34
+ final_video = video.set_audio(audio)
35
+ final_video.write_videofile(output_path, fps=fps, codec='libx264', audio_codec='aac')
36
+
37
+ def load_image(filename, size):
38
+ img = Image.open(filename).convert('RGB')
39
+ img = img.resize((size, size))
40
+ img = np.asarray(img)
41
+ img = np.transpose(img, (2, 0, 1)) # 3 x 256 x 256
42
+ return img / 255.0
43
+
44
+ def img_preprocessing(img_path, size):
45
+ img = load_image(img_path, size) # [0, 1]
46
+ img = torch.from_numpy(img).unsqueeze(0).float() # [0, 1]
47
+ imgs_norm = (img - 0.5) * 2.0 # [-1, 1]
48
+ return imgs_norm
49
+
50
+ def saved_image(img_tensor, img_path):
51
+ toPIL = transforms.ToPILImage()
52
+ img = toPIL(img_tensor.detach().cpu().squeeze(0)) # 使用squeeze(0)来移除批次维度
53
+ img.save(img_path)
54
+
55
+ def main(args):
56
+ frames_result_saved_path = os.path.join(args.result_path, 'frames')
57
+ os.makedirs(frames_result_saved_path, exist_ok=True)
58
+ test_image_name = os.path.splitext(os.path.basename(args.test_image_path))[0]
59
+ audio_name = os.path.splitext(os.path.basename(args.test_audio_path))[0]
60
+ predicted_video_256_path = os.path.join(args.result_path, f'{test_image_name}-{audio_name}.mp4')
61
+ predicted_video_512_path = os.path.join(args.result_path, f'{test_image_name}-{audio_name}_SR.mp4')
62
+
63
+ #======Loading Stage 1 model=========
64
+ lia = LIA_Model(motion_dim=args.motion_dim, fusion_type='weighted_sum')
65
+ lia.load_lightning_model(args.stage1_checkpoint_path)
66
+ lia.to(args.device)
67
+ #============================
68
+
69
+ conf = ffhq256_autoenc()
70
+ conf.seed = args.seed
71
+ conf.decoder_layers = args.decoder_layers
72
+ conf.infer_type = args.infer_type
73
+ conf.motion_dim = args.motion_dim
74
+
75
+ if args.infer_type == 'mfcc_full_control':
76
+ conf.face_location=True
77
+ conf.face_scale=True
78
+ conf.mfcc = True
79
+
80
+ elif args.infer_type == 'mfcc_pose_only':
81
+ conf.face_location=False
82
+ conf.face_scale=False
83
+ conf.mfcc = True
84
+
85
+ elif args.infer_type == 'hubert_pose_only':
86
+ conf.face_location=False
87
+ conf.face_scale=False
88
+ conf.mfcc = False
89
+
90
+ elif args.infer_type == 'hubert_audio_only':
91
+ conf.face_location=False
92
+ conf.face_scale=False
93
+ conf.mfcc = False
94
+
95
+ elif args.infer_type == 'hubert_full_control':
96
+ conf.face_location=True
97
+ conf.face_scale=True
98
+ conf.mfcc = False
99
+
100
+ else:
101
+ print('Type NOT Found!')
102
+ exit(0)
103
+
104
+ if not os.path.exists(args.test_image_path):
105
+ print(f'{args.test_image_path} does not exist!')
106
+ exit(0)
107
+
108
+ if not os.path.exists(args.test_audio_path):
109
+ print(f'{args.test_audio_path} does not exist!')
110
+ exit(0)
111
+
112
+ img_source = img_preprocessing(args.test_image_path, args.image_size).to(args.device)
113
+ one_shot_lia_start, one_shot_lia_direction, feats = lia.get_start_direction_code(img_source, img_source, img_source, img_source)
114
+
115
+
116
+ #======Loading Stage 2 model=========
117
+ model = LitModel(conf)
118
+ state = torch.load(args.stage2_checkpoint_path, map_location='cpu')
119
+ model.load_state_dict(state, strict=True)
120
+ model.ema_model.eval()
121
+ model.ema_model.to(args.device);
122
+ #=================================
123
+
124
+
125
+ #======Audio Input=========
126
+ if conf.infer_type.startswith('mfcc'):
127
+ # MFCC features
128
+ wav, sr = librosa.load(args.test_audio_path, sr=16000)
129
+ input_values = python_speech_features.mfcc(signal=wav, samplerate=sr, numcep=13, winlen=0.025, winstep=0.01)
130
+ d_mfcc_feat = python_speech_features.base.delta(input_values, 1)
131
+ d_mfcc_feat2 = python_speech_features.base.delta(input_values, 2)
132
+ audio_driven_obj = np.hstack((input_values, d_mfcc_feat, d_mfcc_feat2))
133
+ frame_start, frame_end = 0, int(audio_driven_obj.shape[0]/4)
134
+ audio_start, audio_end = int(frame_start * 4), int(frame_end * 4) # The video frame is fixed to 25 hz and the audio is fixed to 100 hz
135
+
136
+ audio_driven = torch.Tensor(audio_driven_obj[audio_start:audio_end,:]).unsqueeze(0).float().to(args.device)
137
+
138
+ elif conf.infer_type.startswith('hubert'):
139
+ # Hubert features
140
+ if not os.path.exists(args.test_hubert_path):
141
+
142
+ if not check_package_installed('transformers'):
143
+ print('Please install transformers module first.')
144
+ exit(0)
145
+ hubert_model_path = 'ckpts/chinese-hubert-large'
146
+ if not os.path.exists(hubert_model_path):
147
+ print('Please download the hubert weight into the ckpts path first.')
148
+ exit(0)
149
+ print('You did not extract the audio features in advance, extracting online now, which will increase processing delay')
150
+
151
+ start_time = time.time()
152
+
153
+ # load hubert model
154
+ from transformers import Wav2Vec2FeatureExtractor, HubertModel
155
+ audio_model = HubertModel.from_pretrained(hubert_model_path).to(args.device)
156
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(hubert_model_path)
157
+ audio_model.feature_extractor._freeze_parameters()
158
+ audio_model.eval()
159
+
160
+ # hubert model forward pass
161
+ audio, sr = librosa.load(args.test_audio_path, sr=16000)
162
+ input_values = feature_extractor(audio, sampling_rate=16000, padding=True, do_normalize=True, return_tensors="pt").input_values
163
+ input_values = input_values.to(args.device)
164
+ ws_feats = []
165
+ with torch.no_grad():
166
+ outputs = audio_model(input_values, output_hidden_states=True)
167
+ for i in range(len(outputs.hidden_states)):
168
+ ws_feats.append(outputs.hidden_states[i].detach().cpu().numpy())
169
+ ws_feat_obj = np.array(ws_feats)
170
+ ws_feat_obj = np.squeeze(ws_feat_obj, 1)
171
+ ws_feat_obj = np.pad(ws_feat_obj, ((0, 0), (0, 1), (0, 0)), 'edge') # align the audio length with video frame
172
+
173
+ execution_time = time.time() - start_time
174
+ print(f"Extraction Audio Feature: {execution_time:.2f} Seconds")
175
+
176
+ audio_driven_obj = ws_feat_obj
177
+ else:
178
+ print(f'Using audio feature from path: {args.test_hubert_path}')
179
+ audio_driven_obj = np.load(args.test_hubert_path)
180
+
181
+ frame_start, frame_end = 0, int(audio_driven_obj.shape[1]/2)
182
+ audio_start, audio_end = int(frame_start * 2), int(frame_end * 2) # The video frame is fixed to 25 hz and the audio is fixed to 50 hz
183
+
184
+ audio_driven = torch.Tensor(audio_driven_obj[:,audio_start:audio_end,:]).unsqueeze(0).float().to(args.device)
185
+ #============================
186
+
187
+ # Diffusion Noise
188
+ noisyT = th.randn((1,frame_end, args.motion_dim)).to(args.device)
189
+
190
+ #======Inputs for Attribute Control=========
191
+ if os.path.exists(args.pose_driven_path):
192
+ pose_obj = np.load(args.pose_driven_path)
193
+
194
+
195
+ if len(pose_obj.shape) != 2:
196
+ print('please check your pose information. The shape must be like (T, 3).')
197
+ exit(0)
198
+ if pose_obj.shape[1] != 3:
199
+ print('please check your pose information. The shape must be like (T, 3).')
200
+ exit(0)
201
+
202
+ if pose_obj.shape[0] >= frame_end:
203
+ pose_obj = pose_obj[:frame_end,:]
204
+ else:
205
+ padding = np.tile(pose_obj[-1, :], (frame_end - pose_obj.shape[0], 1))
206
+ pose_obj = np.vstack((pose_obj, padding))
207
+
208
+ pose_signal = torch.Tensor(pose_obj).unsqueeze(0).to(args.device) / 90 # 90 is for normalization here
209
+ else:
210
+ yaw_signal = torch.zeros(1, frame_end, 1).to(args.device) + args.pose_yaw
211
+ pitch_signal = torch.zeros(1, frame_end, 1).to(args.device) + args.pose_pitch
212
+ roll_signal = torch.zeros(1, frame_end, 1).to(args.device) + args.pose_roll
213
+ pose_signal = torch.cat((yaw_signal, pitch_signal, roll_signal), dim=-1)
214
+
215
+ pose_signal = torch.clamp(pose_signal, -1, 1)
216
+
217
+ face_location_signal = torch.zeros(1, frame_end, 1).to(args.device) + args.face_location
218
+ face_scae_signal = torch.zeros(1, frame_end, 1).to(args.device) + args.face_scale
219
+ #===========================================
220
+
221
+ start_time = time.time()
222
+
223
+ #======Diffusion Denosing Process=========
224
+ generated_directions = model.render(one_shot_lia_start, one_shot_lia_direction, audio_driven, face_location_signal, face_scae_signal, pose_signal, noisyT, args.step_T, control_flag=args.control_flag)
225
+ #=========================================
226
+
227
+ execution_time = time.time() - start_time
228
+ print(f"Motion Diffusion Model: {execution_time:.2f} Seconds")
229
+
230
+ generated_directions = generated_directions.detach().cpu().numpy()
231
+
232
+ start_time = time.time()
233
+ #======Rendering images frame-by-frame=========
234
+ for pred_index in tqdm(range(generated_directions.shape[1])):
235
+ ori_img_recon = lia.render(one_shot_lia_start, torch.Tensor(generated_directions[:,pred_index,:]).to(args.device), feats)
236
+ ori_img_recon = ori_img_recon.clamp(-1, 1)
237
+ wav_pred = (ori_img_recon.detach() + 1) / 2
238
+ saved_image(wav_pred, os.path.join(frames_result_saved_path, "%06d.png"%(pred_index)))
239
+ #==============================================
240
+
241
+ execution_time = time.time() - start_time
242
+ print(f"Renderer Model: {execution_time:.2f} Seconds")
243
+
244
+ frames_to_video(frames_result_saved_path, args.test_audio_path, predicted_video_256_path)
245
+
246
+ shutil.rmtree(frames_result_saved_path)
247
+
248
+
249
+ # Enhancer
250
+ # Code is modified from https://github.com/OpenTalker/SadTalker/blob/cd4c0465ae0b54a6f85af57f5c65fec9fe23e7f8/src/utils/face_enhancer.py#L26
251
+
252
+ if args.face_sr and check_package_installed('gfpgan'):
253
+ from face_sr.face_enhancer import enhancer_list
254
+ import imageio
255
+
256
+ # Super-resolution
257
+ imageio.mimsave(predicted_video_512_path+'.tmp.mp4', enhancer_list(predicted_video_256_path, method='gfpgan', bg_upsampler=None), fps=float(25))
258
+
259
+ # Merge audio and video
260
+ video_clip = VideoFileClip(predicted_video_512_path+'.tmp.mp4')
261
+ audio_clip = AudioFileClip(predicted_video_256_path)
262
+ final_clip = video_clip.set_audio(audio_clip)
263
+ final_clip.write_videofile(predicted_video_512_path, codec='libx264', audio_codec='aac')
264
+
265
+ os.remove(predicted_video_512_path+'.tmp.mp4')
266
+
267
+ if __name__ == '__main__':
268
+ parser = argparse.ArgumentParser()
269
+ parser.add_argument('--infer_type', type=str, default='mfcc_pose_only', help='mfcc_pose_only or mfcc_full_control')
270
+ parser.add_argument('--test_image_path', type=str, default='./test_demos/portraits/monalisa.jpg', help='Path to the portrait')
271
+ parser.add_argument('--test_audio_path', type=str, default='./test_demos/audios/english_female.wav', help='Path to the driven audio')
272
+ parser.add_argument('--test_hubert_path', type=str, default='./test_demos/audios_hubert/english_female.npy', help='Path to the driven audio(hubert type). Not needed for MFCC')
273
+ parser.add_argument('--result_path', type=str, default='./results/', help='Type of inference')
274
+ parser.add_argument('--stage1_checkpoint_path', type=str, default='./ckpts/stage1.ckpt', help='Path to the checkpoint of Stage1')
275
+ parser.add_argument('--stage2_checkpoint_path', type=str, default='./ckpts/pose_only.ckpt', help='Path to the checkpoint of Stage2')
276
+ parser.add_argument('--seed', type=int, default=0, help='seed for generations')
277
+ parser.add_argument('--control_flag', action='store_true', help='Whether to use control signal or not')
278
+ parser.add_argument('--pose_yaw', type=float, default=0.25, help='range from -1 to 1 (-90 ~ 90 angles)')
279
+ parser.add_argument('--pose_pitch', type=float, default=0, help='range from -1 to 1 (-90 ~ 90 angles)')
280
+ parser.add_argument('--pose_roll', type=float, default=0, help='range from -1 to 1 (-90 ~ 90 angles)')
281
+ parser.add_argument('--face_location', type=float, default=0.5, help='range from 0 to 1 (from left to right)')
282
+ parser.add_argument('--pose_driven_path', type=str, default='xxx', help='path to pose numpy, shape is (T, 3). You can check the following code https://github.com/liutaocode/talking_face_preprocessing to extract the yaw, pitch and roll.')
283
+ parser.add_argument('--face_scale', type=float, default=0.5, help='range from 0 to 1 (from small to large)')
284
+ parser.add_argument('--step_T', type=int, default=50, help='Step T for diffusion denoising process')
285
+ parser.add_argument('--image_size', type=int, default=256, help='Size of the image. Do not change.')
286
+ parser.add_argument('--device', type=str, default='cuda:0', help='Device for computation')
287
+ parser.add_argument('--motion_dim', type=int, default=20, help='Dimension of motion. Do not change.')
288
+ parser.add_argument('--decoder_layers', type=int, default=2, help='Layer number for the conformer.')
289
+ parser.add_argument('--face_sr', action='store_true', help='Face super-resolution (Optional). Please install GFPGAN first')
290
+
291
+
292
+
293
+ args = parser.parse_args()
294
+
295
+ main(args)
diffusion/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from typing import Union
2
+
3
+ from .diffusion import SpacedDiffusionBeatGans, SpacedDiffusionBeatGansConfig
4
+
5
+ Sampler = Union[SpacedDiffusionBeatGans]
6
+ SamplerConfig = Union[SpacedDiffusionBeatGansConfig]
diffusion/base.py ADDED
@@ -0,0 +1,1128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This code started out as a PyTorch port of Ho et al's diffusion models:
3
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py
4
+
5
+ Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules.
6
+ """
7
+
8
+ from model.unet_autoenc import AutoencReturn
9
+ from config_base import BaseConfig
10
+ import enum
11
+ import math
12
+
13
+ import numpy as np
14
+ import torch as th
15
+ from model import *
16
+ from model.nn import mean_flat
17
+ from typing import NamedTuple, Tuple
18
+ from choices import *
19
+ from torch.cuda.amp import autocast
20
+ import torch.nn.functional as F
21
+
22
+ from dataclasses import dataclass
23
+
24
+
25
+ @dataclass
26
+ class GaussianDiffusionBeatGansConfig(BaseConfig):
27
+ gen_type: GenerativeType
28
+ betas: Tuple[float]
29
+ model_type: ModelType
30
+ model_mean_type: ModelMeanType
31
+ model_var_type: ModelVarType
32
+ loss_type: LossType
33
+ rescale_timesteps: bool
34
+ fp16: bool
35
+ train_pred_xstart_detach: bool = True
36
+
37
+ def make_sampler(self):
38
+ return GaussianDiffusionBeatGans(self)
39
+
40
+
41
+ class GaussianDiffusionBeatGans:
42
+ """
43
+ Utilities for training and sampling diffusion models.
44
+
45
+ Ported directly from here, and then adapted over time to further experimentation.
46
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
47
+
48
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
49
+ starting at T and going to 1.
50
+ :param model_mean_type: a ModelMeanType determining what the model outputs.
51
+ :param model_var_type: a ModelVarType determining how variance is output.
52
+ :param loss_type: a LossType determining the loss function to use.
53
+ :param rescale_timesteps: if True, pass floating point timesteps into the
54
+ model so that they are always scaled like in the
55
+ original paper (0 to 1000).
56
+ """
57
+ def __init__(self, conf: GaussianDiffusionBeatGansConfig):
58
+ self.conf = conf
59
+ self.model_mean_type = conf.model_mean_type
60
+ self.model_var_type = conf.model_var_type
61
+ self.loss_type = conf.loss_type
62
+ self.rescale_timesteps = conf.rescale_timesteps
63
+
64
+ # Use float64 for accuracy.
65
+ betas = np.array(conf.betas, dtype=np.float64)
66
+ self.betas = betas
67
+ assert len(betas.shape) == 1, "betas must be 1-D"
68
+ assert (betas > 0).all() and (betas <= 1).all()
69
+
70
+ self.num_timesteps = int(betas.shape[0])
71
+
72
+ alphas = 1.0 - betas
73
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
74
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
75
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
76
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps, )
77
+
78
+ # calculations for diffusion q(x_t | x_{t-1}) and others
79
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
80
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
81
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
82
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
83
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod -
84
+ 1)
85
+
86
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
87
+ self.posterior_variance = (betas * (1.0 - self.alphas_cumprod_prev) /
88
+ (1.0 - self.alphas_cumprod))
89
+ # log calculation clipped because the posterior variance is 0 at the
90
+ # beginning of the diffusion chain.
91
+ self.posterior_log_variance_clipped = np.log(
92
+ np.append(self.posterior_variance[1], self.posterior_variance[1:]))
93
+ self.posterior_mean_coef1 = (betas *
94
+ np.sqrt(self.alphas_cumprod_prev) /
95
+ (1.0 - self.alphas_cumprod))
96
+ self.posterior_mean_coef2 = ((1.0 - self.alphas_cumprod_prev) *
97
+ np.sqrt(alphas) /
98
+ (1.0 - self.alphas_cumprod))
99
+
100
+ def training_losses(self,
101
+ model,
102
+ motion_direction_start: th.Tensor,
103
+ motion_target: th.Tensor,
104
+ motion_start: th.Tensor,
105
+ audio_feats: th.Tensor,
106
+ face_location: th.Tensor,
107
+ face_scale: th.Tensor,
108
+ yaw_pitch_roll: th.Tensor,
109
+ t: th.Tensor,
110
+ model_kwargs=None,
111
+ noise: th.Tensor = None):
112
+ """
113
+ Compute training losses for a single timestep.
114
+
115
+ :param model: the model to evaluate loss on.
116
+ :param x_start: the [N x C x ...] tensor of inputs.
117
+ :param t: a batch of timestep indices.
118
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
119
+ pass to the model. This can be used for conditioning.
120
+ :param noise: if specified, the specific Gaussian noise to try to remove.
121
+ :return: a dict with the key "loss" containing a tensor of shape [N].
122
+ Some mean or variance settings may also have other keys.
123
+ """
124
+ if model_kwargs is None:
125
+ model_kwargs = {}
126
+ if noise is None:
127
+ noise = th.randn_like(motion_target)
128
+
129
+ x_t = self.q_sample(motion_target, t, noise=noise)
130
+
131
+ terms = {'x_t': x_t}
132
+
133
+ if self.loss_type in [
134
+ LossType.mse,
135
+ LossType.l1,
136
+ ]:
137
+ with autocast(self.conf.fp16):
138
+ # x_t is static wrt. to the diffusion process
139
+ predicted_direction, predicted_location, predicted_scale, predicted_pose = model.forward(motion_start,
140
+ motion_direction_start,
141
+ audio_feats,
142
+ face_location,
143
+ face_scale,
144
+ yaw_pitch_roll,
145
+ x_t.detach(),
146
+ self._scale_timesteps(t),
147
+ control_flag=False)
148
+
149
+
150
+ target_types = {
151
+ ModelMeanType.eps: noise,
152
+ }
153
+ target = target_types[self.model_mean_type]
154
+ assert predicted_direction.shape == target.shape == motion_target.shape
155
+
156
+ if self.loss_type == LossType.mse:
157
+ if self.model_mean_type == ModelMeanType.eps:
158
+
159
+ direction_loss = mean_flat((target - predicted_direction)**2)
160
+ # import pdb;pdb.set_trace()
161
+ location_loss = mean_flat((face_location.unsqueeze(-1) - predicted_location)**2)
162
+ scale_loss = mean_flat((face_scale - predicted_scale)**2)
163
+ pose_loss = mean_flat((yaw_pitch_roll - predicted_pose)**2)
164
+
165
+ terms["mse"] = direction_loss + location_loss + scale_loss + pose_loss
166
+
167
+ else:
168
+ raise NotImplementedError()
169
+ elif self.loss_type == LossType.l1:
170
+ # (n, c, h, w) => (n, )
171
+ terms["mse"] = mean_flat((target - predicted_direction).abs())
172
+ else:
173
+ raise NotImplementedError()
174
+
175
+ if "vb" in terms:
176
+ # if learning the variance also use the vlb loss
177
+ terms["loss"] = terms["mse"] + terms["vb"]
178
+ else:
179
+ terms["loss"] = terms["mse"]
180
+ else:
181
+ raise NotImplementedError(self.loss_type)
182
+
183
+
184
+ return terms
185
+
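The reductions above rely on mean_flat, a small helper defined elsewhere in this file (outside the excerpt shown) that averages a tensor over every non-batch dimension so each loss term becomes one scalar per example. A minimal sketch of such a helper, under that assumption:

import torch as th

def mean_flat(tensor: th.Tensor) -> th.Tensor:
    # Average over all dimensions except the batch dimension,
    # producing one scalar per example.
    return tensor.mean(dim=list(range(1, len(tensor.shape))))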
186
+ def sample(self,
187
+ model: Model,
188
+ shape=None,
189
+ noise=None,
190
+ cond=None,
191
+ x_start=None,
192
+ clip_denoised=True,
193
+ model_kwargs=None,
194
+ progress=False):
195
+ """
196
+ Args:
197
+ x_start: given for the autoencoder
198
+ """
199
+ if model_kwargs is None:
200
+ model_kwargs = {}
201
+ if self.conf.model_type.has_autoenc():
202
+ model_kwargs['x_start'] = x_start
203
+ model_kwargs['cond'] = cond
204
+
205
+ if self.conf.gen_type == GenerativeType.ddpm:
206
+ return self.p_sample_loop(model,
207
+ shape=shape,
208
+ noise=noise,
209
+ clip_denoised=clip_denoised,
210
+ model_kwargs=model_kwargs,
211
+ progress=progress)
212
+ elif self.conf.gen_type == GenerativeType.ddim:
213
+ return self.ddim_sample_loop(model,
214
+ shape=shape,
215
+ noise=noise,
216
+ clip_denoised=clip_denoised,
217
+ model_kwargs=model_kwargs,
218
+ progress=progress)
219
+ else:
220
+ raise NotImplementedError()
221
+
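For orientation, GenerativeType.ddpm runs the full ancestral loop (p_sample_loop) while GenerativeType.ddim runs the deterministic DDIM loop, usually with far fewer steps. A hedged usage sketch (variable names here are illustrative, not taken from this repository); the conditioning keys expected in model_kwargs are the ones read back in p_mean_variance below ('start', 'audio_driven', 'face_location', 'face_scale', 'yaw_pitch_roll', 'motion_direction_start', 'control_flag'):

# `sampler` is a diffusion object built from the config; `net` is the trained Model.
x_T = th.randn(batch, frames, motion_dim, device=device)       # illustrative shape
out = sampler.sample(model=net, noise=x_T, model_kwargs=cond)   # ddpm or ddim per conf.gen_type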
222
+ def q_mean_variance(self, x_start, t):
223
+ """
224
+ Get the distribution q(x_t | x_0).
225
+
226
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
227
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
228
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
229
+ """
230
+ mean = (
231
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) *
232
+ x_start)
233
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t,
234
+ x_start.shape)
235
+ log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod,
236
+ t, x_start.shape)
237
+ return mean, variance, log_variance
238
+
239
+ def q_sample(self, x_start, t, noise=None):
240
+ """
241
+ Diffuse the data for a given number of diffusion steps.
242
+
243
+ In other words, sample from q(x_t | x_0).
244
+
245
+ :param x_start: the initial data batch.
246
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
247
+ :param noise: if specified, the split-out normal noise.
248
+ :return: A noisy version of x_start.
249
+ """
250
+ if noise is None:
251
+ noise = th.randn_like(x_start)
252
+ assert noise.shape == x_start.shape
253
+ return (
254
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) *
255
+ x_start + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod,
256
+ t, x_start.shape) * noise)
257
+
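q_sample is the closed-form forward process x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps with eps ~ N(0, I), so no iterative noising loop is needed. A small self-contained numeric illustration against a hand-built linear schedule (a sketch, independent of this class):

import numpy as np

betas = np.linspace(1e-4, 0.02, 1000)
alphas_cumprod = np.cumprod(1.0 - betas)
t = 500
x0, eps = 0.7, -1.3   # scalar example
x_t = np.sqrt(alphas_cumprod[t]) * x0 + np.sqrt(1.0 - alphas_cumprod[t]) * eps
# As t grows, the x0 coefficient shrinks toward 0 and the noise coefficient grows toward 1.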
258
+ def q_posterior_mean_variance(self, x_start, x_t, t):
259
+ """
260
+ Compute the mean and variance of the diffusion posterior:
261
+
262
+ q(x_{t-1} | x_t, x_0)
263
+
264
+ """
265
+ assert x_start.shape == x_t.shape
266
+ posterior_mean = (
267
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) *
268
+ x_start +
269
+ _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) *
270
+ x_t)
271
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t,
272
+ x_t.shape)
273
+ posterior_log_variance_clipped = _extract_into_tensor(
274
+ self.posterior_log_variance_clipped, t, x_t.shape)
275
+ assert (posterior_mean.shape[0] == posterior_variance.shape[0] ==
276
+ posterior_log_variance_clipped.shape[0] == x_start.shape[0])
277
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
278
+
279
+ def p_mean_variance(self,
280
+ model,
281
+ x,
282
+ t,
283
+ clip_denoised=True,
284
+ denoised_fn=None,
285
+ model_kwargs=None):
286
+ """
287
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
288
+ the initial x, x_0.
289
+
290
+ :param model: the model, which takes a signal and a batch of timesteps
291
+ as input.
292
+ :param x: the [N x C x ...] tensor at time t.
293
+ :param t: a 1-D Tensor of timesteps.
294
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
295
+ :param denoised_fn: if not None, a function which applies to the
296
+ x_start prediction before it is used to sample. Applies before
297
+ clip_denoised.
298
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
299
+ pass to the model. This can be used for conditioning.
300
+ :return: a dict with the following keys:
301
+ - 'mean': the model mean output.
302
+ - 'variance': the model variance output.
303
+ - 'log_variance': the log of 'variance'.
304
+ - 'pred_xstart': the prediction for x_0.
305
+ """
306
+ if model_kwargs is None:
307
+ model_kwargs = {}
308
+
309
+ motion_start = model_kwargs['start']
310
+ audio_feats = model_kwargs['audio_driven']
311
+ face_location = model_kwargs['face_location']
312
+ face_scale = model_kwargs['face_scale']
313
+ yaw_pitch_roll = model_kwargs['yaw_pitch_roll']
314
+ motion_direction_start = model_kwargs['motion_direction_start']
315
+ control_flag = model_kwargs['control_flag']
316
+
317
+ B, C = x.shape[:2]
318
+ assert t.shape == (B, )
319
+ with autocast(self.conf.fp16):
320
+ model_forward, _, _, _ = model.forward(motion_start,
321
+ motion_direction_start,
322
+ audio_feats,
323
+ face_location,
324
+ face_scale,
325
+ yaw_pitch_roll,
326
+ x,
327
+ self._scale_timesteps(t),
328
+ control_flag)
329
+ model_output = model_forward
330
+
331
+ if self.model_var_type in [
332
+ ModelVarType.fixed_large, ModelVarType.fixed_small
333
+ ]:
334
+ model_variance, model_log_variance = {
335
+ # for fixedlarge, we set the initial (log-)variance like so
336
+ # to get a better decoder log likelihood.
337
+ ModelVarType.fixed_large: (
338
+ np.append(self.posterior_variance[1], self.betas[1:]),
339
+ np.log(
340
+ np.append(self.posterior_variance[1], self.betas[1:])),
341
+ ),
342
+ ModelVarType.fixed_small: (
343
+ self.posterior_variance,
344
+ self.posterior_log_variance_clipped,
345
+ ),
346
+ }[self.model_var_type]
347
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
348
+ model_log_variance = _extract_into_tensor(model_log_variance, t,
349
+ x.shape)
350
+
351
+ def process_xstart(x):
352
+ if denoised_fn is not None:
353
+ x = denoised_fn(x)
354
+ if clip_denoised:
355
+ return x.clamp(-1, 1)
356
+ return x
357
+
358
+ if self.model_mean_type in [
359
+ ModelMeanType.eps,
360
+ ]:
361
+ if self.model_mean_type == ModelMeanType.eps:
362
+ pred_xstart = process_xstart(
363
+ self._predict_xstart_from_eps(x_t=x, t=t,
364
+ eps=model_output))
365
+ else:
366
+ raise NotImplementedError()
367
+ model_mean, _, _ = self.q_posterior_mean_variance(
368
+ x_start=pred_xstart, x_t=x, t=t)
369
+ else:
370
+ raise NotImplementedError(self.model_mean_type)
371
+
372
+ assert (model_mean.shape == model_log_variance.shape ==
373
+ pred_xstart.shape == x.shape)
374
+ return {
375
+ "mean": model_mean,
376
+ "variance": model_variance,
377
+ "log_variance": model_log_variance,
378
+ "pred_xstart": pred_xstart,
379
+ 'model_forward': model_forward,
380
+ }
381
+
382
+ def _predict_xstart_from_eps(self, x_t, t, eps):
383
+ assert x_t.shape == eps.shape
384
+ return (_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t,
385
+ x_t.shape) * x_t -
386
+ _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t,
387
+ x_t.shape) * eps)
388
+
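This is just the forward process solved for x_0: from x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps it follows that x_0 = x_t / sqrt(alpha_bar_t) - sqrt(1 / alpha_bar_t - 1) * eps, which is exactly sqrt_recip_alphas_cumprod * x_t - sqrt_recipm1_alphas_cumprod * eps.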
389
+ def _predict_xstart_from_xprev(self, x_t, t, xprev):
390
+ assert x_t.shape == xprev.shape
391
+ return ( # (xprev - coef2*x_t) / coef1
392
+ _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape)
393
+ * xprev - _extract_into_tensor(
394
+ self.posterior_mean_coef2 / self.posterior_mean_coef1, t,
395
+ x_t.shape) * x_t)
396
+
397
+ def _predict_xstart_from_scaled_xstart(self, t, scaled_xstart):
398
+ return scaled_xstart * _extract_into_tensor(
399
+ self.sqrt_recip_alphas_cumprod, t, scaled_xstart.shape)
400
+
401
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
402
+ return (_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t,
403
+ x_t.shape) * x_t -
404
+ pred_xstart) / _extract_into_tensor(
405
+ self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
406
+
407
+ def _predict_eps_from_scaled_xstart(self, x_t, t, scaled_xstart):
408
+ """
409
+ Args:
410
+ scaled_xstart: is supposed to be sqrt(alphacum) * x_0
411
+ """
412
+ # 1 / sqrt(1-alphabar) * (x_t - scaled xstart)
413
+ return (x_t - scaled_xstart) / _extract_into_tensor(
414
+ self.sqrt_one_minus_alphas_cumprod, t, x_t.shape)
415
+
416
+ def _scale_timesteps(self, t):
417
+ if self.rescale_timesteps:
418
+ # scale t to be maxed out at 1000 steps
419
+ return t.float() * (1000.0 / self.num_timesteps)
420
+ return t
421
+
422
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
423
+ """
424
+ Compute the mean for the previous step, given a function cond_fn that
425
+ computes the gradient of a conditional log probability with respect to
426
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
427
+ condition on y.
428
+
429
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
430
+ """
431
+ gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
432
+ new_mean = (p_mean_var["mean"].float() +
433
+ p_mean_var["variance"] * gradient.float())
434
+ return new_mean
435
+
436
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
437
+ """
438
+ Compute what the p_mean_variance output would have been, should the
439
+ model's score function be conditioned by cond_fn.
440
+
441
+ See condition_mean() for details on cond_fn.
442
+
443
+ Unlike condition_mean(), this instead uses the conditioning strategy
444
+ from Song et al (2020).
445
+ """
446
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
447
+
448
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
449
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
450
+ x, self._scale_timesteps(t), **model_kwargs)
451
+
452
+ out = p_mean_var.copy()
453
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
454
+ out["mean"], _, _ = self.q_posterior_mean_variance(
455
+ x_start=out["pred_xstart"], x_t=x, t=t)
456
+ return out
457
+
458
+ def p_sample(
459
+ self,
460
+ model: Model,
461
+ x,
462
+ t,
463
+ clip_denoised=True,
464
+ denoised_fn=None,
465
+ cond_fn=None,
466
+ model_kwargs=None,
467
+ ):
468
+ """
469
+ Sample x_{t-1} from the model at the given timestep.
470
+
471
+ :param model: the model to sample from.
472
+ :param x: the current tensor at x_{t-1}.
473
+ :param t: the value of t, starting at 0 for the first diffusion step.
474
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
475
+ :param denoised_fn: if not None, a function which applies to the
476
+ x_start prediction before it is used to sample.
477
+ :param cond_fn: if not None, this is a gradient function that acts
478
+ similarly to the model.
479
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
480
+ pass to the model. This can be used for conditioning.
481
+ :return: a dict containing the following keys:
482
+ - 'sample': a random sample from the model.
483
+ - 'pred_xstart': a prediction of x_0.
484
+ """
485
+ out = self.p_mean_variance(
486
+ model,
487
+ x,
488
+ t,
489
+ clip_denoised=clip_denoised,
490
+ denoised_fn=denoised_fn,
491
+ model_kwargs=model_kwargs,
492
+ )
493
+ noise = th.randn_like(x)
494
+ nonzero_mask = ((t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
495
+ ) # no noise when t == 0
496
+ if cond_fn is not None:
497
+ out["mean"] = self.condition_mean(cond_fn,
498
+ out,
499
+ x,
500
+ t,
501
+ model_kwargs=model_kwargs)
502
+ sample = out["mean"] + nonzero_mask * th.exp(
503
+ 0.5 * out["log_variance"]) * noise
504
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
505
+
506
+ def p_sample_loop(
507
+ self,
508
+ model: Model,
509
+ shape=None,
510
+ noise=None,
511
+ clip_denoised=True,
512
+ denoised_fn=None,
513
+ cond_fn=None,
514
+ model_kwargs=None,
515
+ device=None,
516
+ progress=False,
517
+ ):
518
+ """
519
+ Generate samples from the model.
520
+
521
+ :param model: the model module.
522
+ :param shape: the shape of the samples, (N, C, H, W).
523
+ :param noise: if specified, the noise from the encoder to sample.
524
+ Should be of the same shape as `shape`.
525
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
526
+ :param denoised_fn: if not None, a function which applies to the
527
+ x_start prediction before it is used to sample.
528
+ :param cond_fn: if not None, this is a gradient function that acts
529
+ similarly to the model.
530
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
531
+ pass to the model. This can be used for conditioning.
532
+ :param device: if specified, the device to create the samples on.
533
+ If not specified, use a model parameter's device.
534
+ :param progress: if True, show a tqdm progress bar.
535
+ :return: a non-differentiable batch of samples.
536
+ """
537
+ final = None
538
+ for sample in self.p_sample_loop_progressive(
539
+ model,
540
+ shape,
541
+ noise=noise,
542
+ clip_denoised=clip_denoised,
543
+ denoised_fn=denoised_fn,
544
+ cond_fn=cond_fn,
545
+ model_kwargs=model_kwargs,
546
+ device=device,
547
+ progress=progress,
548
+ ):
549
+ final = sample
550
+ return final["sample"]
551
+
552
+ def p_sample_loop_progressive(
553
+ self,
554
+ model: Model,
555
+ shape=None,
556
+ noise=None,
557
+ clip_denoised=True,
558
+ denoised_fn=None,
559
+ cond_fn=None,
560
+ model_kwargs=None,
561
+ device=None,
562
+ progress=False,
563
+ ):
564
+ """
565
+ Generate samples from the model and yield intermediate samples from
566
+ each timestep of diffusion.
567
+
568
+ Arguments are the same as p_sample_loop().
569
+ Returns a generator over dicts, where each dict is the return value of
570
+ p_sample().
571
+ """
572
+ if device is None:
573
+ device = next(model.parameters()).device
574
+ if noise is not None:
575
+ img = noise
576
+ else:
577
+ assert isinstance(shape, (tuple, list))
578
+ img = th.randn(*shape, device=device)
579
+ indices = list(range(self.num_timesteps))[::-1]
580
+
581
+ if progress:
582
+ # Lazy import so that we don't depend on tqdm.
583
+ from tqdm.auto import tqdm
584
+
585
+ indices = tqdm(indices)
586
+
587
+ for i in indices:
588
+ # t = th.tensor([i] * shape[0], device=device)
589
+ t = th.tensor([i] * len(img), device=device)
590
+ with th.no_grad():
591
+ out = self.p_sample(
592
+ model,
593
+ img,
594
+ t,
595
+ clip_denoised=clip_denoised,
596
+ denoised_fn=denoised_fn,
597
+ cond_fn=cond_fn,
598
+ model_kwargs=model_kwargs,
599
+ )
600
+ yield out
601
+ img = out["sample"]
602
+
603
+ def ddim_sample(
604
+ self,
605
+ model: Model,
606
+ x,
607
+ t,
608
+ clip_denoised=True,
609
+ denoised_fn=None,
610
+ cond_fn=None,
611
+ model_kwargs=None,
612
+ eta=0.0,
613
+ ):
614
+ """
615
+ Sample x_{t-1} from the model using DDIM.
616
+
617
+ Same usage as p_sample().
618
+ """
619
+ out = self.p_mean_variance(
620
+ model,
621
+ x,
622
+ t,
623
+ clip_denoised=clip_denoised,
624
+ denoised_fn=denoised_fn,
625
+ model_kwargs=model_kwargs,
626
+ )
627
+ if cond_fn is not None:
628
+ out = self.condition_score(cond_fn,
629
+ out,
630
+ x,
631
+ t,
632
+ model_kwargs=model_kwargs)
633
+
634
+ # Usually our model outputs epsilon, but we re-derive it
635
+ # in case we used x_start or x_prev prediction.
636
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
637
+
638
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
639
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t,
640
+ x.shape)
641
+ sigma = (eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) *
642
+ th.sqrt(1 - alpha_bar / alpha_bar_prev))
643
+ # Equation 12.
644
+ noise = th.randn_like(x)
645
+ mean_pred = (out["pred_xstart"] * th.sqrt(alpha_bar_prev) +
646
+ th.sqrt(1 - alpha_bar_prev - sigma**2) * eps)
647
+ nonzero_mask = ((t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
648
+ ) # no noise when t == 0
649
+ sample = mean_pred + nonzero_mask * sigma * noise
650
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
651
+
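This is Equation 12 of the DDIM paper written out: x_{t-1} = sqrt(alpha_bar_{t-1}) * x0_pred + sqrt(1 - alpha_bar_{t-1} - sigma_t^2) * eps + sigma_t * z, where sigma_t = eta * sqrt((1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)) * sqrt(1 - alpha_bar_t / alpha_bar_{t-1}). With eta = 0 (the default here) sigma_t vanishes and the update is fully deterministic, which is what makes few-step DDIM sampling possible.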
652
+ def ddim_reverse_sample(
653
+ self,
654
+ model: Model,
655
+ x,
656
+ t,
657
+ clip_denoised=True,
658
+ denoised_fn=None,
659
+ model_kwargs=None,
660
+ eta=0.0,
661
+ ):
662
+ """
663
+ Sample x_{t+1} from the model using DDIM reverse ODE.
664
+ NOTE: never used ?
665
+ """
666
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
667
+ out = self.p_mean_variance(
668
+ model,
669
+ x,
670
+ t,
671
+ clip_denoised=clip_denoised,
672
+ denoised_fn=denoised_fn,
673
+ model_kwargs=model_kwargs,
674
+ )
675
+ # Usually our model outputs epsilon, but we re-derive it
676
+ # in case we used x_start or x_prev prediction.
677
+ eps = (_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape)
678
+ * x - out["pred_xstart"]) / _extract_into_tensor(
679
+ self.sqrt_recipm1_alphas_cumprod, t, x.shape)
680
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t,
681
+ x.shape)
682
+
683
+ # Equation 12. reversed (DDIM paper) (th.sqrt == torch.sqrt)
684
+ mean_pred = (out["pred_xstart"] * th.sqrt(alpha_bar_next) +
685
+ th.sqrt(1 - alpha_bar_next) * eps)
686
+
687
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
688
+
689
+ def ddim_reverse_sample_loop(
690
+ self,
691
+ model: Model,
692
+ x,
693
+ clip_denoised=True,
694
+ denoised_fn=None,
695
+ model_kwargs=None,
696
+ eta=0.0,
697
+ device=None,
698
+ ):
699
+ if device is None:
700
+ device = next(model.parameters()).device
701
+ sample_t = []
702
+ xstart_t = []
703
+ T = []
704
+ indices = list(range(self.num_timesteps))
705
+ sample = x
706
+ for i in indices:
707
+ t = th.tensor([i] * len(sample), device=device)
708
+ with th.no_grad():
709
+ out = self.ddim_reverse_sample(model,
710
+ sample,
711
+ t=t,
712
+ clip_denoised=clip_denoised,
713
+ denoised_fn=denoised_fn,
714
+ model_kwargs=model_kwargs,
715
+ eta=eta)
716
+ sample = out['sample']
717
+ # [1, ..., T]
718
+ sample_t.append(sample)
719
+ # [0, ...., T-1]
720
+ xstart_t.append(out['pred_xstart'])
721
+ # [0, ..., T-1] ready to use
722
+ T.append(t)
723
+
724
+ return {
725
+ # xT
726
+ 'sample': sample,
727
+ # (1, ..., T)
728
+ 'sample_t': sample_t,
729
+ # xstart here is a bit different from sampling from T = T-1 to T = 0
730
+ # may not be exact
731
+ 'xstart_t': xstart_t,
732
+ 'T': T,
733
+ }
734
+
735
+ def ddim_sample_loop(
736
+ self,
737
+ model: Model,
738
+ shape=None,
739
+ noise=None,
740
+ clip_denoised=True,
741
+ denoised_fn=None,
742
+ cond_fn=None,
743
+ model_kwargs=None,
744
+ device=None,
745
+ progress=False,
746
+ eta=0.0,
747
+ ):
748
+ """
749
+ Generate samples from the model using DDIM.
750
+
751
+ Same usage as p_sample_loop().
752
+ """
753
+ final = None
754
+ for sample in self.ddim_sample_loop_progressive(
755
+ model,
756
+ shape,
757
+ noise=noise,
758
+ clip_denoised=clip_denoised,
759
+ denoised_fn=denoised_fn,
760
+ cond_fn=cond_fn,
761
+ model_kwargs=model_kwargs,
762
+ device=device,
763
+ progress=progress,
764
+ eta=eta,
765
+ ):
766
+ final = sample
767
+ return final["sample"]
768
+
769
+ def ddim_sample_loop_progressive(
770
+ self,
771
+ model: Model,
772
+ shape=None,
773
+ noise=None,
774
+ clip_denoised=True,
775
+ denoised_fn=None,
776
+ cond_fn=None,
777
+ model_kwargs=None,
778
+ device=None,
779
+ progress=False,
780
+ eta=0.0,
781
+ ):
782
+ """
783
+ Use DDIM to sample from the model and yield intermediate samples from
784
+ each timestep of DDIM.
785
+
786
+ Same usage as p_sample_loop_progressive().
787
+ """
788
+ if device is None:
789
+ device = next(model.parameters()).device
790
+ if noise is not None:
791
+ img = noise
792
+ else:
793
+ assert isinstance(shape, (tuple, list))
794
+ img = th.randn(*shape, device=device)
795
+ indices = list(range(self.num_timesteps))[::-1]
796
+
797
+ if progress:
798
+ # Lazy import so that we don't depend on tqdm.
799
+ from tqdm.auto import tqdm
800
+
801
+ indices = tqdm(indices)
802
+
803
+ for i in indices:
804
+
805
+ if isinstance(model_kwargs, list):
806
+ # index dependent model kwargs
807
+ # (T-1, ..., 0)
808
+ _kwargs = model_kwargs[i]
809
+ else:
810
+ _kwargs = model_kwargs
811
+
812
+ t = th.tensor([i] * len(img), device=device)
813
+ with th.no_grad():
814
+ out = self.ddim_sample(
815
+ model,
816
+ img,
817
+ t,
818
+ clip_denoised=clip_denoised,
819
+ denoised_fn=denoised_fn,
820
+ cond_fn=cond_fn,
821
+ model_kwargs=_kwargs,
822
+ eta=eta,
823
+ )
824
+ out['t'] = t
825
+ yield out
826
+ img = out["sample"]
827
+
828
+ def _vb_terms_bpd(self,
829
+ model: Model,
830
+ x_start,
831
+ x_t,
832
+ t,
833
+ clip_denoised=True,
834
+ model_kwargs=None):
835
+ """
836
+ Get a term for the variational lower-bound.
837
+
838
+ The resulting units are bits (rather than nats, as one might expect).
839
+ This allows for comparison to other papers.
840
+
841
+ :return: a dict with the following keys:
842
+ - 'output': a shape [N] tensor of NLLs or KLs.
843
+ - 'pred_xstart': the x_0 predictions.
844
+ """
845
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
846
+ x_start=x_start, x_t=x_t, t=t)
847
+ out = self.p_mean_variance(model,
848
+ x_t,
849
+ t,
850
+ clip_denoised=clip_denoised,
851
+ model_kwargs=model_kwargs)
852
+ kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"],
853
+ out["log_variance"])
854
+ kl = mean_flat(kl) / np.log(2.0)
855
+
856
+ decoder_nll = -discretized_gaussian_log_likelihood(
857
+ x_start, means=out["mean"], log_scales=0.5 * out["log_variance"])
858
+ assert decoder_nll.shape == x_start.shape
859
+ decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
860
+
861
+ # At the first timestep return the decoder NLL,
862
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
863
+ output = th.where((t == 0), decoder_nll, kl)
864
+ return {
865
+ "output": output,
866
+ "pred_xstart": out["pred_xstart"],
867
+ 'model_forward': out['model_forward'],
868
+ }
869
+
870
+ def _prior_bpd(self, x_start):
871
+ """
872
+ Get the prior KL term for the variational lower-bound, measured in
873
+ bits-per-dim.
874
+
875
+ This term can't be optimized, as it only depends on the encoder.
876
+
877
+ :param x_start: the [N x C x ...] tensor of inputs.
878
+ :return: a batch of [N] KL values (in bits), one per batch element.
879
+ """
880
+ batch_size = x_start.shape[0]
881
+ t = th.tensor([self.num_timesteps - 1] * batch_size,
882
+ device=x_start.device)
883
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
884
+ kl_prior = normal_kl(mean1=qt_mean,
885
+ logvar1=qt_log_variance,
886
+ mean2=0.0,
887
+ logvar2=0.0)
888
+ return mean_flat(kl_prior) / np.log(2.0)
889
+
890
+ def calc_bpd_loop(self,
891
+ model: Model,
892
+ x_start,
893
+ clip_denoised=True,
894
+ model_kwargs=None):
895
+ """
896
+ Compute the entire variational lower-bound, measured in bits-per-dim,
897
+ as well as other related quantities.
898
+
899
+ :param model: the model to evaluate loss on.
900
+ :param x_start: the [N x C x ...] tensor of inputs.
901
+ :param clip_denoised: if True, clip denoised samples.
902
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
903
+ pass to the model. This can be used for conditioning.
904
+
905
+ :return: a dict containing the following keys:
906
+ - total_bpd: the total variational lower-bound, per batch element.
907
+ - prior_bpd: the prior term in the lower-bound.
908
+ - vb: an [N x T] tensor of terms in the lower-bound.
909
+ - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
910
+ - mse: an [N x T] tensor of epsilon MSEs for each timestep.
911
+ """
912
+ device = x_start.device
913
+ batch_size = x_start.shape[0]
914
+
915
+ vb = []
916
+ xstart_mse = []
917
+ mse = []
918
+ for t in list(range(self.num_timesteps))[::-1]:
919
+ t_batch = th.tensor([t] * batch_size, device=device)
920
+ noise = th.randn_like(x_start)
921
+ x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
922
+ # Calculate VLB term at the current timestep
923
+ with th.no_grad():
924
+ out = self._vb_terms_bpd(
925
+ model,
926
+ x_start=x_start,
927
+ x_t=x_t,
928
+ t=t_batch,
929
+ clip_denoised=clip_denoised,
930
+ model_kwargs=model_kwargs,
931
+ )
932
+ vb.append(out["output"])
933
+ xstart_mse.append(mean_flat((out["pred_xstart"] - x_start)**2))
934
+ eps = self._predict_eps_from_xstart(x_t, t_batch,
935
+ out["pred_xstart"])
936
+ mse.append(mean_flat((eps - noise)**2))
937
+
938
+ vb = th.stack(vb, dim=1)
939
+ xstart_mse = th.stack(xstart_mse, dim=1)
940
+ mse = th.stack(mse, dim=1)
941
+
942
+ prior_bpd = self._prior_bpd(x_start)
943
+ total_bpd = vb.sum(dim=1) + prior_bpd
944
+ return {
945
+ "total_bpd": total_bpd,
946
+ "prior_bpd": prior_bpd,
947
+ "vb": vb,
948
+ "xstart_mse": xstart_mse,
949
+ "mse": mse,
950
+ }
951
+
952
+
953
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
954
+ """
955
+ Extract values from a 1-D numpy array for a batch of indices.
956
+
957
+ :param arr: the 1-D numpy array.
958
+ :param timesteps: a tensor of indices into the array to extract.
959
+ :param broadcast_shape: a larger shape of K dimensions with the batch
960
+ dimension equal to the length of timesteps.
961
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
962
+ """
963
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
964
+ while len(res.shape) < len(broadcast_shape):
965
+ res = res[..., None]
966
+ return res.expand(broadcast_shape)
967
+
968
+
969
+ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
970
+ """
971
+ Get a pre-defined beta schedule for the given name.
972
+
973
+ The beta schedule library consists of beta schedules which remain similar
974
+ in the limit of num_diffusion_timesteps.
975
+ Beta schedules may be added, but should not be removed or changed once
976
+ they are committed to maintain backwards compatibility.
977
+ """
978
+ if schedule_name == "linear":
979
+ # Linear schedule from Ho et al, extended to work for any number of
980
+ # diffusion steps.
981
+ scale = 1000 / num_diffusion_timesteps
982
+ beta_start = scale * 0.0001
983
+ beta_end = scale * 0.02
984
+ return np.linspace(beta_start,
985
+ beta_end,
986
+ num_diffusion_timesteps,
987
+ dtype=np.float64)
988
+ elif schedule_name == "cosine":
989
+ return betas_for_alpha_bar(
990
+ num_diffusion_timesteps,
991
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2)**2,
992
+ )
993
+ elif schedule_name == "const0.01":
994
+ scale = 1000 / num_diffusion_timesteps
995
+ return np.array([scale * 0.01] * num_diffusion_timesteps,
996
+ dtype=np.float64)
997
+ elif schedule_name == "const0.015":
998
+ scale = 1000 / num_diffusion_timesteps
999
+ return np.array([scale * 0.015] * num_diffusion_timesteps,
1000
+ dtype=np.float64)
1001
+ elif schedule_name == "const0.008":
1002
+ scale = 1000 / num_diffusion_timesteps
1003
+ return np.array([scale * 0.008] * num_diffusion_timesteps,
1004
+ dtype=np.float64)
1005
+ elif schedule_name == "const0.0065":
1006
+ scale = 1000 / num_diffusion_timesteps
1007
+ return np.array([scale * 0.0065] * num_diffusion_timesteps,
1008
+ dtype=np.float64)
1009
+ elif schedule_name == "const0.0055":
1010
+ scale = 1000 / num_diffusion_timesteps
1011
+ return np.array([scale * 0.0055] * num_diffusion_timesteps,
1012
+ dtype=np.float64)
1013
+ elif schedule_name == "const0.0045":
1014
+ scale = 1000 / num_diffusion_timesteps
1015
+ return np.array([scale * 0.0045] * num_diffusion_timesteps,
1016
+ dtype=np.float64)
1017
+ elif schedule_name == "const0.0035":
1018
+ scale = 1000 / num_diffusion_timesteps
1019
+ return np.array([scale * 0.0035] * num_diffusion_timesteps,
1020
+ dtype=np.float64)
1021
+ elif schedule_name == "const0.0025":
1022
+ scale = 1000 / num_diffusion_timesteps
1023
+ return np.array([scale * 0.0025] * num_diffusion_timesteps,
1024
+ dtype=np.float64)
1025
+ elif schedule_name == "const0.0015":
1026
+ scale = 1000 / num_diffusion_timesteps
1027
+ return np.array([scale * 0.0015] * num_diffusion_timesteps,
1028
+ dtype=np.float64)
1029
+ else:
1030
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
1031
+
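As a concrete example, with num_diffusion_timesteps = 1000 the "linear" branch reduces to the Ho et al. defaults: betas rise linearly from 1e-4 to 0.02, and the final cumulative alpha is close to zero (the data is almost pure noise at t = T):

import numpy as np

betas = np.linspace(0.0001, 0.02, 1000, dtype=np.float64)  # "linear", T = 1000
print(betas[0], betas[-1])          # 0.0001 0.02
print(np.cumprod(1.0 - betas)[-1])  # alpha_bar_T, roughly 4e-5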
1032
+
1033
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
1034
+ """
1035
+ Create a beta schedule that discretizes the given alpha_t_bar function,
1036
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
1037
+
1038
+ :param num_diffusion_timesteps: the number of betas to produce.
1039
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
1040
+ produces the cumulative product of (1-beta) up to that
1041
+ part of the diffusion process.
1042
+ :param max_beta: the maximum beta to use; use values lower than 1 to
1043
+ prevent singularities.
1044
+ """
1045
+ betas = []
1046
+ for i in range(num_diffusion_timesteps):
1047
+ t1 = i / num_diffusion_timesteps
1048
+ t2 = (i + 1) / num_diffusion_timesteps
1049
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
1050
+ return np.array(betas)
1051
+
1052
+
1053
+ def normal_kl(mean1, logvar1, mean2, logvar2):
1054
+ """
1055
+ Compute the KL divergence between two gaussians.
1056
+
1057
+ Shapes are automatically broadcasted, so batches can be compared to
1058
+ scalars, among other use cases.
1059
+ """
1060
+ tensor = None
1061
+ for obj in (mean1, logvar1, mean2, logvar2):
1062
+ if isinstance(obj, th.Tensor):
1063
+ tensor = obj
1064
+ break
1065
+ assert tensor is not None, "at least one argument must be a Tensor"
1066
+
1067
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
1068
+ # Tensors, but it does not work for th.exp().
1069
+ logvar1, logvar2 = [
1070
+ x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
1071
+ for x in (logvar1, logvar2)
1072
+ ]
1073
+
1074
+ return 0.5 * (-1.0 + logvar2 - logvar1 + th.exp(logvar1 - logvar2) +
1075
+ ((mean1 - mean2)**2) * th.exp(-logvar2))
1076
+
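For reference, this is the standard closed form KL(N(mu1, sigma1^2) || N(mu2, sigma2^2)) = 0.5 * (log(sigma2^2 / sigma1^2) + (sigma1^2 + (mu1 - mu2)^2) / sigma2^2 - 1), expressed above in terms of log-variances so it stays numerically stable.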
1077
+
1078
+ def approx_standard_normal_cdf(x):
1079
+ """
1080
+ A fast approximation of the cumulative distribution function of the
1081
+ standard normal.
1082
+ """
1083
+ return 0.5 * (
1084
+ 1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
1085
+
1086
+
1087
+ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
1088
+ """
1089
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
1090
+ given image.
1091
+
1092
+ :param x: the target images. It is assumed that this was uint8 values,
1093
+ rescaled to the range [-1, 1].
1094
+ :param means: the Gaussian mean Tensor.
1095
+ :param log_scales: the Gaussian log stddev Tensor.
1096
+ :return: a tensor like x of log probabilities (in nats).
1097
+ """
1098
+ assert x.shape == means.shape == log_scales.shape
1099
+ centered_x = x - means
1100
+ inv_stdv = th.exp(-log_scales)
1101
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
1102
+ cdf_plus = approx_standard_normal_cdf(plus_in)
1103
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
1104
+ cdf_min = approx_standard_normal_cdf(min_in)
1105
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
1106
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
1107
+ cdf_delta = cdf_plus - cdf_min
1108
+ log_probs = th.where(
1109
+ x < -0.999,
1110
+ log_cdf_plus,
1111
+ th.where(x > 0.999, log_one_minus_cdf_min,
1112
+ th.log(cdf_delta.clamp(min=1e-12))),
1113
+ )
1114
+ assert log_probs.shape == x.shape
1115
+ return log_probs
1116
+
1117
+
1118
+ class DummyModel(th.nn.Module):
1119
+ def __init__(self, pred):
1120
+ super().__init__()
1121
+ self.pred = pred
1122
+
1123
+ def forward(self, *args, **kwargs):
1124
+ return DummyReturn(pred=self.pred)
1125
+
1126
+
1127
+ class DummyReturn(NamedTuple):
1128
+ pred: th.Tensor
diffusion/diffusion.py ADDED
@@ -0,0 +1,156 @@
1
+ from .base import *
2
+ from dataclasses import dataclass
3
+
4
+
5
+ def space_timesteps(num_timesteps, section_counts):
6
+ """
7
+ Create a list of timesteps to use from an original diffusion process,
8
+ given the number of timesteps we want to take from equally-sized portions
9
+ of the original process.
10
+
11
+ For example, if there are 300 timesteps and the section counts are [10,15,20]
12
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
13
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
14
+
15
+ If the stride is a string starting with "ddim", then the fixed striding
16
+ from the DDIM paper is used, and only one section is allowed.
17
+
18
+ :param num_timesteps: the number of diffusion steps in the original
19
+ process to divide up.
20
+ :param section_counts: either a list of numbers, or a string containing
21
+ comma-separated numbers, indicating the step count
22
+ per section. As a special case, use "ddimN" where N
23
+ is a number of steps to use the striding from the
24
+ DDIM paper.
25
+ :return: a set of diffusion steps from the original process to use.
26
+ """
27
+ if isinstance(section_counts, str):
28
+ if section_counts.startswith("ddim"):
29
+ desired_count = int(section_counts[len("ddim"):])
30
+ for i in range(1, num_timesteps):
31
+ if len(range(0, num_timesteps, i)) == desired_count:
32
+ return set(range(0, num_timesteps, i))
33
+ raise ValueError(
34
+ f"cannot create exactly {desired_count} steps with an integer stride"
35
+ )
36
+ section_counts = [int(x) for x in section_counts.split(",")]
37
+ size_per = num_timesteps // len(section_counts)
38
+ extra = num_timesteps % len(section_counts)
39
+ start_idx = 0
40
+ all_steps = []
41
+ for i, section_count in enumerate(section_counts):
42
+ size = size_per + (1 if i < extra else 0)
43
+ if size < section_count:
44
+ raise ValueError(
45
+ f"cannot divide section of {size} steps into {section_count}")
46
+ if section_count <= 1:
47
+ frac_stride = 1
48
+ else:
49
+ frac_stride = (size - 1) / (section_count - 1)
50
+ cur_idx = 0.0
51
+ taken_steps = []
52
+ for _ in range(section_count):
53
+ taken_steps.append(start_idx + round(cur_idx))
54
+ cur_idx += frac_stride
55
+ all_steps += taken_steps
56
+ start_idx += size
57
+ return set(all_steps)
58
+
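Two quick examples of how the spacing behaves (the values follow directly from the arguments shown):

# 1000 training steps thinned to 50 evenly spaced steps (DDIM-style stride of 20):
steps_ddim = space_timesteps(1000, "ddim50")      # {0, 20, 40, ..., 980}
# The docstring example: three 100-step sections kept at 10 / 15 / 20 steps each:
steps_mixed = space_timesteps(300, [10, 15, 20])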
59
+
60
+ @dataclass
61
+ class SpacedDiffusionBeatGansConfig(GaussianDiffusionBeatGansConfig):
62
+ use_timesteps: Tuple[int] = None
63
+
64
+ def make_sampler(self):
65
+ return SpacedDiffusionBeatGans(self)
66
+
67
+
68
+ class SpacedDiffusionBeatGans(GaussianDiffusionBeatGans):
69
+ """
70
+ A diffusion process which can skip steps in a base diffusion process.
71
+
72
+ :param use_timesteps: a collection (sequence or set) of timesteps from the
73
+ original diffusion process to retain.
74
+ :param kwargs: the kwargs to create the base diffusion process.
75
+ """
76
+ def __init__(self, conf: SpacedDiffusionBeatGansConfig):
77
+ self.conf = conf
78
+ self.use_timesteps = set(conf.use_timesteps)
79
+ # how the new t's mapped to the old t's
80
+ self.timestep_map = []
81
+ self.original_num_steps = len(conf.betas)
82
+
83
+ base_diffusion = GaussianDiffusionBeatGans(conf) # pylint: disable=missing-kwoa
84
+ last_alpha_cumprod = 1.0
85
+ new_betas = []
86
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
87
+ if i in self.use_timesteps:
88
+ # getting the new betas of the new timesteps
89
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
90
+ last_alpha_cumprod = alpha_cumprod
91
+ self.timestep_map.append(i)
92
+ conf.betas = np.array(new_betas)
93
+ super().__init__(conf)
94
+
95
+ def p_mean_variance(self, model: Model, *args, **kwargs): # pylint: disable=signature-differs
96
+ return super().p_mean_variance(self._wrap_model(model), *args,
97
+ **kwargs)
98
+
99
+ def training_losses(self, model: Model, *args, **kwargs): # pylint: disable=signature-differs
100
+ return super().training_losses(self._wrap_model(model), *args,
101
+ **kwargs)
102
+
103
+ def condition_mean(self, cond_fn, *args, **kwargs):
104
+ return super().condition_mean(self._wrap_model(cond_fn), *args,
105
+ **kwargs)
106
+
107
+ def condition_score(self, cond_fn, *args, **kwargs):
108
+ return super().condition_score(self._wrap_model(cond_fn), *args,
109
+ **kwargs)
110
+
111
+ def _wrap_model(self, model: Model):
112
+ if isinstance(model, _WrappedModel):
113
+ return model
114
+ return _WrappedModel(model, self.timestep_map, self.rescale_timesteps,
115
+ self.original_num_steps)
116
+
117
+ def _scale_timesteps(self, t):
118
+ # Scaling is done by the wrapped model.
119
+ return t
120
+
121
+
122
+ class _WrappedModel:
123
+ """
124
+ Wraps a model so that the supplied t's are mapped back to the original t scale.
125
+ """
126
+ def __init__(self, model, timestep_map, rescale_timesteps,
127
+ original_num_steps):
128
+ self.model = model
129
+ self.timestep_map = timestep_map
130
+ self.rescale_timesteps = rescale_timesteps
131
+ self.original_num_steps = original_num_steps
132
+
133
+ def forward(self, motion_start, motion_direction_start, audio_feats, face_location, face_scale, yaw_pitch_roll, x_t, t, control_flag=False):
134
+ """
135
+ Args:
136
+ t: t's with different ranges (can be << T due to a smaller eval T); they need to be converted to the original t's
137
+ t_cond: the same as t but can be of different values
138
+ """
139
+ map_tensor = th.tensor(self.timestep_map,
140
+ device=t.device,
141
+ dtype=t.dtype)
142
+
143
+ def do(t):
144
+ new_ts = map_tensor[t]
145
+ if self.rescale_timesteps:
146
+ new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
147
+ return new_ts
148
+
149
+ return self.model(motion_start, motion_direction_start, audio_feats, face_location, face_scale, yaw_pitch_roll, x_t, do(t), control_flag=control_flag)
150
+
151
+ def __getattr__(self, name):
152
+ # allow for calling the model's methods
153
+ if hasattr(self.model, name):
154
+ func = getattr(self.model, name)
155
+ return func
156
+ raise AttributeError(name)
diffusion/resample.py ADDED
@@ -0,0 +1,63 @@
1
+ from abc import ABC, abstractmethod
2
+
3
+ import numpy as np
4
+ import torch as th
5
+ import torch.distributed as dist
6
+
7
+
8
+ def create_named_schedule_sampler(name, diffusion):
9
+ """
10
+ Create a ScheduleSampler from a library of pre-defined samplers.
11
+
12
+ :param name: the name of the sampler.
13
+ :param diffusion: the diffusion object to sample for.
14
+ """
15
+ if name == "uniform":
16
+ return UniformSampler(diffusion.num_timesteps)  # UniformSampler expects a step count
17
+ else:
18
+ raise NotImplementedError(f"unknown schedule sampler: {name}")
19
+
20
+
21
+ class ScheduleSampler(ABC):
22
+ """
23
+ A distribution over timesteps in the diffusion process, intended to reduce
24
+ variance of the objective.
25
+
26
+ By default, samplers perform unbiased importance sampling, in which the
27
+ objective's mean is unchanged.
28
+ However, subclasses may override sample() to change how the resampled
29
+ terms are reweighted, allowing for actual changes in the objective.
30
+ """
31
+ @abstractmethod
32
+ def weights(self):
33
+ """
34
+ Get a numpy array of weights, one per diffusion step.
35
+
36
+ The weights needn't be normalized, but must be positive.
37
+ """
38
+
39
+ def sample(self, batch_size, device):
40
+ """
41
+ Importance-sample timesteps for a batch.
42
+
43
+ :param batch_size: the number of timesteps.
44
+ :param device: the torch device to save to.
45
+ :return: a tuple (timesteps, weights):
46
+ - timesteps: a tensor of timestep indices.
47
+ - weights: a tensor of weights to scale the resulting losses.
48
+ """
49
+ w = self.weights()
50
+ p = w / np.sum(w)
51
+ indices_np = np.random.choice(len(p), size=(batch_size, ), p=p)
52
+ indices = th.from_numpy(indices_np).long().to(device)
53
+ weights_np = 1 / (len(p) * p[indices_np])
54
+ weights = th.from_numpy(weights_np).float().to(device)
55
+ return indices, weights
56
+
57
+
58
+ class UniformSampler(ScheduleSampler):
59
+ def __init__(self, num_timesteps):
60
+ self._weights = np.ones([num_timesteps])
61
+
62
+ def weights(self):
63
+ return self._weights
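A short usage sketch of the uniform sampler (assuming 1000 diffusion steps): every timestep is equally likely, so the returned importance weights are all ones and the objective stays unbiased.

import torch as th

sampler = UniformSampler(1000)
t, weight = sampler.sample(batch_size=8, device=th.device("cpu"))
# t: 8 random timesteps in [0, 1000); weight: a tensor of ones.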
dist_utils.py ADDED
@@ -0,0 +1,42 @@
1
+ from typing import List
2
+ from torch import distributed
3
+
4
+
5
+ def barrier():
6
+ if distributed.is_initialized():
7
+ distributed.barrier()
8
+ else:
9
+ pass
10
+
11
+
12
+ def broadcast(data, src):
13
+ if distributed.is_initialized():
14
+ distributed.broadcast(data, src)
15
+ else:
16
+ pass
17
+
18
+
19
+ def all_gather(data: List, src):
20
+ if distributed.is_initialized():
21
+ distributed.all_gather(data, src)
22
+ else:
23
+ data[0] = src
24
+
25
+
26
+ def get_rank():
27
+ if distributed.is_initialized():
28
+ return distributed.get_rank()
29
+ else:
30
+ return 0
31
+
32
+
33
+ def get_world_size():
34
+ if distributed.is_initialized():
35
+ return distributed.get_world_size()
36
+ else:
37
+ return 1
38
+
39
+
40
+ def chunk_size(size, rank, world_size):
41
+ extra = rank < size % world_size
42
+ return size // world_size + extra
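chunk_size splits `size` items across `world_size` workers, giving the first `size % world_size` ranks one extra item (the boolean `extra` is added as 0/1). For example, chunk_size(10, rank, 3) returns 4 for rank 0 and 3 for ranks 1 and 2, which sums back to 10.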
experiment.py ADDED
@@ -0,0 +1,356 @@
1
+ import copy
2
+ import os
3
+
4
+ import numpy as np
5
+ import pytorch_lightning as pl
6
+ import torch
7
+ from pytorch_lightning import loggers as pl_loggers
8
+ from pytorch_lightning.callbacks import *
9
+ from torch.cuda import amp
10
+ from torch.optim.optimizer import Optimizer
11
+ from torch.utils.data.dataset import TensorDataset
12
+ from model.seq2seq import DiffusionPredictor
13
+
14
+ from config import *
15
+ from dist_utils import *
16
+ from renderer import *
17
+
18
+ # This part is modified from: https://github.com/phizaz/diffae/blob/master/experiment.py
19
+ class LitModel(pl.LightningModule):
20
+ def __init__(self, conf: TrainConfig):
21
+ super().__init__()
22
+ assert conf.train_mode != TrainMode.manipulate
23
+ if conf.seed is not None:
24
+ pl.seed_everything(conf.seed)
25
+
26
+ self.save_hyperparameters(conf.as_dict_jsonable())
27
+
28
+ self.conf = conf
29
+
30
+ self.model = DiffusionPredictor(conf)
31
+
32
+ self.ema_model = copy.deepcopy(self.model)
33
+ self.ema_model.requires_grad_(False)
34
+ self.ema_model.eval()
35
+
36
+ self.sampler = conf.make_diffusion_conf().make_sampler()
37
+ self.eval_sampler = conf.make_eval_diffusion_conf().make_sampler()
38
+
39
+ # this is shared for both model and latent
40
+ self.T_sampler = conf.make_T_sampler()
41
+
42
+ if conf.train_mode.use_latent_net():
43
+ self.latent_sampler = conf.make_latent_diffusion_conf(
44
+ ).make_sampler()
45
+ self.eval_latent_sampler = conf.make_latent_eval_diffusion_conf(
46
+ ).make_sampler()
47
+ else:
48
+ self.latent_sampler = None
49
+ self.eval_latent_sampler = None
50
+
51
+ # initial variables for consistent sampling
52
+ self.register_buffer(
53
+ 'x_T',
54
+ torch.randn(conf.sample_size, 3, conf.img_size, conf.img_size))
55
+
56
+
57
+ def render(self, start, motion_direction_start, audio_driven, face_location, face_scale, ypr_info, noisyT, step_T, control_flag):
58
+ if step_T is None:
59
+ sampler = self.eval_sampler
60
+ else:
61
+ sampler = self.conf._make_diffusion_conf(step_T).make_sampler()
62
+
63
+ pred_img = render_condition(self.conf,
64
+ self.ema_model,
65
+ sampler, start, motion_direction_start, audio_driven, face_location, face_scale, ypr_info, noisyT, control_flag)
66
+ return pred_img
67
+
68
+ def forward(self, noise=None, x_start=None, ema_model: bool = False):
69
+ with amp.autocast(False):
70
+ if not self.disable_ema:
71
+ model = self.ema_model
72
+ else:
73
+ model = self.model
74
+ gen = self.eval_sampler.sample(model=model,
75
+ noise=noise,
76
+ x_start=x_start)
77
+ return gen
78
+
79
+ def setup(self, stage=None) -> None:
80
+ """
81
+ make datasets & seed each worker separately
82
+ """
83
+ ##############################################
84
+ # NEED TO SET THE SEED SEPARATELY HERE
85
+ if self.conf.seed is not None:
86
+ seed = self.conf.seed * get_world_size() + self.global_rank
87
+ np.random.seed(seed)
88
+ torch.manual_seed(seed)
89
+ torch.cuda.manual_seed(seed)
90
+ print('local seed:', seed)
91
+ ##############################################
92
+
93
+ self.train_data = self.conf.make_dataset()
94
+ print('train data:', len(self.train_data))
95
+ self.val_data = self.train_data
96
+ print('val data:', len(self.val_data))
97
+
98
+ def _train_dataloader(self, drop_last=True):
99
+ """
100
+ really make the dataloader
101
+ """
102
+ # make sure to use the fraction of batch size
103
+ # the batch size is global!
104
+ conf = self.conf.clone()
105
+ conf.batch_size = self.batch_size
106
+
107
+ dataloader = conf.make_loader(self.train_data,
108
+ shuffle=True,
109
+ drop_last=drop_last)
110
+ return dataloader
111
+
112
+ def train_dataloader(self):
113
+ """
114
+ return the dataloader, if diffusion mode => return image dataset
115
+ if latent mode => return the inferred latent dataset
116
+ """
117
+ print('on train dataloader start ...')
118
+ if self.conf.train_mode.require_dataset_infer():
119
+ if self.conds is None:
120
+ # usually we load self.conds from a file
121
+ # so we do not need to do this again!
122
+ self.conds = self.infer_whole_dataset()
123
+ # need to use float32! unless the mean & std will be off!
124
+ # (1, c)
125
+ self.conds_mean.data = self.conds.float().mean(dim=0,
126
+ keepdim=True)
127
+ self.conds_std.data = self.conds.float().std(dim=0,
128
+ keepdim=True)
129
+ print('mean:', self.conds_mean.mean(), 'std:',
130
+ self.conds_std.mean())
131
+
132
+ # return the dataset with pre-calculated conds
133
+ conf = self.conf.clone()
134
+ conf.batch_size = self.batch_size
135
+ data = TensorDataset(self.conds)
136
+ return conf.make_loader(data, shuffle=True)
137
+ else:
138
+ return self._train_dataloader()
139
+
140
+ @property
141
+ def batch_size(self):
142
+ """
143
+ local batch size for each worker
144
+ """
145
+ ws = get_world_size()
146
+ assert self.conf.batch_size % ws == 0
147
+ return self.conf.batch_size // ws
148
+
149
+ @property
150
+ def num_samples(self):
151
+ """
152
+ (global) batch size * iterations
153
+ """
154
+ # batch size here is global!
155
+ # global_step already takes into account the accum batches
156
+ return self.global_step * self.conf.batch_size_effective
157
+
158
+ def is_last_accum(self, batch_idx):
159
+ """
160
+ is it the last gradient accumulation loop?
161
+ used with gradient_accum > 1 to check whether the optimizer will perform a "step" in this iteration
162
+ """
163
+ return (batch_idx + 1) % self.conf.accum_batches == 0
164
+
165
+ def training_step(self, batch, batch_idx):
166
+ """
167
+ given an input, calculate the loss function
168
+ no optimization at this stage.
169
+ """
170
+ with amp.autocast(False):
171
+ motion_start = batch['motion_start'] # torch.Size([B, 512])
172
+ motion_direction = batch['motion_direction'] # torch.Size([B, 125, 20])
173
+ audio_feats = batch['audio_feats'].float() # torch.Size([B, 25, 250, 1024])
174
+ face_location = batch['face_location'].float() # torch.Size([B, 125])
175
+ face_scale = batch['face_scale'].float() # torch.Size([B, 125, 1])
176
+ yaw_pitch_roll = batch['yaw_pitch_roll'].float() # torch.Size([B, 125, 3])
177
+ motion_direction_start = batch['motion_direction_start'].float() # torch.Size([B, 20])
178
+
179
+ # import pdb; pdb.set_trace()
180
+ if self.conf.train_mode == TrainMode.diffusion:
181
+ """
182
+ main training mode!!!
183
+ """
184
+ # with numpy seed we have the problem that the sample t's are related!
185
+ t, weight = self.T_sampler.sample(len(motion_start), motion_start.device)
186
+ losses = self.sampler.training_losses(model=self.model,
187
+ motion_direction_start=motion_direction_start,
188
+ motion_target=motion_direction,
189
+ motion_start=motion_start,
190
+ audio_feats=audio_feats,
191
+ face_location=face_location,
192
+ face_scale=face_scale,
193
+ yaw_pitch_roll=yaw_pitch_roll,
194
+ t=t)
195
+ else:
196
+ raise NotImplementedError()
197
+
198
+ loss = losses['loss'].mean()
199
+ # divide by accum batches to make the accumulated gradient exact!
200
+ for key in losses.keys():
201
+ losses[key] = self.all_gather(losses[key]).mean()
202
+
203
+ if self.global_rank == 0:
204
+ self.logger.experiment.add_scalar('loss', losses['loss'],
205
+ self.num_samples)
206
+ for key in losses:
207
+ self.logger.experiment.add_scalar(
208
+ f'loss/{key}', losses[key], self.num_samples)
209
+
210
+ return {'loss': loss}
211
+
212
+ def on_train_batch_end(self, outputs, batch, batch_idx: int,
213
+ dataloader_idx: int) -> None:
214
+ """
215
+ after each training step ...
216
+ """
217
+ if self.is_last_accum(batch_idx):
218
+
219
+ if self.conf.train_mode == TrainMode.latent_diffusion:
220
+ # it trains only the latent hence change only the latent
221
+ ema(self.model.latent_net, self.ema_model.latent_net,
222
+ self.conf.ema_decay)
223
+ else:
224
+ ema(self.model, self.ema_model, self.conf.ema_decay)
225
+
226
+ def on_before_optimizer_step(self, optimizer: Optimizer,
227
+ optimizer_idx: int) -> None:
228
+ # fix the fp16 + clip grad norm problem with pytorch lightning
229
+ # this is the currently correct way to do it
230
+ if self.conf.grad_clip > 0:
231
+ # from trainer.params_grads import grads_norm, iter_opt_params
232
+ params = [
233
+ p for group in optimizer.param_groups for p in group['params']
234
+ ]
235
+ torch.nn.utils.clip_grad_norm_(params,
236
+ max_norm=self.conf.grad_clip)
237
+ def configure_optimizers(self):
238
+ out = {}
239
+ if self.conf.optimizer == OptimizerType.adam:
240
+ optim = torch.optim.Adam(self.model.parameters(),
241
+ lr=self.conf.lr,
242
+ weight_decay=self.conf.weight_decay)
243
+ elif self.conf.optimizer == OptimizerType.adamw:
244
+ optim = torch.optim.AdamW(self.model.parameters(),
245
+ lr=self.conf.lr,
246
+ weight_decay=self.conf.weight_decay)
247
+ else:
248
+ raise NotImplementedError()
249
+ out['optimizer'] = optim
250
+ if self.conf.warmup > 0:
251
+ sched = torch.optim.lr_scheduler.LambdaLR(optim,
252
+ lr_lambda=WarmupLR(
253
+ self.conf.warmup))
254
+ out['lr_scheduler'] = {
255
+ 'scheduler': sched,
256
+ 'interval': 'step',
257
+ }
258
+ return out
259
+
260
+ def split_tensor(self, x):
261
+ """
262
+ extract the tensor for a corresponding "worker" in the batch dimension
263
+
264
+ Args:
265
+ x: (n, c)
266
+
267
+ Returns: x: (n_local, c)
268
+ """
269
+ n = len(x)
270
+ rank = self.global_rank
271
+ world_size = get_world_size()
272
+ # print(f'rank: {rank}/{world_size}')
273
+ per_rank = n // world_size
274
+ return x[rank * per_rank:(rank + 1) * per_rank]
275
+
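split_tensor assumes the global batch dimension is divisible by the world size and hands each rank one contiguous slice; for example, with n = 8 rows and 4 workers, per_rank = 2 and rank 2 receives x[4:6].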
276
+ def ema(source, target, decay):
277
+ source_dict = source.state_dict()
278
+ target_dict = target.state_dict()
279
+ for key in source_dict.keys():
280
+ target_dict[key].data.copy_(target_dict[key].data * decay +
281
+ source_dict[key].data * (1 - decay))
282
+
283
+
284
+ class WarmupLR:
285
+ def __init__(self, warmup) -> None:
286
+ self.warmup = warmup
287
+
288
+ def __call__(self, step):
289
+ return min(step, self.warmup) / self.warmup
290
+
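WarmupLR is used as the LR lambda for LambdaLR in configure_optimizers above: it scales the base learning rate linearly from 0 up to its full value over the first `warmup` steps and holds it there afterwards; e.g. WarmupLR(5000)(1000) == 0.2 and WarmupLR(5000)(10000) == 1.0.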
291
+
292
+ def is_time(num_samples, every, step_size):
293
+ closest = (num_samples // every) * every
294
+ return num_samples - closest < step_size
295
+
296
+
297
+ def train(conf: TrainConfig, gpus, nodes=1, mode: str = 'train'):
298
+ print('conf:', conf.name)
299
+ # assert not (conf.fp16 and conf.grad_clip > 0
300
+ # ), 'pytorch lightning has bug with amp + gradient clipping'
301
+ model = LitModel(conf)
302
+
303
+ if not os.path.exists(conf.logdir):
304
+ os.makedirs(conf.logdir)
305
+ checkpoint = ModelCheckpoint(dirpath=f'{conf.logdir}',
306
+ save_last=True,
307
+ save_top_k=-1,
308
+ every_n_epochs=10)
309
+ checkpoint_path = f'{conf.logdir}/last.ckpt'
310
+ print('ckpt path:', checkpoint_path)
311
+ if os.path.exists(checkpoint_path):
312
+ resume = checkpoint_path
313
+ print('resume!')
314
+ else:
315
+ if conf.continue_from is not None:
316
+ # continue from a checkpoint
317
+ resume = conf.continue_from.path
318
+ else:
319
+ resume = None
320
+
321
+ tb_logger = pl_loggers.TensorBoardLogger(save_dir=conf.logdir,
322
+ name=None,
323
+ version='')
324
+
325
+ # from pytorch_lightning.
326
+
327
+ plugins = []
328
+ if len(gpus) == 1 and nodes == 1:
329
+ accelerator = None
330
+ else:
331
+ accelerator = 'ddp'
332
+ from pytorch_lightning.plugins import DDPPlugin
333
+
334
+ # important for working with gradient checkpoint
335
+ plugins.append(DDPPlugin(find_unused_parameters=True))
336
+
337
+ trainer = pl.Trainer(
338
+ max_steps=conf.total_samples // conf.batch_size_effective,
339
+ resume_from_checkpoint=resume,
340
+ gpus=gpus,
341
+ num_nodes=nodes,
342
+ accelerator=accelerator,
343
+ precision=16 if conf.fp16 else 32,
344
+ callbacks=[
345
+ checkpoint,
346
+ LearningRateMonitor(),
347
+ ],
348
+ # clip in the model instead
349
+ # gradient_clip_val=conf.grad_clip,
350
+ replace_sampler_ddp=True,
351
+ logger=tb_logger,
352
+ accumulate_grad_batches=conf.accum_batches,
353
+ plugins=plugins,
354
+ )
355
+
356
+ trainer.fit(model)
face_sr/face_enhancer.py ADDED
@@ -0,0 +1,123 @@
1
+ import os
2
+ import torch
3
+
4
+ from gfpgan import GFPGANer
5
+
6
+ from tqdm import tqdm
7
+
8
+ from .videoio import load_video_to_cv2
9
+
10
+ import cv2
11
+
12
+
13
+ class GeneratorWithLen(object):
14
+ """ From https://stackoverflow.com/a/7460929 """
15
+
16
+ def __init__(self, gen, length):
17
+ self.gen = gen
18
+ self.length = length
19
+
20
+ def __len__(self):
21
+ return self.length
22
+
23
+ def __iter__(self):
24
+ return self.gen
25
+
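GeneratorWithLen simply pairs a generator with a known length so it can be handed to code that calls len() (progress bars, video writers that preallocate, etc.). A tiny self-contained sketch (the frame values are stand-ins):

frames = [0, 1, 2]                              # stand-in for decoded video frames
gen = (f * 2 for f in frames)                   # stand-in for a per-frame enhancer
wrapped = GeneratorWithLen(gen, len(frames))
assert len(wrapped) == 3
assert list(wrapped) == [0, 2, 4]               # consumes the wrapped generator once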
26
+ def enhancer_list(images, method='gfpgan', bg_upsampler='realesrgan'):
27
+ gen = enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler)
28
+ return list(gen)
29
+
30
+ def enhancer_generator_with_len(images, method='gfpgan', bg_upsampler='realesrgan'):
31
+ """ Provide a generator with a __len__ method so that it can passed to functions that
32
+ call len()"""
33
+
34
+ if os.path.isfile(images): # handle video to images
35
+ # TODO: Create a generator version of load_video_to_cv2
36
+ images = load_video_to_cv2(images)
37
+
38
+ gen = enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler)
39
+ gen_with_len = GeneratorWithLen(gen, len(images))
40
+ return gen_with_len
41
+
42
+ def enhancer_generator_no_len(images, method='gfpgan', bg_upsampler='realesrgan'):
43
+ """ Provide a generator function so that all of the enhanced images don't need
44
+ to be stored in memory at the same time. This can save tons of RAM compared to
45
+ the enhancer function. """
46
+
47
+ print('face enhancer....')
48
+ if not isinstance(images, list) and os.path.isfile(images): # handle video to images
49
+ images = load_video_to_cv2(images)
50
+
51
+ # ------------------------ set up GFPGAN restorer ------------------------
52
+ if method == 'gfpgan':
53
+ arch = 'clean'
54
+ channel_multiplier = 2
55
+ model_name = 'GFPGANv1.4'
56
+ url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth'
57
+ elif method == 'RestoreFormer':
58
+ arch = 'RestoreFormer'
59
+ channel_multiplier = 2
60
+ model_name = 'RestoreFormer'
61
+ url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth'
62
+ elif method == 'codeformer': # TODO:
63
+ arch = 'CodeFormer'
64
+ channel_multiplier = 2
65
+ model_name = 'CodeFormer'
66
+ url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
67
+ else:
68
+ raise ValueError(f'Wrong model version {method}.')
69
+
70
+
71
+ # ------------------------ set up background upsampler ------------------------
72
+ if bg_upsampler == 'realesrgan':
73
+ if not torch.cuda.is_available(): # CPU
74
+ import warnings
75
+ warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. '
76
+ 'If you really want to use it, please modify the corresponding codes.')
77
+ bg_upsampler = None
78
+ else:
79
+ from basicsr.archs.rrdbnet_arch import RRDBNet
80
+ from realesrgan import RealESRGANer
81
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
82
+ bg_upsampler = RealESRGANer(
83
+ scale=2,
84
+ model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
85
+ model=model,
86
+ tile=400,
87
+ tile_pad=10,
88
+ pre_pad=0,
89
+ half=True) # need to set False in CPU mode
90
+ else:
91
+ bg_upsampler = None
92
+
93
+ # determine model paths
94
+ model_path = os.path.join('gfpgan/weights', model_name + '.pth')
95
+
96
+ if not os.path.isfile(model_path):
97
+ model_path = os.path.join('checkpoints', model_name + '.pth')
98
+
99
+ if not os.path.isfile(model_path):
100
+ # download pre-trained models from url
101
+ model_path = url
102
+
103
+ restorer = GFPGANer(
104
+ model_path=model_path,
105
+ upscale=2,
106
+ arch=arch,
107
+ channel_multiplier=channel_multiplier,
108
+ bg_upsampler=bg_upsampler)
109
+
110
+ # ------------------------ restore ------------------------
111
+ for idx in tqdm(range(len(images)), 'Face Enhancer:'):
112
+
113
+ img = cv2.cvtColor(images[idx], cv2.COLOR_RGB2BGR)
114
+
115
+ # restore faces and background if necessary
116
+ cropped_faces, restored_faces, r_img = restorer.enhance(
117
+ img,
118
+ has_aligned=False,
119
+ only_center_face=False,
120
+ paste_back=True)
121
+
122
+ r_img = cv2.cvtColor(r_img, cv2.COLOR_BGR2RGB)
123
+ yield r_img
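A minimal usage sketch of the generator-based enhancer above (file names are hypothetical; it assumes the gfpgan/realesrgan dependencies and an imageio backend with ffmpeg support are installed, and that the GFPGAN weights are available locally or downloadable from the release URL):

    import imageio
    from face_sr.face_enhancer import enhancer_generator_with_len

    # enhance a clip frame by frame without holding every restored frame in memory
    frames = enhancer_generator_with_len('input.mp4', method='gfpgan', bg_upsampler='realesrgan')
    writer = imageio.get_writer('enhanced.mp4', fps=25)
    for frame in frames:              # each frame is an RGB numpy array
        writer.append_data(frame)
    writer.close()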
face_sr/videoio.py ADDED
@@ -0,0 +1,41 @@
1
+ import shutil
2
+ import uuid
3
+
4
+ import os
5
+
6
+ import cv2
7
+
8
+ def load_video_to_cv2(input_path):
9
+ video_stream = cv2.VideoCapture(input_path)
10
+ fps = video_stream.get(cv2.CAP_PROP_FPS)
11
+ full_frames = []
12
+ while True:
13
+ still_reading, frame = video_stream.read()
14
+ if not still_reading:
15
+ video_stream.release()
16
+ break
17
+ full_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
18
+ return full_frames
19
+
20
+ def save_video_with_watermark(video, audio, save_path, watermark=False):
21
+ temp_file = str(uuid.uuid4())+'.mp4'
22
+ cmd = r'ffmpeg -y -hide_banner -loglevel error -i "%s" -i "%s" -vcodec copy "%s"' % (video, audio, temp_file)
23
+ os.system(cmd)
24
+
25
+ if watermark is False:
26
+ shutil.move(temp_file, save_path)
27
+ else:
28
+ # watermark
29
+ try:
30
+ ##### check if stable-diffusion-webui
31
+ import webui
32
+ from modules import paths
33
+ watermark_path = paths.script_path+"/extensions/SadTalker/docs/sadtalker_logo.png"
34
+ except:
35
+ # get the root path of sadtalker.
36
+ dir_path = os.path.dirname(os.path.realpath(__file__))
37
+ watermark_path = dir_path+"/../../docs/sadtalker_logo.png"
38
+
39
+ cmd = r'ffmpeg -y -hide_banner -loglevel error -i "%s" -i "%s" -filter_complex "[1]scale=100:-1[wm];[0][wm]overlay=(main_w-overlay_w)-10:10" "%s"' % (temp_file, watermark_path, save_path)
40
+ os.system(cmd)
41
+ os.remove(temp_file)
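A hedged usage sketch of the helpers above (paths are hypothetical; save_video_with_watermark shells out to ffmpeg, so it must be on the PATH):

    from face_sr.videoio import load_video_to_cv2, save_video_with_watermark

    frames = load_video_to_cv2('result.mp4')      # list of RGB numpy frames
    print(len(frames))

    # mux the silent result with its driving audio; watermark=False just moves the muxed file
    save_video_with_watermark('result.mp4', 'audio.wav', 'result_with_audio.mp4', watermark=False)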
model/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from typing import Union
2
+ from .unet import BeatGANsUNetModel, BeatGANsUNetConfig
3
+ from .unet_autoenc import BeatGANsAutoencConfig, BeatGANsAutoencModel
4
+
5
+ Model = Union[BeatGANsUNetModel, BeatGANsAutoencModel]
6
+ ModelConfig = Union[BeatGANsUNetConfig, BeatGANsAutoencConfig]
model/base.py ADDED
@@ -0,0 +1,37 @@
1
+ # Copyright (C) 2021. Huawei Technologies Co., Ltd. All rights reserved.
2
+ # This program is free software; you can redistribute it and/or modify
3
+ # it under the terms of the MIT License.
4
+ # This program is distributed in the hope that it will be useful,
5
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
6
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
7
+ # MIT License for more details.
8
+
9
+ import numpy as np
10
+ import torch
11
+
12
+
13
+ class BaseModule(torch.nn.Module):
14
+ def __init__(self):
15
+ super(BaseModule, self).__init__()
16
+
17
+ @property
18
+ def nparams(self):
19
+ """
20
+ Returns number of trainable parameters of the module.
21
+ """
22
+ num_params = 0
23
+ for name, param in self.named_parameters():
24
+ if param.requires_grad:
25
+ num_params += np.prod(param.detach().cpu().numpy().shape)
26
+ return num_params
27
+
28
+
29
+ def relocate_input(self, x: list):
30
+ """
31
+ Relocates provided tensors to the same device set for the module.
32
+ """
33
+ device = next(self.parameters()).device
34
+ for i in range(len(x)):
35
+ if isinstance(x[i], torch.Tensor) and x[i].device != device:
36
+ x[i] = x[i].to(device)
37
+ return x
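A small sketch of how the BaseModule helpers behave (the toy module is made up for illustration):

    import torch
    from model.base import BaseModule

    class Toy(BaseModule):
        def __init__(self):
            super().__init__()
            self.lin = torch.nn.Linear(4, 2)

    m = Toy()
    print(m.nparams)                              # 10 = 4*2 weights + 2 biases
    x, = m.relocate_input([torch.zeros(1, 4)])    # moved to the module's device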
model/blocks.py ADDED
@@ -0,0 +1,567 @@
1
+ import math
2
+ from abc import abstractmethod
3
+ from dataclasses import dataclass
4
+ from numbers import Number
5
+
6
+ import numpy as np  # used by count_flops_attn below
+ import torch as th
7
+ import torch.nn.functional as F
8
+ from choices import *
9
+ from config_base import BaseConfig
10
+ from torch import nn
11
+
12
+ from .nn import (avg_pool_nd, conv_nd, linear, normalization,
13
+ timestep_embedding, torch_checkpoint, zero_module)
14
+
15
+
16
+ class ScaleAt(Enum):
17
+ after_norm = 'afternorm'
18
+
19
+
20
+ class TimestepBlock(nn.Module):
21
+ """
22
+ Any module where forward() takes timestep embeddings as a second argument.
23
+ """
24
+ @abstractmethod
25
+ def forward(self, x, emb=None, cond=None, lateral=None):
26
+ """
27
+ Apply the module to `x` given `emb` timestep embeddings.
28
+ """
29
+
30
+
31
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
32
+ """
33
+ A sequential module that passes timestep embeddings to the children that
34
+ support it as an extra input.
35
+ """
36
+ def forward(self, x, emb=None, cond=None, lateral=None):
37
+ for layer in self:
38
+ if isinstance(layer, TimestepBlock):
39
+ x = layer(x, emb=emb, cond=cond, lateral=lateral)
40
+ else:
41
+ x = layer(x)
42
+ return x
43
+
44
+
45
+ @dataclass
46
+ class ResBlockConfig(BaseConfig):
47
+ channels: int
48
+ emb_channels: int
49
+ dropout: float
50
+ out_channels: int = None
51
+ # condition the resblock with time (and encoder's output)
52
+ use_condition: bool = True
53
+ # whether to use 3x3 conv for skip path when the channels aren't matched
54
+ use_conv: bool = False
55
+ # dimension of conv (always 2 = 2d)
56
+ dims: int = 2
57
+ # gradient checkpoint
58
+ use_checkpoint: bool = False
59
+ up: bool = False
60
+ down: bool = False
61
+ # whether to condition with both time & encoder's output
62
+ two_cond: bool = False
63
+ # number of encoders' output channels
64
+ cond_emb_channels: int = None
65
+ # suggest: False
66
+ has_lateral: bool = False
67
+ lateral_channels: int = None
68
+ # whether to init the convolution with zero weights
69
+ # this is default from BeatGANs and seems to help learning
70
+ use_zero_module: bool = True
71
+
72
+ def __post_init__(self):
73
+ self.out_channels = self.out_channels or self.channels
74
+ self.cond_emb_channels = self.cond_emb_channels or self.emb_channels
75
+
76
+ def make_model(self):
77
+ return ResBlock(self)
78
+
79
+
80
+ class ResBlock(TimestepBlock):
81
+ """
82
+ A residual block that can optionally change the number of channels.
83
+
84
+ total layers:
85
+ in_layers
86
+ - norm
87
+ - act
88
+ - conv
89
+ out_layers
90
+ - norm
91
+ - (modulation)
92
+ - act
93
+ - conv
94
+ """
95
+ def __init__(self, conf: ResBlockConfig):
96
+ super().__init__()
97
+ self.conf = conf
98
+
99
+ #############################
100
+ # IN LAYERS
101
+ #############################
102
+ assert conf.lateral_channels is None
103
+ layers = [
104
+ normalization(conf.channels),
105
+ nn.SiLU(),
106
+ conv_nd(conf.dims, conf.channels, conf.out_channels, 3, padding=1)
107
+ ]
108
+ self.in_layers = nn.Sequential(*layers)
109
+
110
+ self.updown = conf.up or conf.down
111
+
112
+ if conf.up:
113
+ self.h_upd = Upsample(conf.channels, False, conf.dims)
114
+ self.x_upd = Upsample(conf.channels, False, conf.dims)
115
+ elif conf.down:
116
+ self.h_upd = Downsample(conf.channels, False, conf.dims)
117
+ self.x_upd = Downsample(conf.channels, False, conf.dims)
118
+ else:
119
+ self.h_upd = self.x_upd = nn.Identity()
120
+
121
+ #############################
122
+ # OUT LAYERS CONDITIONS
123
+ #############################
124
+ if conf.use_condition:
125
+ # condition layers for the out_layers
126
+ self.emb_layers = nn.Sequential(
127
+ nn.SiLU(),
128
+ linear(conf.emb_channels, 2 * conf.out_channels),
129
+ )
130
+
131
+ if conf.two_cond:
132
+ self.cond_emb_layers = nn.Sequential(
133
+ nn.SiLU(),
134
+ linear(conf.cond_emb_channels, conf.out_channels),
135
+ )
136
+ #############################
137
+ # OUT LAYERS (ignored when there is no condition)
138
+ #############################
139
+ # original version
140
+ conv = conv_nd(conf.dims,
141
+ conf.out_channels,
142
+ conf.out_channels,
143
+ 3,
144
+ padding=1)
145
+ if conf.use_zero_module:
146
+ # zero out the weights
147
+ # it seems to help training
148
+ conv = zero_module(conv)
149
+
150
+ # construct the layers
151
+ # - norm
152
+ # - (modulation)
153
+ # - act
154
+ # - dropout
155
+ # - conv
156
+ layers = []
157
+ layers += [
158
+ normalization(conf.out_channels),
159
+ nn.SiLU(),
160
+ nn.Dropout(p=conf.dropout),
161
+ conv,
162
+ ]
163
+ self.out_layers = nn.Sequential(*layers)
164
+
165
+ #############################
166
+ # SKIP LAYERS
167
+ #############################
168
+ if conf.out_channels == conf.channels:
169
+ # cannot be used with gatedconv; also, gatedconv is always used as the first block
170
+ self.skip_connection = nn.Identity()
171
+ else:
172
+ if conf.use_conv:
173
+ kernel_size = 3
174
+ padding = 1
175
+ else:
176
+ kernel_size = 1
177
+ padding = 0
178
+
179
+ self.skip_connection = conv_nd(conf.dims,
180
+ conf.channels,
181
+ conf.out_channels,
182
+ kernel_size,
183
+ padding=padding)
184
+
185
+ def forward(self, x, emb=None, cond=None, lateral=None):
186
+ """
187
+ Apply the block to a Tensor, conditioned on a timestep embedding.
188
+
189
+ Args:
190
+ x: input
191
+ lateral: lateral connection from the encoder
192
+ """
193
+ return torch_checkpoint(self._forward, (x, emb, cond, lateral),
194
+ self.conf.use_checkpoint)
195
+
196
+ def _forward(
197
+ self,
198
+ x,
199
+ emb=None,
200
+ cond=None,
201
+ lateral=None,
202
+ ):
203
+ """
204
+ Args:
205
+ lateral: required if "has_lateral" is set and the block is non-gated; with gated blocks it is optional
206
+ """
207
+ if self.conf.has_lateral:
208
+ # lateral may be supplied even if it isn't required;
209
+ # the model will take the lateral only if "has_lateral"
210
+ assert lateral is not None
211
+ x = th.cat([x, lateral], dim=1)
212
+
213
+ if self.updown:
214
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
215
+ h = in_rest(x)
216
+ h = self.h_upd(h)
217
+ x = self.x_upd(x)
218
+ h = in_conv(h)
219
+ else:
220
+ h = self.in_layers(x)
221
+
222
+ if self.conf.use_condition:
223
+ # it's possible that the network may not receive the time emb
224
+ # this happens with autoenc and setting the time_at
225
+ if emb is not None:
226
+ emb_out = self.emb_layers(emb).type(h.dtype)
227
+ else:
228
+ emb_out = None
229
+
230
+ if self.conf.two_cond:
231
+ # it's possible that the network is two_cond
232
+ # but it doesn't get the second condition
233
+ # in which case, we ignore the second condition
234
+ # and treat as if the network has one condition
235
+ if cond is None:
236
+ cond_out = None
237
+ else:
238
+ cond_out = self.cond_emb_layers(cond).type(h.dtype)
239
+
240
+ if cond_out is not None:
241
+ while len(cond_out.shape) < len(h.shape):
242
+ cond_out = cond_out[..., None]
243
+ else:
244
+ cond_out = None
245
+
246
+ # this is the new refactored code
247
+ h = apply_conditions(
248
+ h=h,
249
+ emb=emb_out,
250
+ cond=cond_out,
251
+ layers=self.out_layers,
252
+ scale_bias=1,
253
+ in_channels=self.conf.out_channels,
254
+ up_down_layer=None,
255
+ )
256
+
257
+ return self.skip_connection(x) + h
258
+
259
+
260
+ def apply_conditions(
261
+ h,
262
+ emb=None,
263
+ cond=None,
264
+ layers: nn.Sequential = None,
265
+ scale_bias: float = 1,
266
+ in_channels: int = 512,
267
+ up_down_layer: nn.Module = None,
268
+ ):
269
+ """
270
+ apply conditions on the feature maps
271
+
272
+ Args:
273
+ emb: time conditional (ready to scale + shift)
274
+ cond: encoder's conditional (ready to scale + shift)
275
+ """
276
+ two_cond = emb is not None and cond is not None
277
+
278
+ if emb is not None:
279
+ # adjusting shapes
280
+ while len(emb.shape) < len(h.shape):
281
+ emb = emb[..., None]
282
+
283
+ if two_cond:
284
+ # adjusting shapes
285
+ while len(cond.shape) < len(h.shape):
286
+ cond = cond[..., None]
287
+ # time first
288
+ scale_shifts = [emb, cond]
289
+ else:
290
+ # "cond" is not used with single cond mode
291
+ scale_shifts = [emb]
292
+
293
+ # support scale, shift or shift only
294
+ for i, each in enumerate(scale_shifts):
295
+ if each is None:
296
+ # special case: the condition is not provided
297
+ a = None
298
+ b = None
299
+ else:
300
+ if each.shape[1] == in_channels * 2:
301
+ a, b = th.chunk(each, 2, dim=1)
302
+ else:
303
+ a = each
304
+ b = None
305
+ scale_shifts[i] = (a, b)
306
+
307
+ # condition scale bias could be a list
308
+ if isinstance(scale_bias, Number):
309
+ biases = [scale_bias] * len(scale_shifts)
310
+ else:
311
+ # a list
312
+ biases = scale_bias
313
+
314
+ # default, the scale & shift are applied after the group norm but BEFORE SiLU
315
+ pre_layers, post_layers = layers[0], layers[1:]
316
+
317
+ # split the post layers to be able to scale up or down before the conv
318
+ # post layers will contain only the conv
319
+ mid_layers, post_layers = post_layers[:-2], post_layers[-2:]
320
+
321
+ h = pre_layers(h)
322
+ # scale and shift for each condition
323
+ for i, (scale, shift) in enumerate(scale_shifts):
324
+ # if scale is None, it indicates that the condition is not provided
325
+ if scale is not None:
326
+ h = h * (biases[i] + scale)
327
+ if shift is not None:
328
+ h = h + shift
329
+ h = mid_layers(h)
330
+
331
+ # upscale or downscale if any just before the last conv
332
+ if up_down_layer is not None:
333
+ h = up_down_layer(h)
334
+ h = post_layers(h)
335
+ return h
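The conditioning above is the usual adaptive group-norm scale-shift; a tiny standalone sketch of the arithmetic (toy tensors, not part of the file):

    import torch as th
    h = th.randn(2, 8, 4, 4)                  # feature map with 8 channels
    emb = th.randn(2, 16, 1, 1)               # embedding projected to 2*channels
    scale, shift = th.chunk(emb, 2, dim=1)
    h = h * (1 + scale) + shift               # scale_bias=1, as used above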
336
+
337
+
338
+ class Upsample(nn.Module):
339
+ """
340
+ An upsampling layer with an optional convolution.
341
+
342
+ :param channels: channels in the inputs and outputs.
343
+ :param use_conv: a bool determining if a convolution is applied.
344
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
345
+ upsampling occurs in the inner-two dimensions.
346
+ """
347
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
348
+ super().__init__()
349
+ self.channels = channels
350
+ self.out_channels = out_channels or channels
351
+ self.use_conv = use_conv
352
+ self.dims = dims
353
+ if use_conv:
354
+ self.conv = conv_nd(dims,
355
+ self.channels,
356
+ self.out_channels,
357
+ 3,
358
+ padding=1)
359
+
360
+ def forward(self, x):
361
+ assert x.shape[1] == self.channels
362
+ if self.dims == 3:
363
+ x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
364
+ mode="nearest")
365
+ else:
366
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
367
+ if self.use_conv:
368
+ x = self.conv(x)
369
+ return x
370
+
371
+
372
+ class Downsample(nn.Module):
373
+ """
374
+ A downsampling layer with an optional convolution.
375
+
376
+ :param channels: channels in the inputs and outputs.
377
+ :param use_conv: a bool determining if a convolution is applied.
378
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
379
+ downsampling occurs in the inner-two dimensions.
380
+ """
381
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
382
+ super().__init__()
383
+ self.channels = channels
384
+ self.out_channels = out_channels or channels
385
+ self.use_conv = use_conv
386
+ self.dims = dims
387
+ stride = 2 if dims != 3 else (1, 2, 2)
388
+ if use_conv:
389
+ self.op = conv_nd(dims,
390
+ self.channels,
391
+ self.out_channels,
392
+ 3,
393
+ stride=stride,
394
+ padding=1)
395
+ else:
396
+ assert self.channels == self.out_channels
397
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
398
+
399
+ def forward(self, x):
400
+ assert x.shape[1] == self.channels
401
+ return self.op(x)
402
+
403
+
404
+ class AttentionBlock(nn.Module):
405
+ """
406
+ An attention block that allows spatial positions to attend to each other.
407
+
408
+ Originally ported from here, but adapted to the N-d case.
409
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
410
+ """
411
+ def __init__(
412
+ self,
413
+ channels,
414
+ num_heads=1,
415
+ num_head_channels=-1,
416
+ use_checkpoint=False,
417
+ use_new_attention_order=False,
418
+ ):
419
+ super().__init__()
420
+ self.channels = channels
421
+ if num_head_channels == -1:
422
+ self.num_heads = num_heads
423
+ else:
424
+ assert (
425
+ channels % num_head_channels == 0
426
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
427
+ self.num_heads = channels // num_head_channels
428
+ self.use_checkpoint = use_checkpoint
429
+ self.norm = normalization(channels)
430
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
431
+ if use_new_attention_order:
432
+ # split qkv before split heads
433
+ self.attention = QKVAttention(self.num_heads)
434
+ else:
435
+ # split heads before split qkv
436
+ self.attention = QKVAttentionLegacy(self.num_heads)
437
+
438
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
439
+
440
+ def forward(self, x):
441
+ return torch_checkpoint(self._forward, (x, ), self.use_checkpoint)
442
+
443
+ def _forward(self, x):
444
+ b, c, *spatial = x.shape
445
+ x = x.reshape(b, c, -1)
446
+ qkv = self.qkv(self.norm(x))
447
+ h = self.attention(qkv)
448
+ h = self.proj_out(h)
449
+ return (x + h).reshape(b, c, *spatial)
450
+
451
+
452
+ def count_flops_attn(model, _x, y):
453
+ """
454
+ A counter for the `thop` package to count the operations in an
455
+ attention operation.
456
+ Meant to be used like:
457
+ macs, params = thop.profile(
458
+ model,
459
+ inputs=(inputs, timestamps),
460
+ custom_ops={QKVAttention: QKVAttention.count_flops},
461
+ )
462
+ """
463
+ b, c, *spatial = y[0].shape
464
+ num_spatial = int(np.prod(spatial))
465
+ # We perform two matmuls with the same number of ops.
466
+ # The first computes the weight matrix, the second computes
467
+ # the combination of the value vectors.
468
+ matmul_ops = 2 * b * (num_spatial**2) * c
469
+ model.total_ops += th.DoubleTensor([matmul_ops])
470
+
471
+
472
+ class QKVAttentionLegacy(nn.Module):
473
+ """
474
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
475
+ """
476
+ def __init__(self, n_heads):
477
+ super().__init__()
478
+ self.n_heads = n_heads
479
+
480
+ def forward(self, qkv):
481
+ """
482
+ Apply QKV attention.
483
+
484
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
485
+ :return: an [N x (H * C) x T] tensor after attention.
486
+ """
487
+ bs, width, length = qkv.shape
488
+ assert width % (3 * self.n_heads) == 0
489
+ ch = width // (3 * self.n_heads)
490
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch,
491
+ dim=1)
492
+ scale = 1 / math.sqrt(math.sqrt(ch))
493
+ weight = th.einsum(
494
+ "bct,bcs->bts", q * scale,
495
+ k * scale) # More stable with f16 than dividing afterwards
496
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
497
+ a = th.einsum("bts,bcs->bct", weight, v)
498
+ return a.reshape(bs, -1, length)
499
+
500
+ @staticmethod
501
+ def count_flops(model, _x, y):
502
+ return count_flops_attn(model, _x, y)
503
+
504
+
505
+ class QKVAttention(nn.Module):
506
+ """
507
+ A module which performs QKV attention and splits in a different order.
508
+ """
509
+ def __init__(self, n_heads):
510
+ super().__init__()
511
+ self.n_heads = n_heads
512
+
513
+ def forward(self, qkv):
514
+ """
515
+ Apply QKV attention.
516
+
517
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
518
+ :return: an [N x (H * C) x T] tensor after attention.
519
+ """
520
+ bs, width, length = qkv.shape
521
+ assert width % (3 * self.n_heads) == 0
522
+ ch = width // (3 * self.n_heads)
523
+ q, k, v = qkv.chunk(3, dim=1)
524
+ scale = 1 / math.sqrt(math.sqrt(ch))
525
+ weight = th.einsum(
526
+ "bct,bcs->bts",
527
+ (q * scale).view(bs * self.n_heads, ch, length),
528
+ (k * scale).view(bs * self.n_heads, ch, length),
529
+ ) # More stable with f16 than dividing afterwards
530
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
531
+ a = th.einsum("bts,bcs->bct", weight,
532
+ v.reshape(bs * self.n_heads, ch, length))
533
+ return a.reshape(bs, -1, length)
534
+
535
+ @staticmethod
536
+ def count_flops(model, _x, y):
537
+ return count_flops_attn(model, _x, y)
538
+
539
+
540
+ class AttentionPool2d(nn.Module):
541
+ """
542
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
543
+ """
544
+ def __init__(
545
+ self,
546
+ spacial_dim: int,
547
+ embed_dim: int,
548
+ num_heads_channels: int,
549
+ output_dim: int = None,
550
+ ):
551
+ super().__init__()
552
+ self.positional_embedding = nn.Parameter(
553
+ th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5)
554
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
555
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
556
+ self.num_heads = embed_dim // num_heads_channels
557
+ self.attention = QKVAttention(self.num_heads)
558
+
559
+ def forward(self, x):
560
+ b, c, *_spatial = x.shape
561
+ x = x.reshape(b, c, -1) # NC(HW)
562
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
563
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
564
+ x = self.qkv_proj(x)
565
+ x = self.attention(x)
566
+ x = self.c_proj(x)
567
+ return x[:, :, 0]
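A shape-level sketch of the attention module above (toy sizes chosen only for illustration; it assumes the sibling choices and config_base modules are importable):

    import torch as th
    from model.blocks import QKVAttention

    attn = QKVAttention(n_heads=4)
    qkv = th.randn(2, 3 * 4 * 16, 64)     # [N, 3*H*C, T] with H=4 heads of C=16 channels
    out = attn(qkv)
    print(out.shape)                      # torch.Size([2, 64, 64]), i.e. [N, H*C, T]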
model/diffusion.py ADDED
@@ -0,0 +1,294 @@
1
+ # Copyright (C) 2021. Huawei Technologies Co., Ltd. All rights reserved.
2
+ # This program is free software; you can redistribute it and/or modify
3
+ # it under the terms of the MIT License.
4
+ # This program is distributed in the hope that it will be useful,
5
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
6
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
7
+ # MIT License for more details.
8
+
9
+ import math
10
+ import torch
11
+ from einops import rearrange
12
+
13
+ from model.base import BaseModule
14
+
15
+
16
+ class Mish(BaseModule):
17
+ def forward(self, x):
18
+ return x * torch.tanh(torch.nn.functional.softplus(x))
19
+
20
+
21
+ class Upsample(BaseModule):
22
+ def __init__(self, dim):
23
+ super(Upsample, self).__init__()
24
+ self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1)
25
+
26
+ def forward(self, x):
27
+ return self.conv(x)
28
+
29
+
30
+ class Downsample(BaseModule):
31
+ def __init__(self, dim):
32
+ super(Downsample, self).__init__()
33
+ self.conv = torch.nn.Conv2d(dim, dim, 3, 2, 1)
34
+
35
+ def forward(self, x):
36
+ return self.conv(x)
37
+
38
+
39
+ class Rezero(BaseModule):
40
+ def __init__(self, fn):
41
+ super(Rezero, self).__init__()
42
+ self.fn = fn
43
+ self.g = torch.nn.Parameter(torch.zeros(1))
44
+
45
+ def forward(self, x):
46
+ return self.fn(x) * self.g
47
+
48
+
49
+ class Block(BaseModule):
50
+ def __init__(self, dim, dim_out, groups=8):
51
+ super(Block, self).__init__()
52
+ self.block = torch.nn.Sequential(torch.nn.Conv2d(dim, dim_out, 3,
53
+ padding=1), torch.nn.GroupNorm(
54
+ groups, dim_out), Mish())
55
+
56
+ def forward(self, x, mask):
57
+ output = self.block(x * mask)
58
+ return output * mask
59
+
60
+
61
+ class ResnetBlock(BaseModule):
62
+ def __init__(self, dim, dim_out, time_emb_dim, groups=8):
63
+ super(ResnetBlock, self).__init__()
64
+ self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim,
65
+ dim_out))
66
+
67
+ self.block1 = Block(dim, dim_out, groups=groups)
68
+ self.block2 = Block(dim_out, dim_out, groups=groups)
69
+ if dim != dim_out:
70
+ self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
71
+ else:
72
+ self.res_conv = torch.nn.Identity()
73
+
74
+ def forward(self, x, mask, time_emb):
75
+ h = self.block1(x, mask)
76
+ h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1)
77
+ h = self.block2(h, mask)
78
+ output = h + self.res_conv(x * mask)
79
+ return output
80
+
81
+
82
+ class LinearAttention(BaseModule):
83
+ def __init__(self, dim, heads=4, dim_head=32):
84
+ super(LinearAttention, self).__init__()
85
+ self.heads = heads
86
+ hidden_dim = dim_head * heads
87
+ self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
88
+ self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1)
89
+
90
+ def forward(self, x):
91
+ b, c, h, w = x.shape
92
+ qkv = self.to_qkv(x)
93
+ q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)',
94
+ heads = self.heads, qkv=3)
95
+ k = k.softmax(dim=-1)
96
+ context = torch.einsum('bhdn,bhen->bhde', k, v)
97
+ out = torch.einsum('bhde,bhdn->bhen', context, q)
98
+ out = rearrange(out, 'b heads c (h w) -> b (heads c) h w',
99
+ heads=self.heads, h=h, w=w)
100
+ return self.to_out(out)
101
+
102
+
103
+ class Residual(BaseModule):
104
+ def __init__(self, fn):
105
+ super(Residual, self).__init__()
106
+ self.fn = fn
107
+
108
+ def forward(self, x, *args, **kwargs):
109
+ output = self.fn(x, *args, **kwargs) + x
110
+ return output
111
+
112
+
113
+ class SinusoidalPosEmb(BaseModule):
114
+ def __init__(self, dim):
115
+ super(SinusoidalPosEmb, self).__init__()
116
+ self.dim = dim
117
+
118
+ def forward(self, x, scale=1000):
119
+ device = x.device
120
+ half_dim = self.dim // 2
121
+ emb = math.log(10000) / (half_dim - 1)
122
+ emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
123
+ emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
124
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
125
+ return emb
126
+
127
+
128
+ class GradLogPEstimator2d(BaseModule):
129
+ def __init__(self, dim, dim_mults=(1, 2, 4), groups=8,
130
+ n_spks=None, spk_emb_dim=64, n_feats=80, pe_scale=1000):
131
+ super(GradLogPEstimator2d, self).__init__()
132
+ self.dim = dim
133
+ self.dim_mults = dim_mults
134
+ self.groups = groups
135
+ self.n_spks = n_spks if n_spks is not None else 1
136
+ self.spk_emb_dim = spk_emb_dim
137
+ self.pe_scale = pe_scale
138
+
139
+ if n_spks > 1:
140
+ self.spk_mlp = torch.nn.Sequential(torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(),
141
+ torch.nn.Linear(spk_emb_dim * 4, n_feats))
142
+ self.time_pos_emb = SinusoidalPosEmb(dim)
143
+ self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(),
144
+ torch.nn.Linear(dim * 4, dim))
145
+
146
+ dims = [2 + (1 if n_spks > 1 else 0), *map(lambda m: dim * m, dim_mults)]
147
+ in_out = list(zip(dims[:-1], dims[1:]))
148
+ self.downs = torch.nn.ModuleList([])
149
+ self.ups = torch.nn.ModuleList([])
150
+ num_resolutions = len(in_out)
151
+
152
+ for ind, (dim_in, dim_out) in enumerate(in_out):
153
+ is_last = ind >= (num_resolutions - 1)
154
+ self.downs.append(torch.nn.ModuleList([
155
+ ResnetBlock(dim_in, dim_out, time_emb_dim=dim),
156
+ ResnetBlock(dim_out, dim_out, time_emb_dim=dim),
157
+ Residual(Rezero(LinearAttention(dim_out))),
158
+ Downsample(dim_out) if not is_last else torch.nn.Identity()]))
159
+
160
+ mid_dim = dims[-1]
161
+ self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
162
+ self.mid_attn = Residual(Rezero(LinearAttention(mid_dim)))
163
+ self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
164
+
165
+ for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
166
+ self.ups.append(torch.nn.ModuleList([
167
+ ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),
168
+ ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
169
+ Residual(Rezero(LinearAttention(dim_in))),
170
+ Upsample(dim_in)]))
171
+ self.final_block = Block(dim, dim)
172
+ self.final_conv = torch.nn.Conv2d(dim, 1, 1)
173
+
174
+ def forward(self, x, mask, mu, t, spk=None):
175
+ if spk is not None:
176
+ s = self.spk_mlp(spk)
177
+
178
+ t = self.time_pos_emb(t, scale=self.pe_scale)
179
+ t = self.mlp(t)
180
+
181
+ if self.n_spks < 2:
182
+ x = torch.stack([mu, x], 1)
183
+ else:
184
+ s = s.unsqueeze(-1).repeat(1, 1, x.shape[-1])
185
+ x = torch.stack([mu, x, s], 1)
186
+ mask = mask.unsqueeze(1)
187
+
188
+ hiddens = []
189
+ masks = [mask]
190
+ for resnet1, resnet2, attn, downsample in self.downs:
191
+ mask_down = masks[-1]
192
+ x = resnet1(x, mask_down, t)
193
+ x = resnet2(x, mask_down, t)
194
+ x = attn(x)
195
+ hiddens.append(x)
196
+ x = downsample(x * mask_down)
197
+ masks.append(mask_down[:, :, :, ::2])
198
+
199
+ masks = masks[:-1]
200
+ mask_mid = masks[-1]
201
+ x = self.mid_block1(x, mask_mid, t)
202
+ x = self.mid_attn(x)
203
+ x = self.mid_block2(x, mask_mid, t)
204
+
205
+ for resnet1, resnet2, attn, upsample in self.ups:
206
+ mask_up = masks.pop()
207
+ x = torch.cat((x, hiddens.pop()), dim=1)
208
+ x = resnet1(x, mask_up, t)
209
+ x = resnet2(x, mask_up, t)
210
+ x = attn(x)
211
+ x = upsample(x * mask_up)
212
+
213
+ x = self.final_block(x, mask)
214
+ output = self.final_conv(x * mask)
215
+
216
+ return (output * mask).squeeze(1)
217
+
218
+
219
+ def get_noise(t, beta_init, beta_term, cumulative=False):
220
+ if cumulative:
221
+ noise = beta_init*t + 0.5*(beta_term - beta_init)*(t**2)
222
+ else:
223
+ noise = beta_init + (beta_term - beta_init)*t
224
+ return noise
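get_noise is a linear beta schedule; with cumulative=True it returns the integral of beta(s) from 0 to t, which the forward diffusion below feeds into exp(-0.5 * integral). A quick numeric check with beta_min=0.05 and beta_max=20 (the defaults of Diffusion below):

    t = 0.5
    beta_t = 0.05 + (20 - 0.05) * t                  # instantaneous beta(t) = 10.025
    cum = 0.05 * t + 0.5 * (20 - 0.05) * t ** 2      # integral of beta over [0, t] = 2.51875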
225
+
226
+
227
+ class Diffusion(BaseModule):
228
+ def __init__(self, n_feats, dim,
229
+ n_spks=1, spk_emb_dim=64,
230
+ beta_min=0.05, beta_max=20, pe_scale=1000):
231
+ super(Diffusion, self).__init__()
232
+ self.n_feats = n_feats
233
+ self.dim = dim
234
+ self.n_spks = n_spks
235
+ self.spk_emb_dim = spk_emb_dim
236
+ self.beta_min = beta_min
237
+ self.beta_max = beta_max
238
+ self.pe_scale = pe_scale
239
+
240
+ self.estimator = GradLogPEstimator2d(dim, n_spks=n_spks,
241
+ spk_emb_dim=spk_emb_dim,
242
+ pe_scale=pe_scale)
243
+
244
+ def forward_diffusion(self, x0, mask, mu, t):
245
+ time = t.unsqueeze(-1).unsqueeze(-1)
246
+ cum_noise = get_noise(time, self.beta_min, self.beta_max, cumulative=True)
247
+ mean = x0*torch.exp(-0.5*cum_noise) + mu*(1.0 - torch.exp(-0.5*cum_noise))
248
+ variance = 1.0 - torch.exp(-cum_noise)
249
+ z = torch.randn(x0.shape, dtype=x0.dtype, device=x0.device,
250
+ requires_grad=False)
251
+ xt = mean + z * torch.sqrt(variance)
252
+ return xt * mask, z * mask
253
+
254
+ @torch.no_grad()
255
+ def reverse_diffusion(self, z, mask, mu, n_timesteps, stoc=False, spk=None):
256
+ h = 1.0 / n_timesteps
257
+ xt = z * mask
258
+ for i in range(n_timesteps):
259
+ t = (1.0 - (i + 0.5)*h) * torch.ones(z.shape[0], dtype=z.dtype,
260
+ device=z.device)
261
+ time = t.unsqueeze(-1).unsqueeze(-1)
262
+ noise_t = get_noise(time, self.beta_min, self.beta_max,
263
+ cumulative=False)
264
+ if stoc: # adds stochastic term
265
+ dxt_det = 0.5 * (mu - xt) - self.estimator(xt, mask, mu, t, spk)
266
+ dxt_det = dxt_det * noise_t * h
267
+ dxt_stoc = torch.randn(z.shape, dtype=z.dtype, device=z.device,
268
+ requires_grad=False)
269
+ dxt_stoc = dxt_stoc * torch.sqrt(noise_t * h)
270
+ dxt = dxt_det + dxt_stoc
271
+ else:
272
+ dxt = 0.5 * (mu - xt - self.estimator(xt, mask, mu, t, spk))
273
+ dxt = dxt * noise_t * h
274
+ xt = (xt - dxt) * mask
275
+ return xt
276
+
277
+ @torch.no_grad()
278
+ def forward(self, z, mask, mu, n_timesteps, stoc=False, spk=None):
279
+ return self.reverse_diffusion(z, mask, mu, n_timesteps, stoc, spk)
280
+
281
+ def loss_t(self, x0, mask, mu, t, spk=None):
282
+ xt, z = self.forward_diffusion(x0, mask, mu, t)
283
+ time = t.unsqueeze(-1).unsqueeze(-1)
284
+ cum_noise = get_noise(time, self.beta_min, self.beta_max, cumulative=True)
285
+ noise_estimation = self.estimator(xt, mask, mu, t, spk)
286
+ noise_estimation *= torch.sqrt(1.0 - torch.exp(-cum_noise))
287
+ loss = torch.sum((noise_estimation + z)**2) / (torch.sum(mask)*self.n_feats)
288
+ return loss, xt
289
+
290
+ def compute_loss(self, x0, mask, mu, spk=None, offset=1e-5):
291
+ t = torch.rand(x0.shape[0], dtype=x0.dtype, device=x0.device,
292
+ requires_grad=False)
293
+ t = torch.clamp(t, offset, 1.0 - offset)
294
+ return self.loss_t(x0, mask, mu, t, spk)
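A hedged, shape-level sketch of training and sampling with the Diffusion module above (toy tensors; the sequence length should be a multiple of 4 so the two stride-2 downsamplings in the estimator invert cleanly):

    import torch
    from model.diffusion import Diffusion

    diff = Diffusion(n_feats=80, dim=64)          # illustrative sizes
    x0 = torch.randn(2, 80, 172)                  # targets, e.g. mel-spectrogram-like features
    mu = torch.randn(2, 80, 172)                  # conditioning mean
    mask = torch.ones(2, 1, 172)

    loss, xt = diff.compute_loss(x0, mask, mu)    # score-matching loss
    z = torch.randn_like(mu)
    sample = diff(z, mask, mu, n_timesteps=50)    # reverse diffusion, shape (2, 80, 172)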
model/latentnet.py ADDED
@@ -0,0 +1,193 @@
1
+ import math
2
+ from dataclasses import dataclass
3
+ from enum import Enum
4
+ from typing import NamedTuple, Tuple
5
+
6
+ import torch
7
+ from choices import *
8
+ from config_base import BaseConfig
9
+ from torch import nn
10
+ from torch.nn import init
11
+
12
+ from .blocks import *
13
+ from .nn import timestep_embedding
14
+ from .unet import *
15
+
16
+
17
+ class LatentNetType(Enum):
18
+ none = 'none'
19
+ # injecting inputs into the hidden layers
20
+ skip = 'skip'
21
+
22
+
23
+ class LatentNetReturn(NamedTuple):
24
+ pred: torch.Tensor = None
25
+
26
+
27
+ @dataclass
28
+ class MLPSkipNetConfig(BaseConfig):
29
+ """
30
+ default MLP for the latent DPM in the paper!
31
+ """
32
+ num_channels: int
33
+ skip_layers: Tuple[int]
34
+ num_hid_channels: int
35
+ num_layers: int
36
+ num_time_emb_channels: int = 64
37
+ activation: Activation = Activation.silu
38
+ use_norm: bool = True
39
+ condition_bias: float = 1
40
+ dropout: float = 0
41
+ last_act: Activation = Activation.none
42
+ num_time_layers: int = 2
43
+ time_last_act: bool = False
44
+
45
+ def make_model(self):
46
+ return MLPSkipNet(self)
47
+
48
+
49
+ class MLPSkipNet(nn.Module):
50
+ """
51
+ concat x to hidden layers
52
+
53
+ default MLP for the latent DPM in the paper!
54
+ """
55
+ def __init__(self, conf: MLPSkipNetConfig):
56
+ super().__init__()
57
+ self.conf = conf
58
+
59
+ layers = []
60
+ for i in range(conf.num_time_layers):
61
+ if i == 0:
62
+ a = conf.num_time_emb_channels
63
+ b = conf.num_channels
64
+ else:
65
+ a = conf.num_channels
66
+ b = conf.num_channels
67
+ layers.append(nn.Linear(a, b))
68
+ if i < conf.num_time_layers - 1 or conf.time_last_act:
69
+ layers.append(conf.activation.get_act())
70
+ self.time_embed = nn.Sequential(*layers)
71
+
72
+ self.layers = nn.ModuleList([])
73
+ for i in range(conf.num_layers):
74
+ if i == 0:
75
+ act = conf.activation
76
+ norm = conf.use_norm
77
+ cond = True
78
+ a, b = conf.num_channels, conf.num_hid_channels
79
+ dropout = conf.dropout
80
+ elif i == conf.num_layers - 1:
81
+ act = Activation.none
82
+ norm = False
83
+ cond = False
84
+ a, b = conf.num_hid_channels, conf.num_channels
85
+ dropout = 0
86
+ else:
87
+ act = conf.activation
88
+ norm = conf.use_norm
89
+ cond = True
90
+ a, b = conf.num_hid_channels, conf.num_hid_channels
91
+ dropout = conf.dropout
92
+
93
+ if i in conf.skip_layers:
94
+ a += conf.num_channels
95
+
96
+ self.layers.append(
97
+ MLPLNAct(
98
+ a,
99
+ b,
100
+ norm=norm,
101
+ activation=act,
102
+ cond_channels=conf.num_channels,
103
+ use_cond=cond,
104
+ condition_bias=conf.condition_bias,
105
+ dropout=dropout,
106
+ ))
107
+ self.last_act = conf.last_act.get_act()
108
+
109
+ def forward(self, x, t, **kwargs):
110
+ t = timestep_embedding(t, self.conf.num_time_emb_channels)
111
+ cond = self.time_embed(t)
112
+ h = x
113
+ for i in range(len(self.layers)):
114
+ if i in self.conf.skip_layers:
115
+ # injecting input into the hidden layers
116
+ h = torch.cat([h, x], dim=1)
117
+ h = self.layers[i].forward(x=h, cond=cond)
118
+ h = self.last_act(h)
119
+ return LatentNetReturn(h)
120
+
121
+
122
+ class MLPLNAct(nn.Module):
123
+ def __init__(
124
+ self,
125
+ in_channels: int,
126
+ out_channels: int,
127
+ norm: bool,
128
+ use_cond: bool,
129
+ activation: Activation,
130
+ cond_channels: int,
131
+ condition_bias: float = 0,
132
+ dropout: float = 0,
133
+ ):
134
+ super().__init__()
135
+ self.activation = activation
136
+ self.condition_bias = condition_bias
137
+ self.use_cond = use_cond
138
+
139
+ self.linear = nn.Linear(in_channels, out_channels)
140
+ self.act = activation.get_act()
141
+ if self.use_cond:
142
+ self.linear_emb = nn.Linear(cond_channels, out_channels)
143
+ self.cond_layers = nn.Sequential(self.act, self.linear_emb)
144
+ if norm:
145
+ self.norm = nn.LayerNorm(out_channels)
146
+ else:
147
+ self.norm = nn.Identity()
148
+
149
+ if dropout > 0:
150
+ self.dropout = nn.Dropout(p=dropout)
151
+ else:
152
+ self.dropout = nn.Identity()
153
+
154
+ self.init_weights()
155
+
156
+ def init_weights(self):
157
+ for module in self.modules():
158
+ if isinstance(module, nn.Linear):
159
+ if self.activation == Activation.relu:
160
+ init.kaiming_normal_(module.weight,
161
+ a=0,
162
+ nonlinearity='relu')
163
+ elif self.activation == Activation.lrelu:
164
+ init.kaiming_normal_(module.weight,
165
+ a=0.2,
166
+ nonlinearity='leaky_relu')
167
+ elif self.activation == Activation.silu:
168
+ init.kaiming_normal_(module.weight,
169
+ a=0,
170
+ nonlinearity='relu')
171
+ else:
172
+ # leave it as default
173
+ pass
174
+
175
+ def forward(self, x, cond=None):
176
+ x = self.linear(x)
177
+ if self.use_cond:
178
+ # (n, c) or (n, c * 2)
179
+ cond = self.cond_layers(cond)
180
+ cond = (cond, None)
181
+
182
+ # scale shift first
183
+ x = x * (self.condition_bias + cond[0])
184
+ if cond[1] is not None:
185
+ x = x + cond[1]
186
+ # then norm
187
+ x = self.norm(x)
188
+ else:
189
+ # no condition
190
+ x = self.norm(x)
191
+ x = self.act(x)
192
+ x = self.dropout(x)
193
+ return x
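A hedged usage sketch of the latent MLP above (sizes are illustrative, not the paper's defaults; it assumes choices.Activation and config_base are importable):

    import torch
    from model.latentnet import MLPSkipNetConfig

    conf = MLPSkipNetConfig(num_channels=512, skip_layers=(1, 2, 3),
                            num_hid_channels=1024, num_layers=4)
    net = conf.make_model()
    x = torch.randn(8, 512)
    t = torch.randint(0, 1000, (8, ))
    out = net(x, t)                   # LatentNetReturn(pred=...) with pred of shape (8, 512)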
model/nn.py ADDED
@@ -0,0 +1,137 @@
1
+ """
2
+ Various utilities for neural networks.
3
+ """
4
+
5
+ from enum import Enum
6
+ import math
7
+ from typing import Optional
8
+
9
+ import torch as th
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+
13
+ import torch.nn.functional as F
14
+
15
+
16
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
17
+ class SiLU(nn.Module):
18
+ # @th.jit.script
19
+ def forward(self, x):
20
+ return x * th.sigmoid(x)
21
+
22
+
23
+ class GroupNorm32(nn.GroupNorm):
24
+ def forward(self, x):
25
+ return super().forward(x.float()).type(x.dtype)
26
+
27
+
28
+ def conv_nd(dims, *args, **kwargs):
29
+ """
30
+ Create a 1D, 2D, or 3D convolution module.
31
+ """
32
+ if dims == 1:
33
+ return nn.Conv1d(*args, **kwargs)
34
+ elif dims == 2:
35
+ return nn.Conv2d(*args, **kwargs)
36
+ elif dims == 3:
37
+ return nn.Conv3d(*args, **kwargs)
38
+ raise ValueError(f"unsupported dimensions: {dims}")
39
+
40
+
41
+ def linear(*args, **kwargs):
42
+ """
43
+ Create a linear module.
44
+ """
45
+ return nn.Linear(*args, **kwargs)
46
+
47
+
48
+ def avg_pool_nd(dims, *args, **kwargs):
49
+ """
50
+ Create a 1D, 2D, or 3D average pooling module.
51
+ """
52
+ if dims == 1:
53
+ return nn.AvgPool1d(*args, **kwargs)
54
+ elif dims == 2:
55
+ return nn.AvgPool2d(*args, **kwargs)
56
+ elif dims == 3:
57
+ return nn.AvgPool3d(*args, **kwargs)
58
+ raise ValueError(f"unsupported dimensions: {dims}")
59
+
60
+
61
+ def update_ema(target_params, source_params, rate=0.99):
62
+ """
63
+ Update target parameters to be closer to those of source parameters using
64
+ an exponential moving average.
65
+
66
+ :param target_params: the target parameter sequence.
67
+ :param source_params: the source parameter sequence.
68
+ :param rate: the EMA rate (closer to 1 means slower).
69
+ """
70
+ for targ, src in zip(target_params, source_params):
71
+ targ.detach().mul_(rate).add_(src, alpha=1 - rate)
72
+
73
+
74
+ def zero_module(module):
75
+ """
76
+ Zero out the parameters of a module and return it.
77
+ """
78
+ for p in module.parameters():
79
+ p.detach().zero_()
80
+ return module
81
+
82
+
83
+ def scale_module(module, scale):
84
+ """
85
+ Scale the parameters of a module and return it.
86
+ """
87
+ for p in module.parameters():
88
+ p.detach().mul_(scale)
89
+ return module
90
+
91
+
92
+ def mean_flat(tensor):
93
+ """
94
+ Take the mean over all non-batch dimensions.
95
+ """
96
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
97
+
98
+
99
+ def normalization(channels):
100
+ """
101
+ Make a standard normalization layer.
102
+
103
+ :param channels: number of input channels.
104
+ :return: an nn.Module for normalization.
105
+ """
106
+ return GroupNorm32(min(32, channels), channels)
107
+
108
+
109
+ def timestep_embedding(timesteps, dim, max_period=10000):
110
+ """
111
+ Create sinusoidal timestep embeddings.
112
+
113
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
114
+ These may be fractional.
115
+ :param dim: the dimension of the output.
116
+ :param max_period: controls the minimum frequency of the embeddings.
117
+ :return: an [N x dim] Tensor of positional embeddings.
118
+ """
119
+ half = dim // 2
120
+ freqs = th.exp(-math.log(max_period) *
121
+ th.arange(start=0, end=half, dtype=th.float32) /
122
+ half).to(device=timesteps.device)
123
+ args = timesteps[:, None].float() * freqs[None]
124
+ embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
125
+ if dim % 2:
126
+ embedding = th.cat(
127
+ [embedding, th.zeros_like(embedding[:, :1])], dim=-1)
128
+ return embedding
129
+
130
+
131
+ def torch_checkpoint(func, args, flag, preserve_rng_state=False):
132
+ # torch's gradient checkpoint works with automatic mixed precision, given torch >= 1.8
133
+ if flag:
134
+ return torch.utils.checkpoint.checkpoint(
135
+ func, *args, preserve_rng_state=preserve_rng_state)
136
+ else:
137
+ return func(*args)
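Two quick checks of the helpers above (toy values):

    import torch as th
    from model.nn import timestep_embedding, update_ema

    emb = timestep_embedding(th.tensor([0, 10, 100]), dim=8)
    print(emb.shape)                                   # torch.Size([3, 8])

    ema, online = th.nn.Linear(4, 4), th.nn.Linear(4, 4)
    update_ema(ema.parameters(), online.parameters(), rate=0.999)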
model/seq2seq.py ADDED
@@ -0,0 +1,141 @@
1
+ import torch
2
+ from torch import nn
3
+ from model.base import BaseModule
4
+ from espnet.nets.pytorch_backend.conformer.encoder import Encoder as ConformerEncoder
5
+ import torch.nn.functional as F
6
+
7
+ class LSTM(nn.Module):
8
+ def __init__(self, motion_dim, output_dim, num_layers=2, hidden_dim=128):
9
+ super().__init__()
10
+ self.lstm = nn.LSTM(input_size=motion_dim, hidden_size=hidden_dim,
11
+ num_layers=num_layers, batch_first=True)
12
+ self.fc = nn.Linear(hidden_dim, output_dim)
13
+
14
+ def forward(self, x):
15
+ x, _ = self.lstm(x)
16
+ return self.fc(x)
17
+
18
+ class DiffusionPredictor(BaseModule):
19
+ def __init__(self, conf):
20
+ super(DiffusionPredictor, self).__init__()
21
+
22
+ self.infer_type = conf.infer_type
23
+
24
+ self.initialize_layers(conf)
25
+ print(f'infer_type: {self.infer_type}')
26
+
27
+ def create_conformer_encoder(self, attention_dim, num_blocks):
28
+ return ConformerEncoder(
29
+ idim=0, attention_dim=attention_dim, attention_heads=2, linear_units=attention_dim,
30
+ num_blocks=num_blocks, input_layer=None, dropout_rate=0.2, positional_dropout_rate=0.2,
31
+ attention_dropout_rate=0.2, normalize_before=False, concat_after=False,
32
+ positionwise_layer_type="linear", positionwise_conv_kernel_size=3, macaron_style=True,
33
+ pos_enc_layer_type="rel_pos", selfattention_layer_type="rel_selfattn", use_cnn_module=True,
34
+ cnn_module_kernel=13)
35
+
36
+ def initialize_layers(self, conf, mfcc_dim=39, hubert_dim=1024, speech_layers=4, speech_dim=512, decoder_dim=1024, motion_start_dim=512, HAL_layers=25):
37
+
38
+ self.conf = conf
39
+ # Speech downsampling
40
+ if self.infer_type.startswith('mfcc'):
41
+ # from 100 hz to 25 hz
42
+ self.down_sample1 = nn.Conv1d(mfcc_dim, 256, kernel_size=3, stride=2, padding=1)
43
+ self.down_sample2 = nn.Conv1d(256, speech_dim, kernel_size=3, stride=2, padding=1)
44
+ elif self.infer_type.startswith('hubert'):
45
+ # from 50 hz to 25 hz
46
+ self.down_sample1 = nn.Conv1d(hubert_dim, speech_dim, kernel_size=3, stride=2, padding=1)
47
+
48
+ self.weights = nn.Parameter(torch.zeros(HAL_layers))
49
+ self.speech_encoder = self.create_conformer_encoder(speech_dim, speech_layers)
50
+ else:
51
+ print('infer_type not supported')
52
+
53
+ # Encoders & Deocoders
54
+ self.coarse_decoder = self.create_conformer_encoder(decoder_dim, conf.decoder_layers)
55
+
56
+ # LSTM predictors for Variance Adapter
57
+ if self.infer_type != 'hubert_audio_only':
58
+ self.pose_predictor = LSTM(speech_dim, 3)
59
+ self.pose_encoder = LSTM(3, speech_dim)
60
+
61
+ if 'full_control' in self.infer_type:
62
+ self.location_predictor = LSTM(speech_dim, 1)
63
+ self.location_encoder = LSTM(1, speech_dim)
64
+ self.face_scale_predictor = LSTM(speech_dim, 1)
65
+ self.face_scale_encoder = LSTM(1, speech_dim)
66
+
67
+ # Linear transformations
68
+ self.init_code_proj = nn.Sequential(nn.Linear(motion_start_dim, 128))
69
+ self.noisy_encoder = nn.Sequential(nn.Linear(conf.motion_dim, 128))
70
+ self.t_encoder = nn.Sequential(nn.Linear(1, 128))
71
+ self.encoder_direction_code = nn.Linear(conf.motion_dim, 128)
72
+
73
+ self.out_proj = nn.Linear(decoder_dim, conf.motion_dim)
74
+
75
+
76
+ def forward(self, initial_code, direction_code, seq_input_vector, face_location, face_scale, yaw_pitch_roll, noisy_x, t_emb, control_flag=False):
77
+
78
+ if self.infer_type.startswith('mfcc'):
79
+ x = self.mfcc_speech_downsample(seq_input_vector)
80
+ elif self.infer_type.startswith('hubert'):
81
+ norm_weights = F.softmax(self.weights, dim=-1)
82
+ weighted_feature = (norm_weights.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) * seq_input_vector).sum(dim=1)
83
+ x = self.down_sample1(weighted_feature.transpose(1,2)).transpose(1,2)
84
+ x, _ = self.speech_encoder(x, masks=None)
85
+ predicted_location, predicted_scale, predicted_pose = face_location, face_scale, yaw_pitch_roll
86
+ if self.infer_type != 'hubert_audio_only':
87
+ print(f'pose controllable. control_flag: {control_flag}')
88
+ x, predicted_location, predicted_scale, predicted_pose = self.adjust_features(x, face_location, face_scale, yaw_pitch_roll, control_flag)
89
+ concatenated_features = self.combine_features(x, initial_code, direction_code, noisy_x, t_emb) # initial_code and direction_code serve as a motion guide extracted from the reference image. This aims to tell the model what the starting motion should be.
90
+ outputs = self.decode_features(concatenated_features)
91
+ return outputs, predicted_location, predicted_scale, predicted_pose
92
+
93
+ def mfcc_speech_downsample(self, seq_input_vector):
94
+ x = self.down_sample1(seq_input_vector.transpose(1,2))
95
+ return self.down_sample2(x).transpose(1,2)
96
+
97
+ def adjust_features(self, x, face_location, face_scale, yaw_pitch_roll, control_flag):
98
+ predicted_location, predicted_scale = 0, 0
99
+ if 'full_control' in self.infer_type:
100
+ print(f'full controllable. control_flag: {control_flag}')
101
+ x_residual, predicted_location = self.adjust_location(x, face_location, control_flag)
102
+ x = x + x_residual
103
+
104
+ x_residual, predicted_scale = self.adjust_scale(x, face_scale, control_flag)
105
+ x = x + x_residual
106
+
107
+ x_residual, predicted_pose = self.adjust_pose(x, yaw_pitch_roll, control_flag)
108
+ x = x + x_residual
109
+ return x, predicted_location, predicted_scale, predicted_pose
110
+
111
+ def adjust_location(self, x, face_location, control_flag):
112
+ if control_flag:
113
+ predicted_location = face_location
114
+ else:
115
+ predicted_location = self.location_predictor(x)
116
+ return self.location_encoder(predicted_location), predicted_location
117
+
118
+ def adjust_scale(self, x, face_scale, control_flag):
119
+ if control_flag:
120
+ predicted_face_scale = face_scale
121
+ else:
122
+ predicted_face_scale = self.face_scale_predictor(x)
123
+ return self.face_scale_encoder(predicted_face_scale), predicted_face_scale
124
+
125
+ def adjust_pose(self, x, yaw_pitch_roll, control_flag):
126
+ if control_flag:
127
+ predicted_pose = yaw_pitch_roll
128
+ else:
129
+ predicted_pose = self.pose_predictor(x)
130
+ return self.pose_encoder(predicted_pose), predicted_pose
131
+
132
+ def combine_features(self, x, initial_code, direction_code, noisy_x, t_emb):
133
+ init_code_proj = self.init_code_proj(initial_code).unsqueeze(1).repeat(1, x.size(1), 1)
134
+ noisy_feature = self.noisy_encoder(noisy_x)
135
+ t_emb_feature = self.t_encoder(t_emb.unsqueeze(1).float()).unsqueeze(1).repeat(1, x.size(1), 1)
136
+ direction_code_feature = self.encoder_direction_code(direction_code).unsqueeze(1).repeat(1, x.size(1), 1)
137
+ return torch.cat((x, direction_code_feature, init_code_proj, noisy_feature, t_emb_feature), dim=-1)
138
+
139
+ def decode_features(self, concatenated_features):
140
+ outputs, _ = self.coarse_decoder(concatenated_features, masks=None)
141
+ return self.out_proj(outputs)
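A shape-level sketch of the small LSTM heads used by the variance adapter above (toy sizes; importing model.seq2seq also pulls in espnet's Conformer encoder, so that package must be installed):

    import torch
    from model.seq2seq import LSTM

    pose_predictor = LSTM(motion_dim=512, output_dim=3)   # e.g. yaw/pitch/roll per frame
    feats = torch.randn(2, 100, 512)                      # (batch, frames, feature)
    pose = pose_predictor(feats)                          # -> (2, 100, 3)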
model/unet.py ADDED
@@ -0,0 +1,552 @@
1
+ import math
2
+ from dataclasses import dataclass
3
+ from numbers import Number
4
+ from typing import NamedTuple, Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch as th
8
+ from torch import nn
9
+ import torch.nn.functional as F
10
+ from choices import *
11
+ from config_base import BaseConfig
12
+ from .blocks import *
13
+
14
+ from .nn import (conv_nd, linear, normalization, timestep_embedding,
15
+ torch_checkpoint, zero_module)
16
+
17
+
18
+ @dataclass
19
+ class BeatGANsUNetConfig(BaseConfig):
20
+ image_size: int = 64
21
+ in_channels: int = 3
22
+ # base channels, will be multiplied
23
+ model_channels: int = 64
24
+ # output of the unet
25
+ # suggest: 3
26
+ # you only need 6 if you also model the variance of the noise prediction (usually we use an analytical variance hence 3)
27
+ out_channels: int = 3
28
+ # how many repeating resblocks per resolution
29
+ # the decoding side would have "one more" resblock
30
+ # default: 2
31
+ num_res_blocks: int = 2
32
+ # you can also set the number of resblocks specifically for the input blocks
33
+ # default: None = above
34
+ num_input_res_blocks: int = None
35
+ # number of time embed channels and style channels
36
+ embed_channels: int = 512
37
+ # at what resolutions you want to do self-attention of the feature maps
38
+ # attentions generally improve performance
39
+ # default: [16]
40
+ # beatgans: [32, 16, 8]
41
+ attention_resolutions: Tuple[int] = (16, )
42
+ # number of time embed channels
43
+ time_embed_channels: int = None
44
+ # dropout applies to the resblocks (on feature maps)
45
+ dropout: float = 0.1
46
+ channel_mult: Tuple[int] = (1, 2, 4, 8)
47
+ input_channel_mult: Tuple[int] = None
48
+ conv_resample: bool = True
49
+ # always 2 = 2d conv
50
+ dims: int = 2
51
+ # don't use this, legacy from BeatGANs
52
+ num_classes: int = None
53
+ use_checkpoint: bool = False
54
+ # number of attention heads
55
+ num_heads: int = 1
56
+ # or specify the number of channels per attention head
57
+ num_head_channels: int = -1
58
+ # what's this?
59
+ num_heads_upsample: int = -1
60
+ # use resblock for upscale/downscale blocks (expensive)
61
+ # default: True (BeatGANs)
62
+ resblock_updown: bool = True
63
+ # never tried
64
+ use_new_attention_order: bool = False
65
+ resnet_two_cond: bool = False
66
+ resnet_cond_channels: int = None
67
+ # init the decoding conv layers with zero weights, this speeds up training
68
+ # default: True (BeatGANs)
69
+ resnet_use_zero_module: bool = True
70
+ # gradient checkpoint the attention operation
71
+ attn_checkpoint: bool = False
72
+
73
+ def make_model(self):
74
+ return BeatGANsUNetModel(self)
75
+
76
+
77
+ class BeatGANsUNetModel(nn.Module):
78
+ def __init__(self, conf: BeatGANsUNetConfig):
79
+ super().__init__()
80
+ self.conf = conf
81
+
82
+ if conf.num_heads_upsample == -1:
83
+ self.num_heads_upsample = conf.num_heads
84
+
85
+ self.dtype = th.float32
86
+
87
+ self.time_emb_channels = conf.time_embed_channels or conf.model_channels
88
+ self.time_embed = nn.Sequential(
89
+ linear(self.time_emb_channels, conf.embed_channels),
90
+ nn.SiLU(),
91
+ linear(conf.embed_channels, conf.embed_channels),
92
+ )
93
+
94
+ if conf.num_classes is not None:
95
+ self.label_emb = nn.Embedding(conf.num_classes,
96
+ conf.embed_channels)
97
+
98
+ ch = input_ch = int(conf.channel_mult[0] * conf.model_channels)
99
+ self.input_blocks = nn.ModuleList([
100
+ TimestepEmbedSequential(
101
+ conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1))
102
+ ])
103
+
104
+ kwargs = dict(
105
+ use_condition=True,
106
+ two_cond=conf.resnet_two_cond,
107
+ use_zero_module=conf.resnet_use_zero_module,
108
+ # style channels for the resnet block
109
+ cond_emb_channels=conf.resnet_cond_channels,
110
+ )
111
+
112
+ self._feature_size = ch
113
+
114
+ # input_block_chans = [ch]
115
+ input_block_chans = [[] for _ in range(len(conf.channel_mult))]
116
+ input_block_chans[0].append(ch)
117
+
118
+ # number of blocks at each resolution
119
+ self.input_num_blocks = [0 for _ in range(len(conf.channel_mult))]
120
+ self.input_num_blocks[0] = 1
121
+ self.output_num_blocks = [0 for _ in range(len(conf.channel_mult))]
122
+
123
+ ds = 1
124
+ resolution = conf.image_size
125
+ for level, mult in enumerate(conf.input_channel_mult
126
+ or conf.channel_mult):
127
+ for _ in range(conf.num_input_res_blocks or conf.num_res_blocks):
128
+ layers = [
129
+ ResBlockConfig(
130
+ ch,
131
+ conf.embed_channels,
132
+ conf.dropout,
133
+ out_channels=int(mult * conf.model_channels),
134
+ dims=conf.dims,
135
+ use_checkpoint=conf.use_checkpoint,
136
+ **kwargs,
137
+ ).make_model()
138
+ ]
139
+ ch = int(mult * conf.model_channels)
140
+ if resolution in conf.attention_resolutions:
141
+ layers.append(
142
+ AttentionBlock(
143
+ ch,
144
+ use_checkpoint=conf.use_checkpoint
145
+ or conf.attn_checkpoint,
146
+ num_heads=conf.num_heads,
147
+ num_head_channels=conf.num_head_channels,
148
+ use_new_attention_order=conf.
149
+ use_new_attention_order,
150
+ ))
151
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
152
+ self._feature_size += ch
153
+ # input_block_chans.append(ch)
154
+ input_block_chans[level].append(ch)
155
+ self.input_num_blocks[level] += 1
156
+ # print(input_block_chans)
157
+ if level != len(conf.channel_mult) - 1:
158
+ resolution //= 2
159
+ out_ch = ch
160
+ self.input_blocks.append(
161
+ TimestepEmbedSequential(
162
+ ResBlockConfig(
163
+ ch,
164
+ conf.embed_channels,
165
+ conf.dropout,
166
+ out_channels=out_ch,
167
+ dims=conf.dims,
168
+ use_checkpoint=conf.use_checkpoint,
169
+ down=True,
170
+ **kwargs,
171
+ ).make_model() if conf.
172
+ resblock_updown else Downsample(ch,
173
+ conf.conv_resample,
174
+ dims=conf.dims,
175
+ out_channels=out_ch)))
176
+ ch = out_ch
177
+ # input_block_chans.append(ch)
178
+ input_block_chans[level + 1].append(ch)
179
+ self.input_num_blocks[level + 1] += 1
180
+ ds *= 2
181
+ self._feature_size += ch
182
+
183
+ self.middle_block = TimestepEmbedSequential(
184
+ ResBlockConfig(
185
+ ch,
186
+ conf.embed_channels,
187
+ conf.dropout,
188
+ dims=conf.dims,
189
+ use_checkpoint=conf.use_checkpoint,
190
+ **kwargs,
191
+ ).make_model(),
192
+ AttentionBlock(
193
+ ch,
194
+ use_checkpoint=conf.use_checkpoint or conf.attn_checkpoint,
195
+ num_heads=conf.num_heads,
196
+ num_head_channels=conf.num_head_channels,
197
+ use_new_attention_order=conf.use_new_attention_order,
198
+ ),
199
+ ResBlockConfig(
200
+ ch,
201
+ conf.embed_channels,
202
+ conf.dropout,
203
+ dims=conf.dims,
204
+ use_checkpoint=conf.use_checkpoint,
205
+ **kwargs,
206
+ ).make_model(),
207
+ )
208
+ self._feature_size += ch
209
+
210
+ self.output_blocks = nn.ModuleList([])
211
+ for level, mult in list(enumerate(conf.channel_mult))[::-1]:
212
+ for i in range(conf.num_res_blocks + 1):
213
+ # print(input_block_chans)
214
+ # ich = input_block_chans.pop()
215
+ try:
216
+ ich = input_block_chans[level].pop()
217
+ except IndexError:
218
+ # this happens only when num_res_block > num_enc_res_block
219
+ # we will not have enough lateral (skip) connections for all the decoder blocks
220
+ ich = 0
221
+ # print('pop:', ich)
222
+ layers = [
223
+ ResBlockConfig(
224
+ # only direct channels when gated
225
+ channels=ch + ich,
226
+ emb_channels=conf.embed_channels,
227
+ dropout=conf.dropout,
228
+ out_channels=int(conf.model_channels * mult),
229
+ dims=conf.dims,
230
+ use_checkpoint=conf.use_checkpoint,
231
+ # lateral channels are described here when gated
232
+ has_lateral=True if ich > 0 else False,
233
+ lateral_channels=None,
234
+ **kwargs,
235
+ ).make_model()
236
+ ]
237
+ ch = int(conf.model_channels * mult)
238
+ if resolution in conf.attention_resolutions:
239
+ layers.append(
240
+ AttentionBlock(
241
+ ch,
242
+ use_checkpoint=conf.use_checkpoint
243
+ or conf.attn_checkpoint,
244
+ num_heads=self.num_heads_upsample,
245
+ num_head_channels=conf.num_head_channels,
246
+ use_new_attention_order=conf.
247
+ use_new_attention_order,
248
+ ))
249
+ if level and i == conf.num_res_blocks:
250
+ resolution *= 2
251
+ out_ch = ch
252
+ layers.append(
253
+ ResBlockConfig(
254
+ ch,
255
+ conf.embed_channels,
256
+ conf.dropout,
257
+ out_channels=out_ch,
258
+ dims=conf.dims,
259
+ use_checkpoint=conf.use_checkpoint,
260
+ up=True,
261
+ **kwargs,
262
+ ).make_model() if (
263
+ conf.resblock_updown
264
+ ) else Upsample(ch,
265
+ conf.conv_resample,
266
+ dims=conf.dims,
267
+ out_channels=out_ch))
268
+ ds //= 2
269
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
270
+ self.output_num_blocks[level] += 1
271
+ self._feature_size += ch
272
+
273
+ # print(input_block_chans)
274
+ # print('inputs:', self.input_num_blocks)
275
+ # print('outputs:', self.output_num_blocks)
276
+
277
+ if conf.resnet_use_zero_module:
278
+ self.out = nn.Sequential(
279
+ normalization(ch),
280
+ nn.SiLU(),
281
+ zero_module(
282
+ conv_nd(conf.dims,
283
+ input_ch,
284
+ conf.out_channels,
285
+ 3,
286
+ padding=1)),
287
+ )
288
+ else:
289
+ self.out = nn.Sequential(
290
+ normalization(ch),
291
+ nn.SiLU(),
292
+ conv_nd(conf.dims, input_ch, conf.out_channels, 3, padding=1),
293
+ )
294
+
295
+ def forward(self, x, t, y=None, **kwargs):
296
+ """
297
+ Apply the model to an input batch.
298
+
299
+ :param x: an [N x C x ...] Tensor of inputs.
300
+ :param t: a 1-D batch of timesteps.
301
+ :param y: an [N] Tensor of labels, if class-conditional.
302
+ :return: an [N x C x ...] Tensor of outputs.
303
+ """
304
+ assert (y is not None) == (
305
+ self.conf.num_classes is not None
306
+ ), "must specify y if and only if the model is class-conditional"
307
+
308
+ # hs = []
309
+ hs = [[] for _ in range(len(self.conf.channel_mult))]
310
+ emb = self.time_embed(timestep_embedding(t, self.time_emb_channels))
311
+
312
+ if self.conf.num_classes is not None:
313
+ raise NotImplementedError()
314
+ # assert y.shape == (x.shape[0], )
315
+ # emb = emb + self.label_emb(y)
316
+
317
+ # new code supports input_num_blocks != output_num_blocks
318
+ h = x.type(self.dtype)
319
+ k = 0
320
+ for i in range(len(self.input_num_blocks)):
321
+ for j in range(self.input_num_blocks[i]):
322
+ h = self.input_blocks[k](h, emb=emb)
323
+ # print(i, j, h.shape)
324
+ hs[i].append(h)
325
+ k += 1
326
+ assert k == len(self.input_blocks)
327
+
328
+ h = self.middle_block(h, emb=emb)
329
+ k = 0
330
+ for i in range(len(self.output_num_blocks)):
331
+ for j in range(self.output_num_blocks[i]):
332
+ # take the lateral connection from the same layer (in reverse order)
333
+ # until there is no more, use None
334
+ try:
335
+ lateral = hs[-i - 1].pop()
336
+ # print(i, j, lateral.shape)
337
+ except IndexError:
338
+ lateral = None
339
+ # print(i, j, lateral)
340
+ h = self.output_blocks[k](h, emb=emb, lateral=lateral)
341
+ k += 1
342
+
343
+ h = h.type(x.dtype)
344
+ pred = self.out(h)
345
+ return Return(pred=pred)
346
+
347
+
348
+ class Return(NamedTuple):
349
+ pred: th.Tensor
350
+
351
+
352
+ @dataclass
353
+ class BeatGANsEncoderConfig(BaseConfig):
354
+ image_size: int
355
+ in_channels: int
356
+ model_channels: int
357
+ out_hid_channels: int
358
+ out_channels: int
359
+ num_res_blocks: int
360
+ attention_resolutions: Tuple[int]
361
+ dropout: float = 0
362
+ channel_mult: Tuple[int] = (1, 2, 4, 8)
363
+ use_time_condition: bool = True
364
+ conv_resample: bool = True
365
+ dims: int = 2
366
+ use_checkpoint: bool = False
367
+ num_heads: int = 1
368
+ num_head_channels: int = -1
369
+ resblock_updown: bool = False
370
+ use_new_attention_order: bool = False
371
+ pool: str = 'adaptivenonzero'
372
+
373
+ def make_model(self):
374
+ return BeatGANsEncoderModel(self)
375
+
376
+
377
+ class BeatGANsEncoderModel(nn.Module):
378
+ """
379
+ The half UNet model with attention and timestep embedding.
380
+
381
+ For usage, see UNet.
382
+ """
383
+ def __init__(self, conf: BeatGANsEncoderConfig):
384
+ super().__init__()
385
+ self.conf = conf
386
+ self.dtype = th.float32
387
+
388
+ if conf.use_time_condition:
389
+ time_embed_dim = conf.model_channels * 4
390
+ self.time_embed = nn.Sequential(
391
+ linear(conf.model_channels, time_embed_dim),
392
+ nn.SiLU(),
393
+ linear(time_embed_dim, time_embed_dim),
394
+ )
395
+ else:
396
+ time_embed_dim = None
397
+
398
+ ch = int(conf.channel_mult[0] * conf.model_channels)
399
+ self.input_blocks = nn.ModuleList([
400
+ TimestepEmbedSequential(
401
+ conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1))
402
+ ])
403
+ self._feature_size = ch
404
+ input_block_chans = [ch]
405
+ ds = 1
406
+ resolution = conf.image_size
407
+ for level, mult in enumerate(conf.channel_mult):
408
+ for _ in range(conf.num_res_blocks):
409
+ layers = [
410
+ ResBlockConfig(
411
+ ch,
412
+ time_embed_dim,
413
+ conf.dropout,
414
+ out_channels=int(mult * conf.model_channels),
415
+ dims=conf.dims,
416
+ use_condition=conf.use_time_condition,
417
+ use_checkpoint=conf.use_checkpoint,
418
+ ).make_model()
419
+ ]
420
+ ch = int(mult * conf.model_channels)
421
+ if resolution in conf.attention_resolutions:
422
+ layers.append(
423
+ AttentionBlock(
424
+ ch,
425
+ use_checkpoint=conf.use_checkpoint,
426
+ num_heads=conf.num_heads,
427
+ num_head_channels=conf.num_head_channels,
428
+ use_new_attention_order=conf.
429
+ use_new_attention_order,
430
+ ))
431
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
432
+ self._feature_size += ch
433
+ input_block_chans.append(ch)
434
+ if level != len(conf.channel_mult) - 1:
435
+ resolution //= 2
436
+ out_ch = ch
437
+ self.input_blocks.append(
438
+ TimestepEmbedSequential(
439
+ ResBlockConfig(
440
+ ch,
441
+ time_embed_dim,
442
+ conf.dropout,
443
+ out_channels=out_ch,
444
+ dims=conf.dims,
445
+ use_condition=conf.use_time_condition,
446
+ use_checkpoint=conf.use_checkpoint,
447
+ down=True,
448
+ ).make_model() if (
449
+ conf.resblock_updown
450
+ ) else Downsample(ch,
451
+ conf.conv_resample,
452
+ dims=conf.dims,
453
+ out_channels=out_ch)))
454
+ ch = out_ch
455
+ input_block_chans.append(ch)
456
+ ds *= 2
457
+ self._feature_size += ch
458
+
459
+ self.middle_block = TimestepEmbedSequential(
460
+ ResBlockConfig(
461
+ ch,
462
+ time_embed_dim,
463
+ conf.dropout,
464
+ dims=conf.dims,
465
+ use_condition=conf.use_time_condition,
466
+ use_checkpoint=conf.use_checkpoint,
467
+ ).make_model(),
468
+ AttentionBlock(
469
+ ch,
470
+ use_checkpoint=conf.use_checkpoint,
471
+ num_heads=conf.num_heads,
472
+ num_head_channels=conf.num_head_channels,
473
+ use_new_attention_order=conf.use_new_attention_order,
474
+ ),
475
+ ResBlockConfig(
476
+ ch,
477
+ time_embed_dim,
478
+ conf.dropout,
479
+ dims=conf.dims,
480
+ use_condition=conf.use_time_condition,
481
+ use_checkpoint=conf.use_checkpoint,
482
+ ).make_model(),
483
+ )
484
+ self._feature_size += ch
485
+ if conf.pool == "adaptivenonzero":
486
+ self.out = nn.Sequential(
487
+ normalization(ch),
488
+ nn.SiLU(),
489
+ nn.AdaptiveAvgPool2d((1, 1)),
490
+ conv_nd(conf.dims, ch, conf.out_channels, 1),
491
+ nn.Flatten(),
492
+ )
493
+ else:
494
+ raise NotImplementedError(f"Unexpected {conf.pool} pooling")
495
+
496
+ def forward(self, x, t=None, return_2d_feature=False):
497
+ """
498
+ Apply the model to an input batch.
499
+
500
+ :param x: an [N x C x ...] Tensor of inputs.
501
+ :param t: a 1-D batch of timesteps.
502
+ :return: an [N x K] Tensor of outputs.
503
+ """
504
+ if self.conf.use_time_condition:
505
+ emb = self.time_embed(timestep_embedding(t, self.conf.model_channels))
506
+ else:
507
+ emb = None
508
+
509
+ results = []
510
+ h = x.type(self.dtype)
511
+ for module in self.input_blocks:
512
+ h = module(h, emb=emb)
513
+ if self.conf.pool.startswith("spatial"):
514
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
515
+ h = self.middle_block(h, emb=emb)
516
+ if self.conf.pool.startswith("spatial"):
517
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
518
+ h = th.cat(results, axis=-1)
519
+ else:
520
+ h = h.type(x.dtype)
521
+
522
+ h_2d = h
523
+ h = self.out(h)
524
+
525
+ if return_2d_feature:
526
+ return h, h_2d
527
+ else:
528
+ return h
529
+
530
+ def forward_flatten(self, x):
531
+ """
532
+ transform the last 2d feature into a flatten vector
533
+ """
534
+ h = self.out(x)
535
+ return h
536
+
537
+
538
+ class SuperResModel(BeatGANsUNetModel):
539
+ """
540
+ A UNetModel that performs super-resolution.
541
+
542
+ Expects an extra kwarg `low_res` to condition on a low-resolution image.
543
+ """
544
+ def __init__(self, image_size, in_channels, *args, **kwargs):
545
+ super().__init__(image_size, in_channels * 2, *args, **kwargs)
546
+
547
+ def forward(self, x, timesteps, low_res=None, **kwargs):
548
+ _, _, new_height, new_width = x.shape
549
+ upsampled = F.interpolate(low_res, (new_height, new_width),
550
+ mode="bilinear")
551
+ x = th.cat([x, upsampled], dim=1)
552
+ return super().forward(x, timesteps, **kwargs)
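A minimal usage sketch for the UNet defined above. The values below are hypothetical: the base config fields (image_size, in_channels, model_channels, out_channels) are defined earlier in model/unet.py and are not shown here, and the import path simply assumes the repo layout of this upload.

    import torch as th
    from model.unet import BeatGANsUNetConfig

    conf = BeatGANsUNetConfig(
        image_size=64,          # hypothetical values, not the repo defaults
        in_channels=3,
        model_channels=64,
        out_channels=3,
        num_res_blocks=2,
        attention_resolutions=(16,),
    )
    net = conf.make_model()                 # -> BeatGANsUNetModel
    x = th.randn(2, 3, 64, 64)              # [N x C x H x W] input batch
    t = th.randint(0, 1000, (2,))           # one diffusion timestep per sample
    out = net(x, t)                         # Return(pred=...)
    print(out.pred.shape)                   # expected: (2, 3, 64, 64)
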
model/unet_autoenc.py ADDED
@@ -0,0 +1,283 @@
1
+ from enum import Enum
2
+
3
+ import torch
4
+ from torch import Tensor
5
+ from torch.nn.functional import silu
6
+
7
+ from .latentnet import *
8
+ from .unet import *
9
+ from choices import *
10
+
11
+
12
+ @dataclass
13
+ class BeatGANsAutoencConfig(BeatGANsUNetConfig):
14
+ # number of style channels
15
+ enc_out_channels: int = 512
16
+ enc_attn_resolutions: Tuple[int] = None
17
+ enc_pool: str = 'depthconv'
18
+ enc_num_res_block: int = 2
19
+ enc_channel_mult: Tuple[int] = None
20
+ enc_grad_checkpoint: bool = False
21
+ latent_net_conf: MLPSkipNetConfig = None
22
+
23
+ def make_model(self):
24
+ return BeatGANsAutoencModel(self)
25
+
26
+
27
+ class BeatGANsAutoencModel(BeatGANsUNetModel):
28
+ def __init__(self, conf: BeatGANsAutoencConfig):
29
+ super().__init__(conf)
30
+ self.conf = conf
31
+
32
+ # having only time, cond
33
+ self.time_embed = TimeStyleSeperateEmbed(
34
+ time_channels=conf.model_channels,
35
+ time_out_channels=conf.embed_channels,
36
+ )
37
+
38
+ self.encoder = BeatGANsEncoderConfig(
39
+ image_size=conf.image_size,
40
+ in_channels=conf.in_channels,
41
+ model_channels=conf.model_channels,
42
+ out_hid_channels=conf.enc_out_channels,
43
+ out_channels=conf.enc_out_channels,
44
+ num_res_blocks=conf.enc_num_res_block,
45
+ attention_resolutions=(conf.enc_attn_resolutions
46
+ or conf.attention_resolutions),
47
+ dropout=conf.dropout,
48
+ channel_mult=conf.enc_channel_mult or conf.channel_mult,
49
+ use_time_condition=False,
50
+ conv_resample=conf.conv_resample,
51
+ dims=conf.dims,
52
+ use_checkpoint=conf.use_checkpoint or conf.enc_grad_checkpoint,
53
+ num_heads=conf.num_heads,
54
+ num_head_channels=conf.num_head_channels,
55
+ resblock_updown=conf.resblock_updown,
56
+ use_new_attention_order=conf.use_new_attention_order,
57
+ pool=conf.enc_pool,
58
+ ).make_model()
59
+
60
+ if conf.latent_net_conf is not None:
61
+ self.latent_net = conf.latent_net_conf.make_model()
62
+
63
+ def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
64
+ """
65
+ Reparameterization trick to sample from N(mu, var) using
66
+ N(0,1).
67
+ :param mu: (Tensor) Mean of the latent Gaussian [B x D]
68
+ :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
69
+ :return: (Tensor) [B x D]
70
+ """
71
+ assert self.conf.is_stochastic
72
+ std = torch.exp(0.5 * logvar)
73
+ eps = torch.randn_like(std)
74
+ return eps * std + mu
75
+
76
+ def sample_z(self, n: int, device):
77
+ assert self.conf.is_stochastic
78
+ return torch.randn(n, self.conf.enc_out_channels, device=device)
79
+
80
+ def noise_to_cond(self, noise: Tensor):
81
+ raise NotImplementedError()
82
+ assert self.conf.noise_net_conf is not None
83
+ return self.noise_net.forward(noise)
84
+
85
+ def encode(self, x):
86
+ cond = self.encoder.forward(x)
87
+ return {'cond': cond}
88
+
89
+ @property
90
+ def stylespace_sizes(self):
91
+ modules = list(self.input_blocks.modules()) + list(
92
+ self.middle_block.modules()) + list(self.output_blocks.modules())
93
+ sizes = []
94
+ for module in modules:
95
+ if isinstance(module, ResBlock):
96
+ linear = module.cond_emb_layers[-1]
97
+ sizes.append(linear.weight.shape[0])
98
+ return sizes
99
+
100
+ def encode_stylespace(self, x, return_vector: bool = True):
101
+ """
102
+ encode to style space
103
+ """
104
+ modules = list(self.input_blocks.modules()) + list(
105
+ self.middle_block.modules()) + list(self.output_blocks.modules())
106
+ # (n, c)
107
+ cond = self.encoder.forward(x)
108
+ S = []
109
+ for module in modules:
110
+ if isinstance(module, ResBlock):
111
+ # (n, c')
112
+ s = module.cond_emb_layers.forward(cond)
113
+ S.append(s)
114
+
115
+ if return_vector:
116
+ # (n, sum_c)
117
+ return torch.cat(S, dim=1)
118
+ else:
119
+ return S
120
+
121
+ def forward(self,
122
+ x,
123
+ t,
124
+ y=None,
125
+ x_start=None,
126
+ cond=None,
127
+ style=None,
128
+ noise=None,
129
+ t_cond=None,
130
+ **kwargs):
131
+ """
132
+ Apply the model to an input batch.
133
+
134
+ Args:
135
+ x_start: the original image to encode
136
+ cond: output of the encoder
137
+ noise: random noise (to predict the cond)
138
+ """
139
+
140
+ if t_cond is None:
141
+ t_cond = t
142
+
143
+ if noise is not None:
144
+ # if the noise is given, we predict the cond from noise
145
+ cond = self.noise_to_cond(noise)
146
+
147
+ if cond is None:
148
+ if x is not None:
149
+ assert len(x) == len(x_start), f'{len(x)} != {len(x_start)}'
150
+
151
+ tmp = self.encode(x_start)
152
+ cond = tmp['cond']
153
+
154
+ if t is not None:
155
+ _t_emb = timestep_embedding(t, self.conf.model_channels)
156
+ _t_cond_emb = timestep_embedding(t_cond, self.conf.model_channels)
157
+ else:
158
+ # this happens when training only autoenc
159
+ _t_emb = None
160
+ _t_cond_emb = None
161
+
162
+ if self.conf.resnet_two_cond:
163
+ res = self.time_embed.forward(
164
+ time_emb=_t_emb,
165
+ cond=cond,
166
+ time_cond_emb=_t_cond_emb,
167
+ )
168
+ else:
169
+ raise NotImplementedError()
170
+
171
+ if self.conf.resnet_two_cond:
172
+ # two cond: first = time emb, second = cond_emb
173
+ emb = res.time_emb
174
+ cond_emb = res.emb
175
+ else:
176
+ # one cond = combined of both time and cond
177
+ emb = res.emb
178
+ cond_emb = None
179
+
180
+ # override the style if given
181
+ style = style or res.style
182
+
183
+ assert (y is not None) == (
184
+ self.conf.num_classes is not None
185
+ ), "must specify y if and only if the model is class-conditional"
186
+
187
+ if self.conf.num_classes is not None:
188
+ raise NotImplementedError()
189
+ # assert y.shape == (x.shape[0], )
190
+ # emb = emb + self.label_emb(y)
191
+
192
+ # where in the model to supply time conditions
193
+ enc_time_emb = emb
194
+ mid_time_emb = emb
195
+ dec_time_emb = emb
196
+ # where in the model to supply style conditions
197
+ enc_cond_emb = cond_emb
198
+ mid_cond_emb = cond_emb
199
+ dec_cond_emb = cond_emb
200
+
201
+ # hs = []
202
+ hs = [[] for _ in range(len(self.conf.channel_mult))]
203
+
204
+ if x is not None:
205
+ h = x.type(self.dtype)
206
+
207
+ # input blocks
208
+ k = 0
209
+ for i in range(len(self.input_num_blocks)):
210
+ for j in range(self.input_num_blocks[i]):
211
+ h = self.input_blocks[k](h,
212
+ emb=enc_time_emb,
213
+ cond=enc_cond_emb)
214
+
215
+ # print(i, j, h.shape)
216
+ hs[i].append(h)
217
+ k += 1
218
+ assert k == len(self.input_blocks)
219
+
220
+ # middle blocks
221
+ h = self.middle_block(h, emb=mid_time_emb, cond=mid_cond_emb)
222
+ else:
223
+ # no lateral connections
224
+ # happens when training only the autoencoder
225
+ h = None
226
+ hs = [[] for _ in range(len(self.conf.channel_mult))]
227
+
228
+ # output blocks
229
+ k = 0
230
+ for i in range(len(self.output_num_blocks)):
231
+ for j in range(self.output_num_blocks[i]):
232
+ # take the lateral connection from the same layer (in reverse order)
233
+ # until there is no more, use None
234
+ try:
235
+ lateral = hs[-i - 1].pop()
236
+ # print(i, j, lateral.shape)
237
+ except IndexError:
238
+ lateral = None
239
+ # print(i, j, lateral)
240
+
241
+ h = self.output_blocks[k](h,
242
+ emb=dec_time_emb,
243
+ cond=dec_cond_emb,
244
+ lateral=lateral)
245
+ k += 1
246
+
247
+ pred = self.out(h)
248
+ return AutoencReturn(pred=pred, cond=cond)
249
+
250
+
251
+ class AutoencReturn(NamedTuple):
252
+ pred: Tensor
253
+ cond: Tensor = None
254
+
255
+
256
+ class EmbedReturn(NamedTuple):
257
+ # style and time
258
+ emb: Tensor = None
259
+ # time only
260
+ time_emb: Tensor = None
261
+ # style only (but could depend on time)
262
+ style: Tensor = None
263
+
264
+
265
+ class TimeStyleSeperateEmbed(nn.Module):
266
+ # embed only style
267
+ def __init__(self, time_channels, time_out_channels):
268
+ super().__init__()
269
+ self.time_embed = nn.Sequential(
270
+ linear(time_channels, time_out_channels),
271
+ nn.SiLU(),
272
+ linear(time_out_channels, time_out_channels),
273
+ )
274
+ self.style = nn.Identity()
275
+
276
+ def forward(self, time_emb=None, cond=None, **kwargs):
277
+ if time_emb is None:
278
+ # happens with autoenc training mode
279
+ time_emb = None
280
+ else:
281
+ time_emb = self.time_embed(time_emb)
282
+ style = self.style(cond)
283
+ return EmbedReturn(emb=style, time_emb=time_emb, style=style)
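A sketch of how the autoencoding UNet above can be driven: encode a clean image into the semantic latent `cond`, then denoise a noisy input conditioned on it. The field values are hypothetical; the base UNet fields come from BeatGANsUNetConfig, resnet_two_cond must be True for this forward path, and enc_pool is set to the only pooling implemented by the encoder model above.

    import torch
    from model.unet_autoenc import BeatGANsAutoencConfig

    conf = BeatGANsAutoencConfig(
        image_size=64, in_channels=3, model_channels=64, out_channels=3,
        num_res_blocks=2, attention_resolutions=(16,),
        resnet_two_cond=True,               # required by forward() above
        enc_out_channels=512,
        enc_pool='adaptivenonzero',
    )
    model = conf.make_model()               # -> BeatGANsAutoencModel
    x0 = torch.randn(4, 3, 64, 64)          # clean images to encode
    xt = torch.randn(4, 3, 64, 64)          # noisy images at timestep t
    t = torch.randint(0, 1000, (4,))
    cond = model.encode(x0)['cond']         # (4, 512) semantic latent
    res = model(xt, t, x_start=x0, cond=cond)
    eps_pred = res.pred                     # AutoencReturn(pred=..., cond=...)
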
networks/__init__.py ADDED
File without changes
networks/discriminator.py ADDED
@@ -0,0 +1,259 @@
1
+ import math
2
+ import torch
3
+ from torch.nn import functional as F
4
+ from torch import nn
5
+
6
+
7
+ def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
8
+ return F.leaky_relu(input + bias, negative_slope) * scale
9
+
10
+
11
+ class FusedLeakyReLU(nn.Module):
12
+ def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
13
+ super().__init__()
14
+ self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1))
15
+ self.negative_slope = negative_slope
16
+ self.scale = scale
17
+
18
+ def forward(self, input):
19
+ # print("FusedLeakyReLU: ", input.abs().mean())
20
+ out = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
21
+ # print("FusedLeakyReLU: ", out.abs().mean())
22
+ return out
23
+
24
+
25
+ def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
26
+ _, minor, in_h, in_w = input.shape
27
+ kernel_h, kernel_w = kernel.shape
28
+
29
+ out = input.view(-1, minor, in_h, 1, in_w, 1)
30
+ out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0])
31
+ out = out.view(-1, minor, in_h * up_y, in_w * up_x)
32
+
33
+ out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
34
+ out = out[:, :, max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0), max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0), ]
35
+
36
+ # out = out.permute(0, 3, 1, 2)
37
+ out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
38
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
39
+ out = F.conv2d(out, w)
40
+ out = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
41
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, )
42
+ # out = out.permute(0, 2, 3, 1)
43
+
44
+ return out[:, :, ::down_y, ::down_x]
45
+
46
+
47
+ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
48
+ return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
49
+
50
+
51
+ def make_kernel(k):
52
+ k = torch.tensor(k, dtype=torch.float32)
53
+
54
+ if k.ndim == 1:
55
+ k = k[None, :] * k[:, None]
56
+
57
+ k /= k.sum()
58
+
59
+ return k
60
+
61
+
62
+ class Blur(nn.Module):
63
+ def __init__(self, kernel, pad, upsample_factor=1):
64
+ super().__init__()
65
+
66
+ kernel = make_kernel(kernel)
67
+
68
+ if upsample_factor > 1:
69
+ kernel = kernel * (upsample_factor ** 2)
70
+
71
+ self.register_buffer('kernel', kernel)
72
+
73
+ self.pad = pad
74
+
75
+ def forward(self, input):
76
+ return upfirdn2d(input, self.kernel, pad=self.pad)
77
+
78
+
79
+ class ScaledLeakyReLU(nn.Module):
80
+ def __init__(self, negative_slope=0.2):
81
+ super().__init__()
82
+
83
+ self.negative_slope = negative_slope
84
+
85
+ def forward(self, input):
86
+ return F.leaky_relu(input, negative_slope=self.negative_slope)
87
+
88
+
89
+ class EqualConv2d(nn.Module):
90
+ def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):
91
+ super().__init__()
92
+
93
+ self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size))
94
+ self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
95
+
96
+ self.stride = stride
97
+ self.padding = padding
98
+
99
+ if bias:
100
+ self.bias = nn.Parameter(torch.zeros(out_channel))
101
+ else:
102
+ self.bias = None
103
+
104
+ def forward(self, input):
105
+
106
+ return F.conv2d(input, self.weight * self.scale, bias=self.bias, stride=self.stride,
107
+ padding=self.padding, )
108
+
109
+ def __repr__(self):
110
+ return (
111
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
112
+ f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
113
+ )
114
+
115
+
116
+ class EqualLinear(nn.Module):
117
+ def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None):
118
+ super().__init__()
119
+
120
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
121
+
122
+ if bias:
123
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
124
+ else:
125
+ self.bias = None
126
+
127
+ self.activation = activation
128
+
129
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
130
+ self.lr_mul = lr_mul
131
+
132
+ def forward(self, input):
133
+
134
+ if self.activation:
135
+ out = F.linear(input, self.weight * self.scale)
136
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
137
+ else:
138
+ out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul)
139
+
140
+ return out
141
+
142
+ def __repr__(self):
143
+ return (f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})')
144
+
145
+
146
+ class ConvLayer(nn.Sequential):
147
+ def __init__(
148
+ self,
149
+ in_channel,
150
+ out_channel,
151
+ kernel_size,
152
+ downsample=False,
153
+ blur_kernel=[1, 3, 3, 1],
154
+ bias=True,
155
+ activate=True,
156
+ ):
157
+ layers = []
158
+
159
+ if downsample:
160
+ factor = 2
161
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
162
+ pad0 = (p + 1) // 2
163
+ pad1 = p // 2
164
+
165
+ layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
166
+
167
+ stride = 2
168
+ self.padding = 0
169
+
170
+ else:
171
+ stride = 1
172
+ self.padding = kernel_size // 2
173
+
174
+ layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=self.padding, stride=stride,
175
+ bias=bias and not activate))
176
+
177
+ if activate:
178
+ if bias:
179
+ layers.append(FusedLeakyReLU(out_channel))
180
+ else:
181
+ layers.append(ScaledLeakyReLU(0.2))
182
+
183
+ super().__init__(*layers)
184
+
185
+
186
+ class ResBlock(nn.Module):
187
+ def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
188
+ super().__init__()
189
+
190
+ self.conv1 = ConvLayer(in_channel, in_channel, 3)
191
+ self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
192
+
193
+ self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False)
194
+
195
+ def forward(self, input):
196
+ out = self.conv1(input)
197
+ out = self.conv2(out)
198
+
199
+ skip = self.skip(input)
200
+ out = (out + skip) / math.sqrt(2)
201
+
202
+ return out
203
+
204
+
205
+ class Discriminator(nn.Module):
206
+ def __init__(self, size, channel_multiplier=1, blur_kernel=[1, 3, 3, 1]):
207
+ super().__init__()
208
+
209
+ self.size = size
210
+
211
+ channels = {
212
+ 4: 512,
213
+ 8: 512,
214
+ 16: 512,
215
+ 32: 512,
216
+ 64: 256 * channel_multiplier,
217
+ 128: 128 * channel_multiplier,
218
+ 256: 64 * channel_multiplier,
219
+ 512: 32 * channel_multiplier,
220
+ 1024: 16 * channel_multiplier,
221
+ }
222
+
223
+ convs = [ConvLayer(3, channels[size], 1)]
224
+ log_size = int(math.log(size, 2))
225
+ in_channel = channels[size]
226
+
227
+ for i in range(log_size, 2, -1):
228
+ out_channel = channels[2 ** (i - 1)]
229
+ convs.append(ResBlock(in_channel, out_channel, blur_kernel))
230
+ in_channel = out_channel
231
+
232
+ self.convs = nn.Sequential(*convs)
233
+
234
+ self.stddev_group = 4
235
+ self.stddev_feat = 1
236
+
237
+ self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
238
+ self.final_linear = nn.Sequential(
239
+ EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
240
+ EqualLinear(channels[4], 1),
241
+ )
242
+
243
+ def forward(self, input):
244
+ out = self.convs(input)
245
+ batch, channel, height, width = out.shape
246
+
247
+ group = min(batch, self.stddev_group)
248
+ stddev = out.view(group, -1, self.stddev_feat, channel // self.stddev_feat, height, width)
249
+ stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
250
+ stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
251
+ stddev = stddev.repeat(group, 1, height, width)
252
+ out = torch.cat([out, stddev], 1)
253
+
254
+ out = self.final_conv(out)
255
+
256
+ out = out.view(batch, -1)
257
+ out = self.final_linear(out)
258
+
259
+ return out
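A short example of using the StyleGAN2-style discriminator above; `size` must be one of the resolutions in the channels table, and the minibatch-stddev branch groups samples in chunks of stddev_group. The loss line is only an illustration of a common non-saturating choice, not something this file prescribes.

    import torch
    from networks.discriminator import Discriminator

    disc = Discriminator(size=256, channel_multiplier=1)
    imgs = torch.randn(4, 3, 256, 256)      # RGB batch, nominally in [-1, 1]
    logits = disc(imgs)                     # (4, 1) realness scores
    loss = torch.nn.functional.softplus(-logits).mean()   # e.g. non-saturating loss on real images
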
networks/encoder.py ADDED
@@ -0,0 +1,374 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
7
+ return F.leaky_relu(input + bias, negative_slope) * scale
8
+
9
+ class FusedLeakyReLU(nn.Module):
10
+ def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
11
+ super().__init__()
12
+ self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1))
13
+ self.negative_slope = negative_slope
14
+ self.scale = scale
15
+
16
+ def forward(self, input):
17
+ out = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
18
+ return out
19
+
20
+
21
+ def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
22
+ _, minor, in_h, in_w = input.shape
23
+ kernel_h, kernel_w = kernel.shape
24
+
25
+ out = input.view(-1, minor, in_h, 1, in_w, 1)
26
+ out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0])
27
+ out = out.view(-1, minor, in_h * up_y, in_w * up_x)
28
+
29
+ out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
30
+ out = out[:, :, max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0),
31
+ max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0), ]
32
+
33
+ out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
34
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
35
+ out = F.conv2d(out, w)
36
+ out = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
37
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, )
38
+
39
+ return out[:, :, ::down_y, ::down_x]
40
+
41
+
42
+ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
43
+ return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
44
+
45
+
46
+ def make_kernel(k):
47
+ k = torch.tensor(k, dtype=torch.float32)
48
+
49
+ if k.ndim == 1:
50
+ k = k[None, :] * k[:, None]
51
+
52
+ k /= k.sum()
53
+
54
+ return k
55
+
56
+
57
+ class Blur(nn.Module):
58
+ def __init__(self, kernel, pad, upsample_factor=1):
59
+ super().__init__()
60
+
61
+ kernel = make_kernel(kernel)
62
+
63
+ if upsample_factor > 1:
64
+ kernel = kernel * (upsample_factor ** 2)
65
+
66
+ self.register_buffer('kernel', kernel)
67
+
68
+ self.pad = pad
69
+
70
+ def forward(self, input):
71
+ return upfirdn2d(input, self.kernel, pad=self.pad)
72
+
73
+
74
+ class ScaledLeakyReLU(nn.Module):
75
+ def __init__(self, negative_slope=0.2):
76
+ super().__init__()
77
+
78
+ self.negative_slope = negative_slope
79
+
80
+ def forward(self, input):
81
+ return F.leaky_relu(input, negative_slope=self.negative_slope)
82
+
83
+
84
+ class EqualConv2d(nn.Module):
85
+ def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):
86
+ super().__init__()
87
+
88
+ self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size))
89
+ self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
90
+
91
+ self.stride = stride
92
+ self.padding = padding
93
+
94
+ if bias:
95
+ self.bias = nn.Parameter(torch.zeros(out_channel))
96
+ else:
97
+ self.bias = None
98
+
99
+ def forward(self, input):
100
+
101
+ return F.conv2d(input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding)
102
+
103
+ def __repr__(self):
104
+ return (
105
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
106
+ f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
107
+ )
108
+
109
+
110
+ class EqualLinear(nn.Module):
111
+ def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None):
112
+ super().__init__()
113
+
114
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
115
+
116
+ if bias:
117
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
118
+ else:
119
+ self.bias = None
120
+
121
+ self.activation = activation
122
+
123
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
124
+ self.lr_mul = lr_mul
125
+
126
+ def forward(self, input):
127
+
128
+ if self.activation:
129
+ out = F.linear(input, self.weight * self.scale)
130
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
131
+ else:
132
+ out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul)
133
+
134
+ return out
135
+
136
+ def __repr__(self):
137
+ return (f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})')
138
+
139
+
140
+ class ConvLayer(nn.Sequential):
141
+ def __init__(
142
+ self,
143
+ in_channel,
144
+ out_channel,
145
+ kernel_size,
146
+ downsample=False,
147
+ blur_kernel=[1, 3, 3, 1],
148
+ bias=True,
149
+ activate=True,
150
+ ):
151
+ layers = []
152
+
153
+ if downsample:
154
+ factor = 2
155
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
156
+ pad0 = (p + 1) // 2
157
+ pad1 = p // 2
158
+
159
+ layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
160
+
161
+ stride = 2
162
+ self.padding = 0
163
+
164
+ else:
165
+ stride = 1
166
+ self.padding = kernel_size // 2
167
+
168
+ layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=self.padding, stride=stride,
169
+ bias=bias and not activate))
170
+
171
+ if activate:
172
+ if bias:
173
+ layers.append(FusedLeakyReLU(out_channel))
174
+ else:
175
+ layers.append(ScaledLeakyReLU(0.2))
176
+
177
+ super().__init__(*layers)
178
+
179
+
180
+ class ResBlock(nn.Module):
181
+ def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
182
+ super().__init__()
183
+
184
+ self.conv1 = ConvLayer(in_channel, in_channel, 3)
185
+ self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
186
+
187
+ self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False)
188
+
189
+ def forward(self, input):
190
+ out = self.conv1(input)
191
+ out = self.conv2(out)
192
+
193
+ skip = self.skip(input)
194
+ out = (out + skip) / math.sqrt(2)
195
+
196
+ return out
197
+
198
+ class WeightedSumLayer(nn.Module):
199
+ def __init__(self, num_tensors=8):
200
+ super(WeightedSumLayer, self).__init__()
201
+
202
+ self.weights = nn.Parameter(torch.randn(num_tensors))
203
+
204
+ def forward(self, tensor_list):
205
+
206
+ weights = torch.softmax(self.weights, dim=0)
207
+ weighted_sum = torch.zeros_like(tensor_list[0])
208
+ for tensor, weight in zip(tensor_list, weights):
209
+ weighted_sum += tensor * weight
210
+
211
+ return weighted_sum
212
+
213
+ class EncoderApp(nn.Module):
214
+ def __init__(self, size, w_dim=512, fusion_type=''):
215
+ super(EncoderApp, self).__init__()
216
+
217
+ channels = {
218
+ 4: 512,
219
+ 8: 512,
220
+ 16: 512,
221
+ 32: 512,
222
+ 64: 256,
223
+ 128: 128,
224
+ 256: 64,
225
+ 512: 32,
226
+ 1024: 16
227
+ }
228
+
229
+ self.w_dim = w_dim
230
+ log_size = int(math.log(size, 2))
231
+
232
+ self.convs = nn.ModuleList()
233
+ self.convs.append(ConvLayer(3, channels[size], 1))
234
+
235
+ in_channel = channels[size]
236
+ for i in range(log_size, 2, -1):
237
+ out_channel = channels[2 ** (i - 1)]
238
+ self.convs.append(ResBlock(in_channel, out_channel))
239
+ in_channel = out_channel
240
+
241
+ self.convs.append(EqualConv2d(in_channel, self.w_dim, 4, padding=0, bias=False))
242
+
243
+ self.fusion_type = fusion_type
244
+ assert self.fusion_type == 'weighted_sum'
245
+ if self.fusion_type == 'weighted_sum':
246
+ print(f'HAL layer is enabled!')
247
+ self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))
248
+ self.fc1 = EqualLinear(64, 512)
249
+ self.fc2 = EqualLinear(128, 512)
250
+ self.fc3 = EqualLinear(256, 512)
251
+ self.ws = WeightedSumLayer()
252
+
253
+ def forward(self, x):
254
+
255
+ res = []
256
+ h = x
257
+ pooled_h_lists = []
258
+ for i, conv in enumerate(self.convs):
259
+ h = conv(h)
260
+ if self.fusion_type == 'weighted_sum':
261
+ pooled_h = self.adaptive_pool(h).view(x.size(0), -1)
262
+ if i == 0:
263
+ pooled_h_lists.append(self.fc1(pooled_h))
264
+ elif i == 1:
265
+ pooled_h_lists.append(self.fc2(pooled_h))
266
+ elif i == 2:
267
+ pooled_h_lists.append(self.fc3(pooled_h))
268
+ else:
269
+ pooled_h_lists.append(pooled_h)
270
+ res.append(h)
271
+
272
+ if self.fusion_type == 'weighted_sum':
273
+ last_layer = self.ws(pooled_h_lists)
274
+ else:
275
+ last_layer = res[-1].squeeze(-1).squeeze(-1)
276
+ layer_features = res[::-1][2:]
277
+
278
+ return last_layer, layer_features
279
+
280
+
281
+ class DecouplingModel(nn.Module):
282
+ def __init__(self, input_dim, hidden_dim, output_dim):
283
+ super(DecouplingModel, self).__init__()
284
+
285
+ # identity_net is called the identity encoder in the paper
286
+ self.identity_net = nn.Sequential(
287
+ nn.Linear(input_dim, hidden_dim),
288
+ nn.ReLU(),
289
+ nn.Linear(hidden_dim, output_dim)
290
+ )
291
+
292
+ self.identity_net_density = nn.Sequential(
293
+ nn.Linear(input_dim, hidden_dim),
294
+ nn.ReLU(),
295
+ nn.Linear(hidden_dim, output_dim)
296
+ )
297
+
298
+ # identity_excluded_net is called motion encoder in the paper
299
+ self.identity_excluded_net = nn.Sequential(
300
+ nn.Linear(input_dim, hidden_dim),
301
+ nn.ReLU(),
302
+ nn.Linear(hidden_dim, output_dim)
303
+ )
304
+
305
+ def forward(self, x):
306
+
307
+ id_, id_rm = self.identity_net(x), self.identity_excluded_net(x)
308
+ id_density = self.identity_net_density(id_)
309
+ return id_, id_rm, id_density
310
+
311
+ class Encoder(nn.Module):
312
+ def __init__(self, size, dim=512, dim_motion=20, weighted_sum=False):
313
+ super(Encoder, self).__init__()
314
+
315
+ # image encoder
316
+ self.net_app = EncoderApp(size, dim, weighted_sum)
317
+
318
+ # decoupling network
319
+ self.net_decouping = DecouplingModel(dim, dim, dim)
320
+
321
+ # part of the motion encoder
322
+ fc = [EqualLinear(dim, dim)]
323
+ for i in range(3):
324
+ fc.append(EqualLinear(dim, dim))
325
+
326
+ fc.append(EqualLinear(dim, dim_motion))
327
+ self.fc = nn.Sequential(*fc)
328
+
329
+ def enc_app(self, x):
330
+
331
+ h_source = self.net_app(x)
332
+
333
+ return h_source
334
+
335
+ def enc_motion(self, x):
336
+
337
+ h, _ = self.net_app(x)
338
+ h_motion = self.fc(h)
339
+
340
+ return h_motion
341
+
342
+ def encode_image_obj(self, image_obj):
343
+ feat, _ = self.net_app(image_obj)
344
+ id_emb, idrm_emb, id_density_emb = self.net_decouping(feat)
345
+ return id_emb, idrm_emb, id_density_emb
346
+
347
+ def forward(self, input_source, input_target, input_face, input_aug):
348
+
349
+
350
+ if input_target is not None:
351
+
352
+ h_source, feats = self.net_app(input_source)
353
+ h_target, _ = self.net_app(input_target)
354
+ h_face, _ = self.net_app(input_face)
355
+ h_aug, _ = self.net_app(input_aug)
356
+
357
+ h_source_id_emb, h_source_idrm_emb, h_source_id_density_emb = self.net_decouping(h_source)
358
+ h_target_id_emb, h_target_idrm_emb, h_target_id_density_emb = self.net_decouping(h_target)
359
+ h_face_id_emb, h_face_idrm_emb, h_face_id_density_emb = self.net_decouping(h_face)
360
+ h_aug_id_emb, h_aug_idrm_emb, h_aug_id_density_emb = self.net_decouping(h_aug)
361
+
362
+ h_target_motion_target = self.fc(h_target_idrm_emb)
363
+ h_another_face_target = self.fc(h_face_idrm_emb)
364
+
365
+ else:
366
+ h_source, feats = self.net_app(input_source)
367
+
368
+
369
+ return {'h_source':h_source, 'h_motion':h_target_motion_target, 'feats':feats, 'h_another_face_target':h_another_face_target, 'h_face':h_face, \
370
+ 'h_source_id_emb':h_source_id_emb, 'h_source_idrm_emb':h_source_idrm_emb, 'h_source_id_density_emb':h_source_id_density_emb, \
371
+ 'h_target_id_emb':h_target_id_emb, 'h_target_idrm_emb':h_target_idrm_emb, 'h_target_id_density_emb':h_target_id_density_emb, \
372
+ 'h_face_id_emb':h_face_id_emb, 'h_face_idrm_emb':h_face_idrm_emb, 'h_face_id_density_emb':h_face_id_density_emb, \
373
+ 'h_aug_id_emb':h_aug_id_emb, 'h_aug_idrm_emb':h_aug_idrm_emb ,'h_aug_id_density_emb':h_aug_id_density_emb, \
374
+ }
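A sketch of the individual encoder entry points above: appearance features, the decoupled identity / identity-removed / density embeddings, and the 20-d motion code. Note that the fourth constructor argument is forwarded to EncoderApp as fusion_type, so the string 'weighted_sum' is passed even though the parameter is named weighted_sum; the input is a hypothetical 256x256 tensor.

    import torch
    from networks.encoder import Encoder

    enc = Encoder(size=256, dim=512, dim_motion=20, weighted_sum='weighted_sum')
    img = torch.randn(1, 3, 256, 256)

    h, feats = enc.enc_app(img)             # h: (1, 512) appearance code, feats: multi-scale feature maps
    id_emb, idrm_emb, id_density = enc.encode_image_obj(img)   # decoupled embeddings, each (1, 512)
    motion = enc.enc_motion(img)            # (1, 20) motion code predicted from the appearance feature
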
networks/generator.py ADDED
@@ -0,0 +1,27 @@
1
+ from torch import nn
2
+ from .encoder import Encoder
3
+ from .styledecoder import Synthesis
4
+
5
+
6
+ class Generator(nn.Module):
7
+ def __init__(self, size, style_dim=512, motion_dim=20, channel_multiplier=1, blur_kernel=[1, 3, 3, 1]):
8
+ super(Generator, self).__init__()
9
+
10
+ # encoder
11
+ self.enc = Encoder(size, style_dim, motion_dim)
12
+ self.dec = Synthesis(size, style_dim, motion_dim, blur_kernel, channel_multiplier)
13
+
14
+ def get_direction(self):
15
+ return self.dec.direction(None)
16
+
17
+ def synthesis(self, wa, alpha, feat):
18
+ img = self.dec(wa, alpha, feat)
19
+
20
+ return img
21
+
22
+ def forward(self, img_source, img_drive, h_start=None):
23
+ wa, alpha, feats = self.enc(img_source, img_drive, h_start)
24
+ # import pdb;pdb.set_trace()
25
+ img_recon = self.dec(wa, alpha, feats)
26
+
27
+ return img_recon
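Generator above keeps the original LIA two-image interface, while the Encoder in networks/encoder.py now expects four inputs (source, target, face, aug), so the sketch below wires Encoder and Synthesis together directly instead. It is untested and uses random tensors standing in for real frames.

    import torch
    from networks.encoder import Encoder
    from networks.styledecoder import Synthesis

    enc = Encoder(256, 512, 20, 'weighted_sum')
    dec = Synthesis(256, 512, 20)

    src = torch.randn(1, 3, 256, 256)       # source (identity) frame
    drv = torch.randn(1, 3, 256, 256)       # driving frame
    face = torch.randn(1, 3, 256, 256)      # another frame of the driving identity
    aug = torch.randn(1, 3, 256, 256)       # augmented source frame

    out = enc(src, drv, face, aug)
    recon = dec(out['h_source'], out['h_motion'], out['feats'])   # reenacted frame
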
networks/styledecoder.py ADDED
@@ -0,0 +1,527 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+ import numpy as np
6
+
7
+
8
+ def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
9
+ return F.leaky_relu(input + bias, negative_slope) * scale
10
+
11
+
12
+ class FusedLeakyReLU(nn.Module):
13
+ def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
14
+ super().__init__()
15
+ self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1))
16
+ self.negative_slope = negative_slope
17
+ self.scale = scale
18
+
19
+ def forward(self, input):
20
+ out = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
21
+ return out
22
+
23
+
24
+ def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
25
+ _, minor, in_h, in_w = input.shape
26
+ kernel_h, kernel_w = kernel.shape
27
+
28
+ out = input.view(-1, minor, in_h, 1, in_w, 1)
29
+ out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0])
30
+ out = out.view(-1, minor, in_h * up_y, in_w * up_x)
31
+
32
+ out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
33
+ out = out[:, :, max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0),
34
+ max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0), ]
35
+
36
+ out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
37
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
38
+ out = F.conv2d(out, w)
39
+ out = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
40
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, )
41
+ return out[:, :, ::down_y, ::down_x]
42
+
43
+
44
+ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
45
+ return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
46
+
47
+
48
+ class PixelNorm(nn.Module):
49
+ def __init__(self):
50
+ super().__init__()
51
+
52
+ def forward(self, input):
53
+ return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
54
+
55
+
56
+ class MotionPixelNorm(nn.Module):
57
+ def __init__(self):
58
+ super().__init__()
59
+
60
+ def forward(self, input):
61
+ return input * torch.rsqrt(torch.mean(input ** 2, dim=2, keepdim=True) + 1e-8)
62
+
63
+
64
+ def make_kernel(k):
65
+ k = torch.tensor(k, dtype=torch.float32)
66
+
67
+ if k.ndim == 1:
68
+ k = k[None, :] * k[:, None]
69
+
70
+ k /= k.sum()
71
+
72
+ return k
73
+
74
+
75
+ class Upsample(nn.Module):
76
+ def __init__(self, kernel, factor=2):
77
+ super().__init__()
78
+
79
+ self.factor = factor
80
+ kernel = make_kernel(kernel) * (factor ** 2)
81
+ self.register_buffer('kernel', kernel)
82
+
83
+ p = kernel.shape[0] - factor
84
+
85
+ pad0 = (p + 1) // 2 + factor - 1
86
+ pad1 = p // 2
87
+
88
+ self.pad = (pad0, pad1)
89
+
90
+ def forward(self, input):
91
+ return upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
92
+
93
+
94
+ class Downsample(nn.Module):
95
+ def __init__(self, kernel, factor=2):
96
+ super().__init__()
97
+
98
+ self.factor = factor
99
+ kernel = make_kernel(kernel)
100
+ self.register_buffer('kernel', kernel)
101
+
102
+ p = kernel.shape[0] - factor
103
+
104
+ pad0 = (p + 1) // 2
105
+ pad1 = p // 2
106
+
107
+ self.pad = (pad0, pad1)
108
+
109
+ def forward(self, input):
110
+ return upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
111
+
112
+
113
+ class Blur(nn.Module):
114
+ def __init__(self, kernel, pad, upsample_factor=1):
115
+ super().__init__()
116
+
117
+ kernel = make_kernel(kernel)
118
+
119
+ if upsample_factor > 1:
120
+ kernel = kernel * (upsample_factor ** 2)
121
+
122
+ self.register_buffer('kernel', kernel)
123
+
124
+ self.pad = pad
125
+
126
+ def forward(self, input):
127
+ return upfirdn2d(input, self.kernel, pad=self.pad)
128
+
129
+
130
+ class EqualConv2d(nn.Module):
131
+ def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):
132
+ super().__init__()
133
+
134
+ self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size))
135
+ self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
136
+
137
+ self.stride = stride
138
+ self.padding = padding
139
+
140
+ if bias:
141
+ self.bias = nn.Parameter(torch.zeros(out_channel))
142
+ else:
143
+ self.bias = None
144
+
145
+ def forward(self, input):
146
+
147
+ return F.conv2d(input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding, )
148
+
149
+ def __repr__(self):
150
+ return (
151
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
152
+ f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
153
+ )
154
+
155
+
156
+ class EqualLinear(nn.Module):
157
+ def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None):
158
+ super().__init__()
159
+
160
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
161
+
162
+ if bias:
163
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
164
+ else:
165
+ self.bias = None
166
+
167
+ self.activation = activation
168
+
169
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
170
+ self.lr_mul = lr_mul
171
+
172
+ def forward(self, input):
173
+
174
+ if self.activation:
175
+ out = F.linear(input, self.weight * self.scale)
176
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
177
+ else:
178
+ out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul)
179
+
180
+ return out
181
+
182
+ def __repr__(self):
183
+ return (f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})')
184
+
185
+
186
+ class ScaledLeakyReLU(nn.Module):
187
+ def __init__(self, negative_slope=0.2):
188
+ super().__init__()
189
+
190
+ self.negative_slope = negative_slope
191
+
192
+ def forward(self, input):
193
+ return F.leaky_relu(input, negative_slope=self.negative_slope)
194
+
195
+
196
+ class ModulatedConv2d(nn.Module):
197
+ def __init__(self, in_channel, out_channel, kernel_size, style_dim, demodulate=True, upsample=False,
198
+ downsample=False, blur_kernel=[1, 3, 3, 1], ):
199
+ super().__init__()
200
+
201
+ self.eps = 1e-8
202
+ self.kernel_size = kernel_size
203
+ self.in_channel = in_channel
204
+ self.out_channel = out_channel
205
+ self.upsample = upsample
206
+ self.downsample = downsample
207
+
208
+ if upsample:
209
+ factor = 2
210
+ p = (len(blur_kernel) - factor) - (kernel_size - 1)
211
+ pad0 = (p + 1) // 2 + factor - 1
212
+ pad1 = p // 2 + 1
213
+
214
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
215
+
216
+ if downsample:
217
+ factor = 2
218
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
219
+ pad0 = (p + 1) // 2
220
+ pad1 = p // 2
221
+
222
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1))
223
+
224
+ fan_in = in_channel * kernel_size ** 2
225
+ self.scale = 1 / math.sqrt(fan_in)
226
+ self.padding = kernel_size // 2
227
+
228
+ self.weight = nn.Parameter(torch.randn(1, out_channel, in_channel, kernel_size, kernel_size))
229
+
230
+ self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
231
+ self.demodulate = demodulate
232
+
233
+ def __repr__(self):
234
+ return (
235
+ f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
236
+ f'upsample={self.upsample}, downsample={self.downsample})'
237
+ )
238
+
239
+ def forward(self, input, style):
240
+ batch, in_channel, height, width = input.shape
241
+
242
+ style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
243
+ weight = self.scale * self.weight * style
244
+
245
+ if self.demodulate:
246
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
247
+ weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
248
+
249
+ weight = weight.view(batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size)
250
+
251
+ if self.upsample:
252
+ input = input.view(1, batch * in_channel, height, width)
253
+ weight = weight.view(batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size)
254
+ weight = weight.transpose(1, 2).reshape(batch * in_channel, self.out_channel, self.kernel_size,
255
+ self.kernel_size)
256
+ out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
257
+ _, _, height, width = out.shape
258
+ out = out.view(batch, self.out_channel, height, width)
259
+ out = self.blur(out)
260
+ elif self.downsample:
261
+ input = self.blur(input)
262
+ _, _, height, width = input.shape
263
+ input = input.view(1, batch * in_channel, height, width)
264
+ out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
265
+ _, _, height, width = out.shape
266
+ out = out.view(batch, self.out_channel, height, width)
267
+ else:
268
+ input = input.view(1, batch * in_channel, height, width)
269
+ out = F.conv2d(input, weight, padding=self.padding, groups=batch)
270
+ _, _, height, width = out.shape
271
+ out = out.view(batch, self.out_channel, height, width)
272
+
273
+ return out
274
+
275
+
276
+ class NoiseInjection(nn.Module):
277
+ def __init__(self):
278
+ super().__init__()
279
+
280
+ self.weight = nn.Parameter(torch.zeros(1))
281
+
282
+ def forward(self, image, noise=None):
283
+
284
+ if noise is None:
285
+ return image
286
+ else:
287
+ return image + self.weight * noise
288
+
289
+
290
+ class ConstantInput(nn.Module):
291
+ def __init__(self, channel, size=4):
292
+ super().__init__()
293
+
294
+ self.input = nn.Parameter(torch.randn(1, channel, size, size))
295
+
296
+ def forward(self, input):
297
+ batch = input.shape[0]
298
+ out = self.input.repeat(batch, 1, 1, 1)
299
+
300
+ return out
301
+
302
+
303
+ class StyledConv(nn.Module):
304
+ def __init__(self, in_channel, out_channel, kernel_size, style_dim, upsample=False, blur_kernel=[1, 3, 3, 1],
305
+ demodulate=True):
306
+ super().__init__()
307
+
308
+ self.conv = ModulatedConv2d(
309
+ in_channel,
310
+ out_channel,
311
+ kernel_size,
312
+ style_dim,
313
+ upsample=upsample,
314
+ blur_kernel=blur_kernel,
315
+ demodulate=demodulate,
316
+ )
317
+
318
+ self.noise = NoiseInjection()
319
+ self.activate = FusedLeakyReLU(out_channel)
320
+
321
+ def forward(self, input, style, noise=None):
322
+ out = self.conv(input, style)
323
+ out = self.noise(out, noise=noise)
324
+ out = self.activate(out)
325
+
326
+ return out
327
+
328
+
329
+ class ConvLayer(nn.Sequential):
330
+ def __init__(
331
+ self,
332
+ in_channel,
333
+ out_channel,
334
+ kernel_size,
335
+ downsample=False,
336
+ blur_kernel=[1, 3, 3, 1],
337
+ bias=True,
338
+ activate=True,
339
+ ):
340
+ layers = []
341
+
342
+ if downsample:
343
+ factor = 2
344
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
345
+ pad0 = (p + 1) // 2
346
+ pad1 = p // 2
347
+
348
+ layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
349
+
350
+ stride = 2
351
+ self.padding = 0
352
+
353
+ else:
354
+ stride = 1
355
+ self.padding = kernel_size // 2
356
+
357
+ layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=self.padding, stride=stride,
358
+ bias=bias and not activate))
359
+
360
+ if activate:
361
+ if bias:
362
+ layers.append(FusedLeakyReLU(out_channel))
363
+ else:
364
+ layers.append(ScaledLeakyReLU(0.2))
365
+
366
+ super().__init__(*layers)
367
+
368
+
369
+ class ToRGB(nn.Module):
370
+ def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
371
+ super().__init__()
372
+
373
+ if upsample:
374
+ self.upsample = Upsample(blur_kernel)
375
+
376
+ self.conv = ConvLayer(in_channel, 3, 1)
377
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
378
+
379
+ def forward(self, input, skip=None):
380
+ out = self.conv(input)
381
+ out = out + self.bias
382
+
383
+ if skip is not None:
384
+ skip = self.upsample(skip)
385
+ out = out + skip
386
+
387
+ return out
388
+
389
+
390
+ class ToFlow(nn.Module):
391
+ def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
392
+ super().__init__()
393
+
394
+ if upsample:
395
+ self.upsample = Upsample(blur_kernel)
396
+
397
+ self.style_dim = style_dim
398
+ self.in_channel = in_channel
399
+ self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
400
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
401
+
402
+ def forward(self, input, style, feat, skip=None): # input: feature map from the previous layer; style: the 512-d condition; feat: the skip feature from the encoder (UNet-style)
403
+ out = self.conv(input, style)
404
+ out = out + self.bias
405
+
406
+ # warping
407
+ xs = np.linspace(-1, 1, input.size(2))
408
+
409
+ xs = np.meshgrid(xs, xs)
410
+ xs = np.stack(xs, 2)
411
+
412
+ xs = torch.tensor(xs, requires_grad=False).float().unsqueeze(0).repeat(input.size(0), 1, 1, 1).to(input.device)
413
+ # import pdb;pdb.set_trace()
414
+ if skip is not None:
415
+ skip = self.upsample(skip)
416
+ out = out + skip
417
+
418
+ sampler = torch.tanh(out[:, 0:2, :, :])
419
+ mask = torch.sigmoid(out[:, 2:3, :, :])
420
+ flow = sampler.permute(0, 2, 3, 1) + xs # xs在这里相当于一个 location 的位置
421
+
422
+ feat_warp = F.grid_sample(feat, flow) * mask
423
+ # import pdb;pdb.set_trace()
424
+ return feat_warp, feat_warp + input * (1.0 - mask), out
425
+
426
+
427
+ class Direction(nn.Module):
+     def __init__(self, motion_dim):
+         super(Direction, self).__init__()
+ 
+         self.weight = nn.Parameter(torch.randn(512, motion_dim))
+ 
+     def forward(self, input):
+         # input: (bs*t) x 512
+ 
+         weight = self.weight + 1e-8
+         Q, R = torch.qr(weight)  # orthonormal motion basis [d1, ..., dm] via QR decomposition
+ 
+         if input is None:
+             return Q
+         else:
+             input_diag = torch.diag_embed(input)  # alpha as a diagonal matrix
+             out = torch.matmul(input_diag, Q.T)
+             out = torch.sum(out, dim=1)
+ 
+             return out
+ 
+ class Synthesis(nn.Module):
+     def __init__(self, size, style_dim, motion_dim, blur_kernel=[1, 3, 3, 1], channel_multiplier=1):
+         super(Synthesis, self).__init__()
+ 
+         self.size = size
+         self.style_dim = style_dim
+         self.motion_dim = motion_dim
+ 
+         self.direction = Direction(motion_dim)  # Linear Motion Decomposition (LMD) from LIA
+ 
+         self.channels = {
+             4: 512,
+             8: 512,
+             16: 512,
+             32: 512,
+             64: 256 * channel_multiplier,
+             128: 128 * channel_multiplier,
+             256: 64 * channel_multiplier,
+             512: 32 * channel_multiplier,
+             1024: 16 * channel_multiplier,
+         }
+ 
+         self.input = ConstantInput(self.channels[4])
+         self.conv1 = StyledConv(self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel)
+         self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
+ 
+         self.log_size = int(math.log(size, 2))
+         self.num_layers = (self.log_size - 2) * 2 + 1
+ 
+         self.convs = nn.ModuleList()
+         self.upsamples = nn.ModuleList()
+         self.to_rgbs = nn.ModuleList()
+         self.to_flows = nn.ModuleList()
+ 
+         in_channel = self.channels[4]
+ 
+         for i in range(3, self.log_size + 1):
+             out_channel = self.channels[2 ** i]
+ 
+             self.convs.append(StyledConv(in_channel, out_channel, 3, style_dim, upsample=True,
+                                          blur_kernel=blur_kernel))
+             self.convs.append(StyledConv(out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel))
+             self.to_rgbs.append(ToRGB(out_channel, style_dim))
+ 
+             self.to_flows.append(ToFlow(out_channel, style_dim))
+ 
+             in_channel = out_channel
+ 
+         self.n_latent = self.log_size * 2 - 2
+ 
+     def forward(self, source_before_decoupling, target_motion, feats):
+ 
+         directions = self.direction(target_motion)
+         latent = source_before_decoupling + directions  # wa + directions
+ 
+         inject_index = self.n_latent
+         latent = latent.unsqueeze(1).repeat(1, inject_index, 1)
+ 
+         out = self.input(latent)
+         out = self.conv1(out, latent[:, 0])
+ 
+         i = 1
+         for conv1, conv2, to_rgb, to_flow, feat in zip(self.convs[::2], self.convs[1::2], self.to_rgbs,
+                                                        self.to_flows, feats):
+             out = conv1(out, latent[:, i])
+             out = conv2(out, latent[:, i + 1])
+             if out.size(2) == 8:
+                 out_warp, out, skip_flow = to_flow(out, latent[:, i + 2], feat)
+                 skip = to_rgb(out_warp)
+             else:
+                 out_warp, out, skip_flow = to_flow(out, latent[:, i + 2], feat, skip_flow)
+                 skip = to_rgb(out_warp, skip)
+             i += 2
+ 
+         img = skip
+ 
+         return img
+ 
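Note on the decoder above: Direction implements LIA's linear motion decomposition, i.e. QR-factorizing the learned weight gives orthonormal directions d_1..d_m, and the motion code alpha selects the combination sum_i alpha_i * d_i that is added to the source latent wa before modulating every StyledConv/ToFlow layer. A minimal shape-level sketch of driving Synthesis with dummy tensors follows; it assumes the StyleGAN2 building blocks defined earlier in this file (ModulatedConv2d, FusedLeakyReLU, Upsample, ...) are available, and the batch size, feature shapes, and names wa/alpha are illustrative, not part of this commit.

import torch

# Hypothetical smoke test for Synthesis (illustrative shapes only).
# size=256 gives six upsampling levels (8, 16, 32, 64, 128, 256), so `feats`
# needs one encoder skip feature per level with channels matching self.channels.
dec = Synthesis(size=256, style_dim=512, motion_dim=20)

bs = 2
wa = torch.randn(bs, 512)      # source/appearance latent
alpha = torch.randn(bs, 20)    # motion magnitudes along the learned directions
sizes = [8, 16, 32, 64, 128, 256]
chans = [512, 512, 512, 256, 128, 64]
feats = [torch.randn(bs, c, s, s) for c, s in zip(chans, sizes)]

img = dec(wa, alpha, feats)    # -> (bs, 3, 256, 256)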
networks/utils.py ADDED
@@ -0,0 +1,53 @@
+ from torch import nn
+ import torch.nn.functional as F
+ import torch
+ 
+ 
+ class AntiAliasInterpolation2d(nn.Module):
+     """
+     Band-limited downsampling, for better preservation of the input signal.
+     """
+ 
+     def __init__(self, channels, scale):
+         super(AntiAliasInterpolation2d, self).__init__()
+         sigma = (1 / scale - 1) / 2
+         kernel_size = 2 * round(sigma * 4) + 1
+         self.ka = kernel_size // 2
+         self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka
+ 
+         kernel_size = [kernel_size, kernel_size]
+         sigma = [sigma, sigma]
+         # The gaussian kernel is the product of the
+         # gaussian function of each dimension.
+         kernel = 1
+         meshgrids = torch.meshgrid(
+             [
+                 torch.arange(size, dtype=torch.float32)
+                 for size in kernel_size
+             ]
+         )
+         for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
+             mean = (size - 1) / 2
+             kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2))
+ 
+         # Make sure sum of values in gaussian kernel equals 1.
+         kernel = kernel / torch.sum(kernel)
+         # Reshape to depthwise convolutional weight
+         kernel = kernel.view(1, 1, *kernel.size())
+         kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
+ 
+         self.register_buffer('weight', kernel)
+         self.groups = channels
+         self.scale = scale
+         inv_scale = 1 / scale
+         self.int_inv_scale = int(inv_scale)
+ 
+     def forward(self, input):
+         if self.scale == 1.0:
+             return input
+ 
+         out = F.pad(input, (self.ka, self.kb, self.ka, self.kb))
+         out = F.conv2d(out, weight=self.weight, groups=self.groups)
+         out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale]
+ 
+         return out
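AntiAliasInterpolation2d builds a fixed depthwise Gaussian kernel at construction time and low-pass filters before strided subsampling, so downsampling by `scale` stays band-limited. A small usage sketch (the shapes below are hypothetical):

import torch

# Downsample a 3-channel 256x256 batch by 4x with anti-aliasing.
down = AntiAliasInterpolation2d(channels=3, scale=0.25)
x = torch.randn(2, 3, 256, 256)
y = down(x)  # -> (2, 3, 64, 64): Gaussian blur, then ::4 subsampling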
renderer.py ADDED
@@ -0,0 +1,25 @@
+ from config import *
+ 
+ def render_condition(
+     conf: TrainConfig,
+     model,
+     sampler, start, motion_direction_start, audio_driven,
+     face_location, face_scale,
+     yaw_pitch_roll, noisyT, control_flag,
+ ):
+     if conf.train_mode == TrainMode.diffusion:
+         assert conf.model_type.has_autoenc()
+ 
+         return sampler.sample(model=model,
+                               noise=noisyT,
+                               model_kwargs={
+                                   'motion_direction_start': motion_direction_start,
+                                   'yaw_pitch_roll': yaw_pitch_roll,
+                                   'start': start,
+                                   'audio_driven': audio_driven,
+                                   'face_location': face_location,
+                                   'face_scale': face_scale,
+                                   'control_flag': control_flag
+                               })
+     else:
+         raise NotImplementedError()
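render_condition is a thin wrapper that routes the conditioning signals into the DDIM sampler: noisyT is denoised while the start code, motion direction, audio features, and the pose/location/scale controls are forwarded through model_kwargs. A hypothetical call site is sketched below; `conf`, `model`, `sampler`, and every tensor are assumed to be prepared elsewhere (e.g. by the inference script) and are not part of this commit.

import torch

# Hypothetical driver (all names and shapes are illustrative assumptions).
pred = render_condition(
    conf, model, sampler,
    start=start_code,                        # appearance / start-frame condition
    motion_direction_start=motion_dir0,      # motion code of the first frame
    audio_driven=audio_feats,                # audio features driving the motion
    face_location=face_loc,
    face_scale=face_scale,
    yaw_pitch_roll=pose,
    noisyT=torch.randn_like(target_motion),  # initial noise for the sampler
    control_flag=True,
)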
templates.py ADDED
@@ -0,0 +1,301 @@
+ from experiment import *
+ 
+ 
+ def ddpm():
+     """
+     base configuration for all DDIM-based models.
+     """
+     conf = TrainConfig()
+     conf.batch_size = 32
+     conf.beatgans_gen_type = GenerativeType.ddim
+     conf.beta_scheduler = 'linear'
+     conf.data_name = 'ffhq'
+     conf.diffusion_type = 'beatgans'
+     conf.eval_ema_every_samples = 200_000
+     conf.eval_every_samples = 200_000
+     conf.fp16 = True
+     conf.lr = 1e-4
+     conf.model_name = ModelName.beatgans_ddpm
+     conf.net_attn = (16, )
+     conf.net_beatgans_attn_head = 1
+     conf.net_beatgans_embed_channels = 512
+     conf.net_ch_mult = (1, 2, 4, 8)
+     conf.net_ch = 64
+     conf.sample_size = 32
+     conf.T_eval = 20
+     conf.T = 1000
+     conf.make_model_conf()
+     return conf
+ 
+ 
+ def autoenc_base():
+     """
+     base configuration for all Diff-AE models.
+     """
+     conf = TrainConfig()
+     conf.batch_size = 32
+     conf.beatgans_gen_type = GenerativeType.ddim
+     conf.beta_scheduler = 'linear'
+     conf.data_name = 'ffhq'
+     conf.diffusion_type = 'beatgans'
+     conf.eval_ema_every_samples = 200_000
+     conf.eval_every_samples = 200_000
+     conf.fp16 = True
+     conf.lr = 1e-4
+     conf.model_name = ModelName.beatgans_autoenc
+     conf.net_attn = (16, )
+     conf.net_beatgans_attn_head = 1
+     conf.net_beatgans_embed_channels = 512
+     conf.net_beatgans_resnet_two_cond = True
+     conf.net_ch_mult = (1, 2, 4, 8)
+     conf.net_ch = 64
+     conf.net_enc_channel_mult = (1, 2, 4, 8, 8)
+     conf.net_enc_pool = 'adaptivenonzero'
+     conf.sample_size = 32
+     conf.T_eval = 20
+     conf.T = 1000
+     conf.make_model_conf()
+     return conf
+ 
+ 
+ def ffhq64_ddpm():
+     conf = ddpm()
+     conf.data_name = 'ffhqlmdb256'
+     conf.warmup = 0
+     conf.total_samples = 72_000_000
+     conf.scale_up_gpus(4)
+     return conf
+ 
+ 
+ def ffhq64_autoenc():
+     conf = autoenc_base()
+     conf.data_name = 'ffhqlmdb256'
+     conf.warmup = 0
+     conf.total_samples = 72_000_000
+     conf.net_ch_mult = (1, 2, 4, 8)
+     conf.net_enc_channel_mult = (1, 2, 4, 8, 8)
+     conf.eval_every_samples = 1_000_000
+     conf.eval_ema_every_samples = 1_000_000
+     conf.scale_up_gpus(4)
+     conf.make_model_conf()
+     return conf
+ 
+ 
+ def celeba64d2c_ddpm():
+     conf = ffhq128_ddpm()
+     conf.data_name = 'celebalmdb'
+     conf.eval_every_samples = 10_000_000
+     conf.eval_ema_every_samples = 10_000_000
+     conf.total_samples = 72_000_000
+     conf.name = 'celeba64d2c_ddpm'
+     return conf
+ 
+ 
+ def celeba64d2c_autoenc():
+     conf = ffhq64_autoenc()
+     conf.data_name = 'celebalmdb'
+     conf.eval_every_samples = 10_000_000
+     conf.eval_ema_every_samples = 10_000_000
+     conf.total_samples = 72_000_000
+     conf.name = 'celeba64d2c_autoenc'
+     return conf
+ 
+ 
+ def ffhq128_ddpm():
+     conf = ddpm()
+     conf.data_name = 'ffhqlmdb256'
+     conf.warmup = 0
+     conf.total_samples = 48_000_000
+     conf.img_size = 128
+     conf.net_ch = 128
+     # channels:
+     # 3 => 128 * 1 => 128 * 1 => 128 * 2 => 128 * 3 => 128 * 4
+     # sizes:
+     # 128 => 128 => 64 => 32 => 16 => 8
+     conf.net_ch_mult = (1, 1, 2, 3, 4)
+     conf.eval_every_samples = 1_000_000
+     conf.eval_ema_every_samples = 1_000_000
+     conf.scale_up_gpus(4)
+     conf.eval_ema_every_samples = 10_000_000
+     conf.eval_every_samples = 10_000_000
+     conf.make_model_conf()
+     return conf
+ 
+ 
+ def ffhq128_autoenc_base():
+     conf = autoenc_base()
+     conf.data_name = 'ffhqlmdb256'
+     conf.scale_up_gpus(4)
+     conf.img_size = 128
+     conf.net_ch = 128
+     # final resolution = 8x8
+     conf.net_ch_mult = (1, 1, 2, 3, 4)
+     # final resolution = 4x4
+     conf.net_enc_channel_mult = (1, 1, 2, 3, 4, 4)
+     conf.eval_ema_every_samples = 10_000_000
+     conf.eval_every_samples = 10_000_000
+     conf.make_model_conf()
+     return conf
+ 
+ 
+ def ffhq256_autoenc():
+     conf = ffhq128_autoenc_base()
+     conf.img_size = 256
+     conf.net_ch = 128
+     conf.net_ch_mult = (1, 1, 2, 2, 4, 4)
+     conf.net_enc_channel_mult = (1, 1, 2, 2, 4, 4, 4)
+     conf.eval_every_samples = 10_000_000
+     conf.eval_ema_every_samples = 10_000_000
+     conf.total_samples = 200_000_000
+     conf.batch_size = 64
+     conf.make_model_conf()
+     conf.name = 'ffhq256_autoenc'
+     return conf
+ 
+ 
+ def ffhq256_autoenc_eco():
+     conf = ffhq128_autoenc_base()
+     conf.img_size = 256
+     conf.net_ch = 128
+     conf.net_ch_mult = (1, 1, 2, 2, 4, 4)
+     conf.net_enc_channel_mult = (1, 1, 2, 2, 4, 4, 4)
+     conf.eval_every_samples = 10_000_000
+     conf.eval_ema_every_samples = 10_000_000
+     conf.total_samples = 200_000_000
+     conf.batch_size = 64
+     conf.make_model_conf()
+     conf.name = 'ffhq256_autoenc_eco'
+     return conf
+ 
+ 
+ def ffhq128_ddpm_72M():
+     conf = ffhq128_ddpm()
+     conf.total_samples = 72_000_000
+     conf.name = 'ffhq128_ddpm_72M'
+     return conf
+ 
+ 
+ def ffhq128_autoenc_72M():
+     conf = ffhq128_autoenc_base()
+     conf.total_samples = 72_000_000
+     conf.name = 'ffhq128_autoenc_72M'
+     return conf
+ 
+ 
+ def ffhq128_ddpm_130M():
+     conf = ffhq128_ddpm()
+     conf.total_samples = 130_000_000
+     conf.eval_ema_every_samples = 10_000_000
+     conf.eval_every_samples = 10_000_000
+     conf.name = 'ffhq128_ddpm_130M'
+     return conf
+ 
+ 
+ def ffhq128_autoenc_130M():
+     conf = ffhq128_autoenc_base()
+     conf.total_samples = 130_000_000
+     conf.eval_ema_every_samples = 10_000_000
+     conf.eval_every_samples = 10_000_000
+     conf.name = 'ffhq128_autoenc_130M'
+     return conf
+ 
+ 
+ def horse128_ddpm():
+     conf = ffhq128_ddpm()
+     conf.data_name = 'horse256'
+     conf.total_samples = 130_000_000
+     conf.eval_ema_every_samples = 10_000_000
+     conf.eval_every_samples = 10_000_000
+     conf.name = 'horse128_ddpm'
+     return conf
+ 
+ 
+ def horse128_autoenc():
+     conf = ffhq128_autoenc_base()
+     conf.data_name = 'horse256'
+     conf.total_samples = 130_000_000
+     conf.eval_ema_every_samples = 10_000_000
+     conf.eval_every_samples = 10_000_000
+     conf.name = 'horse128_autoenc'
+     return conf
+ 
+ 
+ def bedroom128_ddpm():
+     conf = ffhq128_ddpm()
+     conf.data_name = 'bedroom256'
+     conf.eval_ema_every_samples = 10_000_000
+     conf.eval_every_samples = 10_000_000
+     conf.total_samples = 120_000_000
+     conf.name = 'bedroom128_ddpm'
+     return conf
+ 
+ 
+ def bedroom128_autoenc():
+     conf = ffhq128_autoenc_base()
+     conf.data_name = 'bedroom256'
+     conf.eval_ema_every_samples = 10_000_000
+     conf.eval_every_samples = 10_000_000
+     conf.total_samples = 120_000_000
+     conf.name = 'bedroom128_autoenc'
+     return conf
+ 
+ 
+ def pretrain_celeba64d2c_72M():
+     conf = celeba64d2c_autoenc()
+     conf.pretrain = PretrainConfig(
+         name='72M',
+         path=f'checkpoints/{celeba64d2c_autoenc().name}/last.ckpt',
+     )
+     conf.latent_infer_path = f'checkpoints/{celeba64d2c_autoenc().name}/latent.pkl'
+     return conf
+ 
+ 
+ def pretrain_ffhq128_autoenc72M():
+     conf = ffhq128_autoenc_base()
+     conf.postfix = ''
+     conf.pretrain = PretrainConfig(
+         name='72M',
+         path=f'checkpoints/{ffhq128_autoenc_72M().name}/last.ckpt',
+     )
+     conf.latent_infer_path = f'checkpoints/{ffhq128_autoenc_72M().name}/latent.pkl'
+     return conf
+ 
+ 
+ def pretrain_ffhq128_autoenc130M():
+     conf = ffhq128_autoenc_base()
+     conf.pretrain = PretrainConfig(
+         name='130M',
+         path=f'checkpoints/{ffhq128_autoenc_130M().name}/last.ckpt',
+     )
+     conf.latent_infer_path = f'checkpoints/{ffhq128_autoenc_130M().name}/latent.pkl'
+     return conf
+ 
+ 
+ def pretrain_ffhq256_autoenc():
+     conf = ffhq256_autoenc()
+     conf.pretrain = PretrainConfig(
+         name='90M',
+         path=f'checkpoints/{ffhq256_autoenc().name}/last.ckpt',
+     )
+     conf.latent_infer_path = f'checkpoints/{ffhq256_autoenc().name}/latent.pkl'
+     return conf
+ 
+ 
+ def pretrain_horse128():
+     conf = horse128_autoenc()
+     conf.pretrain = PretrainConfig(
+         name='82M',
+         path=f'checkpoints/{horse128_autoenc().name}/last.ckpt',
+     )
+     conf.latent_infer_path = f'checkpoints/{horse128_autoenc().name}/latent.pkl'
+     return conf
+ 
+ 
+ def pretrain_bedroom128():
+     conf = bedroom128_autoenc()
+     conf.pretrain = PretrainConfig(
+         name='120M',
+         path=f'checkpoints/{bedroom128_autoenc().name}/last.ckpt',
+     )
+     conf.latent_infer_path = f'checkpoints/{bedroom128_autoenc().name}/latent.pkl'
+     return conf
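These templates compose by mutation: each function clones a base TrainConfig (directly or via another template), overrides a few fields, and calls make_model_conf(). A hypothetical example of deriving a new variant in the same style (the function name and overridden values are illustrative, not part of this commit):

def ffhq256_autoenc_smallbatch():
    # Hypothetical variant: start from an existing template and override
    # only the fields that differ.
    conf = ffhq256_autoenc()
    conf.batch_size = 16
    conf.name = 'ffhq256_autoenc_smallbatch'
    return conf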