wxl commited on
Commit
93bde9f
·
1 Parent(s): 4150f51

add get bev for front view image

Browse files
app.py CHANGED
@@ -1,7 +1,104 @@
 
 
 
 
 
 
1
  import gradio as gr
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import scipy.io as io
4
+ import numpy as np
5
+ import warnings
6
+ import torch.nn.functional as F
7
  import gradio as gr
8
+ import torchgeometry as tgm
9
+ from models.utils.torch_geometry import get_perspective_transform, warp_perspective
10
 
11
+ warnings.filterwarnings("ignore")
 
12
 
13
+ def get_BEV_kitti(front_img, fov, pitch, scale, out_size):
14
+ Hp, Wp = front_img.shape[:2]
15
+
16
+ Wo,Ho = int(Wp*scale),int(Wp*scale)
17
+
18
+ fov = fov *torch.pi/180 #
19
+ theta = pitch*torch.pi/180 # Camera pitch angle
20
+
21
+
22
+ f = Hp/2/torch.tan(torch.tensor(fov))
23
+ phi = torch.pi/2 - fov
24
+ delta = torch.pi/2+theta - torch.tensor(phi)
25
+ l = torch.sqrt(f**2+(Hp/2)**2)
26
+ h = l*torch.sin(delta)
27
+ f_ = l*torch.cos(delta)
28
+
29
+ ######################
30
+
31
+ frame = torch.from_numpy(front_img).to(device)
32
+
33
+ out = torch.zeros((2, 2,2)).to(device)
34
+
35
+ y = (torch.ones((2, 2)).to(device).T *(torch.arange(0,Ho, step=Ho-1)).to(device)).T
36
+ x = torch.ones((2, 2)).to(device) *torch.arange(0, Wo, step=Wo-1).to(device)
37
+ l0 = torch.ones((2, 2)).to(device)*Ho - y
38
+ l1 = torch.ones((2, 2)).to(device) * f_+ l0
39
+
40
+ f1_0 = torch.arctan(h/l1)
41
+ f1_1 = torch.ones((2, 2)).to(device)*(torch.pi/2+theta) - f1_0
42
+ y_ = l0*torch.sin(f1_0)/torch.sin(f1_1)
43
+ j_p = torch.ones((2, 2)).to(device) * Hp - y_
44
+ i_p = torch.ones((2, 2)).to(device) * Wp/2 -(f_+torch.sin(torch.tensor(theta))*(torch.ones((2, 2)).to(device)*Hp-j_p))*(Wo/2*torch.ones((2, 2)).to(device)-x)/l1
45
+
46
+ out[:,:,0] = i_p.reshape((2, 2))
47
+ out[:,:,1] = j_p.reshape((2, 2))
48
+
49
+ four_point_org = out.permute(2,0,1)
50
+ four_point_new = torch.stack((x,y), dim = -1).permute(2,0,1)
51
+ four_point_org = four_point_org.unsqueeze(0).flatten(2).permute(0, 2, 1)
52
+ four_point_new = four_point_new.unsqueeze(0).flatten(2).permute(0, 2, 1)
53
+ H = get_perspective_transform(four_point_org, four_point_new)
54
+
55
+ scale1,scale2 = out_size/Wo,out_size/Ho
56
+ T3 = np.array([[scale1, 0, 0], [0, scale2, 0], [0, 0, 1]])
57
+ Homo = torch.matmul(torch.tensor(T3).unsqueeze(0).to(device).float(), H)
58
+ BEV = warp_perspective(frame.permute(2,0,1).unsqueeze(0).float(), Homo, (out_size,out_size))
59
+
60
+ BEV = BEV[0].cpu().int().permute(1,2,0).numpy().astype(np.uint8)
61
+
62
+ return BEV
63
+
64
+ @torch.no_grad()
65
+ def KittiBEV():
66
+ torch.cuda.empty_cache()
67
+
68
+ with gr.Blocks() as demo:
69
+ gr.Markdown(
70
+ """
71
+ # HC-Net: Fine-Grained Cross-View Geo-Localization Using a Correlation-Aware Homography Estimator
72
+ ## Get BEV from front-view image.
73
+ """)
74
+
75
+ with gr.Row():
76
+ front_img = gr.Image(label="Front-view Image").style(height=450)
77
+ BEV_output = gr.Image(label="BEV Image").style(height=450)
78
+
79
+ fov = gr.Slider(1,90, value=20, label="FOV")
80
+ pitch = gr.Slider(-180, 180, value=0, label="Pitch")
81
+ scale = gr.Slider(1, 10, value=1.0, label="Scale")
82
+ out_size = gr.Slider(500, 1000, value=500, label="Out size")
83
+ btn = gr.Button(value="Get BEV Image")
84
+ btn.click(get_BEV_kitti,inputs= [front_img, fov, pitch, scale, out_size], outputs=BEV_output, queue=False)
85
+ gr.Markdown(
86
+ """
87
+ ### Note:
88
+ - If you wish to acquire **quantitative localization error results** for your uploaded data, kindly supply the real GPS for the ground image as well as the corresponding GPS for the center of the satellite image.
89
+ - When inputting GPS coordinates, please make sure their precision extends to **at least six decimal places**.
90
+ """)
91
+
92
+ gr.Markdown("## Image Examples")
93
+ gr.Examples(
94
+ examples=[['./figure/exp1.jpg', 27, 7, 6, 1000]],
95
+ inputs= [front_img, fov, pitch, scale, out_size],
96
+ outputs=[BEV_output],
97
+ fn=get_BEV_kitti,
98
+ cache_examples=False,
99
+ )
100
+ demo.launch(server_port=7981)
101
+
102
+ if __name__ == '__main__':
103
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
104
+ KittiBEV()
figure/exp1.jpg ADDED
models/utils/__pycache__/torch_geometry.cpython-38.pyc ADDED
Binary file (16 kB). View file
 
models/utils/torch_geometry.py ADDED
@@ -0,0 +1,496 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import functools
5
+ from typing import Tuple, Optional
6
+
7
+
8
+ ##########################
9
+ #### from pytorch3d ####
10
+ ##########################
11
+ def _axis_angle_rotation(axis: str, angle):
12
+ """
13
+ Return the rotation matrices for one of the rotations about an axis
14
+ of which Euler angles describe, for each value of the angle given.
15
+
16
+ Args:
17
+ axis: Axis label "X" or "Y or "Z".
18
+ angle: any shape tensor of Euler angles in radians
19
+
20
+ Returns:
21
+ Rotation matrices as tensor of shape (..., 3, 3).
22
+ """
23
+
24
+ cos = torch.cos(angle)
25
+ sin = torch.sin(angle)
26
+ one = torch.ones_like(angle)
27
+ zero = torch.zeros_like(angle)
28
+
29
+ if axis == "X":
30
+ R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
31
+ if axis == "Y":
32
+ R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
33
+ if axis == "Z":
34
+ R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
35
+
36
+ return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
37
+
38
+ def euler_angles_to_matrix(euler_angles, convention: str):
39
+ """
40
+ Convert rotations given as Euler angles in radians to rotation matrices.
41
+
42
+ Args:
43
+ euler_angles: Euler angles in radians as tensor of shape (..., 3).
44
+ convention: Convention string of three uppercase letters from
45
+ {"X", "Y", and "Z"}.
46
+
47
+ Returns:
48
+ Rotation matrices as tensor of shape (..., 3, 3).
49
+ """
50
+ if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
51
+ raise ValueError("Invalid input euler angles.")
52
+ if len(convention) != 3:
53
+ raise ValueError("Convention must have 3 letters.")
54
+ if convention[1] in (convention[0], convention[2]):
55
+ raise ValueError(f"Invalid convention {convention}.")
56
+ for letter in convention:
57
+ if letter not in ("X", "Y", "Z"):
58
+ raise ValueError(f"Invalid letter {letter} in convention string.")
59
+ matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1))
60
+ return functools.reduce(torch.matmul, matrices)
61
+
62
+ ###########################
63
+ #### from pytorchgemotry ####
64
+ ###########################
65
+ def get_perspective_transform(src, dst):
66
+ r"""Calculates a perspective transform from four pairs of the corresponding
67
+ points.
68
+
69
+ The function calculates the matrix of a perspective transform so that:
70
+
71
+ .. math ::
72
+
73
+ \begin{bmatrix}
74
+ t_{i}x_{i}^{'} \\
75
+ t_{i}y_{i}^{'} \\
76
+ t_{i} \\
77
+ \end{bmatrix}
78
+ =
79
+ \textbf{map_matrix} \cdot
80
+ \begin{bmatrix}
81
+ x_{i} \\
82
+ y_{i} \\
83
+ 1 \\
84
+ \end{bmatrix}
85
+
86
+ where
87
+
88
+ .. math ::
89
+ dst(i) = (x_{i}^{'},y_{i}^{'}), src(i) = (x_{i}, y_{i}), i = 0,1,2,3
90
+
91
+ Args:
92
+ src (Tensor): coordinates of quadrangle vertices in the source image.
93
+ dst (Tensor): coordinates of the corresponding quadrangle vertices in
94
+ the destination image.
95
+
96
+ Returns:
97
+ Tensor: the perspective transformation.
98
+
99
+ Shape:
100
+ - Input: :math:`(B, 4, 2)` and :math:`(B, 4, 2)`
101
+ - Output: :math:`(B, 3, 3)`
102
+ """
103
+ if not torch.is_tensor(src):
104
+ raise TypeError("Input type is not a torch.Tensor. Got {}"
105
+ .format(type(src)))
106
+ if not torch.is_tensor(dst):
107
+ raise TypeError("Input type is not a torch.Tensor. Got {}"
108
+ .format(type(dst)))
109
+ if not src.shape[-2:] == (4, 2):
110
+ raise ValueError("Inputs must be a Bx4x2 tensor. Got {}"
111
+ .format(src.shape))
112
+ if not src.shape == dst.shape:
113
+ raise ValueError("Inputs must have the same shape. Got {}"
114
+ .format(dst.shape))
115
+ if not (src.shape[0] == dst.shape[0]):
116
+ raise ValueError("Inputs must have same batch size dimension. Got {}"
117
+ .format(src.shape, dst.shape))
118
+
119
+ def ax(p, q):
120
+ ones = torch.ones_like(p)[..., 0:1]
121
+ zeros = torch.zeros_like(p)[..., 0:1]
122
+ return torch.cat(
123
+ [p[:, 0:1], p[:, 1:2], ones, zeros, zeros, zeros,
124
+ -p[:, 0:1] * q[:, 0:1], -p[:, 1:2] * q[:, 0:1]
125
+ ], dim=1)
126
+
127
+ def ay(p, q):
128
+ ones = torch.ones_like(p)[..., 0:1]
129
+ zeros = torch.zeros_like(p)[..., 0:1]
130
+ return torch.cat(
131
+ [zeros, zeros, zeros, p[:, 0:1], p[:, 1:2], ones,
132
+ -p[:, 0:1] * q[:, 1:2], -p[:, 1:2] * q[:, 1:2]], dim=1)
133
+ # we build matrix A by using only 4 point correspondence. The linear
134
+ # system is solved with the least square method, so here
135
+ # we could even pass more correspondence
136
+ p = []
137
+ p.append(ax(src[:, 0], dst[:, 0]))
138
+ p.append(ay(src[:, 0], dst[:, 0]))
139
+
140
+ p.append(ax(src[:, 1], dst[:, 1]))
141
+ p.append(ay(src[:, 1], dst[:, 1]))
142
+
143
+ p.append(ax(src[:, 2], dst[:, 2]))
144
+ p.append(ay(src[:, 2], dst[:, 2]))
145
+
146
+ p.append(ax(src[:, 3], dst[:, 3]))
147
+ p.append(ay(src[:, 3], dst[:, 3]))
148
+
149
+ # A is Bx8x8
150
+ A = torch.stack(p, dim=1)
151
+
152
+ # b is a Bx8x1
153
+ b = torch.stack([
154
+ dst[:, 0:1, 0], dst[:, 0:1, 1],
155
+ dst[:, 1:2, 0], dst[:, 1:2, 1],
156
+ dst[:, 2:3, 0], dst[:, 2:3, 1],
157
+ dst[:, 3:4, 0], dst[:, 3:4, 1],
158
+ ], dim=1)
159
+
160
+ # solve the system Ax = b
161
+ # X, LU = torch.gesv(b, A)
162
+ X = torch.linalg.solve(A, b)
163
+
164
+ # create variable to return
165
+ batch_size = src.shape[0]
166
+ M = torch.ones(batch_size, 9, device=src.device, dtype=src.dtype)
167
+ M[..., :8] = torch.squeeze(X, dim=-1)
168
+ return M.view(-1, 3, 3) # Bx3x3
169
+
170
+ def warp_perspective(src, M, dsize, flags='bilinear', border_mode=None,
171
+ border_value=0):
172
+ r"""Applies a perspective transformation to an image.
173
+
174
+ The function warp_perspective transforms the source image using
175
+ the specified matrix:
176
+
177
+ .. math::
178
+ \text{dst} (x, y) = \text{src} \left(
179
+ \frac{M_{11} x + M_{12} y + M_{13}}{M_{31} x + M_{32} y + M_{33}} ,
180
+ \frac{M_{21} x + M_{22} y + M_{23}}{M_{31} x + M_{32} y + M_{33}}
181
+ \right )
182
+
183
+ Args:
184
+ src (torch.Tensor): input image.
185
+ M (Tensor): transformation matrix.
186
+ dsize (tuple): size of the output image (height, width).
187
+
188
+ Returns:
189
+ Tensor: the warped input image.
190
+
191
+ Shape:
192
+ - Input: :math:`(B, C, H, W)` and :math:`(B, 3, 3)`
193
+ - Output: :math:`(B, C, H, W)`
194
+
195
+ .. note::
196
+ See a working example `here <https://github.com/arraiy/torchgeometry/
197
+ blob/master/examples/warp_perspective.ipynb>`_.
198
+ """
199
+ if not torch.is_tensor(src):
200
+ raise TypeError("Input src type is not a torch.Tensor. Got {}"
201
+ .format(type(src)))
202
+ if not torch.is_tensor(M):
203
+ raise TypeError("Input M type is not a torch.Tensor. Got {}"
204
+ .format(type(M)))
205
+ if not len(src.shape) == 4:
206
+ raise ValueError("Input src must be a BxCxHxW tensor. Got {}"
207
+ .format(src.shape))
208
+ if not (len(M.shape) == 3 or M.shape[-2:] == (3, 3)):
209
+ raise ValueError("Input M must be a Bx3x3 tensor. Got {}"
210
+ .format(src.shape))
211
+ # launches the warper
212
+ return transform_warp_impl(src, M, (src.shape[-2:]), dsize)
213
+
214
+
215
+ def transform_warp_impl(src, dst_pix_trans_src_pix, dsize_src, dsize_dst):
216
+ """Compute the transform in normalized cooridnates and perform the warping.
217
+ """
218
+ dst_norm_trans_dst_norm = dst_norm_to_dst_norm(
219
+ dst_pix_trans_src_pix, dsize_src, dsize_dst)
220
+ return homography_warp(src, torch.inverse(
221
+ dst_norm_trans_dst_norm), dsize_dst)
222
+
223
+ def dst_norm_to_dst_norm(dst_pix_trans_src_pix, dsize_src, dsize_dst):
224
+ # source and destination sizes
225
+ src_h, src_w = dsize_src
226
+ dst_h, dst_w = dsize_dst
227
+ # the devices and types
228
+ device = dst_pix_trans_src_pix.device
229
+ dtype = dst_pix_trans_src_pix.dtype
230
+ # compute the transformation pixel/norm for src/dst
231
+ src_norm_trans_src_pix = normal_transform_pixel(
232
+ src_h, src_w).to(device).to(dtype)
233
+ src_pix_trans_src_norm = torch.inverse(src_norm_trans_src_pix)
234
+ dst_norm_trans_dst_pix = normal_transform_pixel(
235
+ dst_h, dst_w).to(device).to(dtype)
236
+ # compute chain transformations
237
+ dst_norm_trans_src_norm = torch.matmul(
238
+ dst_norm_trans_dst_pix, torch.matmul(
239
+ dst_pix_trans_src_pix, src_pix_trans_src_norm))
240
+ return dst_norm_trans_src_norm
241
+
242
+ def normal_transform_pixel(height, width):
243
+
244
+ tr_mat = torch.Tensor([[1.0, 0.0, -1.0],
245
+ [0.0, 1.0, -1.0],
246
+ [0.0, 0.0, 1.0]]) # 1x3x3
247
+
248
+ tr_mat[0, 0] = tr_mat[0, 0] * 2.0 / (width - 1.0)
249
+ tr_mat[1, 1] = tr_mat[1, 1] * 2.0 / (height - 1.0)
250
+
251
+ tr_mat = tr_mat.unsqueeze(0)
252
+
253
+ return tr_mat
254
+
255
+ def homography_warp(patch_src: torch.Tensor,
256
+ dst_homo_src: torch.Tensor,
257
+ dsize: Tuple[int, int],
258
+ mode: Optional[str] = 'bilinear',
259
+ padding_mode: Optional[str] = 'zeros') -> torch.Tensor:
260
+ r"""Function that warps image patchs or tensors by homographies.
261
+
262
+ See :class:`~torchgeometry.HomographyWarper` for details.
263
+
264
+ Args:
265
+ patch_src (torch.Tensor): The image or tensor to warp. Should be from
266
+ source of shape :math:`(N, C, H, W)`.
267
+ dst_homo_src (torch.Tensor): The homography or stack of homographies
268
+ from source to destination of shape
269
+ :math:`(N, 3, 3)`.
270
+ dsize (Tuple[int, int]): The height and width of the image to warp.
271
+ mode (Optional[str]): interpolation mode to calculate output values
272
+ 'bilinear' | 'nearest'. Default: 'bilinear'.
273
+ padding_mode (Optional[str]): padding mode for outside grid values
274
+ 'zeros' | 'border' | 'reflection'. Default: 'zeros'.
275
+
276
+ Return:
277
+ torch.Tensor: Patch sampled at locations from source to destination.
278
+
279
+ Example:
280
+ >>> input = torch.rand(1, 3, 32, 32)
281
+ >>> homography = torch.eye(3).view(1, 3, 3)
282
+ >>> output = tgm.homography_warp(input, homography, (32, 32)) # NxCxHxW
283
+ """
284
+ height, width = dsize
285
+ warper = HomographyWarper(height, width, mode, padding_mode)
286
+ return warper(patch_src, dst_homo_src)
287
+
288
+ class HomographyWarper(nn.Module):
289
+ r"""Warps image patches or tensors by homographies.
290
+
291
+ .. math::
292
+
293
+ X_{dst} = H_{src}^{\{dst\}} * X_{src}
294
+
295
+ Args:
296
+ height (int): The height of the image to warp.
297
+ width (int): The width of the image to warp.
298
+ mode (Optional[str]): interpolation mode to calculate output values
299
+ 'bilinear' | 'nearest'. Default: 'bilinear'.
300
+ padding_mode (Optional[str]): padding mode for outside grid values
301
+ 'zeros' | 'border' | 'reflection'. Default: 'zeros'.
302
+ normalized_coordinates (Optional[bool]): wether to use a grid with
303
+ normalized coordinates.
304
+ """
305
+
306
+ def __init__(
307
+ self,
308
+ height: int,
309
+ width: int,
310
+ mode: Optional[str] = 'bilinear',
311
+ padding_mode: Optional[str] = 'zeros',
312
+ normalized_coordinates: Optional[bool] = True) -> None:
313
+ super(HomographyWarper, self).__init__()
314
+ self.width: int = width
315
+ self.height: int = height
316
+ self.mode: Optional[str] = mode
317
+ self.padding_mode: Optional[str] = padding_mode
318
+ self.normalized_coordinates: Optional[bool] = normalized_coordinates
319
+
320
+ # create base grid to compute the flow
321
+ self.grid: torch.Tensor = create_meshgrid(
322
+ height, width, normalized_coordinates=normalized_coordinates)
323
+
324
+ def warp_grid(self, dst_homo_src: torch.Tensor) -> torch.Tensor:
325
+ r"""Computes the grid to warp the coordinates grid by an homography.
326
+
327
+ Args:
328
+ dst_homo_src (torch.Tensor): Homography or homographies (stacked) to
329
+ transform all points in the grid. Shape of the
330
+ homography has to be :math:`(N, 3, 3)`.
331
+
332
+ Returns:
333
+ torch.Tensor: the transformed grid of shape :math:`(N, H, W, 2)`.
334
+ """
335
+ batch_size: int = dst_homo_src.shape[0]
336
+ device: torch.device = dst_homo_src.device
337
+ dtype: torch.dtype = dst_homo_src.dtype
338
+ # expand grid to match the input batch size
339
+ grid: torch.Tensor = self.grid.expand(batch_size, -1, -1, -1) # NxHxWx2
340
+ if len(dst_homo_src.shape) == 3: # local homography case
341
+ dst_homo_src = dst_homo_src.view(batch_size, 1, 3, 3) # NxHxWx3x3
342
+ # perform the actual grid transformation,
343
+ # the grid is copied to input device and casted to the same type
344
+ flow: torch.Tensor = transform_points(
345
+ dst_homo_src, grid.to(device).to(dtype)) # NxHxWx2
346
+ return flow.view(batch_size, self.height, self.width, 2) # NxHxWx2
347
+
348
+ def forward(
349
+ self,
350
+ patch_src: torch.Tensor,
351
+ dst_homo_src: torch.Tensor) -> torch.Tensor:
352
+ r"""Warps an image or tensor from source into reference frame.
353
+
354
+ Args:
355
+ patch_src (torch.Tensor): The image or tensor to warp.
356
+ Should be from source.
357
+ dst_homo_src (torch.Tensor): The homography or stack of homographies
358
+ from source to destination. The homography assumes normalized
359
+ coordinates [-1, 1].
360
+
361
+ Return:
362
+ torch.Tensor: Patch sampled at locations from source to destination.
363
+
364
+ Shape:
365
+ - Input: :math:`(N, C, H, W)` and :math:`(N, 3, 3)`
366
+ - Output: :math:`(N, C, H, W)`
367
+
368
+ Example:
369
+ >>> input = torch.rand(1, 3, 32, 32)
370
+ >>> homography = torch.eye(3).view(1, 3, 3)
371
+ >>> warper = tgm.HomographyWarper(32, 32)
372
+ >>> output = warper(input, homography) # NxCxHxW
373
+ """
374
+ if not dst_homo_src.device == patch_src.device:
375
+ raise TypeError("Patch and homography must be on the same device. \
376
+ Got patch.device: {} dst_H_src.device: {}."
377
+ .format(patch_src.device, dst_homo_src.device))
378
+ return F.grid_sample(patch_src, self.warp_grid(dst_homo_src),
379
+ mode=self.mode, padding_mode=self.padding_mode)
380
+
381
+
382
+
383
+ def create_meshgrid(
384
+ height: int,
385
+ width: int,
386
+ normalized_coordinates: Optional[bool] = True):
387
+ """Generates a coordinate grid for an image.
388
+
389
+ When the flag `normalized_coordinates` is set to True, the grid is
390
+ normalized to be in the range [-1,1] to be consistent with the pytorch
391
+ function grid_sample.
392
+ http://pytorch.org/docs/master/nn.html#torch.nn.functional.grid_sample
393
+
394
+ Args:
395
+ height (int): the image height (rows).
396
+ width (int): the image width (cols).
397
+ normalized_coordinates (Optional[bool]): wether to normalize
398
+ coordinates in the range [-1, 1] in order to be consistent with the
399
+ PyTorch function grid_sample.
400
+
401
+ Return:
402
+ torch.Tensor: returns a grid tensor with shape :math:`(1, H, W, 2)`.
403
+ """
404
+ # generate coordinates
405
+ xs: Optional[torch.Tensor] = None
406
+ ys: Optional[torch.Tensor] = None
407
+ if normalized_coordinates:
408
+ xs = torch.linspace(-1, 1, width)
409
+ ys = torch.linspace(-1, 1, height)
410
+ else:
411
+ xs = torch.linspace(0, width - 1, width)
412
+ ys = torch.linspace(0, height - 1, height)
413
+ # generate grid by stacking coordinates
414
+ base_grid: torch.Tensor = torch.stack(
415
+ torch.meshgrid([xs, ys])).transpose(1, 2) # 2xHxW
416
+ return torch.unsqueeze(base_grid, dim=0).permute(0, 2, 3, 1) # 1xHxWx2
417
+
418
+
419
+ def transform_points(trans_01: torch.Tensor,
420
+ points_1: torch.Tensor) -> torch.Tensor:
421
+ r"""Function that applies transformations to a set of points.
422
+
423
+ Args:
424
+ trans_01 (torch.Tensor): tensor for transformations of shape
425
+ :math:`(B, D+1, D+1)`.
426
+ points_1 (torch.Tensor): tensor of points of shape :math:`(B, N, D)`.
427
+ Returns:
428
+ torch.Tensor: tensor of N-dimensional points.
429
+
430
+ Shape:
431
+ - Output: :math:`(B, N, D)`
432
+
433
+ Examples:
434
+
435
+ >>> points_1 = torch.rand(2, 4, 3) # BxNx3
436
+ >>> trans_01 = torch.eye(4).view(1, 4, 4) # Bx4x4
437
+ >>> points_0 = tgm.transform_points(trans_01, points_1) # BxNx3
438
+ """
439
+ if not torch.is_tensor(trans_01) or not torch.is_tensor(points_1):
440
+ raise TypeError("Input type is not a torch.Tensor")
441
+ if not trans_01.device == points_1.device:
442
+ raise TypeError("Tensor must be in the same device")
443
+ if not trans_01.shape[0] == points_1.shape[0]:
444
+ raise ValueError("Input batch size must be the same for both tensors")
445
+ if not trans_01.shape[-1] == (points_1.shape[-1] + 1):
446
+ raise ValueError("Last input dimensions must differe by one unit")
447
+ # to homogeneous
448
+ points_1_h = convert_points_to_homogeneous(points_1) # BxNxD+1
449
+ # transform coordinates
450
+ points_0_h = torch.matmul(
451
+ trans_01.unsqueeze(1), points_1_h.unsqueeze(-1))
452
+ points_0_h = torch.squeeze(points_0_h, dim=-1)
453
+ # to euclidean
454
+ points_0 = convert_points_from_homogeneous(points_0_h) # BxNxD
455
+ return points_0
456
+
457
+
458
+ def convert_points_to_homogeneous(points):
459
+ r"""Function that converts points from Euclidean to homogeneous space.
460
+
461
+ See :class:`~torchgeometry.ConvertPointsToHomogeneous` for details.
462
+
463
+ Examples::
464
+
465
+ >>> input = torch.rand(2, 4, 3) # BxNx3
466
+ >>> output = tgm.convert_points_to_homogeneous(input) # BxNx4
467
+ """
468
+ if not torch.is_tensor(points):
469
+ raise TypeError("Input type is not a torch.Tensor. Got {}".format(
470
+ type(points)))
471
+ if len(points.shape) < 2:
472
+ raise ValueError("Input must be at least a 2D tensor. Got {}".format(
473
+ points.shape))
474
+
475
+ return nn.functional.pad(points, (0, 1), "constant", 1.0)
476
+
477
+
478
+
479
+ def convert_points_from_homogeneous(points):
480
+ r"""Function that converts points from homogeneous to Euclidean space.
481
+
482
+ See :class:`~torchgeometry.ConvertPointsFromHomogeneous` for details.
483
+
484
+ Examples::
485
+
486
+ >>> input = torch.rand(2, 4, 3) # BxNx3
487
+ >>> output = tgm.convert_points_from_homogeneous(input) # BxNx2
488
+ """
489
+ if not torch.is_tensor(points):
490
+ raise TypeError("Input type is not a torch.Tensor. Got {}".format(
491
+ type(points)))
492
+ if len(points.shape) < 2:
493
+ raise ValueError("Input must be at least a 2D tensor. Got {}".format(
494
+ points.shape))
495
+
496
+ return points[..., :-1] / points[..., -1:]