# genpercept/genpercept_pipeline.py
# --------------------------------------------------------
# What Matters When Repurposing Diffusion Models for General Dense Perception Tasks? (https://arxiv.org/abs/2403.06090)
# Github source: https://github.com/aim-uofa/GenPercept
# Copyright (c) 2024, Advanced Intelligent Machines (AIM)
# Licensed under The BSD 2-Clause License [see LICENSE for details]
# By Guangkai Xu
# Based on Marigold, diffusers codebases
# https://github.com/prs-eth/marigold
# https://github.com/huggingface/diffusers
# --------------------------------------------------------
import logging
from typing import Dict, Optional, Union
import numpy as np
import torch
from diffusers import (
AutoencoderKL,
DDIMScheduler,
DiffusionPipeline,
LCMScheduler,
UNet2DConditionModel,
)
from diffusers.utils import BaseOutput
from PIL import Image
from torch.utils.data import DataLoader, TensorDataset
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import pil_to_tensor, resize
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
from .util.batchsize import find_batch_size
from .util.ensemble import ensemble_depth
from .util.image_util import (
chw2hwc,
colorize_depth_maps,
get_tv_resample_method,
resize_max_res,
)
import matplotlib.pyplot as plt
from genpercept.models.dpt_head import DPTNeckHeadForUnetAfterUpsampleIdentity
class GenPerceptOutput(BaseOutput):
"""
Output class for GenPercept general perception pipeline.
Args:
pred_np (`np.ndarray`):
Predicted result, with values in the range of [0, 1].
        pred_colored (`PIL.Image.Image`):
            Visualization of the result as a PIL image: the colorized prediction when `color_map` is set,
            otherwise the raw prediction converted to an image.
"""
pred_np: np.ndarray
pred_colored: Union[None, Image.Image]
class GenPerceptPipeline(DiffusionPipeline):
"""
Pipeline for general perception using GenPercept: https://github.com/aim-uofa/GenPercept.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
unet (`UNet2DConditionModel`):
Conditional U-Net to denoise the perception latent, conditioned on image latent.
vae (`AutoencoderKL`):
Variational Auto-Encoder (VAE) Model to encode and decode images and results
to and from latent representations.
scheduler (`DDIMScheduler`):
A scheduler to be used in combination with `unet` to denoise the encoded image latents.
text_encoder (`CLIPTextModel`):
Text-encoder, for empty text embedding.
tokenizer (`CLIPTokenizer`):
CLIP tokenizer.
default_denoising_steps (`int`, *optional*):
The minimum number of denoising diffusion steps that are required to produce a prediction of reasonable
quality with the given model. This value must be set in the model config. When the pipeline is called
without explicitly setting `num_inference_steps`, the default value is used. This is required to ensure
reasonable results with various model flavors compatible with the pipeline, such as those relying on very
short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`).
default_processing_resolution (`int`, *optional*):
The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in
the model config. When the pipeline is called without explicitly setting `processing_resolution`, the
default value is used. This is required to ensure reasonable results with various model flavors trained
with varying optimal processing resolution values.
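    Example (a minimal usage sketch; the checkpoint path and input file below are illustrative, not released names):

        >>> from PIL import Image
        >>> from genpercept.genpercept_pipeline import GenPerceptPipeline
        >>> pipe = GenPerceptPipeline.from_pretrained("path/to/genpercept-depth")  # hypothetical checkpoint path
        >>> pipe = pipe.to("cuda")
        >>> image = Image.open("input.jpg")
        >>> output = pipe(image, mode="depth", color_map="Spectral")
        >>> depth = output.pred_np  # np.ndarray with values in [0, 1]
        >>> output.pred_colored.save("depth_colored.png")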
"""
latent_scale_factor = 0.18215
def __init__(
self,
unet: UNet2DConditionModel,
vae: AutoencoderKL,
scheduler: Union[DDIMScheduler, LCMScheduler],
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
default_denoising_steps: Optional[int] = 10,
default_processing_resolution: Optional[int] = 768,
rgb_blending = False,
customized_head = None,
genpercept_pipeline = True,
):
super().__init__()
self.genpercept_pipeline = genpercept_pipeline
if self.genpercept_pipeline:
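            # GenPercept uses a single deterministic forward pass: one "denoising" step
            # with the RGB latent fed directly to the U-Net (see `single_infer`).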
default_denoising_steps = 1
rgb_blending = True
self.register_modules(
unet=unet,
customized_head=customized_head,
vae=vae,
scheduler=scheduler,
text_encoder=text_encoder,
tokenizer=tokenizer,
)
self.register_to_config(
default_denoising_steps=default_denoising_steps,
default_processing_resolution=default_processing_resolution,
rgb_blending=rgb_blending,
)
self.default_denoising_steps = default_denoising_steps
self.default_processing_resolution = default_processing_resolution
self.rgb_blending = rgb_blending
self.text_embed = None
self.customized_head = customized_head
if self.customized_head:
assert self.rgb_blending and self.scheduler.beta_start == 1 and self.scheduler.beta_end == 1
assert self.genpercept_pipeline
@torch.no_grad()
def __call__(
self,
input_image: Union[Image.Image, torch.Tensor],
denoising_steps: Optional[int] = None,
ensemble_size: int = 1,
processing_res: Optional[int] = None,
match_input_res: bool = True,
resample_method: str = "bilinear",
batch_size: int = 0,
generator: Union[torch.Generator, None] = None,
color_map: Union[str, None] = None,
show_progress_bar: bool = True,
ensemble_kwargs: Dict = None,
mode = None,
fix_timesteps = None,
prompt = "",
) -> GenPerceptOutput:
"""
Function invoked when calling the pipeline.
Args:
input_image (`Image`):
Input RGB (or gray-scale) image.
denoising_steps (`int`, *optional*, defaults to `None`):
Number of denoising diffusion steps during inference. The default value `None` results in automatic
selection.
            ensemble_size (`int`, *optional*, defaults to `1`):
Number of predictions to be ensembled.
processing_res (`int`, *optional*, defaults to `None`):
Effective processing resolution. When set to `0`, processes at the original image resolution. This
produces crisper predictions, but may also lead to the overall loss of global context. The default
value `None` resolves to the optimal value from the model config.
match_input_res (`bool`, *optional*, defaults to `True`):
Resize perception result to match input resolution.
Only valid if `processing_res` > 0.
            resample_method (`str`, *optional*, defaults to `bilinear`):
                Resampling method used to resize images and perception results. One of `bilinear`, `bicubic`, or `nearest`.
batch_size (`int`, *optional*, defaults to `0`):
                Inference batch size, no bigger than `ensemble_size`.
If set to 0, the script will automatically decide the proper batch size.
            generator (`torch.Generator`, *optional*, defaults to `None`):
Random generator for initial noise generation.
show_progress_bar (`bool`, *optional*, defaults to `True`):
Display a progress bar of diffusion denoising.
            color_map (`str`, *optional*, defaults to `None`):
                Colormap used to colorize the result; if `None`, the raw prediction is converted to an image instead.
                Colormapping is only valid for the `'depth'` and `'disparity'` modes.
            ensemble_kwargs (`dict`, *optional*, defaults to `None`):
                Arguments for detailed ensembling settings.
            mode (`str`):
                Perception task, one of `'depth'`, `'normal'`, `'seg'`, `'matting'` or `'dis'`. Must be set.
            fix_timesteps (`int`, *optional*, defaults to `None`):
                If set, this fixed timestep value is used for every denoising step instead of the scheduler's timesteps.
            prompt (`str`, *optional*, defaults to `""`):
                Text prompt fed to the text encoder; the empty prompt is used by default.
Returns:
`GenPerceptOutput`: Output class for GenPercept general perception pipeline, including:
- **pred_np** (`np.ndarray`) Predicted result, with values in the range of [0, 1]
            - **pred_colored** (`PIL.Image.Image`) Visualization of the result as a PIL image: colorized with `color_map` when set, otherwise the raw prediction converted to an image
"""
        assert mode is not None, "mode of GenPerceptPipeline must be one of ['depth', 'normal', 'seg', 'matting', 'dis']."
self.mode = mode
# Model-specific optimal default values leading to fast and reasonable results.
if denoising_steps is None:
denoising_steps = self.default_denoising_steps
if processing_res is None:
processing_res = self.default_processing_resolution
assert processing_res >= 0
assert ensemble_size >= 1
if self.genpercept_pipeline:
assert ensemble_size == 1
assert denoising_steps == 1
else:
# Check if denoising step is reasonable
self._check_inference_step(denoising_steps)
resample_method: InterpolationMode = get_tv_resample_method(resample_method)
# ----------------- Image Preprocess -----------------
# Convert to torch tensor
if isinstance(input_image, Image.Image):
input_image = input_image.convert("RGB")
# convert to torch tensor [H, W, rgb] -> [rgb, H, W]
rgb = pil_to_tensor(input_image)
rgb = rgb.unsqueeze(0) # [1, rgb, H, W]
elif isinstance(input_image, torch.Tensor):
rgb = input_image
else:
raise TypeError(f"Unknown input type: {type(input_image) = }")
input_size = rgb.shape
assert (
4 == rgb.dim() and 3 == input_size[-3]
), f"Wrong input shape {input_size}, expected [1, rgb, H, W]"
# Resize image
if processing_res > 0:
rgb = resize_max_res(
rgb,
max_edge_resolution=processing_res,
resample_method=resample_method,
)
# Normalize rgb values
rgb_norm: torch.Tensor = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1]
rgb_norm = rgb_norm.to(self.dtype)
assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0
# ----------------- Perception Inference -----------------
# Batch repeated input image
duplicated_rgb = rgb_norm.expand(ensemble_size, -1, -1, -1)
single_rgb_dataset = TensorDataset(duplicated_rgb)
if batch_size > 0:
_bs = batch_size
else:
_bs = find_batch_size(
ensemble_size=ensemble_size,
input_res=max(rgb_norm.shape[1:]),
dtype=self.dtype,
)
single_rgb_loader = DataLoader(
single_rgb_dataset, batch_size=_bs, shuffle=False
)
# Predict results (batched)
pipe_pred_ls = []
if show_progress_bar:
iterable = tqdm(
single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
)
else:
iterable = single_rgb_loader
for batch in iterable:
(batched_img,) = batch
pipe_pred_raw = self.single_infer(
rgb_in=batched_img,
num_inference_steps=denoising_steps,
show_pbar=show_progress_bar,
generator=generator,
fix_timesteps=fix_timesteps,
prompt=prompt,
)
pipe_pred_ls.append(pipe_pred_raw.detach())
pipe_preds = torch.concat(pipe_pred_ls, dim=0)
torch.cuda.empty_cache() # clear vram cache for ensembling
# ----------------- Test-time ensembling -----------------
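        # Multiple predictions are merged with the scale- and shift-invariant depth ensembling routine.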
if ensemble_size > 1:
pipe_pred, _ = ensemble_depth(
pipe_preds,
scale_invariant=True,
shift_invariant=True,
max_res=50,
**(ensemble_kwargs or {}),
)
else:
pipe_pred = pipe_preds
# Resize back to original resolution
if match_input_res:
pipe_pred = resize(
pipe_pred,
input_size[-2:],
interpolation=resample_method,
antialias=True,
)
# Convert to numpy
pipe_pred = pipe_pred.squeeze()
pipe_pred = pipe_pred.cpu().numpy()
# Clip output range
pipe_pred = pipe_pred.clip(0, 1)
# Colorize
if color_map is not None:
assert self.mode in ['depth', 'disparity']
pred_colored = colorize_depth_maps(
pipe_pred, 0, 1, cmap=color_map
).squeeze() # [3, H, W], value in (0, 1)
pred_colored = (pred_colored * 255).astype(np.uint8)
pred_colored_hwc = chw2hwc(pred_colored)
pred_colored_img = Image.fromarray(pred_colored_hwc)
else:
pred_colored_img = (pipe_pred * 255.0).astype(np.uint8)
if len(pred_colored_img.shape) == 3 and pred_colored_img.shape[0] == 3:
pred_colored_img = np.transpose(pred_colored_img, (1, 2, 0))
pred_colored_img = Image.fromarray(pred_colored_img)
if len(pipe_pred.shape) == 3 and pipe_pred.shape[0] == 3:
pipe_pred = np.transpose(pipe_pred, (1, 2, 0))
return GenPerceptOutput(
pred_np=pipe_pred,
pred_colored=pred_colored_img,
)
def _check_inference_step(self, n_step: int) -> None:
"""
Check if denoising step is reasonable
Args:
n_step (`int`): denoising steps
"""
assert n_step >= 1
if isinstance(self.scheduler, DDIMScheduler):
if n_step < 10:
logging.warning(
f"Too few denoising steps: {n_step}. Recommended to use the LCM checkpoint for few-step inference."
)
elif isinstance(self.scheduler, LCMScheduler):
if not 1 <= n_step <= 4:
logging.warning(
f"Non-optimal setting of denoising steps: {n_step}. Recommended setting is 1-4 steps."
)
else:
raise RuntimeError(f"Unsupported scheduler type: {type(self.scheduler)}")
def encode_text(self, prompt):
"""
        Encode the text prompt into an embedding and cache it (an empty prompt by default).
"""
text_inputs = self.tokenizer(
prompt,
padding="do_not_pad",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
self.text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
@torch.no_grad()
def single_infer(
self,
rgb_in: torch.Tensor,
num_inference_steps: int,
generator: Union[torch.Generator, None],
show_pbar: bool,
fix_timesteps = None,
prompt = "",
) -> torch.Tensor:
"""
        Perform a single inference pass without ensembling.
Args:
rgb_in (`torch.Tensor`):
Input RGB image.
num_inference_steps (`int`):
                Number of diffusion denoising steps (DDIM) during inference.
show_pbar (`bool`):
Display a progress bar of diffusion denoising.
            generator (`torch.Generator`):
Random generator for initial noise generation.
Returns:
`torch.Tensor`: Predicted result.
"""
device = self.device
rgb_in = rgb_in.to(device)
# Set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
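        # Optionally override the schedule with a single fixed timestep repeated for every step.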
if fix_timesteps:
timesteps = torch.tensor([fix_timesteps]).long().repeat(self.scheduler.timesteps.shape[0]).to(device)
else:
timesteps = self.scheduler.timesteps # [T]
# Encode image
rgb_latent = self.encode_rgb(rgb_in)
if not (self.rgb_blending or self.genpercept_pipeline):
# Initial result (noise)
pred_latent = torch.randn(
rgb_latent.shape,
device=device,
dtype=self.dtype,
generator=generator,
) # [B, 4, h, w]
else:
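            # RGB-blending / GenPercept path: the RGB latent itself serves as the
            # prediction latent, so no random noise initialization is used.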
pred_latent = rgb_latent
# Batched empty text embedding
if self.text_embed is None:
self.encode_text(prompt)
batch_text_embed = self.text_embed.repeat(
(rgb_latent.shape[0], 1, 1)
).to(device) # [B, 2, 1024]
# Denoising loop
if show_pbar:
iterable = tqdm(
enumerate(timesteps),
total=len(timesteps),
leave=False,
desc=" " * 4 + "Diffusion denoising",
)
else:
iterable = enumerate(timesteps)
if not self.customized_head:
for i, t in iterable:
if self.genpercept_pipeline and i > 0:
                    raise ValueError("GenPercept only forwards once.")
if not (self.rgb_blending or self.genpercept_pipeline):
unet_input = torch.cat(
[rgb_latent, pred_latent], dim=1
) # this order is important
else:
unet_input = pred_latent
# predict the noise residual
noise_pred = self.unet(
unet_input, t, encoder_hidden_states=batch_text_embed
).sample # [B, 4, h, w]
# compute the previous noisy sample x_t -> x_t-1
step_output = self.scheduler.step(
noise_pred, t, pred_latent, generator=generator
)
pred_latent = step_output.prev_sample
pred_latent = step_output.pred_original_sample # NOTE: for GenPercept, it is equivalent to "pred_latent = - noise_pred"
pred = self.decode_pred(pred_latent)
# clip prediction
pred = torch.clip(pred, -1.0, 1.0)
# shift to [0, 1]
pred = (pred + 1.0) / 2.0
elif isinstance(self.customized_head, DPTNeckHeadForUnetAfterUpsampleIdentity):
unet_input = pred_latent
model_pred_output = self.unet(
unet_input, timesteps, encoder_hidden_states=batch_text_embed, return_feature=True
) # [B, 4, h, w]
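            # Reverse the multi-level U-Net features before passing them to the DPT-style head.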
unet_features = model_pred_output.multi_level_feats[::-1]
pred = self.customized_head(hidden_states=unet_features).prediction[:, None]
            # min-max normalize to [0, 1]
pred = (pred - pred.min()) / (pred.max() - pred.min())
else:
            raise ValueError(f"Unsupported customized_head type: {type(self.customized_head)}")
return pred
def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
"""
Encode RGB image into latent.
Args:
rgb_in (`torch.Tensor`):
Input RGB image to be encoded.
Returns:
`torch.Tensor`: Image latent.
"""
# encode
h = self.vae.encoder(rgb_in)
moments = self.vae.quant_conv(h)
mean, logvar = torch.chunk(moments, 2, dim=1)
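        # Use the distribution mean as a deterministic encoding; the log-variance is not used.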
# scale latent
rgb_latent = mean * self.latent_scale_factor
return rgb_latent
def decode_pred(self, pred_latent: torch.Tensor) -> torch.Tensor:
"""
Decode pred latent into result.
Args:
pred_latent (`torch.Tensor`):
pred latent to be decoded.
Returns:
`torch.Tensor`: Decoded result.
"""
# scale latent
pred_latent = pred_latent / self.latent_scale_factor
# decode
z = self.vae.post_quant_conv(pred_latent)
stacked = self.vae.decoder(z)
if self.mode in ['depth', 'matting', 'dis', 'disparity']:
# mean of output channels
stacked = stacked.mean(dim=1, keepdim=True)
return stacked