Spaces:
Sleeping
Sleeping
from pathlib import Path | |
import pytorch_lightning as pl | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import numpy as np | |
from skimage.io import imsave | |
from torch.optim.lr_scheduler import LambdaLR | |
from tqdm import tqdm | |
from ldm.base_utils import read_pickle, concat_images_list | |
from ldm.models.diffusion.sync_dreamer_utils import get_warp_coordinates, create_target_volume | |
from ldm.models.diffusion.sync_dreamer_network import NoisyTargetViewEncoder, SpatialTime3DNet, FrustumTV3DNet | |
from ldm.modules.diffusionmodules.util import make_ddim_timesteps, timestep_embedding | |
from ldm.modules.encoders.modules import FrozenCLIPImageEmbedder | |
from ldm.util import instantiate_from_config | |
def disabled_train(self, mode=True): | |
"""Overwrite model.train with this function to make sure train/eval mode | |
does not change anymore.""" | |
return self | |
def disable_training_module(module: nn.Module): | |
module = module.eval() | |
module.train = disabled_train | |
for para in module.parameters(): | |
para.requires_grad = False | |
return module | |
def repeat_to_batch(tensor, B, VN): | |
t_shape = tensor.shape | |
ones = [1 for _ in range(len(t_shape)-1)] | |
tensor_new = tensor.view(B,1,*t_shape[1:]).repeat(1,VN,*ones).view(B*VN,*t_shape[1:]) | |
return tensor_new | |
class UNetWrapper(nn.Module): | |
def __init__(self, diff_model_config, drop_conditions=False, drop_scheme='default', use_zero_123=True): | |
super().__init__() | |
self.diffusion_model = instantiate_from_config(diff_model_config) | |
self.drop_conditions = drop_conditions | |
self.drop_scheme=drop_scheme | |
self.use_zero_123 = use_zero_123 | |
def drop(self, cond, mask): | |
shape = cond.shape | |
B = shape[0] | |
cond = mask.view(B,*[1 for _ in range(len(shape)-1)]) * cond | |
return cond | |
def get_trainable_parameters(self): | |
return self.diffusion_model.get_trainable_parameters() | |
def get_drop_scheme(self, B, device): | |
if self.drop_scheme=='default': | |
random = torch.rand(B, dtype=torch.float32, device=device) | |
drop_clip = (random > 0.15) & (random <= 0.2) | |
drop_volume = (random > 0.1) & (random <= 0.15) | |
drop_concat = (random > 0.05) & (random <= 0.1) | |
drop_all = random <= 0.05 | |
else: | |
raise NotImplementedError | |
return drop_clip, drop_volume, drop_concat, drop_all | |
def forward(self, x, t, clip_embed, volume_feats, x_concat, is_train=False): | |
""" | |
@param x: B,4,H,W | |
@param t: B, | |
@param clip_embed: B,M,768 | |
@param volume_feats: B,C,D,H,W | |
@param x_concat: B,C,H,W | |
@param is_train: | |
@return: | |
""" | |
if self.drop_conditions and is_train: | |
B = x.shape[0] | |
drop_clip, drop_volume, drop_concat, drop_all = self.get_drop_scheme(B, x.device) | |
clip_mask = 1.0 - (drop_clip | drop_all).float() | |
clip_embed = self.drop(clip_embed, clip_mask) | |
volume_mask = 1.0 - (drop_volume | drop_all).float() | |
for k, v in volume_feats.items(): | |
volume_feats[k] = self.drop(v, mask=volume_mask) | |
concat_mask = 1.0 - (drop_concat | drop_all).float() | |
x_concat = self.drop(x_concat, concat_mask) | |
if self.use_zero_123: | |
# zero123 does not multiply this when encoding, maybe a bug for zero123 | |
first_stage_scale_factor = 0.18215 | |
x_concat_ = x_concat * 1.0 | |
x_concat_[:, :4] = x_concat_[:, :4] / first_stage_scale_factor | |
else: | |
x_concat_ = x_concat | |
x = torch.cat([x, x_concat_], 1) | |
pred = self.diffusion_model(x, t, clip_embed, source_dict=volume_feats) | |
return pred | |
def predict_with_decomposed_unconditional_scales(self, x, t, clip_embed, volume_feats, x_concat, unconditional_scales): | |
x_ = torch.cat([x] * 3, 0) | |
t_ = torch.cat([t] * 3, 0) | |
clip_embed_ = torch.cat([clip_embed, torch.zeros_like(clip_embed), clip_embed], 0) | |
x_concat_ = torch.cat([x_concat, torch.zeros_like(x_concat), x_concat*4], 0) | |
v_ = {} | |
for k, v in volume_feats.items(): | |
v_[k] = torch.cat([v, v, torch.zeros_like(v)], 0) | |
if self.use_zero_123: | |
# zero123 does not multiply this when encoding, maybe a bug for zero123 | |
first_stage_scale_factor = 0.18215 | |
x_concat_[:, :4] = x_concat_[:, :4] / first_stage_scale_factor | |
x_ = torch.cat([x_, x_concat_], 1) | |
s, s_uc1, s_uc2 = self.diffusion_model(x_, t_, clip_embed_, source_dict=v_).chunk(3) | |
s = s + unconditional_scales[0] * (s - s_uc1) + unconditional_scales[1] * (s - s_uc2) | |
return s | |
class SpatialVolumeNet(nn.Module): | |
def __init__(self, time_dim, view_dim, view_num, | |
input_image_size=256, frustum_volume_depth=48, | |
spatial_volume_size=32, spatial_volume_length=0.5, | |
frustum_volume_length=0.86603 # sqrt(3)/2 | |
): | |
super().__init__() | |
self.target_encoder = NoisyTargetViewEncoder(time_dim, view_dim, output_dim=16) | |
self.spatial_volume_feats = SpatialTime3DNet(input_dim=16 * view_num, time_dim=time_dim, dims=(64, 128, 256, 512)) | |
self.frustum_volume_feats = FrustumTV3DNet(64, time_dim, view_dim, dims=(64, 128, 256, 512)) | |
self.frustum_volume_length = frustum_volume_length | |
self.input_image_size = input_image_size | |
self.spatial_volume_size = spatial_volume_size | |
self.spatial_volume_length = spatial_volume_length | |
self.frustum_volume_size = self.input_image_size // 8 | |
self.frustum_volume_depth = frustum_volume_depth | |
self.time_dim = time_dim | |
self.view_dim = view_dim | |
self.default_origin_depth = 1.5 # our rendered images are 1.5 away from the origin, we assume camera is 1.5 away from the origin | |
def construct_spatial_volume(self, x, t_embed, v_embed, target_poses, target_Ks): | |
""" | |
@param x: B,N,4,H,W | |
@param t_embed: B,t_dim | |
@param v_embed: B,N,v_dim | |
@param target_poses: N,3,4 | |
@param target_Ks: N,3,3 | |
@return: | |
""" | |
B, N, _, H, W = x.shape | |
V = self.spatial_volume_size | |
device = x.device | |
spatial_volume_verts = torch.linspace(-self.spatial_volume_length, self.spatial_volume_length, V, dtype=torch.float32, device=device) | |
spatial_volume_verts = torch.stack(torch.meshgrid(spatial_volume_verts, spatial_volume_verts, spatial_volume_verts), -1) | |
spatial_volume_verts = spatial_volume_verts.reshape(1, V ** 3, 3)[:, :, (2, 1, 0)] | |
spatial_volume_verts = spatial_volume_verts.view(1, V, V, V, 3).permute(0, 4, 1, 2, 3).repeat(B, 1, 1, 1, 1) | |
# encode source features | |
t_embed_ = t_embed.view(B, 1, self.time_dim).repeat(1, N, 1).view(B, N, self.time_dim) | |
# v_embed_ = v_embed.view(1, N, self.view_dim).repeat(B, 1, 1).view(B, N, self.view_dim) | |
v_embed_ = v_embed | |
target_Ks = target_Ks.unsqueeze(0).repeat(B, 1, 1, 1) | |
target_poses = target_poses.unsqueeze(0).repeat(B, 1, 1, 1) | |
# extract 2D image features | |
spatial_volume_feats = [] | |
# project source features | |
for ni in range(0, N): | |
pose_source_ = target_poses[:, ni] | |
K_source_ = target_Ks[:, ni] | |
x_ = self.target_encoder(x[:, ni], t_embed_[:, ni], v_embed_[:, ni]) | |
C = x_.shape[1] | |
coords_source = get_warp_coordinates(spatial_volume_verts, x_.shape[-1], self.input_image_size, K_source_, pose_source_).view(B, V, V * V, 2) | |
unproj_feats_ = F.grid_sample(x_, coords_source, mode='bilinear', padding_mode='zeros', align_corners=True) | |
unproj_feats_ = unproj_feats_.view(B, C, V, V, V) | |
spatial_volume_feats.append(unproj_feats_) | |
spatial_volume_feats = torch.stack(spatial_volume_feats, 1) # B,N,C,V,V,V | |
N = spatial_volume_feats.shape[1] | |
spatial_volume_feats = spatial_volume_feats.view(B, N*C, V, V, V) | |
spatial_volume_feats = self.spatial_volume_feats(spatial_volume_feats, t_embed) # b,64,32,32,32 | |
return spatial_volume_feats | |
def construct_view_frustum_volume(self, spatial_volume, t_embed, v_embed, poses, Ks, target_indices): | |
""" | |
@param spatial_volume: B,C,V,V,V | |
@param t_embed: B,t_dim | |
@param v_embed: B,N,v_dim | |
@param poses: N,3,4 | |
@param Ks: N,3,3 | |
@param target_indices: B,TN | |
@return: B*TN,C,H,W | |
""" | |
B, TN = target_indices.shape | |
H, W = self.frustum_volume_size, self.frustum_volume_size | |
D = self.frustum_volume_depth | |
V = self.spatial_volume_size | |
near = torch.ones(B * TN, 1, H, W, dtype=spatial_volume.dtype, device=spatial_volume.device) * self.default_origin_depth - self.frustum_volume_length | |
far = torch.ones(B * TN, 1, H, W, dtype=spatial_volume.dtype, device=spatial_volume.device) * self.default_origin_depth + self.frustum_volume_length | |
target_indices = target_indices.view(B*TN) # B*TN | |
poses_ = poses[target_indices] # B*TN,3,4 | |
Ks_ = Ks[target_indices] # B*TN,3,4 | |
volume_xyz, volume_depth = create_target_volume(D, self.frustum_volume_size, self.input_image_size, poses_, Ks_, near, far) # B*TN,3 or 1,D,H,W | |
volume_xyz_ = volume_xyz / self.spatial_volume_length # since the spatial volume is constructed in [-spatial_volume_length,spatial_volume_length] | |
volume_xyz_ = volume_xyz_.permute(0, 2, 3, 4, 1) # B*TN,D,H,W,3 | |
spatial_volume_ = spatial_volume.unsqueeze(1).repeat(1, TN, 1, 1, 1, 1).view(B * TN, -1, V, V, V) | |
volume_feats = F.grid_sample(spatial_volume_, volume_xyz_, mode='bilinear', padding_mode='zeros', align_corners=True) # B*TN,C,D,H,W | |
v_embed_ = v_embed[torch.arange(B)[:,None], target_indices.view(B,TN)].view(B*TN, -1) # B*TN | |
t_embed_ = t_embed.unsqueeze(1).repeat(1,TN,1).view(B*TN,-1) | |
volume_feats_dict = self.frustum_volume_feats(volume_feats, t_embed_, v_embed_) | |
return volume_feats_dict, volume_depth | |
class SyncMultiviewDiffusion(pl.LightningModule): | |
def __init__(self, unet_config, scheduler_config, | |
finetune_unet=False, finetune_projection=True, | |
view_num=16, image_size=256, | |
cfg_scale=3.0, output_num=8, batch_view_num=4, | |
drop_conditions=False, drop_scheme='default', | |
clip_image_encoder_path="/apdcephfs/private_rondyliu/projects/clip/ViT-L-14.pt"): | |
super().__init__() | |
self.finetune_unet = finetune_unet | |
self.finetune_projection = finetune_projection | |
self.view_num = view_num | |
self.viewpoint_dim = 4 | |
self.output_num = output_num | |
self.image_size = image_size | |
self.batch_view_num = batch_view_num | |
self.cfg_scale = cfg_scale | |
self.clip_image_encoder_path = clip_image_encoder_path | |
self._init_time_step_embedding() | |
self._init_first_stage() | |
self._init_schedule() | |
self._init_multiview() | |
self._init_clip_image_encoder() | |
self._init_clip_projection() | |
self.spatial_volume = SpatialVolumeNet(self.time_embed_dim, self.viewpoint_dim, self.view_num) | |
self.model = UNetWrapper(unet_config, drop_conditions=drop_conditions, drop_scheme=drop_scheme) | |
self.scheduler_config = scheduler_config | |
latent_size = image_size//8 | |
self.ddim = SyncDDIMSampler(self, 200, "uniform", 1.0, latent_size=latent_size) | |
def _init_clip_projection(self): | |
self.cc_projection = nn.Linear(772, 768) | |
nn.init.eye_(list(self.cc_projection.parameters())[0][:768, :768]) | |
nn.init.zeros_(list(self.cc_projection.parameters())[1]) | |
self.cc_projection.requires_grad_(True) | |
if not self.finetune_projection: | |
disable_training_module(self.cc_projection) | |
def _init_multiview(self): | |
K, azs, _, _, poses = read_pickle(f'meta_info/camera-{self.view_num}.pkl') | |
default_image_size = 256 | |
ratio = self.image_size/default_image_size | |
K = np.diag([ratio,ratio,1]) @ K | |
K = torch.from_numpy(K.astype(np.float32)) # [3,3] | |
K = K.unsqueeze(0).repeat(self.view_num,1,1) # N,3,3 | |
poses = torch.from_numpy(poses.astype(np.float32)) # N,3,4 | |
self.register_buffer('poses', poses) | |
self.register_buffer('Ks', K) | |
azs = (azs + np.pi) % (np.pi * 2) - np.pi # scale to [-pi,pi] and the index=0 has az=0 | |
self.register_buffer('azimuth', torch.from_numpy(azs.astype(np.float32))) | |
def get_viewpoint_embedding(self, batch_size, elevation_ref): | |
""" | |
@param batch_size: | |
@param elevation_ref: B | |
@return: | |
""" | |
azimuth_input = self.azimuth[0].unsqueeze(0) # 1 | |
azimuth_target = self.azimuth # N | |
elevation_input = -elevation_ref # note that zero123 use a negative elevation here!!! | |
elevation_target = -np.deg2rad(30) | |
d_e = elevation_target - elevation_input # B | |
N = self.azimuth.shape[0] | |
B = batch_size | |
d_e = d_e.unsqueeze(1).repeat(1, N) | |
d_a = azimuth_target - azimuth_input # N | |
d_a = d_a.unsqueeze(0).repeat(B, 1) | |
d_z = torch.zeros_like(d_a) | |
embedding = torch.stack([d_e, torch.sin(d_a), torch.cos(d_a), d_z], -1) # B,N,4 | |
return embedding | |
def _init_first_stage(self): | |
first_stage_config={ | |
"target": "ldm.models.autoencoder.AutoencoderKL", | |
"params": { | |
"embed_dim": 4, | |
"monitor": "val/rec_loss", | |
"ddconfig":{ | |
"double_z": True, | |
"z_channels": 4, | |
"resolution": self.image_size, | |
"in_channels": 3, | |
"out_ch": 3, | |
"ch": 128, | |
"ch_mult": [1,2,4,4], | |
"num_res_blocks": 2, | |
"attn_resolutions": [], | |
"dropout": 0.0 | |
}, | |
"lossconfig": {"target": "torch.nn.Identity"}, | |
} | |
} | |
self.first_stage_scale_factor = 0.18215 | |
self.first_stage_model = instantiate_from_config(first_stage_config) | |
self.first_stage_model = disable_training_module(self.first_stage_model) | |
def _init_clip_image_encoder(self): | |
self.clip_image_encoder = FrozenCLIPImageEmbedder(model=self.clip_image_encoder_path) | |
self.clip_image_encoder = disable_training_module(self.clip_image_encoder) | |
def _init_schedule(self): | |
self.num_timesteps = 1000 | |
linear_start = 0.00085 | |
linear_end = 0.0120 | |
num_timesteps = 1000 | |
betas = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, num_timesteps, dtype=torch.float32) ** 2 # T | |
assert betas.shape[0] == self.num_timesteps | |
# all in float64 first | |
alphas = 1. - betas | |
alphas_cumprod = torch.cumprod(alphas, dim=0) # T | |
alphas_cumprod_prev = torch.cat([torch.ones(1, dtype=torch.float64), alphas_cumprod[:-1]], 0) | |
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) # T | |
posterior_log_variance_clipped = torch.log(torch.clamp(posterior_variance, min=1e-20)) | |
posterior_log_variance_clipped = torch.clamp(posterior_log_variance_clipped, min=-10) | |
self.register_buffer("betas", betas.float()) | |
self.register_buffer("alphas", alphas.float()) | |
self.register_buffer("alphas_cumprod", alphas_cumprod.float()) | |
self.register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod).float()) | |
self.register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1 - alphas_cumprod).float()) | |
self.register_buffer("posterior_variance", posterior_variance.float()) | |
self.register_buffer('posterior_log_variance_clipped', posterior_log_variance_clipped.float()) | |
def _init_time_step_embedding(self): | |
self.time_embed_dim = 256 | |
self.time_embed = nn.Sequential( | |
nn.Linear(self.time_embed_dim, self.time_embed_dim), | |
nn.SiLU(True), | |
nn.Linear(self.time_embed_dim, self.time_embed_dim), | |
) | |
def encode_first_stage(self, x, sample=True): | |
with torch.no_grad(): | |
posterior = self.first_stage_model.encode(x) # b,4,h//8,w//8 | |
if sample: | |
return posterior.sample().detach() * self.first_stage_scale_factor | |
else: | |
return posterior.mode().detach() * self.first_stage_scale_factor | |
def decode_first_stage(self, z): | |
with torch.no_grad(): | |
z = 1. / self.first_stage_scale_factor * z | |
return self.first_stage_model.decode(z) | |
def prepare(self, batch): | |
# encode target | |
if 'target_image' in batch: | |
image_target = batch['target_image'].permute(0, 1, 4, 2, 3) # b,n,3,h,w | |
N = image_target.shape[1] | |
x = [self.encode_first_stage(image_target[:,ni], True) for ni in range(N)] | |
x = torch.stack(x, 1) # b,n,4,h//8,w//8 | |
else: | |
x = None | |
image_input = batch['input_image'].permute(0, 3, 1, 2) | |
elevation_input = batch['input_elevation'][:, 0] # b | |
x_input = self.encode_first_stage(image_input) | |
input_info = {'image': image_input, 'elevation': elevation_input, 'x': x_input} | |
with torch.no_grad(): | |
clip_embed = self.clip_image_encoder.encode(image_input) | |
return x, clip_embed, input_info | |
def embed_time(self, t): | |
t_embed = timestep_embedding(t, self.time_embed_dim, repeat_only=False) # B,TED | |
t_embed = self.time_embed(t_embed) # B,TED | |
return t_embed | |
def get_target_view_feats(self, x_input, spatial_volume, clip_embed, t_embed, v_embed, target_index): | |
""" | |
@param x_input: B,4,H,W | |
@param spatial_volume: B,C,V,V,V | |
@param clip_embed: B,1,768 | |
@param t_embed: B,t_dim | |
@param v_embed: B,N,v_dim | |
@param target_index: B,TN | |
@return: | |
tensors of size B*TN,* | |
""" | |
B, _, H, W = x_input.shape | |
frustum_volume_feats, frustum_volume_depth = self.spatial_volume.construct_view_frustum_volume(spatial_volume, t_embed, v_embed, self.poses, self.Ks, target_index) | |
# clip | |
TN = target_index.shape[1] | |
v_embed_ = v_embed[torch.arange(B)[:,None], target_index].view(B*TN, self.viewpoint_dim) # B*TN,v_dim | |
clip_embed_ = clip_embed.unsqueeze(1).repeat(1,TN,1,1).view(B*TN,1,768) | |
clip_embed_ = self.cc_projection(torch.cat([clip_embed_, v_embed_.unsqueeze(1)], -1)) # B*TN,1,768 | |
x_input_ = x_input.unsqueeze(1).repeat(1, TN, 1, 1, 1).view(B * TN, 4, H, W) | |
x_concat = x_input_ | |
return clip_embed_, frustum_volume_feats, x_concat | |
def training_step(self, batch): | |
B = batch['target_image'].shape[0] | |
time_steps = torch.randint(0, self.num_timesteps, (B,), device=self.device).long() | |
x, clip_embed, input_info = self.prepare(batch) | |
x_noisy, noise = self.add_noise(x, time_steps) # B,N,4,H,W | |
N = self.view_num | |
target_index = torch.randint(0, N, (B, 1), device=self.device).long() # B, 1 | |
v_embed = self.get_viewpoint_embedding(B, input_info['elevation']) # N,v_dim | |
t_embed = self.embed_time(time_steps) | |
spatial_volume = self.spatial_volume.construct_spatial_volume(x_noisy, t_embed, v_embed, self.poses, self.Ks) | |
clip_embed, volume_feats, x_concat = self.get_target_view_feats(input_info['x'], spatial_volume, clip_embed, t_embed, v_embed, target_index) | |
x_noisy_ = x_noisy[torch.arange(B)[:,None],target_index][:,0] # B,4,H,W | |
noise_predict = self.model(x_noisy_, time_steps, clip_embed, volume_feats, x_concat, is_train=True) # B,4,H,W | |
noise_target = noise[torch.arange(B)[:,None],target_index][:,0] # B,4,H,W | |
# loss simple for diffusion | |
loss_simple = torch.nn.functional.mse_loss(noise_target, noise_predict, reduction='none') | |
loss = loss_simple.mean() | |
self.log('sim', loss_simple.mean(), prog_bar=True, logger=True, on_step=True, on_epoch=True, rank_zero_only=True) | |
# log others | |
lr = self.optimizers().param_groups[0]['lr'] | |
self.log('lr', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False, rank_zero_only=True) | |
self.log("step", self.global_step, prog_bar=True, logger=True, on_step=True, on_epoch=False, rank_zero_only=True) | |
return loss | |
def add_noise(self, x_start, t): | |
""" | |
@param x_start: B,* | |
@param t: B, | |
@return: | |
""" | |
B = x_start.shape[0] | |
noise = torch.randn_like(x_start) # B,* | |
sqrt_alphas_cumprod_ = self.sqrt_alphas_cumprod[t] # B, | |
sqrt_one_minus_alphas_cumprod_ = self.sqrt_one_minus_alphas_cumprod[t] # B | |
sqrt_alphas_cumprod_ = sqrt_alphas_cumprod_.view(B, *[1 for _ in range(len(x_start.shape)-1)]) | |
sqrt_one_minus_alphas_cumprod_ = sqrt_one_minus_alphas_cumprod_.view(B, *[1 for _ in range(len(x_start.shape)-1)]) | |
x_noisy = sqrt_alphas_cumprod_ * x_start + sqrt_one_minus_alphas_cumprod_ * noise | |
return x_noisy, noise | |
def sample(self, sampler, batch, cfg_scale, batch_view_num, return_inter_results=False, inter_interval=50, inter_view_interval=2): | |
_, clip_embed, input_info = self.prepare(batch) | |
x_sample, inter = sampler.sample(input_info, clip_embed, unconditional_scale=cfg_scale, log_every_t=inter_interval, batch_view_num=batch_view_num) | |
N = x_sample.shape[1] | |
x_sample = torch.stack([self.decode_first_stage(x_sample[:, ni]) for ni in range(N)], 1) | |
if return_inter_results: | |
torch.cuda.synchronize() | |
torch.cuda.empty_cache() | |
inter = torch.stack(inter['x_inter'], 2) # # B,N,T,C,H,W | |
B,N,T,C,H,W = inter.shape | |
inter_results = [] | |
for ni in tqdm(range(0, N, inter_view_interval)): | |
inter_results_ = [] | |
for ti in range(T): | |
inter_results_.append(self.decode_first_stage(inter[:, ni, ti])) | |
inter_results.append(torch.stack(inter_results_, 1)) # B,T,3,H,W | |
inter_results = torch.stack(inter_results,1) # B,N,T,3,H,W | |
return x_sample, inter_results | |
else: | |
return x_sample | |
def log_image(self, x_sample, batch, step, output_dir): | |
process = lambda x: ((torch.clip(x, min=-1, max=1).cpu().numpy() * 0.5 + 0.5) * 255).astype(np.uint8) | |
B = x_sample.shape[0] | |
N = x_sample.shape[1] | |
image_cond = [] | |
for bi in range(B): | |
img_pr_ = concat_images_list(process(batch['input_image'][bi]),*[process(x_sample[bi, ni].permute(1, 2, 0)) for ni in range(N)]) | |
image_cond.append(img_pr_) | |
output_dir = Path(output_dir) | |
imsave(str(output_dir/f'{step}.jpg'), concat_images_list(*image_cond, vert=True)) | |
def validation_step(self, batch, batch_idx): | |
if batch_idx==0 and self.global_rank==0: | |
self.eval() | |
step = self.global_step | |
batch_ = {} | |
for k, v in batch.items(): batch_[k] = v[:self.output_num] | |
x_sample = self.sample(batch_, self.cfg_scale, self.batch_view_num) | |
output_dir = Path(self.image_dir) / 'images' / 'val' | |
output_dir.mkdir(exist_ok=True, parents=True) | |
self.log_image(x_sample, batch, step, output_dir=output_dir) | |
def configure_optimizers(self): | |
lr = self.learning_rate | |
print(f'setting learning rate to {lr:.4f} ...') | |
paras = [] | |
if self.finetune_projection: | |
paras.append({"params": self.cc_projection.parameters(), "lr": lr},) | |
if self.finetune_unet: | |
paras.append({"params": self.model.parameters(), "lr": lr},) | |
else: | |
paras.append({"params": self.model.get_trainable_parameters(), "lr": lr},) | |
paras.append({"params": self.time_embed.parameters(), "lr": lr*10.0},) | |
paras.append({"params": self.spatial_volume.parameters(), "lr": lr*10.0},) | |
opt = torch.optim.AdamW(paras, lr=lr) | |
scheduler = instantiate_from_config(self.scheduler_config) | |
print("Setting up LambdaLR scheduler...") | |
scheduler = [{'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), 'interval': 'step', 'frequency': 1}] | |
return [opt], scheduler | |
class SyncDDIMSampler: | |
def __init__(self, model: SyncMultiviewDiffusion, ddim_num_steps, ddim_discretize="uniform", ddim_eta=1.0, latent_size=32): | |
super().__init__() | |
self.model = model | |
self.ddpm_num_timesteps = model.num_timesteps | |
self.latent_size = latent_size | |
self._make_schedule(ddim_num_steps, ddim_discretize, ddim_eta) | |
self.eta = ddim_eta | |
def _make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): | |
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose) # DT | |
ddim_timesteps_ = torch.from_numpy(self.ddim_timesteps.astype(np.int64)) # DT | |
alphas_cumprod = self.model.alphas_cumprod # T | |
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' | |
self.ddim_alphas = alphas_cumprod[ddim_timesteps_].double() # DT | |
self.ddim_alphas_prev = torch.cat([alphas_cumprod[0:1], alphas_cumprod[ddim_timesteps_[:-1]]], 0) # DT | |
self.ddim_sigmas = ddim_eta * torch.sqrt((1 - self.ddim_alphas_prev) / (1 - self.ddim_alphas) * (1 - self.ddim_alphas / self.ddim_alphas_prev)) | |
self.ddim_alphas_raw = self.model.alphas[ddim_timesteps_].float() # DT | |
self.ddim_sigmas = self.ddim_sigmas.float() | |
self.ddim_alphas = self.ddim_alphas.float() | |
self.ddim_alphas_prev = self.ddim_alphas_prev.float() | |
self.ddim_sqrt_one_minus_alphas = torch.sqrt(1. - self.ddim_alphas).float() | |
def denoise_apply_impl(self, x_target_noisy, index, noise_pred, is_step0=False): | |
""" | |
@param x_target_noisy: B,N,4,H,W | |
@param index: index | |
@param noise_pred: B,N,4,H,W | |
@param is_step0: bool | |
@return: | |
""" | |
device = x_target_noisy.device | |
B,N,_,H,W = x_target_noisy.shape | |
# apply noise | |
a_t = self.ddim_alphas[index].to(device).float().view(1,1,1,1,1) | |
a_prev = self.ddim_alphas_prev[index].to(device).float().view(1,1,1,1,1) | |
sqrt_one_minus_at = self.ddim_sqrt_one_minus_alphas[index].to(device).float().view(1,1,1,1,1) | |
sigma_t = self.ddim_sigmas[index].to(device).float().view(1,1,1,1,1) | |
pred_x0 = (x_target_noisy - sqrt_one_minus_at * noise_pred) / a_t.sqrt() | |
dir_xt = torch.clamp(1. - a_prev - sigma_t**2, min=1e-7).sqrt() * noise_pred | |
x_prev = a_prev.sqrt() * pred_x0 + dir_xt | |
if not is_step0: | |
noise = sigma_t * torch.randn_like(x_target_noisy) | |
x_prev = x_prev + noise | |
return x_prev | |
def denoise_apply(self, x_target_noisy, input_info, clip_embed, time_steps, index, unconditional_scale, batch_view_num=1, is_step0=False): | |
""" | |
@param x_target_noisy: B,N,4,H,W | |
@param input_info: | |
@param clip_embed: B,M,768 | |
@param time_steps: B, | |
@param index: int | |
@param unconditional_scale: | |
@param batch_view_num: int | |
@param is_step0: bool | |
@return: | |
""" | |
x_input, elevation_input = input_info['x'], input_info['elevation'] | |
B, N, C, H, W = x_target_noisy.shape | |
# construct source data | |
v_embed = self.model.get_viewpoint_embedding(B, elevation_input) # B,N,v_dim | |
t_embed = self.model.embed_time(time_steps) # B,t_dim | |
spatial_volume = self.model.spatial_volume.construct_spatial_volume(x_target_noisy, t_embed, v_embed, self.model.poses, self.model.Ks) | |
e_t = [] | |
target_indices = torch.arange(N) # N | |
for ni in range(0, N, batch_view_num): | |
x_target_noisy_ = x_target_noisy[:, ni:ni + batch_view_num] | |
VN = x_target_noisy_.shape[1] | |
x_target_noisy_ = x_target_noisy_.reshape(B*VN,C,H,W) | |
time_steps_ = repeat_to_batch(time_steps, B, VN) | |
target_indices_ = target_indices[ni:ni+batch_view_num].unsqueeze(0).repeat(B,1) | |
clip_embed_, volume_feats_, x_concat_ = self.model.get_target_view_feats(x_input, spatial_volume, clip_embed, t_embed, v_embed, target_indices_) | |
noise = self.model.model.predict_with_decomposed_unconditional_scales(x_target_noisy_, time_steps_, clip_embed_, volume_feats_, x_concat_, unconditional_scale) | |
e_t.append(noise.view(B,VN,4,H,W)) | |
e_t = torch.cat(e_t, 1) | |
x_prev = self.denoise_apply_impl(x_target_noisy, index, e_t, is_step0) | |
return x_prev | |
def sample(self, input_info, clip_embed, unconditional_scale, log_every_t=50, batch_view_num=1): | |
""" | |
@param input_info: x, elevation | |
@param clip_embed: B,M,768 | |
@param unconditional_scale: | |
@param log_every_t: | |
@param batch_view_num: | |
@return: | |
""" | |
C, H, W = 4, self.latent_size, self.latent_size | |
B = clip_embed.shape[0] | |
N = self.model.view_num | |
device = self.model.device | |
x_target_noisy = torch.randn([B, N, C, H, W], device=device) | |
timesteps = self.ddim_timesteps | |
intermediates = {'x_inter': []} | |
time_range = np.flip(timesteps) | |
total_steps = timesteps.shape[0] | |
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) | |
for i, step in enumerate(iterator): | |
index = total_steps - i - 1 # index in ddim state | |
time_steps = torch.full((B,), step, device=device, dtype=torch.long) | |
x_target_noisy = self.denoise_apply(x_target_noisy, input_info, clip_embed, time_steps, index, unconditional_scale, batch_view_num=batch_view_num, is_step0=index==0) | |
if index % log_every_t == 0 or index == total_steps - 1: | |
intermediates['x_inter'].append(x_target_noisy) | |
return x_target_noisy, intermediates |