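'''Gradio demo: generate novel viewpoints of an object depicted in a single
input image, using a fine-tuned version of Stable Diffusion.
'''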
import math

import fire
import gradio as gr
import numpy as np
import torch
from contextlib import nullcontext
from einops import rearrange
from functools import partial
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import load_and_preprocess, instantiate_from_config
from omegaconf import OmegaConf
from PIL import Image
from rich import print
from torch import autocast
from torchvision import transforms

_SHOW_INTERMEDIATE = True
_GPU_INDEX = 0
# _GPU_INDEX = 2


def load_model_from_config(config, ckpt, device, verbose=False):
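    '''Instantiate the latent diffusion model from `config` and load the
    checkpoint at `ckpt`. Missing / unexpected state-dict keys are tolerated
    (strict=False) and printed when `verbose` is set.
    '''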
    print(f'Loading model from {ckpt}')
    pl_sd = torch.load(ckpt, map_location=device)
    if 'global_step' in pl_sd:
        print(f'Global Step: {pl_sd["global_step"]}')
    sd = pl_sd['state_dict']
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print('missing keys:')
        print(m)
    if len(u) > 0 and verbose:
        print('unexpected keys:')
        print(u)
    model.to(device)
    model.eval()
    return model


def sample_model(input_im, model, sampler, precision, h, w, ddim_steps, n_samples, scale,
                 ddim_eta, x, y, z):
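    '''Encode the input view and camera pose, run DDIM sampling, and decode
    the resulting latents to images with values in [0, 1].
    '''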
    precision_scope = autocast if precision == 'autocast' else nullcontext
    with precision_scope('cuda'):
        with model.ema_scope():
            c = model.get_learned_conditioning(input_im).tile(n_samples, 1, 1)
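            # 4-D relative camera pose: polar angle in radians, sin/cos of the
            # azimuth angle, and the radius offset, appended to the image
            # embedding and projected by `cc_projection`.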
            T = torch.tensor([math.radians(x), math.sin(math.radians(y)),
                              math.cos(math.radians(y)), z])
            T = T[None, None, :].repeat(n_samples, 1, 1).to(c.device)
            c = torch.cat([c, T], dim=-1)
            c = model.cc_projection(c)
            cond = {}
            cond['c_crossattn'] = [c]
            # Encode the input view once and reuse it for the concat conditioning.
            c_concat = model.encode_first_stage(input_im.to(c.device)).mode().detach()
            cond['c_concat'] = [c_concat.repeat(n_samples, 1, 1, 1)]
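            # Classifier-free guidance: when scale != 1, the sampler also
            # denoises under an all-zero conditioning and extrapolates between
            # the conditional and unconditional predictions.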
            if scale != 1.0:
                uc = {}
                uc['c_concat'] = [torch.zeros(n_samples, 4, h // 8, w // 8).to(c.device)]
                uc['c_crossattn'] = [torch.zeros_like(c).to(c.device)]
            else:
                uc = None

            shape = [4, h // 8, w // 8]
            samples_ddim, _ = sampler.sample(S=ddim_steps,
                                             conditioning=cond,
                                             batch_size=n_samples,
                                             shape=shape,
                                             verbose=False,
                                             unconditional_guidance_scale=scale,
                                             unconditional_conditioning=uc,
                                             eta=ddim_eta,
                                             x_T=None)

            print(samples_ddim.shape)
            # samples_ddim = torch.nn.functional.interpolate(samples_ddim, 64, mode='nearest', antialias=False)
            x_samples_ddim = model.decode_first_stage(samples_ddim)
            return torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0).cpu()


def main(
        model,
        device,
        input_im,
        preprocess=True,
        x=0.,
        y=0.,
        z=0.,
        scale=3.0,
        n_samples=4,
        ddim_steps=50,
        ddim_eta=1.0,
        precision='fp32',
        h=256,
        w=256,
):
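    '''Gradio callback: preprocess the input image and synthesize novel views.

    `x`, `y`, and `z` are the polar angle (degrees), azimuth angle (degrees),
    and radius offset of the target camera relative to the input view.
    '''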
    # input_im[input_im == [0., 0., 0.]] = [1., 1., 1., 1.]
    print('old input_im:', input_im.size)

    if preprocess:
        input_im = load_and_preprocess(input_im)
        input_im = (input_im / 255.0).astype(np.float32)
        # (H, W, 3) array in [0, 1].
    else:
        input_im = input_im.resize([256, 256], Image.Resampling.LANCZOS)
        input_im = np.asarray(input_im, dtype=np.float32) / 255.0
        # (H, W, 4) array in [0, 1].
        # old method: very important, thresholding background
        # input_im[input_im[:, :, -1] <= 0.9] = [1., 1., 1., 1.]

        # new method: apply correct method of compositing to avoid sudden transitions / thresholding
        # (smoothly transition foreground to white background based on alpha values)
        alpha = input_im[:, :, 3:4]
        white_im = np.ones_like(input_im)
        input_im = alpha * input_im + (1.0 - alpha) * white_im
        input_im = input_im[:, :, 0:3]
        # (H, W, 3) array in [0, 1].
    print('new input_im:', input_im.shape, input_im.dtype, input_im.min(), input_im.max())
    show_in_im = Image.fromarray((input_im * 255).astype(np.uint8))

    input_im = transforms.ToTensor()(input_im).unsqueeze(0).to(device)
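    # Map pixel values from [0, 1] to [-1, 1], the range the autoencoder expects.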
    input_im = input_im * 2 - 1
    input_im = transforms.functional.resize(input_im, [h, w])
    sampler = DDIMSampler(model)
    x_samples_ddim = sample_model(input_im, model, sampler, precision, h, w,
                                  ddim_steps, n_samples, scale, ddim_eta, x, y, z)

    output_ims = []
    for x_sample in x_samples_ddim:
        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
        output_ims.append(Image.fromarray(x_sample.astype(np.uint8)))

    if _SHOW_INTERMEDIATE:
        return (output_ims, show_in_im)
    else:
        return output_ims


description = '''
Generate novel viewpoints of an object depicted in one input image using a fine-tuned version of Stable Diffusion.
'''

article = '''
## How to use this?
TBD

## How does this work?
TBD
'''


def run_demo(
        device_idx=_GPU_INDEX,
        ckpt='last.ckpt',
        config='configs/sd-objaverse-finetune-c_concat-256.yaml',
):
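    '''Load the fine-tuned model and launch the Gradio demo.'''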
    device = f'cuda:{device_idx}'
    config = OmegaConf.load(config)
    model = load_model_from_config(config, ckpt, device=device)

    inputs = [
        gr.Image(type='pil', image_mode='RGBA'),  # shape=[512, 512]
        gr.Checkbox(True, label='Preprocess image (remove background and center)',
                    info='If enabled, the uploaded image will be preprocessed to remove the background and center the object by cropping and/or padding as necessary. '
                         'If disabled, the image will be used as-is, *BUT* a fully transparent or white background is required.'),
        # gr.Number(label='polar (between axis z+)'),
        # gr.Number(label='azimuth (between axis x+)'),
        # gr.Number(label='z (distance from center)'),
        gr.Slider(-90, 90, value=0, step=5, label='Polar angle (vertical rotation in degrees)',
                  info='Positive values move the camera down, while negative values move the camera up.'),
        gr.Slider(-90, 90, value=0, step=5, label='Azimuth angle (horizontal rotation in degrees)',
                  info='Positive values move the camera right, while negative values move the camera left.'),
        gr.Slider(-2, 2, value=0, step=0.5, label='Radius (distance from center)',
                  info='Positive values move the camera further away, while negative values move the camera closer.'),
        gr.Slider(0, 30, value=3, step=1, label='Classifier-free guidance scale'),
        gr.Slider(1, 8, value=4, step=1, label='Number of samples to generate'),
        gr.Slider(5, 200, value=100, step=5, label='Number of DDIM sampling steps'),
    ]

    output = [gr.Gallery(label='Generated images from specified new viewpoint')]
    output[0].style(grid=2)
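    # NOTE: `.style(grid=2)` is the legacy Gradio 3 API for the gallery layout;
    # newer Gradio releases configure this via e.g. `gr.Gallery(columns=2)`.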
    if _SHOW_INTERMEDIATE:
        output += [gr.Image(type='pil', image_mode='RGB', label='Preprocessed input image')]

    fn_with_model = partial(main, model, device)
    fn_with_model.__name__ = 'fn_with_model'

    examples = [
        # ['assets/zero-shot/bear.png', 0, 0, 0, 3, 4, 100],
        # ['assets/zero-shot/car.png', 0, 0, 0, 3, 4, 100],
        # ['assets/zero-shot/elephant.png', 0, 0, 0, 3, 4, 100],
        # ['assets/zero-shot/pikachu.png', 0, 0, 0, 3, 4, 100],
        # ['assets/zero-shot/spyro.png', 0, 0, 0, 3, 4, 100],
        # ['assets/zero-shot/taxi.png', 0, 0, 0, 3, 4, 100],
    ]
    demo = gr.Interface(
        fn=fn_with_model,
        title='Demo for Zero-Shot Control of Camera Viewpoints within a Single Image',
        description=description,
        article=article,
        inputs=inputs,
        outputs=output,
        examples=examples,
        allow_flagging='never',
    )
    demo.launch(enable_queue=True, share=True)


if __name__ == '__main__':
    fire.Fire(run_demo)