from ..custom_types import *
from ..options import Options
from ..utils import train_utils, mcubes_meshing, files_utils, mesh_utils
from ..models.occ_gmm import Spaghetti
from ..models import models_utils


class Inference:

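    # Build an occupancy closure for the latent z: query points x are scored by the
    # occupancy network and mapped to (-1, 1) through a sigmoid.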
    def get_occ_fun(self, z: T):

        def forward(x: T) -> T:
            nonlocal z
            x = x.unsqueeze(0)
            out = self.model.occupancy_network(x, z)[0, :]
            out = 2 * out.sigmoid_() - 1
            return out

        if z.dim() == 2:
            z = z.unsqueeze(0)
        return forward

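    # Extract a mesh from the implicit surface defined by z via marching cubes at grid resolution res.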
    @models_utils.torch_no_grad
    def get_mesh(self, z: T, res: int) -> Optional[T_Mesh]:
        mesh = self.meshing.occ_meshing(self.get_occ_fun(z), res=res)
        return mesh

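    # Export a mesh, the corresponding base latent and (optionally) the GMM parameters
    # for each latent in the batch into <cp_folder>/<folder_name>/occ/.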
    def plot_occ(self, z: Union[T, TS], z_base, gmms: Optional[TS], fixed_items: T,
                 folder_name: str, res=200, verbose=False):
        for i in range(len(z)):
            mesh = self.get_mesh(z[i], res)
            name = f'{fixed_items[i]:04d}'
            if mesh is not None:
                files_utils.export_mesh(mesh, f'{self.opt.cp_folder}/{folder_name}/occ/{name}')
                files_utils.save_pickle(z_base[i].detach().cpu(), f'{self.opt.cp_folder}/{folder_name}/occ/{name}')
                if gmms is not None:
                    files_utils.export_gmm(gmms, i, f'{self.opt.cp_folder}/{folder_name}/occ/{name}')
            if verbose:
                print(f'done {i + 1:d}/{len(z):d}')

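    # Load a pickled UI selection: fetch the embeddings of the referenced shapes, keep only the
    # selected Gaussians (optionally dropping indices listed in disclude), and return the gathered
    # latents, the masked GMMs, per-shape split labels and the id mapping.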
    def load_file(self, info_path, disclude: Optional[List[int]] = None):
        info = files_utils.load_pickle(''.join(info_path))
        keys = list(info['ids'].keys())
        items = map(lambda x: int(x.split('_')[1]) if type(x) is str else x, keys)
        items = torch.tensor(list(items), dtype=torch.int64, device=self.device)
        zh, _, gmms_sanity, _ = self.model.get_embeddings(items)
        gmms = [item for item in info['gmm']]
        zh_ = []
        split = []
        gmm_mask = torch.ones(gmms[0].shape[2], dtype=torch.bool)
        counter = 0
        # gmms_ = [[] for _ in range(len(gmms))]
        for i, key in enumerate(keys):
            gaussian_inds = info['ids'][key]
            if disclude is not None:
                for j in range(len(gaussian_inds)):
                    gmm_mask[j + counter] = gaussian_inds[j] not in disclude
                counter += len(gaussian_inds)
                gaussian_inds = [ind for ind in gaussian_inds if ind not in disclude]
                info['ids'][key] = gaussian_inds
            gaussian_inds = torch.tensor(gaussian_inds, dtype=torch.int64)
            zh_.append(zh[i, gaussian_inds])
            split.append(len(split) + torch.ones(len(info['ids'][key]), dtype=torch.int64, device=self.device))
        zh_ = torch.cat(zh_, dim=0).unsqueeze(0).to(self.device)
        gmms = [item[:, :, gmm_mask].to(self.device) for item in info['gmm']]
        return zh_, gmms, split, info['ids']

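    # Convert a selection file into decoder-ready latents: project the gathered per-Gaussian codes
    # (merge_zh_step_a) and pass them through the affine transformer.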
    @models_utils.torch_no_grad
    def get_z_from_file(self, info_path):
        zh_, gmms, split, _ = self.load_file(info_path)
        zh_ = self.model.merge_zh_step_a(zh_, [gmms])
        zh, _ = self.model.affine_transformer.forward_with_attention(zh_)
        # gmms_ = [torch.cat(item, dim=1).unsqueeze(0) for item in gmms_]
        # zh, _ = self.model.merge_zh(zh_, [gmms])
        return zh, zh_, gmms, torch.cat(split)

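    # Reconstruct the mesh described by a selection file and compute per-face attention labels for coloring.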
    def plot_from_info(self, info_path, res):
        zh, zh_, gmms, split = self.get_z_from_file(info_path)
        mesh = self.get_mesh(zh[0], res)
        if mesh is not None:
            attention = self.get_attention_faces(mesh, zh, fixed_z=split)
        else:
            attention = None
        return mesh, attention

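    # Stack two latent sequences along the batch dimension, zero-padding the shorter one; returns the
    # stacked tensor and a boolean mask marking the padded positions (None if the lengths already match).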
    @staticmethod
    def combine_and_pad(zh_a: T, zh_b: T) -> Tuple[T, TN]:
        if zh_a.shape[1] == zh_b.shape[1]:
            mask = None
        else:
            pad_length = max(zh_a.shape[1], zh_b.shape[1])
            mask = torch.zeros(2, pad_length, device=zh_a.device, dtype=torch.bool)
            padding = torch.zeros(1, abs(zh_a.shape[1] - zh_b.shape[1]), zh_a.shape[-1], device=zh_a.device)
            if zh_a.shape[1] > zh_b.shape[1]:
                mask[1, zh_b.shape[1]:] = True
                zh_b = torch.cat((zh_b, padding), dim=1)
            else:
                mask[0, zh_a.shape[1]:] = True
                zh_a = torch.cat((zh_a, padding), dim=1)
        return torch.cat((zh_a, zh_b), dim=0), mask

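    # For each latent in z_a and in z_b, flag whether a near-duplicate exists in the other set
    # (L1 distance below 0.1); the two flag vectors are padded to equal length and concatenated.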
    @staticmethod
    def get_intersection_z(z_a: T, z_b: T) -> T:
        diff = (z_a[0, :, None, :] - z_b[0, None]).abs().sum(-1)
        diff_a = diff.min(1)[0].lt(.1)
        diff_b = diff.min(0)[0].lt(.1)
        if diff_a.shape[0] != diff_b.shape[0]:
            padding = torch.zeros(abs(diff_a.shape[0] - diff_b.shape[0]), device=z_a.device, dtype=torch.bool)
            if diff_a.shape[0] > diff_b.shape[0]:
                diff_b = torch.cat((diff_b, padding))
            else:
                diff_a = torch.cat((diff_a, padding))
        return torch.cat((diff_a, diff_b))

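    # For each query point, return the index of the latent token that receives the highest
    # layer-averaged attention in the occupancy network.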
    def get_attention_points(self, vs: T, zh: T, mask: TN = None, alpha: TN = None):
        vs = vs.unsqueeze(0)
        attention = self.model.occupancy_network.forward_attention(vs, zh, mask=mask, alpha=alpha)
        attention = torch.stack(attention, 0).mean(0).mean(-1)
        attention = attention.permute(1, 0, 2).reshape(attention.shape[1], -1)
        attention_max = attention.argmax(-1)
        return attention_max

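    # Label each mesh face by the dominant attention index of its centroid; if fixed_z is given,
    # map token indices through it (e.g. the per-shape split labels).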
    @models_utils.torch_no_grad
    def get_attention_faces(self, mesh: T_Mesh, zh: T, mask: TN = None, fixed_z: TN = None, alpha: TN = None):
        coords = mesh[0][mesh[1]].mean(1).to(zh.device)

        attention_max = self.get_attention_points(coords, zh, mask, alpha)
        if fixed_z is not None:
            attention_select = fixed_z[attention_max].cpu()
        else:
            attention_select = attention_max
        return attention_select

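    # Batch-convert every .pkl selection file in the given folders into an exported mesh plus a
    # per-face label list under <cp_folder>/from_ui/.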
    @models_utils.torch_no_grad
    def plot_folder(self, *folders, res: int = 256):
        logger = train_utils.Logger()
        for folder in folders:
            paths = files_utils.collect(folder, '.pkl')
            logger.start(len(paths))
            for path in paths:
                name = path[1]
                out_path = f"{self.opt.cp_folder}/from_ui/{name}"
                mesh, colors = self.plot_from_info(path, res)
                if mesh is not None:
                    files_utils.export_mesh(mesh, out_path)
                    files_utils.export_list(colors.tolist(), f"{out_path}_faces")
                logger.reset_iter()
            logger.stop()

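    # Fetch the stored embeddings for the given dataset indices and merge them into decoder-ready latents.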
    def get_zh_from_idx(self, items: T):
        zh, _, gmms, __ = self.model.get_embeddings(items.to(self.device))
        zh, attn_b = self.model.merge_zh(zh, gmms)
        return zh, gmms

    @property
    def device(self):
        return self.opt.device

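    # Choose export ids that do not collide with meshes already in <folder_name>/occ/ or with dataset indices.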
    def get_new_ids(self, folder_name, nums_sample):
        names = [int(path[1]) for path in files_utils.collect(f'{self.opt.cp_folder}/{folder_name}/occ/', '.obj')]
        ids = torch.arange(nums_sample)
        if len(names) == 0:
            return ids + self.opt.dataset_size
        return ids + max(max(names) + 1, self.opt.dataset_size)

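    # Sample nums_sample random shapes from the model and export their meshes.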
    @models_utils.torch_no_grad
    def random_plot(self, folder_name: str, nums_sample, res=200, verbose=False):
        zh_base, gmms = self.model.random_samples(nums_sample)
        zh, attn_b = self.model.merge_zh(zh_base, gmms)
        numbers = self.get_new_ids(folder_name, nums_sample)
        self.plot_occ(zh, zh_base, gmms, numbers, folder_name, verbose=verbose, res=res)

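    # Export reconstructions of nums_sample training shapes (all shapes if the dataset is smaller).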
    @models_utils.torch_no_grad
    def plot(self, folder_name: str, nums_sample: int, verbose=False, res: int = 200):
        if self.model.opt.dataset_size < nums_sample:
            fixed_items = torch.arange(self.model.opt.dataset_size)
        else:
            fixed_items = torch.randint(low=0, high=self.opt.dataset_size, size=(nums_sample,))
        zh_base, _, gmms, _ = self.model.get_embeddings(fixed_items.to(self.device))
        zh, attn_b = self.model.merge_zh(zh_base, gmms)
        self.plot_occ(zh, zh_base, gmms, fixed_items, folder_name, verbose=verbose, res=res)

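    # Decode a mesh from a subset of the cached latents set via set_items, selected by the
    # (row, token) index pairs in included.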
    def get_mesh_from_mid(self, gmm, included: T, res: int) -> Optional[T_Mesh]:
        if self.mid is None:
            return None
        gmm = [elem.to(self.device) for elem in gmm]
        included = included.to(device=self.device)
        mid_ = self.mid[included[:, 0], included[:, 1]].unsqueeze(0)
        zh = self.model.merge_zh(mid_, gmm)[0]
        mesh = self.get_mesh(zh[0], res)
        return mesh

    def set_items(self, items: T):
        self.mid = items.to(self.device)

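    # Load the trained Spaghetti model for the given options and set up the marching-cubes mesher.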
    def __init__(self, opt: Options):
        self.opt = opt
        model: Tuple[Spaghetti, Options] = train_utils.model_lc(opt)
        self.model, self.opt = model
        self.model.eval()
        self.mid: Optional[T] = None
        self.gmms: Optional[TN] = None
        self.meshing = mcubes_meshing.MarchingCubesMeshing(self.device, scale=1.)
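

# A minimal usage sketch, kept as a comment because the relative imports above require this module
# to be used from within its package. It assumes an Options instance `opt` configured as elsewhere
# in the project, with opt.cp_folder pointing at a trained checkpoint; the folder names below are
# illustrative:
#
#     inference = Inference(opt)
#     # sample four random shapes and export their meshes on a 128^3 marching-cubes grid
#     inference.random_plot('random_samples', nums_sample=4, res=128, verbose=True)
#     # or reconstruct shapes from the training set instead
#     inference.plot('recon', nums_sample=4, res=128, verbose=True)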