import torch
import einops
import numpy as np
import torch.nn as nn
from torch import Tensor
from functools import partial
from torchdiffeq import odeint
from unet import UNetModel
from diffusers import AutoencoderKL


def exists(val):
    return val is not None


class DepthFM(nn.Module):
    def __init__(self, ckpt_path: str):
        super().__init__()
        # frozen Stable Diffusion VAE for encoding images and decoding depth latents
        vae_id = "runwayml/stable-diffusion-v1-5"
        self.vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae")
        self.scale_factor = 0.18215  # standard SD latent scaling factor

        # restore hyperparameters and weights from the checkpoint
        ckpt = torch.load(ckpt_path, map_location="cpu")
        self.noising_step = ckpt['noising_step']
        self.empty_text_embed = ckpt['empty_text_embedding']
        self.model = UNetModel(**ckpt['ldm_hparams'])
        self.model.load_state_dict(ckpt['state_dict'])

    def ode_fn(self, t: Tensor, x: Tensor, **kwargs):
        """Vector field for the ODE solver; odeint passes a scalar t,
        which is broadcast to the batch dimension of x."""
        if t.numel() == 1:
            t = t.expand(x.size(0))
        return self.model(x=x, t=t, **kwargs)

    def generate(self, z: Tensor, num_steps: int = 4, n_intermediates: int = 0, **kwargs):
        """
        ODE solving from z0 (images) to z1 (depth).
        """
        # rtol/atol only apply to adaptive solvers; the fixed-step Euler method
        # used here is governed by step_size
        ode_kwargs = dict(method="euler", rtol=1e-5, atol=1e-5, options=dict(step_size=1.0 / num_steps))

        # t specifies at which intermediate times the solver should return the
        # solution, e.g. t = [0, 0.5, 1] returns the solution at t=0, t=0.5, and t=1.
        # For fixed-step methods without an explicit step_size, it would also
        # determine the number of steps.
        t = torch.linspace(0, 1, n_intermediates + 2, device=z.device, dtype=z.dtype)
        # t = torch.tensor([0., 1.], device=z.device, dtype=z.dtype)

        # allow conditioning information for the model
        ode_fn = partial(self.ode_fn, **kwargs)

        ode_results = odeint(ode_fn, z, t, **ode_kwargs)

        if n_intermediates > 0:
            return ode_results
        return ode_results[-1]
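
    # Example time grids (a worked illustration added for clarity, not in the
    # original file): n_intermediates=0 gives t = [0.0, 1.0]; with num_steps=4
    # the fixed-step Euler solver still takes 4 internal steps of size 0.25 and
    # returns only the final latent. n_intermediates=1 gives t = [0.0, 0.5, 1.0]
    # and all three snapshots are returned.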

    def forward(self, ims: Tensor, num_steps: int = 4, ensemble_size: int = 1):
        """
        Args:
            ims: Tensor of shape (b, 3, h, w) in range [-1, 1]
        Returns:
            depth: Tensor of shape (b, 1, h, w) in range [0, 1]
        """
        if ensemble_size > 1:
            assert ims.shape[0] == 1, "Ensemble mode only supported with batch size 1"
            ims = ims.repeat(ensemble_size, 1, 1, 1)

        bs, dev = ims.shape[0], ims.device

        # encode the image into the VAE latent space
        ims_z = self.encode(ims, sample_posterior=False)

        # empty-text embedding for cross-attention; image latents as conditioning context
        conditioning = torch.tensor(self.empty_text_embed).to(dev).repeat(bs, 1, 1)
        context = ims_z

        # optionally perturb the source latent with the diffusion forward process
        x_source = ims_z
        if self.noising_step > 0:
            x_source = q_sample(x_source, self.noising_step)

        # solve the ODE from the (noised) image latent to the depth latent
        depth_z = self.generate(x_source, num_steps=num_steps, context=context, context_ca=conditioning)

        # decode and average the three decoded channels into a single depth channel
        depth = self.decode(depth_z)
        depth = depth.mean(dim=1, keepdim=True)

        # average the ensemble predictions
        if ensemble_size > 1:
            depth = depth.mean(dim=0, keepdim=True)

        # normalize depth maps to range [0, 1]
        depth = per_sample_min_max_normalization(depth.exp())

        return depth

    @torch.no_grad()
    def predict_depth(self, ims: Tensor, num_steps: int = 4, ensemble_size: int = 1):
        """Inference method for DepthFM."""
        return self.forward(ims, num_steps, ensemble_size)

    @torch.no_grad()
    def encode(self, x: Tensor, sample_posterior: bool = True):
        posterior = self.vae.encode(x)
        if sample_posterior:
            z = posterior.latent_dist.sample()
        else:
            z = posterior.latent_dist.mode()
        # normalize latent code
        z = z * self.scale_factor
        return z

    @torch.no_grad()
    def decode(self, z: Tensor):
        # invert the latent scaling before decoding
        z = 1.0 / self.scale_factor * z
        return self.vae.decode(z).sample


def sigmoid(x):
    return 1 / (1 + np.exp(-x))


def cosine_log_snr(t, eps=1e-5):
    """
    Returns the log signal-to-noise ratio of the cosine schedule at time t,
    calibrated for image size 64.
    eps: avoids taking log(0) at t=0.
    """
    return -2 * np.log(np.tan((np.pi * t) / 2) + eps)


def cosine_alpha_bar(t):
    return sigmoid(cosine_log_snr(t))
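
# Sanity-check values for the cosine schedule (illustrative, not from the
# original file): cosine_alpha_bar(0.0) ~= 1.0 (no noise), cosine_alpha_bar(0.5)
# = 0.5 (tan(pi/4) = 1, so the log-SNR ~= 0), cosine_alpha_bar(1.0) ~= 0.0 (pure noise).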


def q_sample(x_start: torch.Tensor, t: int, noise: torch.Tensor = None, n_diffusion_timesteps: int = 1000):
    """
    Diffuse the data for a given number of diffusion steps. In other
    words, sample from q(x_t | x_0).
    """
    dev = x_start.device
    dtype = x_start.dtype

    if noise is None:
        noise = torch.randn_like(x_start)

    alpha_bar_t = cosine_alpha_bar(t / n_diffusion_timesteps)
    alpha_bar_t = torch.tensor(alpha_bar_t).to(dev).to(dtype)

    return torch.sqrt(alpha_bar_t) * x_start + torch.sqrt(1 - alpha_bar_t) * noise
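
# For intuition (an illustrative comment, not from the original file): with the
# default 1000 timesteps, q_sample(x, 400) evaluates the schedule at t = 0.4 and
# returns sqrt(alpha_bar) * x + sqrt(1 - alpha_bar) * noise, i.e. a variance-
# preserving interpolation between the clean latent and Gaussian noise.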


def per_sample_min_max_normalization(x):
    """Normalize each sample in a batch independently
    with min-max normalization to [0, 1]."""
    bs, *shape = x.shape
    x_ = einops.rearrange(x, "b ... -> b (...)")
    min_val = einops.reduce(x_, "b ... -> b", "min")[..., None]
    max_val = einops.reduce(x_, "b ... -> b", "max")[..., None]
    x_ = (x_ - min_val) / (max_val - min_val)
    return x_.reshape(bs, *shape)
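

# Minimal usage sketch (an illustrative addition, not part of the original file).
# The checkpoint path is hypothetical and the random input stands in for a real
# image preprocessed to [-1, 1].
if __name__ == "__main__":
    model = DepthFM("checkpoints/depthfm-v1.ckpt")  # hypothetical checkpoint path
    model.eval()
    ims = torch.rand(1, 3, 512, 512) * 2 - 1  # dummy (b, 3, h, w) image in [-1, 1]
    depth = model.predict_depth(ims, num_steps=4, ensemble_size=1)
    print(depth.shape)  # expected: torch.Size([1, 1, 512, 512]), values in [0, 1]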