# LiDAR-Diffusion / sample_cond.py

import os

import numpy as np
import torch
from omegaconf import OmegaConf
from PIL import Image

from lidm.models.diffusion.ddim import DDIMSampler
from lidm.utils.lidar_utils import range2pcd
from lidm.utils.misc_utils import instantiate_from_config

# sampling settings
CUSTOM_STEPS = 50  # number of DDIM steps
ETA = 1.0  # DDIM eta (1.0 = fully stochastic sampling)

# model loading
MODEL_PATH = './models/lidm/kitti/cam2lidar'
CFG_PATH = os.path.join(MODEL_PATH, 'config.yaml')
CKPT_PATH = os.path.join(MODEL_PATH, 'model.ckpt')
MODEL_CFG = OmegaConf.load(CFG_PATH)


def custom_to_pcd(x, config, rgb=None):
    """Convert a range image in [-1, 1] (plus optional color) to a point cloud."""
    x = x.squeeze().detach().cpu().numpy()
    x = (np.clip(x, -1., 1.) + 1.) / 2.  # rescale from [-1, 1] to [0, 1]
    if rgb is not None:
        rgb = rgb.squeeze().detach().cpu().numpy()
        rgb = (np.clip(rgb, -1., 1.) + 1.) / 2.
        rgb = rgb.transpose(1, 2, 0)  # CHW -> HWC
    xyz, rgb, _ = range2pcd(x, color=rgb, **config['data']['params']['dataset'])
    return xyz, rgb


def custom_to_pil(x):
    """Convert a tensor in [-1, 1] to an 8-bit PIL image."""
    x = x.detach().cpu().squeeze().numpy()
    x = (np.clip(x, -1., 1.) + 1.) / 2.
    x = (255 * x).astype(np.uint8)
    if x.ndim == 3:
        x = x.transpose(1, 2, 0)  # CHW -> HWC
    x = Image.fromarray(x)
    return x


def logs2pil(logs, keys=["sample"]):
    """Convert a dict of logged tensors to PIL images (None on failure)."""
    # note: 'keys' is currently unused; all entries of 'logs' are converted
    imgs = dict()
    for k in logs:
        try:
            if len(logs[k].shape) == 4:
                img = custom_to_pil(logs[k][0, ...])  # first item of the batch
            elif len(logs[k].shape) == 3:
                img = custom_to_pil(logs[k])
            else:
                print(f"Unknown format for key {k}.")
                img = None
        except Exception:
            img = None
        imgs[k] = img
    return imgs


def load_model_from_config(config, sd):
    """Instantiate the model from config and load the given state dict."""
    model = instantiate_from_config(config)
    model.load_state_dict(sd, strict=False)
    model.eval()
    return model
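

# A minimal loading sketch (not part of the original script): assuming the
# checkpoint at CKPT_PATH is a PyTorch Lightning-style file that keeps its
# weights under a 'state_dict' key and that the model config lives under the
# 'model' key of config.yaml, the model could be built like this:
#
#     sd = torch.load(CKPT_PATH, map_location='cpu')['state_dict']
#     model = load_model_from_config(MODEL_CFG.model, sd)
#     model = model.to('cuda' if torch.cuda.is_available() else 'cpu')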


@torch.no_grad()
def convsample_ddim(model, cond, steps, shape, eta=1.0, verbose=False):
    """Draw conditional samples from the model with a DDIM sampler."""
    ddim = DDIMSampler(model)
    bs = shape[0]
    shape = shape[1:]  # per-sample shape without the batch dimension
    samples, intermediates = ddim.sample(steps, conditioning=cond, batch_size=bs, shape=shape,
                                         eta=eta, verbose=verbose, disable_tqdm=True)
    return samples, intermediates


@torch.no_grad()
def make_convolutional_sample(model, batch, batch_size, custom_steps=None, eta=1.0):
    """Encode the camera condition, sample in latent space, and decode to a range image."""
    xc = batch['camera']
    c = model.get_learned_conditioning(xc.to(model.device))
    with model.ema_scope("Plotting"):
        samples, z_denoise_row = model.sample_log(cond=c, batch_size=batch_size, ddim=True,
                                                  ddim_steps=custom_steps, eta=eta)
    x_samples = model.decode_first_stage(samples)
    return x_samples


def sample(model, cond):
    """Sample one range image from a camera condition and convert it to a point cloud."""
    batch = {'camera': cond}
    img = make_convolutional_sample(model, batch, batch_size=1, custom_steps=CUSTOM_STEPS,
                                    eta=ETA)  # TODO add arguments for batch_size, custom_steps and eta
    pcd = custom_to_pcd(img, MODEL_CFG)[0].astype(np.float32)
    img = img.squeeze().detach().cpu().numpy()
    return img, pcd
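

# An end-to-end usage sketch (an assumption, not part of the original script),
# building on the loading sketch above. The conditioning tensor shape is a
# placeholder and would need to match the cam2lidar config under MODEL_PATH:
#
#     if __name__ == '__main__':
#         sd = torch.load(CKPT_PATH, map_location='cpu')['state_dict']
#         model = load_model_from_config(MODEL_CFG.model, sd)
#         camera = torch.randn(1, 3, 64, 256)  # hypothetical camera image in [-1, 1]
#         range_img, pcd = sample(model, camera)
#         np.save('sample_pcd.npy', pcd)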