Spaces:
Paused
Paused
File size: 5,172 Bytes
5a35e98 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
import torch
import einops
import numpy as np
import torch.nn as nn
from torch import Tensor
from functools import partial
from torchdiffeq import odeint
from unet import UNetModel
from diffusers import AutoencoderKL
def exists(val):
    """Return True when *val* is anything other than ``None``."""
    if val is None:
        return False
    return True
class DepthFM(nn.Module):
    """Depth estimator that solves an ODE in the latent space of a pretrained VAE.

    Images are encoded to VAE latents and integrated from t=0 to t=1 with ``self.model`` (a UNet)
    as the ODE right-hand side. The resulting latents are then decoded back to pixel space and
    reduced to a single normalized depth channel.
    """

    def __init__(self, ckpt_path: str, vae_id: str = "runwayml/stable-diffusion-v1-5"):
        """Build the VAE and load the UNet plus sampling settings from a checkpoint.

        Args:
            ckpt_path: Path to a checkpoint readable by ``torch.load``. It must contain the keys
                ``noising_step``, ``empty_text_embedding``, ``ldm_hparams`` and ``state_dict``.
            vae_id: Model id or local path passed to ``AutoencoderKL.from_pretrained``. The
                autoencoder is read from its ``vae`` subfolder. The default is the repository
                this class was originally hard-coded to.
        """
        super().__init__()
        self.vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae")
        # Latents are multiplied by this factor in encode() and divided by it in decode().
        # NOTE(review): 0.18215 matches the Stable Diffusion v1 VAE convention; keep it in sync
        # with vae_id if a different autoencoder is used.
        self.scale_factor = 0.18215

        # Sampling settings and UNet weights are set from the checkpoint.
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        ckpt = torch.load(ckpt_path, map_location="cpu")
        self.noising_step = ckpt['noising_step']
        self.empty_text_embed = ckpt['empty_text_embedding']
        self.model = UNetModel(**ckpt['ldm_hparams'])
        self.model.load_state_dict(ckpt['state_dict'])

    def ode_fn(self, t: Tensor, x: Tensor, **kwargs):
        """Right-hand side of the ODE: the UNet evaluated at time ``t`` and state ``x``.

        A single-element ``t`` is expanded to one time value per batch element of ``x``.
        Extra keyword arguments (conditioning) are forwarded to the UNet unchanged.
        """
        if t.numel() == 1:
            t = t.expand(x.size(0))
        return self.model(x=x, t=t, **kwargs)

    def generate(self, z: Tensor, num_steps: int = 4, n_intermediates: int = 0, **kwargs):
        """Solve the ODE from z0 (image latents) to z1 (depth latents).

        Args:
            z: State at t=0.
            num_steps: Number of fixed-size Euler steps over [0, 1] (step size 1 / num_steps).
            n_intermediates: Number of additional, evenly spaced time points at which the solver
                should also report the state.
            **kwargs: Conditioning forwarded to the UNet at every function evaluation.

        Returns:
            The state at t=1 when ``n_intermediates == 0``. Otherwise, the full ``odeint``
            result, with one entry per requested time point (``n_intermediates + 2`` in total).

        Raises:
            ValueError: If ``num_steps`` is smaller than 1.
        """
        if num_steps < 1:
            raise ValueError(f"num_steps must be >= 1, got {num_steps}")

        # rtol/atol are kept for compatibility; the fixed-step "euler" method only uses step_size.
        ode_kwargs = dict(method="euler", rtol=1e-5, atol=1e-5, options=dict(step_size=1.0 / num_steps))

        # t selects the times at which the solver returns the solution,
        # e.g. t = [0, 0.5, 1] returns the state at t=0, t=0.5 and t=1.
        # The number of integration steps is governed by step_size above.
        t = torch.linspace(0, 1, n_intermediates + 2, device=z.device, dtype=z.dtype)

        # bind the conditioning information so the solver only supplies (t, x)
        ode_fn = partial(self.ode_fn, **kwargs)

        ode_results = odeint(ode_fn, z, t, **ode_kwargs)

        if n_intermediates > 0:
            return ode_results
        return ode_results[-1]

    def forward(self, ims: Tensor, num_steps: int = 4, ensemble_size: int = 1):
        """Predict a per-sample normalized depth map for each input image.

        Args:
            ims: Tensor of shape (b, 3, h, w) in range [-1, 1].
            num_steps: Number of Euler steps used to solve the ODE.
            ensemble_size: If > 1, the single input image is repeated this many times and the
                decoded predictions are averaged. Only batch size 1 is supported in this mode.
                The copies differ only through the noise added when ``noising_step > 0``.

        Returns:
            depth: Tensor of shape (b, 1, h, w) in range [0, 1] (b == 1 when ensembling).
        """
        if ensemble_size > 1:
            assert ims.shape[0] == 1, "Ensemble mode only supported with batch size 1"
            ims = ims.repeat(ensemble_size, 1, 1, 1)

        bs, dev = ims.shape[0], ims.device

        ims_z = self.encode(ims, sample_posterior=False)

        # The checkpoint's empty-text embedding is repeated once per sample and passed as context_ca.
        # torch.as_tensor avoids the copy-construct warning torch.tensor() emits for tensor input.
        # Casting to the latent dtype keeps the UNet inputs consistent, e.g. under half precision.
        # NOTE(review): repeat(bs, 1, 1) assumes the stored embedding is 2-D or has a leading dim of 1.
        conditioning = torch.as_tensor(self.empty_text_embed).to(device=dev, dtype=ims_z.dtype).repeat(bs, 1, 1)
        context = ims_z
        x_source = ims_z

        if self.noising_step > 0:
            x_source = q_sample(x_source, self.noising_step)

        # solve ODE
        depth_z = self.generate(x_source, num_steps=num_steps, context=context, context_ca=conditioning)

        depth = self.decode(depth_z)
        # collapse the decoded channels into a single depth channel
        depth = depth.mean(dim=1, keepdim=True)

        if ensemble_size > 1:
            depth = depth.mean(dim=0, keepdim=True)

        # Normalize each depth map to the range [0, 1].
        # NOTE(review): exp() suggests the decoder output is log-scaled depth; confirm against training.
        depth = per_sample_min_max_normalization(depth.exp())

        return depth

    @torch.no_grad()
    def predict_depth(self, ims: Tensor, num_steps: int = 4, ensemble_size: int = 1):
        """Inference method for DepthFM: ``forward`` with gradient tracking disabled."""
        return self.forward(ims, num_steps, ensemble_size)

    @torch.no_grad()
    def encode(self, x: Tensor, sample_posterior: bool = True):
        """Encode images to scaled VAE latents.

        Args:
            x: Images to encode.
            sample_posterior: If True, sample from the latent distribution; otherwise take its mode.

        Returns:
            Latents multiplied by ``self.scale_factor``.
        """
        posterior = self.vae.encode(x)
        if sample_posterior:
            z = posterior.latent_dist.sample()
        else:
            z = posterior.latent_dist.mode()
        # normalize latent code
        z = z * self.scale_factor
        return z

    @torch.no_grad()
    def decode(self, z: Tensor):
        """Undo the latent scaling applied in ``encode`` and decode with the VAE."""
        z = 1.0 / self.scale_factor * z
        return self.vae.decode(z).sample
def sigmoid(x):
    """Logistic function ``1 / (1 + exp(-x))``, evaluated element-wise with NumPy."""
    denominator = 1 + np.exp(-x)
    return 1 / denominator
def cosine_log_snr(t, eps=0.00001):
    """Log signal-to-noise ratio of the cosine schedule at normalized time ``t``.

    Computes ``-2 * log(tan(pi * t / 2) + eps)``. No resolution-dependent shift is applied; the
    original note ties this unshifted schedule to an image size of 64.

    Args:
        t: Normalized time (scalar or NumPy array), typically in [0, 1].
        eps: Added to the tangent so the logarithm stays finite at ``t = 0``.
    """
    half_angle = np.pi * t / 2
    return -2 * np.log(np.tan(half_angle) + eps)
def cosine_alpha_bar(t):
    """Signal coefficient alpha_bar(t) of the cosine schedule, i.e. sigmoid(log SNR(t)).

    Up to the ``eps`` inside ``cosine_log_snr`` this equals ``cos(pi * t / 2) ** 2``.
    """
    log_snr = cosine_log_snr(t)
    return sigmoid(log_snr)
def q_sample(x_start: torch.Tensor, t: int, noise: torch.Tensor = None, n_diffusion_timesteps: int = 1000):
    """Diffuse ``x_start`` by ``t`` steps, i.e. sample from q(x_t | x_0).

    Computes ``sqrt(alpha_bar) * x_start + sqrt(1 - alpha_bar) * noise``, with ``alpha_bar`` taken
    from the cosine schedule at the normalized time ``t / n_diffusion_timesteps``.

    Args:
        x_start: Clean input x_0.
        t: Diffusion step, measured against ``n_diffusion_timesteps``.
        noise: Noise to mix in; drawn with ``torch.randn_like(x_start)`` when omitted.
        n_diffusion_timesteps: Total number of diffusion steps.

    Returns:
        The noised tensor x_t.
    """
    if noise is None:
        noise = torch.randn_like(x_start)
    # alpha_bar is computed in NumPy (float64), then moved to x_start's device and cast to its dtype
    abar = torch.tensor(cosine_alpha_bar(t / n_diffusion_timesteps), device=x_start.device).to(x_start.dtype)
    signal_scale = abar.sqrt()
    noise_scale = (1 - abar).sqrt()
    return signal_scale * x_start + noise_scale * noise
def per_sample_min_max_normalization(x):
    """Normalize each sample in a batch independently to [0, 1] with min-max scaling.

    The minimum and maximum are taken over every non-batch axis, so each sample's smallest value
    maps to 0 and its largest to 1. A sample whose values are all equal would otherwise compute
    0 / 0 and become NaN; such samples are mapped to all zeros instead.

    Args:
        x: Batched array/tensor with the batch on the first axis (any backend einops supports).

    Returns:
        Array/tensor with the same shape as ``x``.
    """
    bs, *shape = x.shape
    flat = einops.rearrange(x, "b ... -> b (...)")
    min_val = einops.reduce(flat, "b ... -> b", "min")[..., None]
    max_val = einops.reduce(flat, "b ... -> b", "max")[..., None]
    value_range = max_val - min_val
    # Add 1 only where the range is exactly zero: constant samples become 0 instead of NaN.
    # Non-zero ranges are untouched. Works for both NumPy arrays and torch tensors.
    value_range = value_range + (value_range == 0)
    flat = (flat - min_val) / value_range
    return flat.reshape(bs, *shape)
|