from diffusers import DDIMScheduler
import torchvision.transforms.functional as TF
import numpy as np
import rembg
import torch
import torch.nn as nn
import torch.nn.functional as F

import sys
sys.path.append('./')

from sparseags.guidance_utils.zero123 import Zero123Pipeline


# Maps attention-layer key prefixes in the original LDM-style checkpoint to
# the corresponding diffusers UNet names, so the fine-tuned weights can be
# loaded with load_state_dict below.
name_mapping = {
    "model.diffusion_model.input_blocks.1.1.": "down_blocks.0.attentions.0.",
    "model.diffusion_model.input_blocks.2.1.": "down_blocks.0.attentions.1.",
    "model.diffusion_model.input_blocks.4.1.": "down_blocks.1.attentions.0.",
    "model.diffusion_model.input_blocks.5.1.": "down_blocks.1.attentions.1.",
    "model.diffusion_model.input_blocks.7.1.": "down_blocks.2.attentions.0.",
    "model.diffusion_model.input_blocks.8.1.": "down_blocks.2.attentions.1.",
    "model.diffusion_model.middle_block.1.": "mid_block.attentions.0.",
    "model.diffusion_model.output_blocks.3.1.": "up_blocks.1.attentions.0.",
    "model.diffusion_model.output_blocks.4.1.": "up_blocks.1.attentions.1.",
    "model.diffusion_model.output_blocks.5.1.": "up_blocks.1.attentions.2.",
    "model.diffusion_model.output_blocks.6.1.": "up_blocks.2.attentions.0.",
    "model.diffusion_model.output_blocks.7.1.": "up_blocks.2.attentions.1.",
    "model.diffusion_model.output_blocks.8.1.": "up_blocks.2.attentions.2.",
    "model.diffusion_model.output_blocks.9.1.": "up_blocks.3.attentions.0.",
    "model.diffusion_model.output_blocks.10.1.": "up_blocks.3.attentions.1.",
    "model.diffusion_model.output_blocks.11.1.": "up_blocks.3.attentions.2.",
}


class Zero123(nn.Module):
    def __init__(self, device, fp16=True, t_range=[0.02, 0.98], model_key="ashawkey/zero123-xl-diffusers"):
        super().__init__()

        self.device = device
        self.fp16 = fp16
        self.dtype = torch.float16 if fp16 else torch.float32

        self.pipe = Zero123Pipeline.from_pretrained(
            model_key,
            trust_remote_code=True,
            torch_dtype=self.dtype,
        ).to(self.device)

        # Load the fine-tuned 6-DoF weights from the checkpoint.
        ckpt_path = "checkpoints/zero123_6dof_23k.ckpt"
        print(f'[INFO] loading checkpoint from {ckpt_path} ...')
        old_state = torch.load(ckpt_path, map_location='cpu')

        # The 6-DoF camera conditioning is 18-D (16 for the flattened relative
        # pose matrix plus 2 for the log focal-length ratio), so the camera
        # projection takes 768 + 18 inputs instead of the original 768 + 4.
        pretrained_weights = old_state['state_dict']['cc_projection.weight']
        pretrained_biases = old_state['state_dict']['cc_projection.bias']
        linear_layer = torch.nn.Linear(768 + 18, 768)
        linear_layer.weight.data = pretrained_weights
        linear_layer.bias.data = pretrained_biases
        self.pipe.clip_camera_projection.proj = linear_layer.to(dtype=self.dtype, device=self.device)

        # Rename the LDM-style UNet keys to diffusers names and load them.
        for name in list(old_state['state_dict'].keys()):
            for k, v in name_mapping.items():
                if k in name:
                    old_state['state_dict'][name.replace(k, v)] = old_state['state_dict'][name].to(dtype=self.dtype, device=self.device)
        missing, unexpected = self.pipe.unet.load_state_dict(old_state['state_dict'], strict=False)

        # stable-zero123 has a different camera embedding
        self.use_stable_zero123 = 'stable' in model_key

        self.pipe.image_encoder.eval()
        self.pipe.vae.eval()
        self.pipe.unet.eval()
        self.pipe.clip_camera_projection.eval()

        self.vae = self.pipe.vae
        self.unet = self.pipe.unet

        self.pipe.set_progress_bar_config(disable=True)

        self.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
        self.num_train_timesteps = self.scheduler.config.num_train_timesteps

        self.min_step = int(self.num_train_timesteps * t_range[0])
        self.max_step = int(self.num_train_timesteps * t_range[1])
        self.alphas = self.scheduler.alphas_cumprod.to(self.device)  # for convenience

        self.embeddings = None
        self.num_views = None  # set when conditioning embeddings are computed
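    # Illustrative example of the key remapping performed in __init__ (the
    # exact key is an assumption about the checkpoint layout, shown only for
    # clarity):
    #   "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q.weight"
    #   -> "down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.weight"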
    @torch.no_grad()
    def get_img_embeds(self, x):
        # x: image tensor in [0, 1], shape [B, 3, H, W]
        x = F.interpolate(x, (256, 256), mode='bilinear', align_corners=False)
        x_pil = [TF.to_pil_image(image) for image in x]
        x_clip = self.pipe.feature_extractor(images=x_pil, return_tensors="pt").pixel_values.to(device=self.device, dtype=self.dtype)
        c = self.pipe.image_encoder(x_clip).image_embeds
        # The UNet is conditioned on the *unscaled* VAE latent, so undo the
        # scaling factor applied inside encode_imgs.
        v = self.encode_imgs(x.to(self.dtype)) / self.vae.config.scaling_factor
        self.embeddings = [c, v]
        self.num_views = x.shape[0]  # number of conditioning views, used by batch_train_step

    def get_cam_embeddings(self, polar, azimuth, radius, default_elevation=0):
        if self.use_stable_zero123:
            T = np.stack([np.deg2rad(polar), np.sin(np.deg2rad(azimuth)), np.cos(np.deg2rad(azimuth)), np.deg2rad([90 + default_elevation] * len(polar))], axis=-1)
        else:
            # original zero123 camera embedding
            T = np.stack([np.deg2rad(polar), np.sin(np.deg2rad(azimuth)), np.cos(np.deg2rad(azimuth)), radius], axis=-1)
        T = torch.from_numpy(T).unsqueeze(1).to(dtype=self.dtype, device=self.device)  # [B, 1, 4]
        return T

    def get_cam_embeddings_6D(self, target_RT, cond_RT):
        # Relative 6-DoF camera embedding: the flattened 4x4 relative pose
        # (target-to-cond) plus the log focal-length ratio -> 18 dims.
        T_target = torch.from_numpy(target_RT["c2w"])
        focal_len_target = torch.from_numpy(target_RT["focal_length"])

        T_cond = torch.from_numpy(cond_RT["c2w"])
        focal_len_cond = torch.from_numpy(cond_RT["focal_length"])
        focal_len = focal_len_target / focal_len_cond

        d_T = torch.linalg.inv(T_target) @ T_cond
        d_T = torch.cat([d_T.flatten(), torch.log(focal_len)])
        return d_T.unsqueeze(0).unsqueeze(0).to(dtype=self.dtype, device=self.device)  # [1, 1, 18]

    @torch.no_grad()
    def refine(self, pred_rgb, cam_embed, guidance_scale=5, steps=50, strength=0.8, idx=None):
        # pred_rgb may be None, in which case generation starts from pure noise.
        if pred_rgb is not None:
            batch_size = pred_rgb.shape[0]
        else:
            batch_size = 1

        self.scheduler.set_timesteps(steps)

        if strength == 0:
            init_step = 0
            latents = torch.randn((1, 4, 32, 32), device=self.device, dtype=self.dtype)
        else:
            init_step = int(steps * strength)
            pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False)
            latents = self.encode_imgs(pred_rgb_256.to(self.dtype))
            latents = self.scheduler.add_noise(latents, torch.randn_like(latents), self.scheduler.timesteps[init_step])

        T = cam_embed
        if idx is not None:
            cc_emb = torch.cat([self.embeddings[0][idx].repeat(batch_size, 1, 1), T], dim=-1)
        else:
            cc_emb = torch.cat([self.embeddings[0].repeat(batch_size, 1, 1), T], dim=-1)
        cc_emb = self.pipe.clip_camera_projection(cc_emb)
        # Conditional embedding first, unconditional (zeros) second.
        cc_emb = torch.cat([cc_emb, torch.zeros_like(cc_emb)], dim=0)

        if idx is not None:
            vae_emb = self.embeddings[1][idx].repeat(batch_size, 1, 1, 1)
        else:
            vae_emb = self.embeddings[1].repeat(batch_size, 1, 1, 1)
        vae_emb = torch.cat([vae_emb, torch.zeros_like(vae_emb)], dim=0)

        for i, t in enumerate(self.scheduler.timesteps[init_step:]):
            x_in = torch.cat([latents] * 2)
            t_in = t.view(1).to(self.device)

            noise_pred = self.unet(
                torch.cat([x_in, vae_emb], dim=1),
                t_in.to(self.unet.dtype),
                encoder_hidden_states=cc_emb,
            ).sample

            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        imgs = self.decode_latents(latents)  # [1, 3, 256, 256]
        return imgs
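    # Minimal usage sketch for refine (names are illustrative, not part of the
    # module; assumes `pose` matches the dict format get_cam_embeddings_6D
    # expects):
    #   guidance.get_img_embeds(image)                    # image: [1, 3, H, W] in [0, 1]
    #   cam = guidance.get_cam_embeddings_6D(pose, pose)  # identity relative pose
    #   out = guidance.refine(image, cam, strength=0.8)   # out: [1, 3, 256, 256]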
    def train_step(self, pred_rgb, polar, azimuth, radius, step_ratio=None, guidance_scale=5, as_latent=False):
        # pred_rgb: tensor [B, 3, H, W] in [0, 1]
        batch_size = pred_rgb.shape[0]

        if as_latent:
            latents = F.interpolate(pred_rgb, (32, 32), mode='bilinear', align_corners=False) * 2 - 1
        else:
            pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False)
            latents = self.encode_imgs(pred_rgb_256.to(self.dtype))

        if step_ratio is not None:
            # dreamtime-like: anneal the timestep from max_step down to min_step
            # t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio)
            t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step)
            t = torch.full((batch_size,), t, dtype=torch.long, device=self.device)
        else:
            t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device)

        w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1)

        with torch.no_grad():
            noise = torch.randn_like(latents)
            latents_noisy = self.scheduler.add_noise(latents, noise, t)

            x_in = torch.cat([latents_noisy] * 2)
            t_in = torch.cat([t] * 2)

            T = self.get_cam_embeddings(polar, azimuth, radius)
            cc_emb = torch.cat([self.embeddings[0].repeat(batch_size, 1, 1), T], dim=-1)
            cc_emb = self.pipe.clip_camera_projection(cc_emb)
            cc_emb = torch.cat([cc_emb, torch.zeros_like(cc_emb)], dim=0)

            vae_emb = self.embeddings[1].repeat(batch_size, 1, 1, 1)
            vae_emb = torch.cat([vae_emb, torch.zeros_like(vae_emb)], dim=0)

            noise_pred = self.unet(
                torch.cat([x_in, vae_emb], dim=1),
                t_in.to(self.unet.dtype),
                encoder_hidden_states=cc_emb,
            ).sample

            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

        grad = w * (noise_pred - noise)
        grad = torch.nan_to_num(grad)
        target = (latents - grad).detach()
        loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum')
        return loss

    def angle_between(self, sph_v1, sph_v2):
        # Pairwise angles between two sets of spherical coordinates (r, theta, phi).

        def sph2cart(sv):
            r, theta, phi = sv[0], sv[1], sv[2]
            # Note: this polar convention differs from Stable-DreamFusion's.
            return torch.tensor([r * torch.cos(theta) * torch.cos(phi), r * torch.cos(theta) * torch.sin(phi), r * torch.sin(theta)])

        def unit_vector(v):
            return v / torch.linalg.norm(v)

        def angle_between_2_sph(sv1, sv2):
            v1, v2 = sph2cart(sv1), sph2cart(sv2)
            v1_u, v2_u = unit_vector(v1), unit_vector(v2)
            return torch.arccos(torch.clip(torch.dot(v1_u, v2_u), -1.0, 1.0))

        angles = torch.empty(len(sph_v1), len(sph_v2))
        for i, sv1 in enumerate(sph_v1):
            for j, sv2 in enumerate(sph_v2):
                angles[i, j] = angle_between_2_sph(sv1, sv2)
        return angles
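    # Note on the loss used in train_step / batch_train_step: because `target`
    # is detached, d(loss)/d(latents) = latents - target = grad, so minimizing
    # 0.5 * ||latents - target||^2 injects exactly the SDS gradient
    # w * (noise_pred - noise) into the latents without backpropagating
    # through the UNet.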
    def batch_train_step(self, pred_rgb, target_RT, cond_cams, step_ratio=None, guidance_scale=5, as_latent=False, step=None):
        # pred_rgb: tensor [1, 3, H, W] in [0, 1]
        # (the per-view CFG indexing below assumes batch size 1)
        batch_size = pred_rgb.shape[0]

        if as_latent:
            latents = F.interpolate(pred_rgb, (32, 32), mode='bilinear', align_corners=False) * 2 - 1
        else:
            pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False)
            latents = self.encode_imgs(pred_rgb_256.to(self.dtype))

        if step_ratio is not None:
            # dreamtime-like: anneal the timestep from max_step down to min_step
            # t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio)
            t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step)
            t = torch.full((batch_size,), t, dtype=torch.long, device=self.device)
        else:
            t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device)

        w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1)

        with torch.no_grad():
            noise = torch.randn_like(latents)
            latents_noisy = self.scheduler.add_noise(latents, noise, t)

            # One (conditional, unconditional) pair per conditioning view.
            x_in = torch.cat([latents_noisy] * 2 * self.num_views)
            t_in = torch.cat([t] * 2 * self.num_views)

            cc_embs = []
            vae_embs = []
            noise_preds = []
            for idx in range(self.num_views):
                cond_RT = {
                    "c2w": cond_cams[idx].c2w,
                    "focal_length": cond_cams[idx].focal_length,
                }
                T = self.get_cam_embeddings_6D(target_RT, cond_RT)
                cc_emb = torch.cat([self.embeddings[0][idx].repeat(batch_size, 1, 1), T], dim=-1)
                cc_emb = self.pipe.clip_camera_projection(cc_emb)
                cc_emb = torch.cat([cc_emb, torch.zeros_like(cc_emb)], dim=0)

                vae_emb = self.embeddings[1][idx].repeat(batch_size, 1, 1, 1)
                vae_emb = torch.cat([vae_emb, torch.zeros_like(vae_emb)], dim=0)

                cc_embs.append(cc_emb)
                vae_embs.append(vae_emb)
            cc_emb = torch.cat(cc_embs, dim=0)
            vae_emb = torch.cat(vae_embs, dim=0)

            noise_pred = self.unet(
                torch.cat([x_in, vae_emb], dim=1),
                t_in.to(self.unet.dtype),
                encoder_hidden_states=cc_emb,
            ).sample

            # Classifier-free guidance per view, then average over all views.
            noise_pred_chunks = noise_pred.chunk(self.num_views)
            for idx in range(self.num_views):
                noise_pred_cond, noise_pred_uncond = noise_pred_chunks[idx][0], noise_pred_chunks[idx][1]
                noise_preds.append(noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond))
            noise_pred = torch.stack(noise_preds).sum(dim=0) / len(noise_preds)

        grad = w * (noise_pred - noise)
        grad = torch.nan_to_num(grad)
        target = (latents - grad).detach()
        loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum')
        return loss

    def decode_latents(self, latents):
        latents = 1 / self.vae.config.scaling_factor * latents
        imgs = self.vae.decode(latents).sample
        imgs = (imgs / 2 + 0.5).clamp(0, 1)
        return imgs

    def encode_imgs(self, imgs, mode=False):
        # imgs: [B, 3, H, W] in [0, 1]
        imgs = 2 * imgs - 1
        posterior = self.vae.encode(imgs).latent_dist
        if mode:
            latents = posterior.mode()
        else:
            latents = posterior.sample()
        latents = latents * self.vae.config.scaling_factor
        return latents


def process_im(im, device, bg_remover=None):
    # im: [H, W, 3] or [H, W, 4] uint8 numpy image in BGR channel order.
    # Removes the background if there is no alpha channel, composites onto
    # white, converts to RGB, and resizes to 256x256.
    if im.shape[-1] == 3:
        if bg_remover is None:
            bg_remover = rembg.new_session()
        im = rembg.remove(im, session=bg_remover)
    im = im.astype(np.float32) / 255.0
    input_mask = im[..., 3:]
    input_img = im[..., :3] * input_mask + (1 - input_mask)  # white background
    input_img = input_img[..., ::-1].copy()  # BGR -> RGB
    image = torch.from_numpy(input_img).permute(2, 0, 1).unsqueeze(0).contiguous().to(device)
    image = F.interpolate(image, (256, 256), mode='bilinear', align_corners=False)
    return image


def get_T_6d(target_RT, cond_RT, use_objaverse):
    if use_objaverse:
        new_row = torch.tensor([[0., 0., 0., 1.]])
        T_target = torch.from_numpy(target_RT)  # world-to-cam matrix, [3, 4]
        T_target = torch.cat((T_target, new_row), dim=0)
        T_target = torch.linalg.inv(T_target)  # cam-to-world matrix
        T_target[:3, :] = T_target[[1, 2, 0]]  # permute axes to the model's convention

        T_cond = torch.from_numpy(cond_RT)
        T_cond = torch.cat((T_cond, new_row), dim=0)
        T_cond = torch.linalg.inv(T_cond)
        T_cond[:3, :] = T_cond[[1, 2, 0]]

        focal_len = torch.tensor([1., 1.])  # Objaverse renders share one focal length
    else:
        T_target = torch.from_numpy(target_RT["c2w"])
        focal_len_target = torch.from_numpy(target_RT["focal_length"])

        T_cond = torch.from_numpy(cond_RT["c2w"])
        focal_len_cond = torch.from_numpy(cond_RT["focal_length"])
        focal_len = focal_len_target / focal_len_cond

    d_T = torch.linalg.inv(T_target) @ T_cond
    d_T = torch.cat([d_T.flatten(), torch.log(focal_len)])
    return d_T
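if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the library. Assumptions: a CUDA
    # device is available, the checkpoint hard-coded in __init__ exists, and
    # the random image and identity pose below are purely illustrative.
    device = torch.device("cuda")
    guidance = Zero123(device)

    image = torch.rand(1, 3, 256, 256, device=device)  # stand-in for a real input
    guidance.get_img_embeds(image)

    # Identity relative pose: condition and target cameras coincide, so the
    # 18-D embedding is the flattened identity plus log(1) focal ratios.
    pose = {
        "c2w": np.eye(4, dtype=np.float32),
        "focal_length": np.array([1.0, 1.0], dtype=np.float32),
    }
    cam_embed = guidance.get_cam_embeddings_6D(pose, pose)

    out = guidance.refine(image, cam_embed, guidance_scale=5, strength=0.8)
    print("refined image:", out.shape)  # expected: [1, 3, 256, 256]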