# Hugging Face Spaces banner (scrape residue, not code): demo "Running on A100".
import torch
import einops
import numpy as np
import torch.nn as nn
from torch import Tensor
from functools import partial
from torchdiffeq import odeint
from unet import UNetModel
from diffusers import AutoencoderKL
def exists(val):
    """Return True when *val* is anything other than None."""
    return not (val is None)
class DepthFM(nn.Module):
    """Monocular depth estimation via flow matching in SD latent space.

    A Stable Diffusion VAE maps images to/from latents; a UNet (weights and
    hyper-parameters restored from ``ckpt_path``) provides the vector field
    that is integrated from image latents to depth latents.
    """

    def __init__(self, ckpt_path: str):
        """
        Args:
            ckpt_path: Path to a torch checkpoint with keys 'noising_step',
                'empty_text_embedding', 'ldm_hparams' and 'state_dict'.
        """
        super().__init__()
        vae_id = "runwayml/stable-diffusion-v1-5"
        self.vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae")
        self.scale_factor = 0.18215  # SD latent scaling constant

        # set with checkpoint
        # NOTE(review): torch.load without weights_only=True can execute
        # arbitrary code while unpickling — only load trusted checkpoints.
        ckpt = torch.load(ckpt_path, map_location="cpu")
        self.noising_step = ckpt['noising_step']
        self.empty_text_embed = ckpt['empty_text_embedding']
        self.model = UNetModel(**ckpt['ldm_hparams'])
        self.model.load_state_dict(ckpt['state_dict'])

    def ode_fn(self, t: Tensor, x: Tensor, **kwargs):
        """Vector field for the ODE solver: evaluates the UNet at (x, t)."""
        # The solver passes a scalar t; the model wants one entry per sample.
        if t.numel() == 1:
            t = t.expand(x.size(0))
        return self.model(x=x, t=t, **kwargs)

    def generate(self, z: Tensor, num_steps: int = 4, n_intermediates: int = 0, **kwargs):
        """
        ODE solving from z0 (ims) to z1 (depth).

        Args:
            z: Source latent at t=0.
            num_steps: Number of fixed-size Euler steps from t=0 to t=1.
            n_intermediates: If > 0, also return this many evenly spaced
                intermediate solutions.
            **kwargs: Conditioning forwarded to the model (e.g. context).

        Returns:
            Solution at t=1, or the full trajectory (endpoints included)
            when n_intermediates > 0.
        """
        ode_kwargs = dict(method="euler", rtol=1e-5, atol=1e-5,
                          options=dict(step_size=1.0 / num_steps))
        # t specifies which intermediate times should the solver return
        # e.g. t = [0, 0.5, 1] means return the solution at t=0, t=0.5 and t=1
        # but it also specifies the number of steps for fixed step size methods
        t = torch.linspace(0, 1, n_intermediates + 2, device=z.device, dtype=z.dtype)
        # bind conditioning information so the solver sees f(t, x) only
        ode_fn = partial(self.ode_fn, **kwargs)
        ode_results = odeint(ode_fn, z, t, **ode_kwargs)
        if n_intermediates > 0:
            return ode_results
        return ode_results[-1]

    def forward(self, ims: Tensor, num_steps: int = 4, ensemble_size: int = 1):
        """
        Args:
            ims: Tensor of shape (b, 3, h, w) in range [-1, 1]
            num_steps: Euler steps used by the ODE solve.
            ensemble_size: If > 1, replicate the single input image and
                average the resulting depth predictions.
        Returns:
            depth: Tensor of shape (b, 1, h, w) in range [0, 1]
        """
        if ensemble_size > 1:
            assert ims.shape[0] == 1, "Ensemble mode only supported with batch size 1"
            ims = ims.repeat(ensemble_size, 1, 1, 1)

        bs, dev = ims.shape[0], ims.device

        ims_z = self.encode(ims, sample_posterior=False)

        # Empty-text embedding as cross-attention conditioning, one per sample.
        # as_tensor avoids a copy (and a UserWarning) when the stored
        # embedding is already a tensor.
        conditioning = torch.as_tensor(self.empty_text_embed).to(dev).repeat(bs, 1, 1)
        context = ims_z

        x_source = ims_z
        if self.noising_step > 0:
            x_source = q_sample(x_source, self.noising_step)

        # solve ODE from the (optionally noised) image latent to the depth latent
        depth_z = self.generate(x_source, num_steps=num_steps, context=context, context_ca=conditioning)
        depth = self.decode(depth_z)
        depth = depth.mean(dim=1, keepdim=True)

        if ensemble_size > 1:
            depth = depth.mean(dim=0, keepdim=True)

        # normalize depth maps to range [0, 1]
        # (previous comment claimed [-1, 1]; min-max normalization yields [0, 1])
        depth = per_sample_min_max_normalization(depth.exp())
        return depth

    def predict_depth(self, ims: Tensor, num_steps: int = 4, ensemble_size: int = 1):
        """ Inference method for DepthFM; see `forward` for argument docs. """
        return self.forward(ims, num_steps, ensemble_size)

    def encode(self, x: Tensor, sample_posterior: bool = True):
        """Encode images to scaled VAE latents.

        Args:
            x: Image batch in the VAE's expected input range.
            sample_posterior: Draw a sample from the posterior if True,
                otherwise use its (deterministic) mode.
        """
        posterior = self.vae.encode(x)
        if sample_posterior:
            z = posterior.latent_dist.sample()
        else:
            z = posterior.latent_dist.mode()
        # normalize latent code to the scale the flow model was trained on
        z = z * self.scale_factor
        return z

    def decode(self, z: Tensor):
        """Decode scaled latents back to image space via the VAE."""
        z = 1.0 / self.scale_factor * z
        return self.vae.decode(z).sample
def sigmoid(x):
    """Logistic function 1 / (1 + e^{-x}); works on scalars and numpy arrays."""
    neg_exp = np.exp(-x)
    return 1 / (1 + neg_exp)
def cosine_log_snr(t, eps=0.00001):
    """
    Returns log Signal-to-Noise ratio for time step t and image size 64.

    eps keeps the argument of the log strictly positive at t = 0
    (where tan vanishes).
    """
    tan_term = np.tan((np.pi * t) / 2)
    return -2 * np.log(tan_term + eps)
def cosine_alpha_bar(t):
    """Cosine-schedule alpha_bar(t): the sigmoid of the cosine log-SNR.

    Inlined form of sigmoid(cosine_log_snr(t)) — same operations, same
    result; close to 1 at t=0 and close to 0 at t=1.
    """
    log_snr = -2 * np.log(np.tan((np.pi * t) / 2) + 0.00001)
    return 1 / (1 + np.exp(-log_snr))
def q_sample(x_start: torch.Tensor, t: int, noise: torch.Tensor = None, n_diffusion_timesteps: int = 1000):
    """
    Diffuse the data for a given number of diffusion steps. In other
    words sample from q(x_t | x_0).

    Args:
        x_start: Clean input x_0.
        t: Integer diffusion step, interpreted relative to
            n_diffusion_timesteps.
        noise: Optional pre-drawn noise; defaults to fresh standard Gaussian
            noise shaped like x_start.
        n_diffusion_timesteps: Total number of steps the schedule spans.

    Returns:
        x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise.
    """
    if noise is None:
        noise = torch.randn_like(x_start)
    alpha_bar_t = cosine_alpha_bar(t / n_diffusion_timesteps)
    # Build the scalar schedule value on the right device/dtype in one call
    # (replaces the tensor(...).to(dev).to(dtype) round-trip).
    alpha_bar_t = torch.tensor(alpha_bar_t, device=x_start.device, dtype=x_start.dtype)
    return torch.sqrt(alpha_bar_t) * x_start + torch.sqrt(1 - alpha_bar_t) * noise
def per_sample_min_max_normalization(x):
    """ Normalize each sample in a batch independently
    with min-max normalization to [0, 1].

    Args:
        x: Tensor of shape (b, ...); each sample is scaled by its own min
            and max taken over all non-batch dimensions.

    Returns:
        Tensor of the same shape as x with every sample in [0, 1].

    NOTE(review): a constant sample (min == max) divides by zero and yields
    NaNs, exactly as the einops-based original did — confirm callers never
    pass constant inputs. Implemented with native torch ops; einops adds
    nothing here.
    """
    bs = x.shape[0]
    flat = x.reshape(bs, -1)
    # .min/.max with dim return (values, indices) named tuples.
    min_val = flat.min(dim=1, keepdim=True).values
    max_val = flat.max(dim=1, keepdim=True).values
    flat = (flat - min_val) / (max_val - min_val)
    return flat.reshape(x.shape)