lehduong committed on
Commit a767b8d · verified · 1 Parent(s): c2900a3

Delete dataset/raydiff_utils.py with huggingface_hub

Files changed (1)
dataset/raydiff_utils.py +0 -739
dataset/raydiff_utils.py DELETED
@@ -1,739 +0,0 @@
"""
Adapted from code originally written by David Novotny.
"""

import cv2
import numpy as np
import torch
from pytorch3d.renderer import PerspectiveCameras, RayBundle
from pytorch3d.transforms import Rotate, Translate


def intersect_skew_line_groups(p, r, mask):
    # p, r both of shape (B, N, n_intersected_lines, 3)
    # mask of shape (B, N, n_intersected_lines)
    p_intersect, r = intersect_skew_lines_high_dim(p, r, mask=mask)
    if p_intersect is None:
        return None, None, None, None
    _, p_line_intersect = point_line_distance(
        p, r, p_intersect[..., None, :].expand_as(p)
    )
    intersect_dist_squared = ((p_line_intersect - p_intersect[..., None, :]) ** 2).sum(
        dim=-1
    )
    return p_intersect, p_line_intersect, intersect_dist_squared, r


def intersect_skew_lines_high_dim(p, r, mask=None):
    # Implements https://en.wikipedia.org/wiki/Skew_lines in more than two dimensions.
    dim = p.shape[-1]
    # Make sure the heading vectors are l2-normed.
    if mask is None:
        mask = torch.ones_like(p[..., 0])
    r = torch.nn.functional.normalize(r, dim=-1)

    eye = torch.eye(dim, device=p.device, dtype=p.dtype)[None, None]
    I_min_cov = (eye - (r[..., None] * r[..., None, :])) * mask[..., None, None]
    sum_proj = I_min_cov.matmul(p[..., None]).sum(dim=-3)
    p_intersect = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0]

    if torch.any(torch.isnan(p_intersect)):
        print(p_intersect)
        return None, None
    return p_intersect, r
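

# Illustrative sketch (not part of the original file): a minimal sanity check of
# intersect_skew_lines_high_dim. Two lines in R^3, one through (1, 0, 0) along
# +y and one through (0, 1, 0) along +x, meet at (1, 1, 0); the least-squares
# "intersection" should recover that point. The _example_* name is hypothetical.
def _example_intersect_skew_lines():
    p = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])  # a point on each line
    r = torch.tensor([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]])  # line directions
    p_intersect, _ = intersect_skew_lines_high_dim(p, r)  # shape (1, 3)
    assert torch.allclose(p_intersect, torch.tensor([[1.0, 1.0, 0.0]]), atol=1e-5)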


def point_line_distance(p1, r1, p2):
    # Distance from point p2 to the line through p1 with (unit) direction r1,
    # along with the nearest point on that line.
    df = p2 - p1
    proj_vector = df - ((df * r1).sum(dim=-1, keepdim=True) * r1)
    line_pt_nearest = p2 - proj_vector
    d = proj_vector.norm(dim=-1)
    return d, line_pt_nearest
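

# Illustrative sketch (not part of the original file): the distance from the
# point (0, 1, 0) to the x-axis is 1, and the nearest point on the line is the
# origin. The _example_* name is hypothetical.
def _example_point_line_distance():
    p1 = torch.tensor([0.0, 0.0, 0.0])  # point on the line
    r1 = torch.tensor([1.0, 0.0, 0.0])  # unit direction (the x-axis)
    p2 = torch.tensor([0.0, 1.0, 0.0])  # query point
    d, nearest = point_line_distance(p1, r1, p2)
    assert torch.isclose(d, torch.tensor(1.0))
    assert torch.allclose(nearest, torch.zeros(3))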


def compute_optical_axis_intersection(cameras):
    centers = cameras.get_camera_center()
    principal_points = cameras.principal_point

    one_vec = torch.ones((len(cameras), 1), device=centers.device)
    optical_axis = torch.cat((principal_points, one_vec), -1)

    pp = cameras.unproject_points(optical_axis, from_ndc=True, world_coordinates=True)
    pp2 = torch.diagonal(pp, dim1=0, dim2=1).T

    directions = pp2 - centers
    centers = centers.unsqueeze(0).unsqueeze(0)
    directions = directions.unsqueeze(0).unsqueeze(0)

    p_intersect, p_line_intersect, _, r = intersect_skew_line_groups(
        p=centers, r=directions, mask=None
    )

    if p_intersect is None:
        dist = None
    else:
        p_intersect = p_intersect.squeeze().unsqueeze(0)
        dist = (p_intersect - centers).norm(dim=-1)

    return p_intersect, dist, p_line_intersect, pp2, r


def normalize_cameras(cameras, scale=1.0):
    """
    Normalizes cameras such that the optical axes point to the origin, the rotation
    of the first camera is the identity, and the norm of the translation of the
    first camera is 1.

    Args:
        cameras (pytorch3d.renderer.cameras.CamerasBase).
        scale (float): Norm of the translation of the first camera.

    Returns:
        new_cameras (pytorch3d.renderer.cameras.CamerasBase): Normalized cameras.
        undo_transform (function): Function that undoes the normalization.
    """
    # Let the distance from the first camera to the origin be unit.
    new_cameras = cameras.clone()
    new_transform = (
        new_cameras.get_world_to_view_transform()
    )  # Note: R may not be a valid rotation matrix here.
    p_intersect, dist, p_line_intersect, pp, r = compute_optical_axis_intersection(
        cameras
    )

    if p_intersect is None:
        print("Warning: optical axes code has a nan. Returning identity cameras.")
        new_cameras.R[:] = torch.eye(3, device=cameras.R.device, dtype=cameras.R.dtype)
        new_cameras.T[:] = torch.tensor(
            [0, 0, 1], device=cameras.T.device, dtype=cameras.T.dtype
        )
        return new_cameras, lambda x: x

    d = dist.squeeze(dim=1).squeeze(dim=0)[0]
    # Degenerate case: the intersection point coincides with the first camera center.
    if d == 0:
        print(cameras.T)
        print(new_transform.get_matrix()[:, 3, :3])
    assert d != 0

    # Can't figure out how to make scale part of the transform too without messing up R.
    # Ideally, we would just wrap it all in a single Pytorch3D transform so that it
    # would work with any structure (eg PointClouds, Meshes).
    tR = Rotate(new_cameras.R[0].unsqueeze(0)).inverse()
    tT = Translate(p_intersect)
    t = tR.compose(tT)

    new_transform = t.compose(new_transform)
    new_cameras.R = new_transform.get_matrix()[:, :3, :3]
    new_cameras.T = new_transform.get_matrix()[:, 3, :3] / d * scale

    def undo_transform(cameras):
        cameras_copy = cameras.clone()
        cameras_copy.T *= d / scale
        new_t = (
            t.inverse().compose(cameras_copy.get_world_to_view_transform()).get_matrix()
        )
        cameras_copy.R = new_t[:, :3, :3]
        cameras_copy.T = new_t[:, 3, :3]
        return cameras_copy

    return new_cameras, undo_transform
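

# Illustrative sketch (not part of the original file): a round-trip check for
# normalize_cameras, assuming pytorch3d's look_at_view_transform is available.
# Two cameras looking at the origin have optical axes intersecting there, so the
# first normalized translation has unit norm, and undo_transform restores the
# original extrinsics. The _example_* name is hypothetical.
def _example_normalize_cameras():
    from pytorch3d.renderer import look_at_view_transform

    R, T = look_at_view_transform(
        dist=torch.tensor([2.0, 2.0]),
        elev=torch.tensor([0.0, 10.0]),
        azim=torch.tensor([0.0, 30.0]),
    )
    cameras = PerspectiveCameras(focal_length=3.0, R=R, T=T)
    new_cameras, undo_transform = normalize_cameras(cameras)
    assert torch.isclose(new_cameras.T[0].norm(), torch.tensor(1.0), atol=1e-4)
    restored = undo_transform(new_cameras)
    assert torch.allclose(restored.R, cameras.R, atol=1e-4)
    assert torch.allclose(restored.T, cameras.T, atol=1e-4)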


def first_camera_transform(cameras, rotation_only=True):
    new_cameras = cameras.clone()
    new_transform = new_cameras.get_world_to_view_transform()
    tR = Rotate(new_cameras.R[0].unsqueeze(0))
    if rotation_only:
        t = tR.inverse()
    else:
        tT = Translate(new_cameras.T[0].unsqueeze(0))
        t = tR.compose(tT).inverse()

    new_transform = t.compose(new_transform)
    new_cameras.R = new_transform.get_matrix()[:, :3, :3]
    new_cameras.T = new_transform.get_matrix()[:, 3, :3]

    return new_cameras


def get_identity_cameras_with_intrinsics(cameras):
    D = len(cameras)
    device = cameras.R.device

    new_cameras = cameras.clone()
    new_cameras.R = torch.eye(3, device=device).unsqueeze(0).repeat((D, 1, 1))
    new_cameras.T = torch.zeros((D, 3), device=device)

    return new_cameras


def normalize_cameras_batch(cameras, scale=1.0, normalize_first_camera=False):
    new_cameras = []
    undo_transforms = []
    for cam in cameras:
        if normalize_first_camera:
            # Normalize cameras such that the first camera is the identity and
            # the origin is at the first camera center.
            normalized_cameras = first_camera_transform(cam, rotation_only=False)
            undo_transform = None
        else:
            normalized_cameras, undo_transform = normalize_cameras(cam, scale=scale)
        new_cameras.append(normalized_cameras)
        undo_transforms.append(undo_transform)
    return new_cameras, undo_transforms


class Rays(object):
    def __init__(
        self,
        rays=None,
        origins=None,
        directions=None,
        moments=None,
        is_plucker=False,
        moments_rescale=1.0,
        ndc_coordinates=None,
        crop_parameters=None,
        num_patches_x=16,
        num_patches_y=16,
    ):
        """
        Ray class to keep track of the current ray representation.

        Args:
            rays: (..., 6).
            origins: (..., 3).
            directions: (..., 3).
            moments: (..., 3).
            is_plucker: If True, rays are in Plucker coordinates (Default: False).
            moments_rescale: Rescale the moment component of the rays by a scalar.
            ndc_coordinates: (..., 2): NDC coordinates of each ray.
            crop_parameters: (..., 4): if given and ndc_coordinates is not, the
                NDC grid is computed from these crop parameters.
            num_patches_x: Number of patches in x direction (Default: 16).
            num_patches_y: Number of patches in y direction (Default: 16).
        """
        if rays is not None:
            self.rays = rays
            self._is_plucker = is_plucker
        elif origins is not None and directions is not None:
            self.rays = torch.cat((origins, directions), dim=-1)
            self._is_plucker = False
        elif directions is not None and moments is not None:
            self.rays = torch.cat((directions, moments), dim=-1)
            self._is_plucker = True
        else:
            raise Exception("Invalid combination of arguments")

        if moments_rescale != 1.0:
            self.rescale_moments(moments_rescale)

        if ndc_coordinates is not None:
            self.ndc_coordinates = ndc_coordinates
        elif crop_parameters is not None:
            # (..., H, W, 2)
            xy_grid = compute_ndc_coordinates(
                crop_parameters,
                num_patches_x=num_patches_x,
                num_patches_y=num_patches_y,
            )[..., :2]
            xy_grid = xy_grid.reshape(*xy_grid.shape[:-3], -1, 2)
            self.ndc_coordinates = xy_grid
        else:
            self.ndc_coordinates = None
    def __getitem__(self, index):
        return Rays(
            rays=self.rays[index],
            is_plucker=self._is_plucker,
            ndc_coordinates=(
                self.ndc_coordinates[index]
                if self.ndc_coordinates is not None
                else None
            ),
        )

    def to_spatial(self, include_ndc_coordinates=False):
        """
        Converts rays to spatial representation: (..., H * W, 6) --> (..., 6, H, W)

        Returns:
            torch.Tensor: (..., 6, H, W)
        """
        rays = self.to_plucker().rays
        *batch_dims, P, D = rays.shape
        H = W = int(np.sqrt(P))
        assert H * W == P
        rays = torch.transpose(rays, -1, -2)  # (..., 6, H * W)
        rays = rays.reshape(*batch_dims, D, H, W)
        if include_ndc_coordinates:
            ndc_coords = self.ndc_coordinates.transpose(-1, -2)  # (..., 2, H * W)
            ndc_coords = ndc_coords.reshape(*batch_dims, 2, H, W)
            rays = torch.cat((rays, ndc_coords), dim=-3)
        return rays

    def rescale_moments(self, scale):
        """
        Rescale the moment component of the rays by a scalar. Might be desirable since
        moments may come from a very narrow distribution.

        Note that this modifies in place!
        """
        if self.is_plucker:
            self.rays[..., 3:] *= scale
            return self
        else:
            return self.to_plucker().rescale_moments(scale)

    @classmethod
    def from_spatial(cls, rays, moments_rescale=1.0, ndc_coordinates=None):
        """
        Converts rays from spatial representation: (..., 6, H, W) --> (..., H * W, 6)

        Args:
            rays: (..., 6, H, W)

        Returns:
            Rays: (..., H * W, 6)
        """
        *batch_dims, D, H, W = rays.shape
        rays = rays.reshape(*batch_dims, D, H * W)
        rays = torch.transpose(rays, -1, -2)
        return cls(
            rays=rays,
            is_plucker=True,
            moments_rescale=moments_rescale,
            ndc_coordinates=ndc_coordinates,
        )

    def to_point_direction(self, normalize_moment=True):
        """
        Convert to point direction representation <O, D>.

        Returns:
            rays: (..., 6).
        """
        if self._is_plucker:
            direction = torch.nn.functional.normalize(self.rays[..., :3], dim=-1)
            moment = self.rays[..., 3:]
            if normalize_moment:
                # NOTE: direction is already unit-norm here, so c == 1 and this
                # is a no-op as written; kept for parity with the original code.
                c = torch.linalg.norm(direction, dim=-1, keepdim=True)
                moment = moment / c
            points = torch.cross(direction, moment, dim=-1)
            return Rays(
                rays=torch.cat((points, direction), dim=-1),
                is_plucker=False,
                ndc_coordinates=self.ndc_coordinates,
            )
        else:
            return self
    def to_plucker(self):
        """
        Convert to Plucker representation <D, OxD>.
        """
        if self.is_plucker:
            return self
        else:
            ray = self.rays.clone()
            ray_origins = ray[..., :3]
            ray_directions = ray[..., 3:]
            # Normalize ray directions to unit vectors.
            ray_directions = ray_directions / ray_directions.norm(dim=-1, keepdim=True)
            plucker_normal = torch.cross(ray_origins, ray_directions, dim=-1)
            new_ray = torch.cat([ray_directions, plucker_normal], dim=-1)
            return Rays(
                rays=new_ray, is_plucker=True, ndc_coordinates=self.ndc_coordinates
            )

    def get_directions(self, normalize=True):
        if self.is_plucker:
            directions = self.rays[..., :3]
        else:
            directions = self.rays[..., 3:]
        if normalize:
            directions = torch.nn.functional.normalize(directions, dim=-1)
        return directions

    def get_origins(self):
        if self.is_plucker:
            origins = self.to_point_direction().get_origins()
        else:
            origins = self.rays[..., :3]
        return origins

    def get_moments(self):
        if self.is_plucker:
            moments = self.rays[..., 3:]
        else:
            moments = self.to_plucker().get_moments()
        return moments

    def get_ndc_coordinates(self):
        return self.ndc_coordinates

    @property
    def is_plucker(self):
        return self._is_plucker

    @property
    def device(self):
        return self.rays.device

    def __repr__(self, *args, **kwargs):
        ray_str = self.rays.__repr__(*args, **kwargs)[6:]  # remove "tensor"
        if self._is_plucker:
            return "PluRay" + ray_str
        else:
            return "DirRay" + ray_str

    def to(self, device):
        self.rays = self.rays.to(device)

    def clone(self):
        return Rays(rays=self.rays.clone(), is_plucker=self._is_plucker)

    @property
    def shape(self):
        return self.rays.shape

    def visualize(self):
        directions = torch.nn.functional.normalize(self.get_directions(), dim=-1).cpu()
        moments = torch.nn.functional.normalize(self.get_moments(), dim=-1).cpu()
        return (directions + 1) / 2, (moments + 1) / 2

    def to_ray_bundle(self, length=0.3, recenter=True):
        lengths = torch.ones_like(self.get_origins()[..., :2]) * length
        lengths[..., 0] = 0
        if recenter:
            centers, _ = intersect_skew_lines_high_dim(
                self.get_origins(), self.get_directions()
            )
            centers = centers.unsqueeze(1).repeat(1, lengths.shape[1], 1)
        else:
            centers = self.get_origins()
        return RayBundle(
            origins=centers,
            directions=self.get_directions(),
            lengths=lengths,
            xys=self.get_directions(),
        )
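

# Illustrative sketch (not part of the original file): the point-direction ->
# Plucker -> point-direction round trip recovers the point on each ray closest
# to the world origin. For a ray through (1, 0, 0) along +z, that closest point
# is (1, 0, 0) itself. The _example_* name is hypothetical.
def _example_rays_round_trip():
    origins = torch.tensor([[[1.0, 0.0, 0.0]]])
    directions = torch.tensor([[[0.0, 0.0, 1.0]]])
    rays = Rays(origins=origins, directions=directions)
    back = rays.to_plucker().to_point_direction()
    assert torch.allclose(back.get_origins(), origins, atol=1e-5)
    assert torch.allclose(back.get_directions(), directions, atol=1e-5)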


def cameras_to_rays(
    cameras,
    crop_parameters,
    use_half_pix=True,
    use_plucker=True,
    num_patches_x=16,
    num_patches_y=16,
):
    """
    Unprojects rays from camera center to grid on image plane.

    Args:
        cameras: Pytorch3D cameras to unproject. Can be batched.
        crop_parameters: Crop parameters in NDC (cc_x, cc_y, crop_width, scale).
            Shape is (B, 4).
        use_half_pix: If True, use half pixel offset (Default: True).
        use_plucker: If True, return rays in Plucker coordinates (Default: True).
        num_patches_x: Number of patches in x direction (Default: 16).
        num_patches_y: Number of patches in y direction (Default: 16).
    """
    unprojected = []
    crop_parameters_list = (
        crop_parameters if crop_parameters is not None else [None for _ in cameras]
    )
    for camera, crop_param in zip(cameras, crop_parameters_list):
        xyd_grid = compute_ndc_coordinates(
            crop_parameters=crop_param,
            use_half_pix=use_half_pix,
            num_patches_x=num_patches_x,
            num_patches_y=num_patches_y,
        )
        unprojected.append(
            camera.unproject_points(
                xyd_grid.reshape(-1, 3), world_coordinates=True, from_ndc=True
            )
        )
    unprojected = torch.stack(unprojected, dim=0)  # (N, P, 3)
    origins = cameras.get_camera_center().unsqueeze(1)  # (N, 1, 3)
    origins = origins.repeat(1, num_patches_x * num_patches_y, 1)  # (N, P, 3)
    directions = unprojected - origins

    rays = Rays(
        origins=origins,
        directions=directions,
        crop_parameters=crop_parameters,
        num_patches_x=num_patches_x,
        num_patches_y=num_patches_y,
    )
    if use_plucker:
        return rays.to_plucker()
    return rays
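

# Illustrative sketch (not part of the original file): one camera and a 16x16
# patch grid yield 256 Plucker rays of dimension 6. The crop parameters
# (0, 0, 2, 1) stand for an uncropped image (full NDC width 2); the last entry
# is unused by compute_ndc_coordinates. Assumes pytorch3d's
# look_at_view_transform; the _example_* name is hypothetical.
def _example_cameras_to_rays():
    from pytorch3d.renderer import look_at_view_transform

    R, T = look_at_view_transform(dist=2.0, elev=10.0, azim=30.0)
    cameras = PerspectiveCameras(focal_length=3.453, R=R, T=T)
    crop_parameters = torch.tensor([[0.0, 0.0, 2.0, 1.0]])
    rays = cameras_to_rays(cameras, crop_parameters)
    assert rays.shape == (1, 256, 6)
    assert rays.is_plucker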


def rays_to_cameras(
    rays,
    crop_parameters,
    num_patches_x=16,
    num_patches_y=16,
    use_half_pix=True,
    sampled_ray_idx=None,
    cameras=None,
    focal_length=(3.453,),
):
    """
    If cameras are provided, uses those intrinsics. Otherwise uses the provided
    focal_length(s). Dataset default is 3.32.

    Args:
        rays (Rays): (N, P, 6)
        crop_parameters (torch.Tensor): (N, 4)
    """
    device = rays.device
    origins = rays.get_origins()
    directions = rays.get_directions()
    camera_centers, _ = intersect_skew_lines_high_dim(origins, directions)

    # Retrieve target rays
    if cameras is None:
        if len(focal_length) == 1:
            focal_length = focal_length * rays.shape[0]
        I_camera = PerspectiveCameras(focal_length=focal_length, device=device)
    else:
        # Use the same intrinsics but reset to identity extrinsics.
        I_camera = cameras.clone()
        I_camera.R[:] = torch.eye(3, device=device)
        I_camera.T[:] = torch.zeros(3, device=device)
    I_patch_rays = cameras_to_rays(
        cameras=I_camera,
        num_patches_x=num_patches_x,
        num_patches_y=num_patches_y,
        use_half_pix=use_half_pix,
        crop_parameters=crop_parameters,
    ).get_directions()

    if sampled_ray_idx is not None:
        I_patch_rays = I_patch_rays[:, sampled_ray_idx]

    # Compute optimal rotation to align rays
    R = torch.zeros_like(I_camera.R)
    for i in range(len(I_camera)):
        R[i] = compute_optimal_rotation_alignment(
            I_patch_rays[i],
            directions[i],
        )

    # Construct and return rotated camera
    cam = I_camera.clone()
    cam.R = R
    cam.T = -torch.matmul(R.transpose(1, 2), camera_centers.unsqueeze(2)).squeeze(2)
    return cam
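

# Illustrative sketch (not part of the original file): rays generated from a
# camera by cameras_to_rays should reproduce that camera's pose when converted
# back (noiseless round trip). Assumes pytorch3d's look_at_view_transform; the
# _example_* name is hypothetical.
def _example_rays_to_cameras_round_trip():
    from pytorch3d.renderer import look_at_view_transform

    R, T = look_at_view_transform(dist=2.0, elev=10.0, azim=30.0)
    cameras = PerspectiveCameras(focal_length=3.453, R=R, T=T)
    crop_parameters = torch.tensor([[0.0, 0.0, 2.0, 1.0]])
    rays = cameras_to_rays(cameras, crop_parameters)
    recovered = rays_to_cameras(rays, crop_parameters, cameras=cameras)
    assert torch.allclose(recovered.R, cameras.R, atol=1e-3)
    assert torch.allclose(recovered.T, cameras.T, atol=1e-3)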


# https://www.reddit.com/r/learnmath/comments/v1crd7/linear_algebra_qr_to_ql_decomposition/
def ql_decomposition(A):
    P = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]], device=A.device).float()
    A_tilde = torch.matmul(A, P)
    Q_tilde, R_tilde = torch.linalg.qr(A_tilde)
    Q = torch.matmul(Q_tilde, P)
    L = torch.matmul(torch.matmul(P, R_tilde), P)
    d = torch.diag(L)
    Q[:, 0] *= torch.sign(d[0])
    Q[:, 1] *= torch.sign(d[1])
    Q[:, 2] *= torch.sign(d[2])
    L[0] *= torch.sign(d[0])
    L[1] *= torch.sign(d[1])
    L[2] *= torch.sign(d[2])
    return Q, L
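

# Illustrative sketch (not part of the original file): for a generic 3x3 matrix,
# ql_decomposition should return an orthogonal Q and a lower-triangular L with
# non-negative diagonal such that Q @ L reproduces A. The _example_* name is
# hypothetical.
def _example_ql_decomposition():
    torch.manual_seed(0)
    A = torch.randn(3, 3)
    Q, L = ql_decomposition(A)
    assert torch.allclose(Q @ Q.T, torch.eye(3), atol=1e-5)   # Q orthogonal
    assert torch.allclose(L.triu(1), torch.zeros(3, 3), atol=1e-6)  # L lower-tri
    assert (torch.diag(L) >= 0).all()
    assert torch.allclose(Q @ L, A, atol=1e-5)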


def rays_to_cameras_homography(
    rays,
    crop_parameters,
    num_patches_x=16,
    num_patches_y=16,
    use_half_pix=True,
    sampled_ray_idx=None,
    reproj_threshold=0.2,
):
    """
    Args:
        rays (Rays): (N, P, 6)
        crop_parameters (torch.Tensor): (N, 4)
    """
    device = rays.device
    origins = rays.get_origins()
    directions = rays.get_directions()
    camera_centers, _ = intersect_skew_lines_high_dim(origins, directions)

    # Retrieve target rays
    I_camera = PerspectiveCameras(focal_length=[1] * rays.shape[0], device=device)
    I_patch_rays = cameras_to_rays(
        cameras=I_camera,
        num_patches_x=num_patches_x,
        num_patches_y=num_patches_y,
        use_half_pix=use_half_pix,
        crop_parameters=crop_parameters,
    ).get_directions()

    if sampled_ray_idx is not None:
        I_patch_rays = I_patch_rays[:, sampled_ray_idx]

    # Compute optimal rotation, focal length, and principal point per camera
    Rs = []
    focal_lengths = []
    principal_points = []
    for i in range(rays.shape[-3]):
        R, f, pp = compute_optimal_rotation_intrinsics(
            I_patch_rays[i],
            directions[i],
            reproj_threshold=reproj_threshold,
        )
        Rs.append(R)
        focal_lengths.append(f)
        principal_points.append(pp)

    R = torch.stack(Rs)
    focal_lengths = torch.stack(focal_lengths)
    principal_points = torch.stack(principal_points)
    T = -torch.matmul(R.transpose(1, 2), camera_centers.unsqueeze(2)).squeeze(2)
    return PerspectiveCameras(
        R=R,
        T=T,
        focal_length=focal_lengths,
        principal_point=principal_points,
        device=device,
    )


def compute_optimal_rotation_alignment(A, B):
    """
    Compute optimal R that minimizes: || A - B @ R ||_F

    Args:
        A (torch.Tensor): (N, 3)
        B (torch.Tensor): (N, 3)

    Returns:
        R (torch.Tensor): (3, 3)
    """
    # Normally with R @ B, this would be A @ B.T
    H = B.T @ A
    U, _, Vh = torch.linalg.svd(H, full_matrices=True)
    s = torch.linalg.det(U @ Vh)
    S_prime = torch.diag(torch.tensor([1, 1, torch.sign(s)], device=A.device))
    return U @ S_prime @ Vh
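

# Illustrative sketch (not part of the original file): with noiseless
# correspondences A = B @ R_gt, the Kabsch-style alignment should recover R_gt
# exactly. The _example_* name is hypothetical.
def _example_rotation_alignment():
    torch.manual_seed(0)
    R_gt, _ = torch.linalg.qr(torch.randn(3, 3))  # random orthogonal matrix
    if torch.linalg.det(R_gt) < 0:
        R_gt[:, 0] *= -1  # force a proper rotation (det = +1)
    B = torch.randn(100, 3)
    A = B @ R_gt
    R = compute_optimal_rotation_alignment(A, B)
    assert torch.allclose(R, R_gt, atol=1e-4)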


def compute_optimal_rotation_intrinsics(
    rays_origin, rays_target, z_threshold=1e-4, reproj_threshold=0.2
):
    """
    Note: for some reason, f seems to be 1/f.

    Args:
        rays_origin (torch.Tensor): (N, 3)
        rays_target (torch.Tensor): (N, 3)
        z_threshold (float): Threshold for z value to be considered valid.
        reproj_threshold (float): RANSAC reprojection threshold for findHomography.

    Returns:
        R (torch.Tensor): (3, 3)
        focal_length (torch.Tensor): (2,)
        principal_point (torch.Tensor): (2,)
    """
    device = rays_origin.device
    z_mask = torch.logical_and(
        torch.abs(rays_target) > z_threshold, torch.abs(rays_origin) > z_threshold
    )[:, 2]
    rays_target = rays_target[z_mask]
    rays_origin = rays_origin[z_mask]
    rays_origin = rays_origin[:, :2] / rays_origin[:, -1:]
    rays_target = rays_target[:, :2] / rays_target[:, -1:]

    A, _ = cv2.findHomography(
        rays_origin.cpu().numpy(),
        rays_target.cpu().numpy(),
        cv2.RANSAC,
        reproj_threshold,
    )
    A = torch.from_numpy(A).float().to(device)

    if torch.linalg.det(A) < 0:
        A = -A

    R, L = ql_decomposition(A)
    L = L / L[2][2]

    f = torch.stack((L[0][0], L[1][1]))
    pp = torch.stack((L[2][0], L[2][1]))
    return R, f, pp


def compute_ndc_coordinates(
    crop_parameters=None,
    use_half_pix=True,
    num_patches_x=16,
    num_patches_y=16,
    device=None,
):
    """
    Computes the NDC grid using crop_parameters. If crop_parameters is not provided,
    then it assumes that the crop is the entire image (corresponding to an NDC grid
    where the top left corner is (1, 1) and the bottom right corner is (-1, -1)).
    """
    if crop_parameters is None:
        cc_x, cc_y, width = 0, 0, 2
    else:
        if len(crop_parameters.shape) > 1:
            return torch.stack(
                [
                    compute_ndc_coordinates(
                        crop_parameters=crop_param,
                        use_half_pix=use_half_pix,
                        num_patches_x=num_patches_x,
                        num_patches_y=num_patches_y,
                    )
                    for crop_param in crop_parameters
                ],
                dim=0,
            )
        device = crop_parameters.device
        cc_x, cc_y, width, _ = crop_parameters

    dx = 1 / num_patches_x
    dy = 1 / num_patches_y
    if use_half_pix:
        min_y = 1 - dy
        max_y = -min_y
        min_x = 1 - dx
        max_x = -min_x
    else:
        min_y = min_x = 1
        max_y = -1 + 2 * dy
        max_x = -1 + 2 * dx

    y, x = torch.meshgrid(
        torch.linspace(min_y, max_y, num_patches_y, dtype=torch.float32, device=device),
        torch.linspace(min_x, max_x, num_patches_x, dtype=torch.float32, device=device),
        indexing="ij",
    )
    x_prime = x * width / 2 - cc_x
    y_prime = y * width / 2 - cc_y
    xyd_grid = torch.stack([x_prime, y_prime, torch.ones_like(x)], dim=-1)
    return xyd_grid
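

# Illustrative sketch (not part of the original file): with no crop parameters
# and the default half-pixel offset, the grid spans the full NDC square, so the
# top-left sample sits at (1 - 1/16, 1 - 1/16) = (0.9375, 0.9375). The
# _example_* name is hypothetical.
def _example_ndc_grid():
    grid = compute_ndc_coordinates(num_patches_x=16, num_patches_y=16)
    assert grid.shape == (16, 16, 3)  # (H, W, xyd)
    assert torch.allclose(grid[0, 0, :2], torch.tensor([0.9375, 0.9375]))
    assert torch.all(grid[..., 2] == 1)  # depth channel is all ones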