# HarmonyView: ldm/models/diffusion/sync_dreamer.py
# Last changed by byeongjun-park in commit 01a5b8c ("HarmonyView update").
from pathlib import Path
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from skimage.io import imsave
from torch.optim.lr_scheduler import LambdaLR
from tqdm import tqdm
from ldm.base_utils import read_pickle, concat_images_list
from ldm.models.diffusion.sync_dreamer_utils import get_warp_coordinates, create_target_volume
from ldm.models.diffusion.sync_dreamer_network import NoisyTargetViewEncoder, SpatialTime3DNet, FrustumTV3DNet
from ldm.modules.diffusionmodules.util import make_ddim_timesteps, timestep_embedding
from ldm.modules.encoders.modules import FrozenCLIPImageEmbedder
from ldm.util import instantiate_from_config
def disabled_train(self, mode=True):
    """No-op stand-in for ``nn.Module.train``.

    Installed over a module's ``train`` attribute so that later requests to
    switch train/eval mode leave the module's current mode untouched. The
    ``mode`` argument is accepted only for call compatibility and is ignored;
    the first positional argument is handed back unchanged.
    """
    unchanged = self  # the requested ``mode`` is deliberately not applied
    return unchanged
def disable_training_module(module: nn.Module):
    """Freeze ``module``: put it in eval mode, stop gradients, and pin the mode.

    @param module: the module to freeze (modified in place)
    @return: the same module object
    """
    frozen = module.eval()
    # Shadow ``train`` with a plain (unbound) function stored on the instance.
    # A call such as ``frozen.train(mode)`` therefore becomes
    # ``disabled_train(mode)``, which just returns its argument and leaves the
    # module in eval mode. Note a bare ``frozen.train()`` would raise TypeError
    # (missing positional argument).
    frozen.train = disabled_train
    for weight in frozen.parameters():
        weight.requires_grad = False
    return frozen
def repeat_to_batch(tensor, B, VN):
    """Repeat each of the ``B`` leading entries of ``tensor`` ``VN`` times in a row.

    @param tensor: tensor whose first dimension has size ``B``
    @param B: batch size (size of dim 0 of ``tensor``)
    @param VN: number of consecutive copies per entry
    @return: tensor of shape ``(B * VN, *tensor.shape[1:])`` where entry ``b``
        occupies rows ``b*VN`` .. ``b*VN + VN - 1``
    """
    trailing = tuple(tensor.shape[1:])
    repeats = (1, VN) + (1,) * len(trailing)
    tiled = tensor.view(B, 1, *trailing).repeat(*repeats)  # B,VN,*trailing
    return tiled.view(B * VN, *trailing)
class UNetWrapper(nn.Module):
    """Wrapper around the denoising UNet that assembles its conditioning inputs.

    Handles random condition dropout during training, the zero123-style
    rescaling of the concatenated input latent, and guided prediction with two
    separately weighted unconditional branches.
    """

    # First-stage (autoencoder) latent scale factor; the first 4 channels of the
    # concatenated latent are divided by it when ``use_zero_123`` is set.
    _ZERO123_SCALE_FACTOR = 0.18215

    def __init__(self, diff_model_config, drop_conditions=False, drop_scheme='default', use_zero_123=True):
        super().__init__()
        self.diffusion_model = instantiate_from_config(diff_model_config)
        self.drop_conditions = drop_conditions
        self.drop_scheme = drop_scheme
        self.use_zero_123 = use_zero_123

    def drop(self, cond, mask):
        """Scale every batch entry of ``cond`` by the matching entry of ``mask``.

        @param cond: B,* tensor
        @param mask: B tensor (1.0 keeps an entry, 0.0 zeroes it)
        @return: tensor shaped like ``cond``
        """
        shape = cond.shape
        B = shape[0]
        return mask.view(B, *[1 for _ in range(len(shape) - 1)]) * cond

    def get_trainable_parameters(self):
        return self.diffusion_model.get_trainable_parameters()

    def get_drop_scheme(self, B, device):
        """Sample per-sample condition-dropout flags.

        With the 'default' scheme one uniform draw per sample selects exactly one
        of: drop everything (5%), drop the concat latent (5%), drop the volume
        features (5%), drop the CLIP embedding (5%), or keep all (80%).

        @return: bool tensors of shape (B,): drop_clip, drop_volume, drop_concat, drop_all
        @raise NotImplementedError: for any other ``drop_scheme``
        """
        if self.drop_scheme == 'default':
            random = torch.rand(B, dtype=torch.float32, device=device)
            drop_clip = (random > 0.15) & (random <= 0.2)
            drop_volume = (random > 0.1) & (random <= 0.15)
            drop_concat = (random > 0.05) & (random <= 0.1)
            drop_all = random <= 0.05
        else:
            raise NotImplementedError
        return drop_clip, drop_volume, drop_concat, drop_all

    def _scale_concat(self, x_concat):
        """Return ``x_concat`` with its first 4 channels divided by the first-stage
        scale factor when ``use_zero_123`` is set. The input is never modified."""
        if not self.use_zero_123:
            return x_concat
        # zero123 does not multiply this when encoding, maybe a bug for zero123
        scaled = x_concat.clone()
        scaled[:, :4] = scaled[:, :4] / self._ZERO123_SCALE_FACTOR
        return scaled

    def forward(self, x, t, clip_embed, volume_feats, x_concat, is_train=False):
        """
        @param x: B,4,H,W noisy latent
        @param t: B, timesteps
        @param clip_embed: B,M,768
        @param volume_feats: dict of B,* tensors passed to the UNet as ``source_dict``;
            the caller's dict is left unmodified
        @param x_concat: B,C,H,W latent concatenated to ``x`` along channels
        @param is_train: apply random condition dropout (only when ``drop_conditions``)
        @return: UNet prediction
        """
        if self.drop_conditions and is_train:
            B = x.shape[0]
            drop_clip, drop_volume, drop_concat, drop_all = self.get_drop_scheme(B, x.device)
            clip_mask = 1.0 - (drop_clip | drop_all).float()
            clip_embed = self.drop(clip_embed, clip_mask)
            volume_mask = 1.0 - (drop_volume | drop_all).float()
            # Build a new dict rather than overwriting the caller's entries.
            volume_feats = {k: self.drop(v, mask=volume_mask) for k, v in volume_feats.items()}
            concat_mask = 1.0 - (drop_concat | drop_all).float()
            x_concat = self.drop(x_concat, concat_mask)
        x = torch.cat([x, self._scale_concat(x_concat)], 1)
        return self.diffusion_model(x, t, clip_embed, source_dict=volume_feats)

    def predict_with_decomposed_unconditional_scales(self, x, t, clip_embed, volume_feats, x_concat, unconditional_scales):
        """Guided prediction with two separately weighted unconditional branches.

        Three copies of the batch go through the UNet in a single call:
          * ``s``     - all conditions;
          * ``s_uc1`` - CLIP embedding and concat latent zeroed, volume kept;
          * ``s_uc2`` - volume features zeroed, concat latent multiplied by 4
            (factor kept as-is; its rationale is not documented here).
        Result: ``s + w0 * (s - s_uc1) + w1 * (s - s_uc2)``.

        @param unconditional_scales: indexable with two entries ``(w0, w1)``;
            a bare float is not accepted
        @return: guided prediction with the batch size of ``x``
        """
        x_ = torch.cat([x] * 3, 0)
        t_ = torch.cat([t] * 3, 0)
        clip_embed_ = torch.cat([clip_embed, torch.zeros_like(clip_embed), clip_embed], 0)
        x_concat_ = torch.cat([x_concat, torch.zeros_like(x_concat), x_concat * 4], 0)
        v_ = {k: torch.cat([v, v, torch.zeros_like(v)], 0) for k, v in volume_feats.items()}
        x_ = torch.cat([x_, self._scale_concat(x_concat_)], 1)
        s, s_uc1, s_uc2 = self.diffusion_model(x_, t_, clip_embed_, source_dict=v_).chunk(3)
        return s + unconditional_scales[0] * (s - s_uc1) + unconditional_scales[1] * (s - s_uc2)
class SpatialVolumeNet(nn.Module):
    """Builds a shared 3D feature volume from all views and per-view frustum volumes.

    ``construct_spatial_volume`` unprojects per-view 2D features into a cube of
    ``spatial_volume_size ** 3`` vertices spanning
    ``[-spatial_volume_length, spatial_volume_length]`` on each axis and fuses
    them with a 3D network. ``construct_view_frustum_volume`` resamples that
    cube inside the frustum of each requested target view.
    """

    def __init__(self, time_dim, view_dim, view_num,
                 input_image_size=256, frustum_volume_depth=48,
                 spatial_volume_size=32, spatial_volume_length=0.5,
                 frustum_volume_length=0.86603  # sqrt(3)/2
                 ):
        super().__init__()
        # Sub-network construction order is kept fixed (affects init RNG and parameter order).
        self.target_encoder = NoisyTargetViewEncoder(time_dim, view_dim, output_dim=16)
        self.spatial_volume_feats = SpatialTime3DNet(input_dim=16 * view_num, time_dim=time_dim, dims=(64, 128, 256, 512))
        self.frustum_volume_feats = FrustumTV3DNet(64, time_dim, view_dim, dims=(64, 128, 256, 512))
        self.frustum_volume_length = frustum_volume_length
        self.input_image_size = input_image_size
        self.spatial_volume_size = spatial_volume_size
        self.spatial_volume_length = spatial_volume_length
        self.frustum_volume_size = self.input_image_size // 8  # frustum H = W
        self.frustum_volume_depth = frustum_volume_depth
        self.time_dim = time_dim
        self.view_dim = view_dim
        # our rendered images are 1.5 away from the origin, we assume camera is 1.5 away from the origin
        self.default_origin_depth = 1.5

    def construct_spatial_volume(self, x, t_embed, v_embed, target_poses, target_Ks):
        """Unproject every view's encoded features into the shared cube and fuse them.

        @param x: B,N,4,H,W
        @param t_embed: B,t_dim
        @param v_embed: B,N,v_dim
        @param target_poses: N,3,4
        @param target_Ks: N,3,3
        @return: output of ``self.spatial_volume_feats`` (original note: b,64,32,32,32)
        """
        B, N, _, H, W = x.shape
        V = self.spatial_volume_size
        device = x.device

        # Regular V x V x V grid of vertex coordinates, coordinate order reversed via (2, 1, 0).
        axis = torch.linspace(-self.spatial_volume_length, self.spatial_volume_length, V, dtype=torch.float32, device=device)
        grid = torch.stack(torch.meshgrid(axis, axis, axis), -1)
        grid = grid.reshape(1, V ** 3, 3)[:, :, (2, 1, 0)]
        verts = grid.view(1, V, V, V, 3).permute(0, 4, 1, 2, 3).repeat(B, 1, 1, 1, 1)  # B,3,V,V,V

        t_per_view = t_embed.view(B, 1, self.time_dim).repeat(1, N, 1).view(B, N, self.time_dim)
        Ks_batched = target_Ks.unsqueeze(0).repeat(B, 1, 1, 1)
        poses_batched = target_poses.unsqueeze(0).repeat(B, 1, 1, 1)

        per_view = []
        for view_idx in range(N):
            feats_2d = self.target_encoder(x[:, view_idx], t_per_view[:, view_idx], v_embed[:, view_idx])
            C = feats_2d.shape[1]
            coords = get_warp_coordinates(verts, feats_2d.shape[-1], self.input_image_size,
                                          Ks_batched[:, view_idx], poses_batched[:, view_idx]).view(B, V, V * V, 2)
            sampled = F.grid_sample(feats_2d, coords, mode='bilinear', padding_mode='zeros', align_corners=True)
            per_view.append(sampled.view(B, C, V, V, V))

        stacked = torch.stack(per_view, 1)  # B,N,C,V,V,V
        fused_input = stacked.view(B, stacked.shape[1] * C, V, V, V)
        return self.spatial_volume_feats(fused_input, t_embed)

    def construct_view_frustum_volume(self, spatial_volume, t_embed, v_embed, poses, Ks, target_indices):
        """Sample the spatial volume inside the frustum of each requested target view.

        @param spatial_volume: B,C,V,V,V
        @param t_embed: B,t_dim
        @param v_embed: B,N,v_dim
        @param poses: N,3,4
        @param Ks: N,3,3
        @param target_indices: B,TN
        @return: (output of ``self.frustum_volume_feats`` for B*TN views, frustum depth
            values from ``create_target_volume``)
        """
        B, TN = target_indices.shape
        H = W = self.frustum_volume_size
        D = self.frustum_volume_depth
        V = self.spatial_volume_size
        count = B * TN

        near = torch.ones(count, 1, H, W, dtype=spatial_volume.dtype, device=spatial_volume.device) * self.default_origin_depth - self.frustum_volume_length
        far = torch.ones(count, 1, H, W, dtype=spatial_volume.dtype, device=spatial_volume.device) * self.default_origin_depth + self.frustum_volume_length

        flat_idx = target_indices.view(count)  # B*TN
        poses_sel = poses[flat_idx]  # B*TN,3,4
        Ks_sel = Ks[flat_idx]  # B*TN,3,3
        # volume_xyz: B*TN,3,D,H,W ; volume_depth: B*TN,1,D,H,W
        volume_xyz, volume_depth = create_target_volume(D, self.frustum_volume_size, self.input_image_size, poses_sel, Ks_sel, near, far)

        # The spatial cube spans [-spatial_volume_length, spatial_volume_length]; rescale to [-1, 1].
        sample_grid = (volume_xyz / self.spatial_volume_length).permute(0, 2, 3, 4, 1)  # B*TN,D,H,W,3
        cube = spatial_volume.unsqueeze(1).repeat(1, TN, 1, 1, 1, 1).view(count, -1, V, V, V)
        volume_feats = F.grid_sample(cube, sample_grid, mode='bilinear', padding_mode='zeros', align_corners=True)  # B*TN,C,D,H,W

        batch_rows = torch.arange(B)[:, None]
        v_sel = v_embed[batch_rows, flat_idx.view(B, TN)].view(count, -1)
        t_sel = t_embed.unsqueeze(1).repeat(1, TN, 1).view(count, -1)
        volume_feats_dict = self.frustum_volume_feats(volume_feats, t_sel, v_sel)
        return volume_feats_dict, volume_depth
class SyncMultiviewDiffusion(pl.LightningModule):
    """Synchronized multiview latent diffusion model.

    Generates latents for ``view_num`` fixed camera poses from one input image.
    At every denoising step the noisy latents of all views are fused into a
    shared 3D volume (``SpatialVolumeNet``), which conditions the UNet for each
    target view together with a projected CLIP embedding and the input latent.

    Attributes read here but never assigned in this class (set them before use):
        learning_rate: used by ``configure_optimizers``.
        image_dir: validation images are written below ``image_dir/images/val``.

    NOTE(review): the sampling path passes ``cfg_scale`` down to
    ``UNetWrapper.predict_with_decomposed_unconditional_scales``, which indexes it
    as ``[0]`` and ``[1]``; the default float 3.0 will therefore fail there — pass a pair.
    """

    def __init__(self, unet_config, scheduler_config,
                 finetune_unet=False, finetune_projection=True,
                 view_num=16, image_size=256,
                 cfg_scale=3.0, output_num=8, batch_view_num=4,
                 drop_conditions=False, drop_scheme='default',
                 clip_image_encoder_path="/apdcephfs/private_rondyliu/projects/clip/ViT-L-14.pt"):
        super().__init__()
        self.finetune_unet = finetune_unet
        self.finetune_projection = finetune_projection
        self.view_num = view_num
        self.viewpoint_dim = 4
        self.output_num = output_num
        self.image_size = image_size
        self.batch_view_num = batch_view_num
        self.cfg_scale = cfg_scale
        self.clip_image_encoder_path = clip_image_encoder_path
        # The call order below fixes the registration order of submodules/buffers.
        self._init_time_step_embedding()
        self._init_first_stage()
        self._init_schedule()
        self._init_multiview()
        self._init_clip_image_encoder()
        self._init_clip_projection()
        self.spatial_volume = SpatialVolumeNet(self.time_embed_dim, self.viewpoint_dim, self.view_num)
        self.model = UNetWrapper(unet_config, drop_conditions=drop_conditions, drop_scheme=drop_scheme)
        self.scheduler_config = scheduler_config
        latent_size = image_size // 8
        # 200-step uniform DDIM sampler with eta=1.0 (used by validation_step)
        self.ddim = SyncDDIMSampler(self, 200, "uniform", 1.0, latent_size=latent_size)

    def _init_clip_projection(self):
        """Linear map from [CLIP embedding (768) | viewpoint embedding (4)] to 768.

        The 768x768 block of the weight is initialised to identity and the bias to
        zero; the 4 viewpoint columns keep ``nn.Linear``'s default initialisation.
        """
        self.cc_projection = nn.Linear(772, 768)
        nn.init.eye_(list(self.cc_projection.parameters())[0][:768, :768])
        nn.init.zeros_(list(self.cc_projection.parameters())[1])
        self.cc_projection.requires_grad_(True)
        if not self.finetune_projection:
            disable_training_module(self.cc_projection)

    def _init_multiview(self):
        """Load camera intrinsics/poses/azimuths from ``meta_info/camera-{view_num}.pkl``
        (path relative to the working directory) and register them as buffers."""
        K, azs, _, _, poses = read_pickle(f'meta_info/camera-{self.view_num}.pkl')
        default_image_size = 256
        ratio = self.image_size / default_image_size
        K = np.diag([ratio, ratio, 1]) @ K  # intrinsics stored for 256px images; rescale
        K = torch.from_numpy(K.astype(np.float32))  # [3,3]
        K = K.unsqueeze(0).repeat(self.view_num, 1, 1)  # N,3,3
        poses = torch.from_numpy(poses.astype(np.float32))  # N,3,4
        self.register_buffer('poses', poses)
        self.register_buffer('Ks', K)
        azs = (azs + np.pi) % (np.pi * 2) - np.pi  # scale to [-pi,pi) and the index=0 has az=0
        self.register_buffer('azimuth', torch.from_numpy(azs.astype(np.float32)))

    def get_viewpoint_embedding(self, batch_size, elevation_ref):
        """Relative camera embedding of every target view w.r.t. the input view.

        The input view uses view 0's azimuth and elevation ``elevation_ref``; all
        target views use a fixed 30 degree elevation. Elevations are negated to
        follow zero123's convention.

        @param batch_size: B
        @param elevation_ref: B tensor; assumed radians (it is combined with
            ``np.deg2rad(30)``) — TODO confirm against the data loader
        @return: B,N,4 tensor ``[d_elevation, sin(d_azimuth), cos(d_azimuth), 0]``
        """
        azimuth_input = self.azimuth[0].unsqueeze(0)  # 1
        azimuth_target = self.azimuth  # N
        elevation_input = -elevation_ref  # note that zero123 use a negative elevation here!!!
        elevation_target = -np.deg2rad(30)
        d_e = elevation_target - elevation_input  # B
        N = self.azimuth.shape[0]
        B = batch_size
        d_e = d_e.unsqueeze(1).repeat(1, N)
        d_a = azimuth_target - azimuth_input  # N
        d_a = d_a.unsqueeze(0).repeat(B, 1)
        d_z = torch.zeros_like(d_a)
        embedding = torch.stack([d_e, torch.sin(d_a), torch.cos(d_a), d_z], -1)  # B,N,4
        return embedding

    def _init_first_stage(self):
        """Build the frozen KL autoencoder used to move between images and latents."""
        first_stage_config = {
            "target": "ldm.models.autoencoder.AutoencoderKL",
            "params": {
                "embed_dim": 4,
                "monitor": "val/rec_loss",
                "ddconfig": {
                    "double_z": True,
                    "z_channels": 4,
                    "resolution": self.image_size,
                    "in_channels": 3,
                    "out_ch": 3,
                    "ch": 128,
                    "ch_mult": [1, 2, 4, 4],
                    "num_res_blocks": 2,
                    "attn_resolutions": [],
                    "dropout": 0.0
                },
                "lossconfig": {"target": "torch.nn.Identity"},
            }
        }
        self.first_stage_scale_factor = 0.18215
        self.first_stage_model = instantiate_from_config(first_stage_config)
        self.first_stage_model = disable_training_module(self.first_stage_model)

    def _init_clip_image_encoder(self):
        """Load the frozen CLIP image encoder from ``clip_image_encoder_path``."""
        self.clip_image_encoder = FrozenCLIPImageEmbedder(model=self.clip_image_encoder_path)
        self.clip_image_encoder = disable_training_module(self.clip_image_encoder)

    def _init_schedule(self):
        """Register the 1000-step 'scaled linear' DDPM noise schedule buffers (float32)."""
        self.num_timesteps = 1000
        linear_start = 0.00085
        linear_end = 0.0120
        num_timesteps = 1000
        betas = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, num_timesteps, dtype=torch.float32) ** 2  # T
        assert betas.shape[0] == self.num_timesteps

        # NOTE: betas/alphas/alphas_cumprod are float32 here; only the
        # concatenation with a float64 ``ones`` below (and what is derived from
        # it) is promoted to float64. All buffers are stored as float32.
        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)  # T
        alphas_cumprod_prev = torch.cat([torch.ones(1, dtype=torch.float64), alphas_cumprod[:-1]], 0)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)  # T
        posterior_log_variance_clipped = torch.log(torch.clamp(posterior_variance, min=1e-20))
        posterior_log_variance_clipped = torch.clamp(posterior_log_variance_clipped, min=-10)

        self.register_buffer("betas", betas.float())
        self.register_buffer("alphas", alphas.float())
        self.register_buffer("alphas_cumprod", alphas_cumprod.float())
        self.register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod).float())
        self.register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1 - alphas_cumprod).float())
        self.register_buffer("posterior_variance", posterior_variance.float())
        self.register_buffer('posterior_log_variance_clipped', posterior_log_variance_clipped.float())

    def _init_time_step_embedding(self):
        """MLP applied to the sinusoidal timestep embedding (dim 256)."""
        self.time_embed_dim = 256
        self.time_embed = nn.Sequential(
            nn.Linear(self.time_embed_dim, self.time_embed_dim),
            nn.SiLU(True),
            nn.Linear(self.time_embed_dim, self.time_embed_dim),
        )

    def encode_first_stage(self, x, sample=True):
        """Encode images to scaled latents (posterior sample, or mode if ``sample`` is False)."""
        with torch.no_grad():
            posterior = self.first_stage_model.encode(x)  # b,4,h//8,w//8
            if sample:
                return posterior.sample().detach() * self.first_stage_scale_factor
            else:
                return posterior.mode().detach() * self.first_stage_scale_factor

    def decode_first_stage(self, z):
        """Decode scaled latents back to images."""
        with torch.no_grad():
            z = 1. / self.first_stage_scale_factor * z
            return self.first_stage_model.decode(z)

    def prepare(self, batch):
        """Encode a batch into target latents, CLIP embedding and input-view info.

        @param batch: dict with 'input_image' (b,h,w,3), 'input_elevation' (first
            column used) and optionally 'target_image' (b,n,h,w,3)
        @return: (x: b,n,4,h//8,w//8 or None, clip_embed, input_info dict with
            'image', 'elevation', 'x')
        """
        # encode target
        if 'target_image' in batch:
            image_target = batch['target_image'].permute(0, 1, 4, 2, 3)  # b,n,3,h,w
            N = image_target.shape[1]
            x = [self.encode_first_stage(image_target[:, ni], True) for ni in range(N)]
            x = torch.stack(x, 1)  # b,n,4,h//8,w//8
        else:
            x = None

        image_input = batch['input_image'].permute(0, 3, 1, 2)
        elevation_input = batch['input_elevation'][:, 0]  # b
        x_input = self.encode_first_stage(image_input)
        input_info = {'image': image_input, 'elevation': elevation_input, 'x': x_input}
        with torch.no_grad():
            clip_embed = self.clip_image_encoder.encode(image_input)
        return x, clip_embed, input_info

    def embed_time(self, t):
        """Sinusoidal timestep embedding followed by ``time_embed``: B -> B,time_embed_dim."""
        t_embed = timestep_embedding(t, self.time_embed_dim, repeat_only=False)  # B,TED
        t_embed = self.time_embed(t_embed)  # B,TED
        return t_embed

    def get_target_view_feats(self, x_input, spatial_volume, clip_embed, t_embed, v_embed, target_index):
        """Gather the UNet conditions for the requested target views.

        @param x_input: B,4,H,W
        @param spatial_volume: B,C,V,V,V
        @param clip_embed: B,1,768
        @param t_embed: B,t_dim
        @param v_embed: B,N,v_dim
        @param target_index: B,TN
        @return: (clip_embed_: B*TN,1,768, frustum volume feats, x_concat: B*TN,4,H,W)
        """
        B, _, H, W = x_input.shape
        frustum_volume_feats, frustum_volume_depth = self.spatial_volume.construct_view_frustum_volume(spatial_volume, t_embed, v_embed, self.poses, self.Ks, target_index)

        # clip embedding projected together with the target view's viewpoint embedding
        TN = target_index.shape[1]
        v_embed_ = v_embed[torch.arange(B)[:, None], target_index].view(B * TN, self.viewpoint_dim)  # B*TN,v_dim
        clip_embed_ = clip_embed.unsqueeze(1).repeat(1, TN, 1, 1).view(B * TN, 1, 768)
        clip_embed_ = self.cc_projection(torch.cat([clip_embed_, v_embed_.unsqueeze(1)], -1))  # B*TN,1,768

        x_input_ = x_input.unsqueeze(1).repeat(1, TN, 1, 1, 1).view(B * TN, 4, H, W)

        x_concat = x_input_
        return clip_embed_, frustum_volume_feats, x_concat

    def training_step(self, batch):
        """Noise-prediction MSE loss on one randomly chosen target view per sample."""
        B = batch['target_image'].shape[0]
        time_steps = torch.randint(0, self.num_timesteps, (B,), device=self.device).long()

        x, clip_embed, input_info = self.prepare(batch)
        x_noisy, noise = self.add_noise(x, time_steps)  # B,N,4,H,W

        N = self.view_num
        target_index = torch.randint(0, N, (B, 1), device=self.device).long()  # B, 1
        v_embed = self.get_viewpoint_embedding(B, input_info['elevation'])  # B,N,v_dim

        t_embed = self.embed_time(time_steps)
        spatial_volume = self.spatial_volume.construct_spatial_volume(x_noisy, t_embed, v_embed, self.poses, self.Ks)

        clip_embed, volume_feats, x_concat = self.get_target_view_feats(input_info['x'], spatial_volume, clip_embed, t_embed, v_embed, target_index)

        x_noisy_ = x_noisy[torch.arange(B)[:, None], target_index][:, 0]  # B,4,H,W
        noise_predict = self.model(x_noisy_, time_steps, clip_embed, volume_feats, x_concat, is_train=True)  # B,4,H,W

        noise_target = noise[torch.arange(B)[:, None], target_index][:, 0]  # B,4,H,W
        # loss simple for diffusion
        loss_simple = torch.nn.functional.mse_loss(noise_target, noise_predict, reduction='none')
        loss = loss_simple.mean()
        self.log('sim', loss_simple.mean(), prog_bar=True, logger=True, on_step=True, on_epoch=True, rank_zero_only=True)

        # log others
        lr = self.optimizers().param_groups[0]['lr']
        self.log('lr', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False, rank_zero_only=True)
        self.log("step", self.global_step, prog_bar=True, logger=True, on_step=True, on_epoch=False, rank_zero_only=True)
        return loss

    def add_noise(self, x_start, t):
        """Forward diffusion q(x_t | x_0).

        @param x_start: B,*
        @param t: B,
        @return: (x_noisy, noise), both shaped like ``x_start``
        """
        B = x_start.shape[0]
        noise = torch.randn_like(x_start)  # B,*

        sqrt_alphas_cumprod_ = self.sqrt_alphas_cumprod[t]  # B,
        sqrt_one_minus_alphas_cumprod_ = self.sqrt_one_minus_alphas_cumprod[t]  # B
        sqrt_alphas_cumprod_ = sqrt_alphas_cumprod_.view(B, *[1 for _ in range(len(x_start.shape) - 1)])
        sqrt_one_minus_alphas_cumprod_ = sqrt_one_minus_alphas_cumprod_.view(B, *[1 for _ in range(len(x_start.shape) - 1)])
        x_noisy = sqrt_alphas_cumprod_ * x_start + sqrt_one_minus_alphas_cumprod_ * noise
        return x_noisy, noise

    def sample(self, sampler, batch, cfg_scale, batch_view_num, return_inter_results=False, inter_interval=50, inter_view_interval=2):
        """Sample all views for ``batch`` with ``sampler`` and decode them to images.

        @return: B,N,3,H,W images; with ``return_inter_results`` also the decoded
            intermediates (B, N // inter_view_interval rounded up, T, 3, H, W)
        """
        _, clip_embed, input_info = self.prepare(batch)
        x_sample, inter = sampler.sample(input_info, clip_embed, unconditional_scale=cfg_scale, log_every_t=inter_interval, batch_view_num=batch_view_num)

        N = x_sample.shape[1]
        x_sample = torch.stack([self.decode_first_stage(x_sample[:, ni]) for ni in range(N)], 1)
        if return_inter_results:
            # Free cached GPU memory before decoding many intermediates (no-op without CUDA).
            if torch.cuda.is_available():
                torch.cuda.synchronize()
                torch.cuda.empty_cache()
            inter = torch.stack(inter['x_inter'], 2)  # B,N,T,C,H,W
            B, N, T, C, H, W = inter.shape
            inter_results = []
            for ni in tqdm(range(0, N, inter_view_interval)):
                inter_results_ = []
                for ti in range(T):
                    inter_results_.append(self.decode_first_stage(inter[:, ni, ti]))
                inter_results.append(torch.stack(inter_results_, 1))  # B,T,3,H,W
            inter_results = torch.stack(inter_results, 1)  # B,N,T,3,H,W
            return x_sample, inter_results
        else:
            return x_sample

    def log_image(self, x_sample, batch, step, output_dir):
        """Save one row per sample: the input image followed by all generated views."""
        process = lambda x: ((torch.clip(x, min=-1, max=1).cpu().numpy() * 0.5 + 0.5) * 255).astype(np.uint8)
        B = x_sample.shape[0]
        N = x_sample.shape[1]
        image_cond = []
        for bi in range(B):
            img_pr_ = concat_images_list(process(batch['input_image'][bi]), *[process(x_sample[bi, ni].permute(1, 2, 0)) for ni in range(N)])
            image_cond.append(img_pr_)

        output_dir = Path(output_dir)
        imsave(str(output_dir / f'{step}.jpg'), concat_images_list(*image_cond, vert=True))

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        """On rank 0, sample the first ``output_num`` items of the first validation
        batch and save them under ``image_dir/images/val``."""
        if batch_idx == 0 and self.global_rank == 0:
            self.eval()
            step = self.global_step
            batch_ = {k: v[:self.output_num] for k, v in batch.items()}
            # ``sample`` takes the sampler as its first argument.
            x_sample = self.sample(self.ddim, batch_, self.cfg_scale, self.batch_view_num)
            output_dir = Path(self.image_dir) / 'images' / 'val'
            output_dir.mkdir(exist_ok=True, parents=True)
            self.log_image(x_sample, batch_, step, output_dir=output_dir)

    def configure_optimizers(self):
        """AdamW over the trainable parts plus a per-step LambdaLR schedule.

        Time embedding and spatial volume network use 10x the base learning rate;
        the LR multiplier comes from ``scheduler_config``'s ``schedule`` callable.
        """
        lr = self.learning_rate
        print(f'setting learning rate to {lr:.4f} ...')
        paras = []
        if self.finetune_projection:
            paras.append({"params": self.cc_projection.parameters(), "lr": lr},)
        if self.finetune_unet:
            paras.append({"params": self.model.parameters(), "lr": lr},)
        else:
            paras.append({"params": self.model.get_trainable_parameters(), "lr": lr},)

        paras.append({"params": self.time_embed.parameters(), "lr": lr * 10.0},)
        paras.append({"params": self.spatial_volume.parameters(), "lr": lr * 10.0},)

        opt = torch.optim.AdamW(paras, lr=lr)

        scheduler = instantiate_from_config(self.scheduler_config)
        print("Setting up LambdaLR scheduler...")
        scheduler = [{'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), 'interval': 'step', 'frequency': 1}]
        return [opt], scheduler
class SyncDDIMSampler:
    """DDIM sampler that denoises all views of a ``SyncMultiviewDiffusion`` jointly.

    Each step builds one shared spatial volume from all current noisy views, then
    predicts the noise of ``batch_view_num`` views per UNet call and applies a
    DDIM update to every view at once.
    """

    def __init__(self, model: SyncMultiviewDiffusion, ddim_num_steps, ddim_discretize="uniform", ddim_eta=1.0, latent_size=32):
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.latent_size = latent_size
        self._make_schedule(ddim_num_steps, ddim_discretize, ddim_eta)
        self.eta = ddim_eta

    def _make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
        """Precompute per-DDIM-step coefficient tables (stored as float32, on CPU)."""
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                                  num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)  # DT
        step_ids = torch.from_numpy(self.ddim_timesteps.astype(np.int64))  # DT
        cumprod = self.model.alphas_cumprod  # T
        assert cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'

        alphas = cumprod[step_ids].double()  # DT
        alphas_prev = torch.cat([cumprod[0:1], cumprod[step_ids[:-1]]], 0)  # DT
        sigmas = ddim_eta * torch.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))

        self.ddim_alphas_raw = self.model.alphas[step_ids].float()  # DT
        self.ddim_sigmas = sigmas.float()
        self.ddim_alphas = alphas.float()
        self.ddim_alphas_prev = alphas_prev.float()
        self.ddim_sqrt_one_minus_alphas = torch.sqrt(1. - self.ddim_alphas).float()

    @torch.no_grad()
    def denoise_apply_impl(self, x_target_noisy, index, noise_pred, is_step0=False):
        """One DDIM update from the predicted noise.

        @param x_target_noisy: B,N,4,H,W
        @param index: DDIM step index into the coefficient tables
        @param noise_pred: B,N,4,H,W
        @param is_step0: when True no fresh noise is added
        @return: x at the previous step, B,N,4,H,W
        """
        device = x_target_noisy.device
        B, N, _, H, W = x_target_noisy.shape  # also requires a 5-D input

        def coefficient(table):
            return table[index].to(device).float().view(1, 1, 1, 1, 1)

        a_t = coefficient(self.ddim_alphas)
        a_prev = coefficient(self.ddim_alphas_prev)
        sqrt_one_minus_at = coefficient(self.ddim_sqrt_one_minus_alphas)
        sigma_t = coefficient(self.ddim_sigmas)

        pred_x0 = (x_target_noisy - sqrt_one_minus_at * noise_pred) / a_t.sqrt()
        dir_xt = torch.clamp(1. - a_prev - sigma_t ** 2, min=1e-7).sqrt() * noise_pred
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt
        if is_step0:
            return x_prev
        return x_prev + sigma_t * torch.randn_like(x_target_noisy)

    @torch.no_grad()
    def denoise_apply(self, x_target_noisy, input_info, clip_embed, time_steps, index, unconditional_scale, batch_view_num=1, is_step0=False):
        """Predict noise for every view (in chunks of ``batch_view_num``) and step back once.

        @param x_target_noisy: B,N,4,H,W
        @param input_info: dict with 'x' (input latent) and 'elevation'
        @param clip_embed: B,M,768
        @param time_steps: B,
        @param index: DDIM step index
        @param unconditional_scale: guidance scales forwarded to the UNet wrapper
        @param batch_view_num: views per UNet call
        @param is_step0: final step (no noise added)
        @return: B,N,4,H,W
        """
        model = self.model
        x_input, elevation_input = input_info['x'], input_info['elevation']
        B, N, C, H, W = x_target_noisy.shape

        # shared conditioning built from all noisy views
        v_embed = model.get_viewpoint_embedding(B, elevation_input)  # B,N,v_dim
        t_embed = model.embed_time(time_steps)  # B,t_dim
        spatial_volume = model.spatial_volume.construct_spatial_volume(x_target_noisy, t_embed, v_embed, model.poses, model.Ks)

        view_ids = torch.arange(N)  # N
        chunks = []
        for start in range(0, N, batch_view_num):
            stop = start + batch_view_num
            x_chunk = x_target_noisy[:, start:stop]
            VN = x_chunk.shape[1]
            x_chunk = x_chunk.reshape(B * VN, C, H, W)
            t_chunk = repeat_to_batch(time_steps, B, VN)
            ids_chunk = view_ids[start:stop].unsqueeze(0).repeat(B, 1)
            clip_c, volume_c, concat_c = model.get_target_view_feats(x_input, spatial_volume, clip_embed, t_embed, v_embed, ids_chunk)
            eps = model.model.predict_with_decomposed_unconditional_scales(x_chunk, t_chunk, clip_c, volume_c, concat_c, unconditional_scale)
            chunks.append(eps.view(B, VN, 4, H, W))

        e_t = torch.cat(chunks, 1)
        return self.denoise_apply_impl(x_target_noisy, index, e_t, is_step0)

    @torch.no_grad()
    def sample(self, input_info, clip_embed, unconditional_scale, log_every_t=50, batch_view_num=1):
        """Run the full DDIM chain from Gaussian noise.

        @param input_info: dict with 'x' and 'elevation'
        @param clip_embed: B,M,768
        @param unconditional_scale: guidance scales
        @param log_every_t: keep an intermediate whenever the DDIM index is a multiple of this
            (the first step is always kept)
        @param batch_view_num: views per UNet call
        @return: (final latents B,N,4,h,w, {'x_inter': list of intermediate latents})
        """
        C, H, W = 4, self.latent_size, self.latent_size
        B = clip_embed.shape[0]
        N = self.model.view_num
        device = self.model.device
        x_target_noisy = torch.randn([B, N, C, H, W], device=device)

        timesteps = self.ddim_timesteps
        intermediates = {'x_inter': []}
        total_steps = timesteps.shape[0]
        progress = tqdm(np.flip(timesteps), desc='DDIM Sampler', total=total_steps)
        for i, step in enumerate(progress):
            index = total_steps - i - 1  # index in ddim state
            time_steps = torch.full((B,), step, device=device, dtype=torch.long)
            x_target_noisy = self.denoise_apply(x_target_noisy, input_info, clip_embed, time_steps, index, unconditional_scale,
                                                batch_view_num=batch_view_num, is_step0=index == 0)
            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['x_inter'].append(x_target_noisy)

        return x_target_noisy, intermediates