File size: 10,591 Bytes
4f54ccd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
import os
import math
import cv2
import trimesh
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import nvdiffrast.torch as dr
from sparseags.mesh_utils.mesh import Mesh, safe_normalize


def scale_img_nhwc(x, size, mag='bilinear', min='bilinear'):
    """Resize a batch of NHWC images to ``size`` = (height, width).

    When both spatial dimensions are strictly larger than the target the
    ``min`` interpolation mode is used; every other case uses ``mag``, with
    ``align_corners=True`` for the 'bilinear' and 'bicubic' modes.

    Raises:
        AssertionError: if one dimension would grow while the other shrinks.
    """
    src_h, src_w = x.shape[1], x.shape[2]
    dst_h, dst_w = size[0], size[1]
    not_growing = src_h >= dst_h and src_w >= dst_w
    growing = src_h < dst_h and src_w < dst_w
    assert not_growing or growing, "Trying to magnify image in one dimension and minify in the other"

    nchw = x.permute(0, 3, 1, 2)  # NHWC -> NCHW
    if src_h > dst_h and src_w > dst_w:
        # minification: the source is larger in both dimensions
        out = torch.nn.functional.interpolate(nchw, size, mode=min)
    elif mag in ('bilinear', 'bicubic'):
        out = torch.nn.functional.interpolate(nchw, size, mode=mag, align_corners=True)
    else:
        out = torch.nn.functional.interpolate(nchw, size, mode=mag)
    return out.permute(0, 2, 3, 1).contiguous()  # NCHW -> NHWC

def scale_img_hwc(x, size, mag='bilinear', min='bilinear'):
    """Resize a single HWC image by treating it as a batch of one."""
    batched = x.unsqueeze(0)
    return scale_img_nhwc(batched, size, mag, min)[0]

def scale_img_nhw(x, size, mag='bilinear', min='bilinear'):
    """Resize a batch of single-channel NHW images (a channel axis is added and removed)."""
    with_channel = x.unsqueeze(-1)
    return scale_img_nhwc(with_channel, size, mag, min)[..., 0]

def scale_img_hw(x, size, mag='bilinear', min='bilinear'):
    """Resize a single HW image by wrapping it into a 1 x H x W x 1 batch."""
    wrapped = x.unsqueeze(0).unsqueeze(-1)
    return scale_img_nhwc(wrapped, size, mag, min)[0, ..., 0]

def trunc_rev_sigmoid(x, eps=1e-6):
    """Inverse sigmoid (logit) with the input clamped to ``[eps, 1 - eps]``.

    Clamping keeps the result finite for inputs at or beyond 0 and 1.
    """
    p = torch.clamp(x, min=eps, max=1 - eps)
    odds = p / (1 - p)
    return odds.log()

def make_divisible(x, m=8):
    """Round ``x`` up to the nearest multiple of ``m`` and return it as an int."""
    n_chunks = math.ceil(x / m)
    return int(n_chunks * m)

class Renderer(nn.Module):
    """Differentiable textured-mesh renderer built on nvdiffrast.

    Holds a ``Mesh`` together with per-vertex offsets, an albedo texture stored
    in logit space (``raw_albedo``; ``sigmoid`` is applied at render time) and,
    when ``opt.lambda_dino > 0``, a feature texture (``raw_feature``).

    Options read from ``opt``: ``lambda_dino``, ``mesh``, ``force_cuda_rast``,
    ``gui``, ``trainable_texture``, ``texture_lr``, ``train_geo``, ``geom_lr``.
    """

    def __init__(self, opt):
        super().__init__()

        self.opt = opt
        self.enable_dino = self.opt.lambda_dino > 0

        self.mesh = Mesh.load(self.opt.mesh, resize=False, enable_dino=self.enable_dino)

        # OpenGL rasterizer unless CUDA is forced, or a GUI runs on a non-Windows host.
        if not self.opt.force_cuda_rast and (not self.opt.gui or os.name == 'nt'):
            self.glctx = dr.RasterizeGLContext()
        else:
            self.glctx = dr.RasterizeCudaContext()

        # Offsets start at zero. Albedo is stored as logits so that the
        # sigmoid applied in render() keeps colours inside (0, 1).
        self.v_offsets = torch.zeros_like(self.mesh.v)
        self.raw_albedo = trunc_rev_sigmoid(self.mesh.albedo)

        # extract trainable parameters
        # NOTE(review): v_offsets only becomes an nn.Parameter when
        # opt.trainable_texture is set, yet get_params() hands it to the
        # optimizer whenever opt.train_geo is set; with train_geo alone the
        # offsets never receive gradients. Confirm the intended flag coupling.
        if opt.trainable_texture:
            self.v_offsets = nn.Parameter(self.v_offsets)
            self.raw_albedo = nn.Parameter(self.raw_albedo)

        if self.enable_dino:
            self.raw_feature = nn.Parameter(self.mesh.feature)

    def get_params(self):
        """Return optimizer parameter groups.

        Always includes the albedo texture (``opt.texture_lr``). Adds the
        feature texture when DINO features are enabled, and the vertex offsets
        (``opt.geom_lr``) when ``opt.train_geo`` is set.
        """
        params = [{'params': self.raw_albedo, 'lr': self.opt.texture_lr}]

        if self.enable_dino:
            params.append({'params': self.raw_feature, 'lr': self.opt.texture_lr})

        if self.opt.train_geo:
            params.append({'params': self.v_offsets, 'lr': self.opt.geom_lr})

        return params

    @torch.no_grad()
    def export_mesh(self, save_path):
        """Bake the current state into ``self.mesh`` and write it to ``save_path``.

        The vertex offsets are folded into ``mesh.v`` and then reset to zero,
        so later renders (which compute ``mesh.v + v_offsets``) do not apply
        them a second time. The albedo is stored back as sigmoid(raw_albedo),
        and the feature texture is copied when DINO features are enabled.
        """
        self.mesh.v = (self.mesh.v + self.v_offsets).detach()
        # The offsets now live in mesh.v. Zero them in place, keeping the same
        # tensor/Parameter object the optimizer refers to.
        self.v_offsets.zero_()
        self.mesh.albedo = torch.sigmoid(self.raw_albedo.detach())
        if self.enable_dino:
            self.mesh.feature = self.raw_feature.detach()
        # NOTE(review): mesh.vn is not recomputed here. Whether Mesh.write
        # emits stale normals after geometry training is not visible from here.
        self.mesh.write(save_path, self.enable_dino)

    def _supersample_size(self, h0, w0, ssaa):
        """Return the (h, w) raster size used for super-sampling factor ``ssaa``."""
        if ssaa == 1:
            return h0, w0
        # enlarge, then round up to a multiple of 8
        return make_divisible(h0 * ssaa, 8), make_divisible(w0 * ssaa, 8)

    def _current_vertices(self):
        """Return vertex positions, including offsets when geometry is trained."""
        if self.opt.train_geo:
            return self.mesh.v + self.v_offsets  # [N, 3]
        return self.mesh.v

    def _vertex_normals(self, v):
        """Return per-vertex normals for the vertex positions ``v``.

        With ``opt.train_geo`` they are recomputed from ``v``. The face normals
        (normalized with ``safe_normalize``) are summed onto each face's three
        vertices. Vertices whose accumulated normal has a squared length of at
        most 1e-20 fall back to +Z. Otherwise the mesh's stored ``vn`` is
        returned.
        """
        if not self.opt.train_geo:
            return self.mesh.vn

        i0, i1, i2 = self.mesh.f[:, 0].long(), self.mesh.f[:, 1].long(), self.mesh.f[:, 2].long()
        v0, v1, v2 = v[i0, :], v[i1, :], v[i2, :]

        # Pass dim explicitly: without it torch.cross uses the first size-3
        # axis, which is the face axis when the mesh has exactly 3 faces.
        face_normals = safe_normalize(torch.cross(v1 - v0, v2 - v0, dim=-1))

        vn = torch.zeros_like(v)
        vn.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
        vn.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
        vn.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)

        fallback = torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device)
        return torch.where(torch.sum(vn * vn, -1, keepdim=True) > 1e-20, vn, fallback)

    @staticmethod
    def _resize_batch(x, size):
        """Resize every [H, W, C] image of a [B, H, W, C] batch to ``size``."""
        return torch.stack([scale_img_hwc(img, size) for img in x])

    def render(self, pose, proj, h0, w0, ssaa=1, bg_color=1, texture_filter='linear-mipmap-linear'):
        """Render a single view.

        Args:
            pose: 4x4 camera-to-world matrix (numpy). Its inverse maps vertices
                into camera space.
            proj: 4x4 projection matrix (numpy), applied as ``v_cam @ proj.T``.
            h0, w0: output resolution.
            ssaa: super-sampling factor. Rasterization runs at the enlarged
                size (rounded up to a multiple of 8) and the results are
                resized back to ``(h0, w0)``.
            bg_color: background value blended in where alpha is 0.
            texture_filter: nvdiffrast filter mode for the albedo texture.

        Returns:
            dict with ``image`` [H, W, 3] clamped to [0, 1], ``alpha``
            [H, W, 1], ``depth`` [H, W, 1] (negated camera-space z),
            ``normal`` [H, W, 3] mapped from [-1, 1] to [0, 1], ``viewcos``
            [H, W, 1], and ``feature`` (None unless DINO features are enabled).
        """
        h, w = self._supersample_size(h0, w0, ssaa)
        results = {}

        v = self._current_vertices()
        pose = torch.from_numpy(pose.astype(np.float32)).to(v.device)
        proj = torch.from_numpy(proj.astype(np.float32)).to(v.device)

        # homogeneous world coordinates -> camera space -> clip space
        v_cam = torch.matmul(F.pad(v, pad=(0, 1), mode='constant', value=1.0), torch.inverse(pose).T).float().unsqueeze(0)
        v_clip = v_cam @ proj.T

        rast, rast_db = dr.rasterize(self.glctx, v_clip, self.mesh.f, (h, w))

        alpha = (rast[0, ..., 3:] > 0).float()
        depth, _ = dr.interpolate(-v_cam[..., [2]], rast, self.mesh.f)  # [1, H, W, 1]
        depth = depth.squeeze(0)  # [H, W, 1]

        texc, texc_db = dr.interpolate(self.mesh.vt.unsqueeze(0).contiguous(), rast, self.mesh.ft, rast_db=rast_db, diff_attrs='all')
        albedo = dr.texture(self.raw_albedo.unsqueeze(0), texc, uv_da=texc_db, filter_mode=texture_filter)  # [1, H, W, 3]
        albedo = torch.sigmoid(albedo)
        if self.enable_dino:
            # NOTE: backward error when use filter_mode='linear-mipmap-linear'
            feature = dr.texture(self.raw_feature.unsqueeze(0), texc, uv_da=texc_db, filter_mode='linear')

        vn = self._vertex_normals(v)
        normal, _ = dr.interpolate(vn.unsqueeze(0).contiguous(), rast, self.mesh.fn)
        normal = safe_normalize(normal[0])

        # rotated normal (where [0, 0, 1] always faces camera)
        rot_normal = normal @ pose[:3, :3]
        viewcos = rot_normal[..., [2]]

        # antialias, then composite over the background
        albedo = dr.antialias(albedo, rast, v_clip, self.mesh.f).squeeze(0)  # [H, W, 3]
        albedo = alpha * albedo + (1 - alpha) * bg_color

        if self.enable_dino:
            feature = dr.antialias(feature, rast, v_clip, self.mesh.f).squeeze(0)  # [H, W, C]
            # NOTE(review): the feature background reuses bg_color. A
            # per-channel (e.g. RGB) bg_color tensor would not broadcast
            # against C feature channels. Confirm the intended background.
            feature = alpha * feature + (1 - alpha) * bg_color

        # ssaa: bring everything back to the requested resolution
        if ssaa != 1:
            albedo = scale_img_hwc(albedo, (h0, w0))
            alpha = scale_img_hwc(alpha, (h0, w0))
            depth = scale_img_hwc(depth, (h0, w0))
            normal = scale_img_hwc(normal, (h0, w0))
            viewcos = scale_img_hwc(viewcos, (h0, w0))
            if self.enable_dino:
                feature = scale_img_hwc(feature, (h0, w0))

        results['image'] = albedo.clamp(0, 1)
        results['alpha'] = alpha
        results['depth'] = depth
        results['normal'] = (normal + 1) / 2
        results['viewcos'] = viewcos
        results['feature'] = feature if self.enable_dino else None  # [H, W, C]

        return results

    def render_batch(self, pose, proj, h0, w0, ssaa=1, bg_color=1, texture_filter='linear-mipmap-linear'):
        """Render a batch of views in a single rasterization pass.

        Args:
            pose: [B, 4, 4] camera-to-world tensor (cast to float32 here).
            proj: [B, 4, 4] projection tensor (cast to float32 here).
            h0, w0, ssaa, bg_color, texture_filter: as in ``render``.

        Returns:
            dict with ``image``, ``alpha``, ``depth`` and ``viewcos`` of shape
            [B, H, W, C]. ``normal`` is returned flattened as [B, H*W, 3]
            (unlike ``render``). ``feature`` is None unless DINO features are
            enabled.
        """
        h, w = self._supersample_size(h0, w0, ssaa)
        results = {}

        v = self._current_vertices()

        bs = pose.shape[0]
        # Cast to float32 as render() does, so float64 inputs do not hit a
        # dtype mismatch in bmm.
        pose = pose.to(device=v.device, dtype=torch.float32)
        proj = proj.to(device=v.device, dtype=torch.float32).transpose(1, 2)

        # homogeneous world coordinates -> camera space -> clip space, per view
        v_hom = F.pad(v, pad=(0, 1), mode='constant', value=1.0).expand(bs, -1, -1)
        v_cam = torch.bmm(v_hom, torch.linalg.inv(pose).transpose(1, 2)).float()
        v_clip = torch.bmm(v_cam, proj)

        rast, rast_db = dr.rasterize(self.glctx, v_clip, self.mesh.f, (h, w))

        alpha = (rast[..., 3:] > 0).float()  # [B, H, W, 1]
        depth, _ = dr.interpolate(-v_cam[..., [2]], rast, self.mesh.f)  # [B, H, W, 1]

        texc, texc_db = dr.interpolate(self.mesh.vt.expand(bs, -1, -1).contiguous(), rast, self.mesh.ft, rast_db=rast_db, diff_attrs='all')
        # NOTE(review): raw_albedo is detached here (unlike render()), so
        # batched renders never send gradients to the albedo texture, while
        # raw_feature below is not detached. Confirm this is intended.
        albedo = dr.texture(self.raw_albedo.detach().unsqueeze(0).contiguous(), texc, uv_da=texc_db, filter_mode=texture_filter)  # [B, H, W, 3]
        albedo = torch.sigmoid(albedo)
        if self.enable_dino:
            # NOTE: backward error when use filter_mode='linear-mipmap-linear'
            feature = dr.texture(self.raw_feature.unsqueeze(0), texc, uv_da=texc_db, filter_mode='linear')

        vn = self._vertex_normals(v)
        normal, _ = dr.interpolate(vn.expand(bs, -1, -1).contiguous(), rast, self.mesh.fn)
        normal = safe_normalize(normal).reshape(bs, -1, 3)  # [B, H*W, 3]

        # rotated normal (where [0, 0, 1] always faces camera)
        rot_normal = torch.bmm(normal, pose[:, :3, :3]).reshape(bs, h, w, 3)
        viewcos = rot_normal[..., [2]]

        # antialias, then composite over the background
        albedo = dr.antialias(albedo, rast, v_clip, self.mesh.f)  # [B, H, W, 3]
        albedo = alpha * albedo + (1 - alpha) * bg_color

        if self.enable_dino:
            feature = dr.antialias(feature, rast, v_clip, self.mesh.f)  # [B, H, W, C]
            # NOTE(review): same bg_color broadcasting caveat as in render().
            feature = alpha * feature + (1 - alpha) * bg_color

        # ssaa: every tensor here is batched, so resize image by image.
        # scale_img_hwc cannot take a [B, H, W, C] tensor directly.
        if ssaa != 1:
            size = (h0, w0)
            albedo = self._resize_batch(albedo, size)
            alpha = self._resize_batch(alpha, size)
            depth = self._resize_batch(depth, size)
            # normal is kept flattened: unflatten, resize, flatten again
            normal = self._resize_batch(normal.reshape(bs, h, w, 3), size).reshape(bs, -1, 3)
            viewcos = self._resize_batch(viewcos, size)
            if self.enable_dino:
                feature = self._resize_batch(feature, size)

        results['image'] = albedo.clamp(0, 1)
        results['alpha'] = alpha
        results['depth'] = depth
        results['normal'] = (normal + 1) / 2
        results['viewcos'] = viewcos
        results['feature'] = feature if self.enable_dino else None  # [B, H, W, C]

        return results