import os
import torch
import numpy as np

from omegaconf import OmegaConf
from PIL import Image

from lidm.models.diffusion.ddim import DDIMSampler
from lidm.utils.misc_utils import instantiate_from_config
from lidm.utils.lidar_utils import range2pcd


# sampling settings: number of DDIM steps and DDIM eta (eta=1.0 keeps full sampling stochasticity)
CUSTOM_STEPS = 50
ETA = 1.0

# paths to the pretrained camera-to-LiDAR model
MODEL_PATH = './models/lidm/kitti/cam2lidar'
CFG_PATH = os.path.join(MODEL_PATH, 'config.yaml')
CKPT_PATH = os.path.join(MODEL_PATH, 'model.ckpt')

# model configuration
MODEL_CFG = OmegaConf.load(CFG_PATH)


def custom_to_pcd(x, config, rgb=None):
    """Convert a generated range image in [-1, 1] to a point cloud (with optional per-point color)."""
    x = x.squeeze().detach().cpu().numpy()
    x = (np.clip(x, -1., 1.) + 1.) / 2.  # rescale from [-1, 1] to [0, 1]
    if rgb is not None:
        rgb = rgb.squeeze().detach().cpu().numpy()
        rgb = (np.clip(rgb, -1., 1.) + 1.) / 2.
        rgb = rgb.transpose(1, 2, 0)  # CHW -> HWC
    xyz, rgb, _ = range2pcd(x, color=rgb, **config['data']['params']['dataset'])

    return xyz, rgb


def custom_to_pil(x):
    """Convert a tensor in [-1, 1] to an 8-bit PIL image."""
    x = x.detach().cpu().squeeze().numpy()
    x = (np.clip(x, -1., 1.) + 1.) / 2.
    x = (255 * x).astype(np.uint8)

    if x.ndim == 3:
        x = x.transpose(1, 2, 0)  # CHW -> HWC
    x = Image.fromarray(x)

    return x


def logs2pil(logs, keys=("sample",)):
    """Convert each 3D/4D tensor in a log dict to a PIL image (``keys`` is kept for compatibility but unused)."""
    imgs = dict()
    for k in logs:
        try:
            if len(logs[k].shape) == 4:
                img = custom_to_pil(logs[k][0, ...])  # take the first sample of the batch
            elif len(logs[k].shape) == 3:
                img = custom_to_pil(logs[k])
            else:
                print(f"Unknown format for key {k}.")
                img = None
        except Exception:
            img = None
        imgs[k] = img
    return imgs


def load_model_from_config(config, sd):
    """Instantiate the model from its config and load a (possibly partial) state dict."""
    model = instantiate_from_config(config)
    model.load_state_dict(sd, strict=False)
    model.eval()
    return model
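

# A minimal helper sketch (not part of the original script) showing how the pieces above are
# typically wired together. It assumes a PyTorch-Lightning-style checkpoint that stores the
# weights under the 'state_dict' key; adjust if your checkpoint layout differs.
def load_pretrained_model(device='cuda'):
    ckpt = torch.load(CKPT_PATH, map_location='cpu')
    sd = ckpt.get('state_dict', ckpt)  # fall back to a raw state dict
    model = load_model_from_config(MODEL_CFG.model, sd)
    return model.to(device)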


@torch.no_grad()
def convsample_ddim(model, cond, steps, shape, eta=1.0, verbose=False):
    """Run DDIM sampling directly for a given latent shape (batch_size, C, H, W)."""
    ddim = DDIMSampler(model)
    bs = shape[0]
    shape = shape[1:]  # the sampler expects the per-sample latent shape without the batch dimension
    samples, intermediates = ddim.sample(steps, conditioning=cond, batch_size=bs, shape=shape, eta=eta, verbose=verbose, disable_tqdm=True)
    return samples, intermediates


@torch.no_grad()
def make_convolutional_sample(model, batch, batch_size, custom_steps=None, eta=1.0):
    """Sample range images from the conditional LiDM given a batch with a 'camera' conditioning image."""
    xc = batch['camera']
    c = model.get_learned_conditioning(xc.to(model.device))  # encode the camera image into conditioning features

    with model.ema_scope("Plotting"):  # sample with the EMA weights
        samples, z_denoise_row = model.sample_log(cond=c, batch_size=batch_size, ddim=True,
                                                  ddim_steps=custom_steps, eta=eta)
    x_samples = model.decode_first_stage(samples)  # decode latents back to range-image space

    return x_samples


def sample(model, cond):
    """Generate one LiDAR sample conditioned on a camera image; returns the range image and the point cloud."""
    batch = {'camera': cond}
    img = make_convolutional_sample(model, batch, batch_size=1, custom_steps=CUSTOM_STEPS, eta=ETA)  # TODO add arguments for batch_size, custom_steps and eta
    pcd = custom_to_pcd(img, MODEL_CFG)[0].astype(np.float32)  # keep only the xyz coordinates
    img = img.squeeze().detach().cpu().numpy()
    return img, pcd
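

if __name__ == '__main__':
    # End-to-end usage sketch (an illustration, not part of the original script), using the
    # load_pretrained_model helper sketched above. The conditioning below is a random placeholder;
    # in practice it should be a preprocessed camera image batch, and the (1, 3, 64, 256) shape is
    # only an assumption about the expected input resolution.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = load_pretrained_model(device)
    dummy_cond = torch.randn(1, 3, 64, 256, device=device)
    range_img, pcd = sample(model, dummy_cond)
    print('range image:', range_img.shape, 'point cloud:', pcd.shape)
    np.save('sample_pcd.npy', pcd)  # save the generated point cloud for inspection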