import sys

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms.functional as TF
from torchvision import transforms
from torchvision.utils import save_image

from einops import rearrange
from diffusers import DDIMScheduler
import rembg  # needed by process_im below

sys.path.append('./')

from sparseags.guidance_utils.zero123 import Zero123Pipeline


name_mapping = {
    "model.diffusion_model.input_blocks.1.1.": "down_blocks.0.attentions.0.",
    "model.diffusion_model.input_blocks.2.1.": "down_blocks.0.attentions.1.",
    "model.diffusion_model.input_blocks.4.1.": "down_blocks.1.attentions.0.",
    "model.diffusion_model.input_blocks.5.1.": "down_blocks.1.attentions.1.",
    "model.diffusion_model.input_blocks.7.1.": "down_blocks.2.attentions.0.",
    "model.diffusion_model.input_blocks.8.1.": "down_blocks.2.attentions.1.",
    "model.diffusion_model.middle_block.1.": "mid_block.attentions.0.",
    "model.diffusion_model.output_blocks.3.1.": "up_blocks.1.attentions.0.",
    "model.diffusion_model.output_blocks.4.1.": "up_blocks.1.attentions.1.",
    "model.diffusion_model.output_blocks.5.1.": "up_blocks.1.attentions.2.",
    "model.diffusion_model.output_blocks.6.1.": "up_blocks.2.attentions.0.",
    "model.diffusion_model.output_blocks.7.1.": "up_blocks.2.attentions.1.",
    "model.diffusion_model.output_blocks.8.1.": "up_blocks.2.attentions.2.",
    "model.diffusion_model.output_blocks.9.1.": "up_blocks.3.attentions.0.",
    "model.diffusion_model.output_blocks.10.1.": "up_blocks.3.attentions.1.",
    "model.diffusion_model.output_blocks.11.1.": "up_blocks.3.attentions.2.",
}

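# name_mapping translates LDM-style UNet state-dict prefixes (as stored in the original
# Zero123 checkpoints) into the corresponding diffusers UNet prefixes. Example of the
# key translation performed in Zero123.__init__ (illustrative; the suffix shown is a
# typical attention sub-key present in both layouts):
#   "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q.weight"
#   -> "down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.weight"

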
class Zero123(nn.Module):
    def __init__(self, device, fp16=True, t_range=[0.02, 0.98], model_key="ashawkey/zero123-xl-diffusers"):
        super().__init__()

        self.device = device
        self.fp16 = fp16
        self.dtype = torch.float16 if fp16 else torch.float32

        self.pipe = Zero123Pipeline.from_pretrained(
            model_key,
            trust_remote_code=True,
            torch_dtype=self.dtype,
        ).to(self.device)

        # Load the 6-DoF fine-tuned checkpoint. Its camera projection expects
        # 768 CLIP image-embedding dims + 18 camera dims (16 from the flattened
        # 4x4 relative pose and 2 from the log focal-length ratio).
        ckpt_path = "checkpoints/zero123_6dof_23k.ckpt"
        print(f'[INFO] loading checkpoint from {ckpt_path} ...')
        old_state = torch.load(ckpt_path, map_location='cpu')
        pretrained_weights = old_state['state_dict']['cc_projection.weight']
        pretrained_biases = old_state['state_dict']['cc_projection.bias']
        linear_layer = torch.nn.Linear(768 + 18, 768)
        linear_layer.weight.data = pretrained_weights
        linear_layer.bias.data = pretrained_biases
        self.pipe.clip_camera_projection.proj = linear_layer.to(dtype=self.dtype, device=self.device)

        # Duplicate the UNet weights under their diffusers names so they can be loaded below.
        for name in list(old_state['state_dict'].keys()):
            for k, v in name_mapping.items():
                if k in name:
                    old_state['state_dict'][name.replace(k, v)] = old_state['state_dict'][name].to(dtype=self.dtype, device=self.device)

        m, u = self.pipe.unet.load_state_dict(old_state['state_dict'], strict=False)

        self.use_stable_zero123 = 'stable' in model_key

        self.pipe.image_encoder.eval()
        self.pipe.vae.eval()
        self.pipe.unet.eval()
        self.pipe.clip_camera_projection.eval()

        self.vae = self.pipe.vae
        self.unet = self.pipe.unet

        self.pipe.set_progress_bar_config(disable=True)

        self.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
        self.num_train_timesteps = self.scheduler.config.num_train_timesteps

        self.min_step = int(self.num_train_timesteps * t_range[0])
        self.max_step = int(self.num_train_timesteps * t_range[1])
        self.alphas = self.scheduler.alphas_cumprod.to(self.device)  # alphas_cumprod, kept for convenience

        # Conditioning embeddings [CLIP image embeds, VAE latents]; set by get_img_embeds.
        self.embeddings = None
        # Number of conditioning views; set by get_img_embeds, used by batch_train_step.
        self.num_views = None

    @torch.no_grad()
    def get_img_embeds(self, x):
        # x: [B, 3, H, W] images in [0, 1]
        x = F.interpolate(x, (256, 256), mode='bilinear', align_corners=False)
        x_pil = [TF.to_pil_image(image) for image in x]
        x_clip = self.pipe.feature_extractor(images=x_pil, return_tensors="pt").pixel_values.to(device=self.device, dtype=self.dtype)
        c = self.pipe.image_encoder(x_clip).image_embeds
        v = self.encode_imgs(x.to(self.dtype)) / self.vae.config.scaling_factor
        self.embeddings = [c, v]
        # Assumed: one conditioning view per input image (batch_train_step indexes views this way).
        self.num_views = c.shape[0]

    def get_cam_embeddings(self, polar, azimuth, radius, default_elevation=0):
        # Relative spherical camera conditioning used by the original (4-DoF) Zero123 models.
        if self.use_stable_zero123:
            T = np.stack([np.deg2rad(polar), np.sin(np.deg2rad(azimuth)), np.cos(np.deg2rad(azimuth)), np.deg2rad([90 + default_elevation] * len(polar))], axis=-1)
        else:
            # original zero123 camera embedding
            T = np.stack([np.deg2rad(polar), np.sin(np.deg2rad(azimuth)), np.cos(np.deg2rad(azimuth)), radius], axis=-1)
        T = torch.from_numpy(T).unsqueeze(1).to(dtype=self.dtype, device=self.device)  # [B, 1, 4]
        return T

    def get_cam_embeddings_6D(self, target_RT, cond_RT):
        # Build the 18-D camera conditioning: the flattened 4x4 relative transform between
        # the target and conditioning cameras (16 values) plus the log focal-length ratio (2 values).
        T_target = torch.from_numpy(target_RT["c2w"])
        focal_len_target = torch.from_numpy(target_RT["focal_length"])

        T_cond = torch.from_numpy(cond_RT["c2w"])
        focal_len_cond = torch.from_numpy(cond_RT["focal_length"])

        focal_len = focal_len_target / focal_len_cond

        d_T = torch.linalg.inv(T_target) @ T_cond
        d_T = torch.cat([d_T.flatten(), torch.log(focal_len)])
        return d_T.unsqueeze(0).unsqueeze(0).to(dtype=self.dtype, device=self.device)  # [1, 1, 18]

    @torch.no_grad()
    def refine(self, pred_rgb, cam_embed,
               guidance_scale=5, steps=50, strength=0.8, idx=None
               ):
        # Run a (partial) DDIM sampling loop conditioned on the stored image embeddings and the
        # given camera embedding. With strength > 0 this is img2img-style refinement of pred_rgb;
        # with strength == 0 it samples from pure noise.
        if pred_rgb is not None:
            batch_size = pred_rgb.shape[0]
        else:
            batch_size = 1

        self.scheduler.set_timesteps(steps)

        if strength == 0:
            init_step = 0
            # batch_size (not 1) so the shapes match cc_emb/vae_emb below
            latents = torch.randn((batch_size, 4, 32, 32), device=self.device, dtype=self.dtype)
        else:
            init_step = int(steps * strength)
            pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False)
            latents = self.encode_imgs(pred_rgb_256.to(self.dtype))
            latents = self.scheduler.add_noise(latents, torch.randn_like(latents), self.scheduler.timesteps[init_step])

        T = cam_embed
        if idx is not None:
            cc_emb = torch.cat([self.embeddings[0][idx].repeat(batch_size, 1, 1), T], dim=-1)
        else:
            cc_emb = torch.cat([self.embeddings[0].repeat(batch_size, 1, 1), T], dim=-1)
        cc_emb = self.pipe.clip_camera_projection(cc_emb)
        cc_emb = torch.cat([cc_emb, torch.zeros_like(cc_emb)], dim=0)  # [cond, uncond] for CFG

        if idx is not None:
            vae_emb = self.embeddings[1][idx].repeat(batch_size, 1, 1, 1)
        else:
            vae_emb = self.embeddings[1].repeat(batch_size, 1, 1, 1)
        vae_emb = torch.cat([vae_emb, torch.zeros_like(vae_emb)], dim=0)

        for i, t in enumerate(self.scheduler.timesteps[init_step:]):

            x_in = torch.cat([latents] * 2)
            t_in = t.view(1).to(self.device)

            noise_pred = self.unet(
                torch.cat([x_in, vae_emb], dim=1),
                t_in.to(self.unet.dtype),
                encoder_hidden_states=cc_emb,
            ).sample

            # Classifier-free guidance.
            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        imgs = self.decode_latents(latents)
        return imgs

    def train_step(self, pred_rgb, polar, azimuth, radius, step_ratio=None, guidance_scale=5, as_latent=False):
        # Score Distillation Sampling (SDS) step against a single conditioning view.
        # pred_rgb: [B, 3, H, W] rendered images in [0, 1]
        batch_size = pred_rgb.shape[0]

        if as_latent:
            latents = F.interpolate(pred_rgb, (32, 32), mode='bilinear', align_corners=False) * 2 - 1
        else:
            pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False)
            latents = self.encode_imgs(pred_rgb_256.to(self.dtype))

        if step_ratio is not None:
            # Anneal the timestep from max_step down to min_step as training progresses.
            t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step)
            t = torch.full((batch_size,), t, dtype=torch.long, device=self.device)
        else:
            t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device)

        w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1)

        with torch.no_grad():
            noise = torch.randn_like(latents)
            latents_noisy = self.scheduler.add_noise(latents, noise, t)

            x_in = torch.cat([latents_noisy] * 2)
            t_in = torch.cat([t] * 2)

            T = self.get_cam_embeddings(polar, azimuth, radius)
            cc_emb = torch.cat([self.embeddings[0].repeat(batch_size, 1, 1), T], dim=-1)
            cc_emb = self.pipe.clip_camera_projection(cc_emb)
            cc_emb = torch.cat([cc_emb, torch.zeros_like(cc_emb)], dim=0)

            vae_emb = self.embeddings[1].repeat(batch_size, 1, 1, 1)
            vae_emb = torch.cat([vae_emb, torch.zeros_like(vae_emb)], dim=0)

            noise_pred = self.unet(
                torch.cat([x_in, vae_emb], dim=1),
                t_in.to(self.unet.dtype),
                encoder_hidden_states=cc_emb,
            ).sample

        # Classifier-free guidance.
        noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

        grad = w * (noise_pred - noise)
        grad = torch.nan_to_num(grad)

        # Surrogate loss whose gradient w.r.t. latents equals `grad`
        # (see _sds_surrogate_loss_example below).
        target = (latents - grad).detach()
        loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum')

        return loss

    def angle_between(self, sph_v1, sph_v2):
        # Pairwise angles (in radians) between two sets of spherical coordinates (r, theta, phi).
        def sph2cart(sv):
            r, theta, phi = sv[0], sv[1], sv[2]
            return torch.tensor([r * torch.cos(theta) * torch.cos(phi), r * torch.cos(theta) * torch.sin(phi), r * torch.sin(theta)])

        def unit_vector(v):
            return v / torch.linalg.norm(v)

        def angle_between_2_sph(sv1, sv2):
            v1, v2 = sph2cart(sv1), sph2cart(sv2)
            v1_u, v2_u = unit_vector(v1), unit_vector(v2)
            return torch.arccos(torch.clip(torch.dot(v1_u, v2_u), -1.0, 1.0))

        angles = torch.empty(len(sph_v1), len(sph_v2))
        for i, sv1 in enumerate(sph_v1):
            for j, sv2 in enumerate(sph_v2):
                angles[i][j] = angle_between_2_sph(sv1, sv2)
        return angles

    def batch_train_step(self, pred_rgb, target_RT, cond_cams, step_ratio=None, guidance_scale=5, as_latent=False, step=None):
        # SDS step that averages guidance from all conditioning views, using 6-DoF camera embeddings.
        # pred_rgb: [B, 3, H, W] rendered images in [0, 1]
        batch_size = pred_rgb.shape[0]

        if as_latent:
            latents = F.interpolate(pred_rgb, (32, 32), mode='bilinear', align_corners=False) * 2 - 1
        else:
            pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False)
            latents = self.encode_imgs(pred_rgb_256.to(self.dtype))

        if step_ratio is not None:
            # Anneal the timestep from max_step down to min_step as training progresses.
            t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step)
            t = torch.full((batch_size,), t, dtype=torch.long, device=self.device)
        else:
            t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device)

        w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1)

        with torch.no_grad():
            noise = torch.randn_like(latents)
            latents_noisy = self.scheduler.add_noise(latents, noise, t)

            # One [cond, uncond] pair per conditioning view, stacked along the batch dimension.
            x_in = torch.cat([latents_noisy] * 2 * self.num_views)
            t_in = torch.cat([t] * 2 * self.num_views)

            cc_embs = []
            vae_embs = []
            noise_preds = []
            for idx in range(self.num_views):
                cond_RT = {
                    "c2w": cond_cams[idx].c2w,
                    "focal_length": cond_cams[idx].focal_length,
                }
                T = self.get_cam_embeddings_6D(target_RT, cond_RT)
                cc_emb = torch.cat([self.embeddings[0][idx].repeat(batch_size, 1, 1), T], dim=-1)
                cc_emb = self.pipe.clip_camera_projection(cc_emb)
                cc_emb = torch.cat([cc_emb, torch.zeros_like(cc_emb)], dim=0)

                vae_emb = self.embeddings[1][idx].repeat(batch_size, 1, 1, 1)
                vae_emb = torch.cat([vae_emb, torch.zeros_like(vae_emb)], dim=0)

                cc_embs.append(cc_emb)
                vae_embs.append(vae_emb)

            cc_emb = torch.cat(cc_embs, dim=0)
            vae_emb = torch.cat(vae_embs, dim=0)
            noise_pred = self.unet(
                torch.cat([x_in, vae_emb], dim=1),
                t_in.to(self.unet.dtype),
                encoder_hidden_states=cc_emb,
            ).sample

        # Classifier-free guidance per conditioning view, then average over views.
        noise_pred_chunks = noise_pred.chunk(self.num_views)
        for idx in range(self.num_views):
            noise_pred_cond, noise_pred_uncond = noise_pred_chunks[idx].chunk(2)
            noise_preds.append(noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond))

        noise_pred = torch.stack(noise_preds).mean(dim=0)

        grad = w * (noise_pred - noise)
        grad = torch.nan_to_num(grad)

        # Surrogate loss whose gradient w.r.t. latents equals `grad`
        # (see _sds_surrogate_loss_example below).
        target = (latents - grad).detach()
        loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum')

        return loss

    def decode_latents(self, latents):
        latents = 1 / self.vae.config.scaling_factor * latents

        imgs = self.vae.decode(latents).sample
        imgs = (imgs / 2 + 0.5).clamp(0, 1)

        return imgs

    def encode_imgs(self, imgs, mode=False):
        # imgs: [B, 3, H, W] in [0, 1]
        imgs = 2 * imgs - 1

        posterior = self.vae.encode(imgs).latent_dist
        latents = posterior.mode() if mode else posterior.sample()
        latents = latents * self.vae.config.scaling_factor

        return latents


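# Why the SDS losses above use `0.5 * F.mse_loss(latents, (latents - grad).detach(), reduction='sum')`:
# the gradient of that surrogate w.r.t. `latents` is exactly `grad = w * (noise_pred - noise)`,
# so backpropagation injects the SDS gradient without differentiating through the UNet.
# A minimal, self-contained sketch with hypothetical tensors (not used by the training code):
def _sds_surrogate_loss_example():
    latents = torch.randn(1, 4, 32, 32, requires_grad=True)
    grad = torch.randn_like(latents)  # stands in for w * (noise_pred - noise)
    loss = 0.5 * F.mse_loss(latents, (latents - grad).detach(), reduction="sum")
    loss.backward()
    assert torch.allclose(latents.grad, grad)

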
def process_im(im, device, bg_remover=None):
    # Preprocess an HxWx3 or HxWx4 uint8 image into a [1, 3, 256, 256] tensor in [0, 1]:
    # remove the background if no alpha channel is present, composite onto white, and
    # flip the channel order. Takes `device` and an optional rembg session explicitly,
    # since this is a module-level helper rather than a class method.
    if im.shape[-1] == 3:
        if bg_remover is None:
            bg_remover = rembg.new_session()
        im = rembg.remove(im, session=bg_remover)

    im = im.astype(np.float32) / 255.0

    input_mask = im[..., 3:]
    input_img = im[..., :3] * input_mask + (1 - input_mask)  # composite onto white
    input_img = input_img[..., ::-1].copy()  # BGR -> RGB (input assumed to come from cv2)
    image = torch.from_numpy(input_img).permute(2, 0, 1).unsqueeze(0).contiguous().to(device)
    image = F.interpolate(image, (256, 256), mode='bilinear', align_corners=False)

    return image


def get_T_6d(target_RT, cond_RT, use_objaverse):
    # Relative 6-DoF camera embedding between a target and a conditioning view:
    # flattened 4x4 relative transform (16 values) + log focal-length ratio (2 values).
    if use_objaverse:
        # Objaverse-style poses: 3x4 [R|t] matrices. Homogenize, invert, and reorder the
        # rotation rows to match the axis convention used elsewhere in this file.
        new_row = torch.tensor([[0., 0., 0., 1.]])

        T_target = torch.from_numpy(target_RT)
        T_target = torch.cat((T_target, new_row), dim=0)
        T_target = torch.linalg.inv(T_target)
        T_target[:3, :] = T_target[[1, 2, 0]]

        T_cond = torch.from_numpy(cond_RT)
        T_cond = torch.cat((T_cond, new_row), dim=0)
        T_cond = torch.linalg.inv(T_cond)
        T_cond[:3, :] = T_cond[[1, 2, 0]]

        # Shared intrinsics across Objaverse renders: focal ratio fixed to 1 (log -> 0).
        focal_len = torch.tensor([1., 1.])

    else:
        T_target = torch.from_numpy(target_RT["c2w"])
        focal_len_target = torch.from_numpy(target_RT["focal_length"])

        T_cond = torch.from_numpy(cond_RT["c2w"])
        focal_len_cond = torch.from_numpy(cond_RT["focal_length"])

        focal_len = focal_len_target / focal_len_cond

    d_T = torch.linalg.inv(T_target) @ T_cond
    d_T = torch.cat([d_T.flatten(), torch.log(focal_len)])
    return d_T

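
# Minimal usage sketch (assumptions: a CUDA device is available, the checkpoint referenced
# in Zero123.__init__ exists on disk, and "data/input.png" is a hypothetical RGB test image).
# The actual training loop in this repo drives train_step / batch_train_step instead.
if __name__ == "__main__":
    device = torch.device("cuda")
    guidance = Zero123(device)

    # Reference image -> [1, 3, H, W] float tensor in [0, 1].
    im = np.array(Image.open("data/input.png").convert("RGB")).astype(np.float32) / 255.0
    im = torch.from_numpy(im).permute(2, 0, 1).unsqueeze(0).to(device)
    guidance.get_img_embeds(im)

    # Identity relative pose and equal focal lengths -> refine toward the input view.
    eye = np.eye(4, dtype=np.float32)
    target_RT = {"c2w": eye, "focal_length": np.array([1.0, 1.0], dtype=np.float32)}
    cond_RT = {"c2w": eye.copy(), "focal_length": np.array([1.0, 1.0], dtype=np.float32)}
    cam = guidance.get_cam_embeddings_6D(target_RT, cond_RT)

    out = guidance.refine(im, cam, strength=0.8)
    save_image(out, "zero123_refine_example.png")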