alexnasa committed (verified)
Commit 5049961 · Parent(s): f0d65fe

Delete src/pixel3dmm/preprocessing/facer/facer/transform.py

src/pixel3dmm/preprocessing/facer/facer/transform.py DELETED
@@ -1,384 +0,0 @@
from typing import List, Dict, Callable, Tuple, Optional
import torch
import torch.nn.functional as F
import functools
import numpy as np


def get_crop_and_resize_matrix(
        box: torch.Tensor, target_shape: Tuple[int, int],
        target_face_scale: float = 1.0, make_square_crop: bool = True,
        offset_xy: Optional[Tuple[float, float]] = None, align_corners: bool = True,
        offset_box_coords: bool = False) -> torch.Tensor:
    """
    Args:
        box: b x 4 (x1, y1, x2, y2)
        align_corners (bool): Set this to `True` only if the box you give has coordinates
            ranging from `0` to `h-1` or `w-1`.

        offset_box_coords (bool): Set this to `True` if the box you give has coordinates
            ranging from `0` to `h` or `w`.

            Set this to `False` if the box coordinates range from `-0.5` to `h-0.5` or `w-0.5`.

            If the box coordinates range from `0` to `h-1` or `w-1`, set `align_corners=True`.

    Returns:
        torch.Tensor: b x 3 x 3.
    """
    if offset_xy is None:
        offset_xy = (0.0, 0.0)

    x1, y1, x2, y2 = box.split(1, dim=1)  # b x 1
    cx = (x1 + x2) / 2 + offset_xy[0]
    cy = (y1 + y2) / 2 + offset_xy[1]
    rx = (x2 - x1) / 2 / target_face_scale
    ry = (y2 - y1) / 2 / target_face_scale
    if make_square_crop:
        rx = ry = torch.maximum(rx, ry)

    x1, y1, x2, y2 = cx - rx, cy - ry, cx + rx, cy + ry

    h, w, *_ = target_shape

    zeros_pl = torch.zeros_like(x1)
    ones_pl = torch.ones_like(x1)

    if align_corners:
        # x -> (x - x1) / (x2 - x1) * (w - 1)
        # y -> (y - y1) / (y2 - y1) * (h - 1)
        ax = 1.0 / (x2 - x1) * (w - 1)
        ay = 1.0 / (y2 - y1) * (h - 1)
        matrix = torch.cat([
            ax, zeros_pl, -x1 * ax,
            zeros_pl, ay, -y1 * ay,
            zeros_pl, zeros_pl, ones_pl
        ], dim=1).reshape(-1, 3, 3)  # b x 3 x 3
    else:
        if offset_box_coords:
            # x1, x2 \in [0, w], y1, y2 \in [0, h]
            # first we should offset x1, x2, y1, y2 to be ranging in
            # [-0.5, w-0.5] and [-0.5, h-0.5]
            # so to convert these pixel coordinates into boundary coordinates.
            x1, x2, y1, y2 = x1 - 0.5, x2 - 0.5, y1 - 0.5, y2 - 0.5

        # x -> (x - x1) / (x2 - x1) * w - 0.5
        # y -> (y - y1) / (y2 - y1) * h - 0.5
        ax = 1.0 / (x2 - x1) * w
        ay = 1.0 / (y2 - y1) * h
        matrix = torch.cat([
            ax, zeros_pl, -x1 * ax - 0.5 * ones_pl,
            zeros_pl, ay, -y1 * ay - 0.5 * ones_pl,
            zeros_pl, zeros_pl, ones_pl
        ], dim=1).reshape(-1, 3, 3)  # b x 3 x 3
    return matrix
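As a quick illustration of the box conventions above (not part of the deleted file; the box values are made up), the returned matrix should map the corners of a box onto the corners of a 256 x 256 crop when `align_corners=True`:

    import torch

    box = torch.tensor([[40.0, 60.0, 200.0, 220.0]])                      # hypothetical b x 4 box
    M = get_crop_and_resize_matrix(box, (256, 256), align_corners=True)

    corners = torch.tensor([[[40.0, 60.0, 1.0], [200.0, 220.0, 1.0]]])    # homogeneous (x, y, 1)
    mapped = corners @ M.transpose(1, 2)
    print(mapped[..., :2] / mapped[..., 2:])                              # approx. (0, 0) and (255, 255)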
def get_similarity_transform_matrix(
        from_pts: torch.Tensor, to_pts: torch.Tensor) -> torch.Tensor:
    """
    Args:
        from_pts, to_pts: b x n x 2

    Returns:
        torch.Tensor: b x 3 x 3
    """
    mfrom = from_pts.mean(dim=1, keepdim=True)  # b x 1 x 2
    mto = to_pts.mean(dim=1, keepdim=True)  # b x 1 x 2

    a1 = (from_pts - mfrom).square().sum([1, 2], keepdim=False)  # b
    c1 = ((to_pts - mto) * (from_pts - mfrom)).sum([1, 2], keepdim=False)  # b

    to_delta = to_pts - mto
    from_delta = from_pts - mfrom
    c2 = (to_delta[:, :, 0] * from_delta[:, :, 1] -
          to_delta[:, :, 1] * from_delta[:, :, 0]).sum([1], keepdim=False)  # b

    a = c1 / a1
    b = c2 / a1
    dx = mto[:, 0, 0] - a * mfrom[:, 0, 0] - b * mfrom[:, 0, 1]  # b
    dy = mto[:, 0, 1] + b * mfrom[:, 0, 0] - a * mfrom[:, 0, 1]  # b

    ones_pl = torch.ones_like(a1)
    zeros_pl = torch.zeros_like(a1)

    return torch.stack([
        a, b, dx,
        -b, a, dy,
        zeros_pl, zeros_pl, ones_pl,
    ], dim=-1).reshape(-1, 3, 3)
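A small self-check sketch (with made-up points, not part of the deleted file): points transformed by a known similarity (scale, rotation, translation) should be mapped back onto their targets by the estimated matrix:

    import math
    import torch

    theta, s = math.radians(30), 2.0
    A = torch.tensor([[s * math.cos(theta), -s * math.sin(theta), 10.0],
                      [s * math.sin(theta),  s * math.cos(theta), -5.0],
                      [0.0, 0.0, 1.0]])
    src = torch.rand(1, 5, 2) * 100                                  # hypothetical source points
    src_h = torch.cat([src, torch.ones(1, 5, 1)], dim=-1)
    dst = (src_h @ A.T)[..., :2]

    M = get_similarity_transform_matrix(src, dst)
    pred = (src_h @ M.transpose(1, 2))[..., :2]
    print((pred - dst).abs().max())                                  # expected to be near zero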
@functools.lru_cache()
def _standard_face_pts():
    pts = torch.tensor([
        196.0, 226.0,
        316.0, 226.0,
        256.0, 286.0,
        220.0, 360.4,
        292.0, 360.4], dtype=torch.float32) / 256.0 - 1.0
    return torch.reshape(pts, (5, 2))


def get_face_align_matrix(
        face_pts: torch.Tensor, target_shape: Tuple[int, int],
        target_face_scale: float = 1.0, offset_xy: Optional[Tuple[float, float]] = None,
        target_pts: Optional[torch.Tensor] = None):

    if target_pts is None:
        with torch.no_grad():
            std_pts = _standard_face_pts().to(face_pts)  # [-1 1]
            h, w, *_ = target_shape
            target_pts = (std_pts * target_face_scale + 1) * \
                torch.tensor([w - 1, h - 1]).to(face_pts) / 2.0
            if offset_xy is not None:
                target_pts[:, 0] += offset_xy[0]
                target_pts[:, 1] += offset_xy[1]
    else:
        target_pts = target_pts.to(face_pts)

    if target_pts.dim() == 2:
        target_pts = target_pts.unsqueeze(0)
    if target_pts.size(0) == 1:
        target_pts = target_pts.broadcast_to(face_pts.shape)

    assert target_pts.shape == face_pts.shape

    return get_similarity_transform_matrix(face_pts, target_pts)
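For illustration (the landmark values are invented, not from the repository), aligning a single face from five detected keypoints to a 448 x 448 crop looks roughly like this; the resulting matrix is what the warp-grid helpers further down consume:

    import torch

    # eyes, nose tip, mouth corners, in original-image pixels
    landmarks = torch.tensor([[[210.0, 230.0], [330.0, 228.0], [270.0, 300.0],
                               [225.0, 360.0], [320.0, 358.0]]])
    M = get_face_align_matrix(landmarks, target_shape=(448, 448), target_face_scale=1.0)
    print(M.shape)                                                   # torch.Size([1, 3, 3])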
def rot90(v):
    return np.array([-v[1], v[0]])


def get_quad(lm: torch.Tensor):
    # N, 2
    lm = lm.detach().cpu().numpy()
    # Choose oriented crop rectangle.
    eye_avg = (lm[0] + lm[1]) * 0.5 + 0.5
    mouth_avg = (lm[3] + lm[4]) * 0.5 + 0.5
    eye_to_eye = lm[1] - lm[0]
    eye_to_mouth = mouth_avg - eye_avg
    x = eye_to_eye - rot90(eye_to_mouth)
    x /= np.hypot(*x)
    x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
    y = rot90(x)
    c = eye_avg + eye_to_mouth * 0.1
    quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
    quad_for_coeffs = quad[[0, 3, 2, 1]]  # reorder the corners
    return torch.from_numpy(quad_for_coeffs).float()


def get_face_align_matrix_celebm(
        face_pts: torch.Tensor, target_shape: Tuple[int, int]):

    face_pts = torch.stack([get_quad(pts) for pts in face_pts], dim=0).to(face_pts)

    assert target_shape[0] == target_shape[1]
    target_size = target_shape[0]
    target_pts = torch.as_tensor(
        [[0, 0], [target_size, 0], [target_size, target_size], [0, target_size]]).to(face_pts)

    if target_pts.dim() == 2:
        target_pts = target_pts.unsqueeze(0)
    if target_pts.size(0) == 1:
        target_pts = target_pts.broadcast_to(face_pts.shape)

    assert target_pts.shape == face_pts.shape

    return get_similarity_transform_matrix(face_pts, target_pts)
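A rough sketch (with invented landmarks) of the `celebm` path above: `get_quad` turns five keypoints into an oriented 4 x 2 crop quad, and `get_face_align_matrix_celebm` maps that quad onto a square target:

    import torch

    landmarks = torch.tensor([[[210.0, 230.0], [330.0, 228.0], [270.0, 300.0],
                               [225.0, 360.0], [320.0, 358.0]]])
    quad = get_quad(landmarks[0])
    print(quad.shape)                                                # torch.Size([4, 2])

    M = get_face_align_matrix_celebm(landmarks, target_shape=(512, 512))
    print(M.shape)                                                   # torch.Size([1, 3, 3])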
@functools.lru_cache(maxsize=128)
def _meshgrid(h, w) -> Tuple[torch.Tensor, torch.Tensor]:
    yy, xx = torch.meshgrid(torch.arange(h).float(),
                            torch.arange(w).float(),
                            indexing='ij')
    return yy + 0.5, xx + 0.5


def _forge_grid(batch_size: int, device: torch.device,
                output_shape: Tuple[int, int],
                fn: Callable[[torch.Tensor], torch.Tensor]
                ) -> torch.Tensor:
    """ Forge transform maps with a given function `fn`.

    Args:
        output_shape (tuple): (b, h, w, ...).
        fn (Callable[[torch.Tensor], torch.Tensor]): The function that accepts
            a b x n x 2 array and outputs the transformed b x n x 2 array. Both
            input and output store (x, y) coordinates.

    Note:
        Both input and output arrays of `fn` store (x, y) coordinates, matching
        the order produced by `torch.stack([xx, yy], dim=-1)` below.

    Returns:
        torch.Tensor: a b x h x w x 2 map with `(X, Y)` stacked in the last
            dimension, where for each pixel (y, x),
            `(X[y, x], Y[y, x]) = fn([x, y])`.
    """
    h, w, *_ = output_shape
    yy, xx = _meshgrid(h, w)  # h x w
    yy = yy.unsqueeze(0).broadcast_to(batch_size, h, w).to(device)
    xx = xx.unsqueeze(0).broadcast_to(batch_size, h, w).to(device)

    in_xxyy = torch.stack(
        [xx, yy], dim=-1).reshape([batch_size, h * w, 2])  # (h x w) x 2
    out_xxyy: torch.Tensor = fn(in_xxyy)  # (h x w) x 2
    return out_xxyy.reshape(batch_size, h, w, 2)


def _safe_arctanh(x: torch.Tensor, eps: float = 0.001) -> torch.Tensor:
    return torch.clamp(x, -1 + eps, 1 - eps).arctanh()
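A tiny sanity check of `_forge_grid` (not from the original file): with an identity `fn` it simply returns the pixel-center (x, y) coordinates of the output grid:

    import torch

    grid = _forge_grid(1, torch.device('cpu'), (2, 3), lambda pts: pts)
    print(grid.shape)        # torch.Size([1, 2, 3, 2])
    print(grid[0, 0, 0])     # tensor([0.5000, 0.5000]), the (x, y) of the top-left pixel center
    print(grid[0, 1, 2])     # tensor([2.5000, 1.5000])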
def inverted_tanh_warp_transform(coords: torch.Tensor, matrix: torch.Tensor,
                                 warp_factor: float, warped_shape: Tuple[int, int]):
    """ Inverted tanh-warp function.

    Args:
        coords (torch.Tensor): b x n x 2 (x, y). The transformed coordinates.
        matrix: b x 3 x 3. A matrix that transforms un-normalized coordinates
            from the original image to the aligned yet not-warped image.
        warp_factor (float): The warp factor.
            0 means linear transform, 1 means full tanh warp.
        warped_shape (tuple): [height, width].

    Returns:
        torch.Tensor: b x n x 2 (x, y). The original coordinates.
    """
    h, w, *_ = warped_shape
    # h -= 1
    # w -= 1

    w_h = torch.tensor([[w, h]]).to(coords)

    if warp_factor > 0:
        # normalize coordinates to [-1, +1]
        coords = coords / w_h * 2 - 1

        nl_part1 = coords > 1.0 - warp_factor
        nl_part2 = coords < -1.0 + warp_factor

        ret_nl_part1 = _safe_arctanh(
            (coords - 1.0 + warp_factor) /
            warp_factor) * warp_factor + \
            1.0 - warp_factor
        ret_nl_part2 = _safe_arctanh(
            (coords + 1.0 - warp_factor) /
            warp_factor) * warp_factor - \
            1.0 + warp_factor

        coords = torch.where(nl_part1, ret_nl_part1,
                             torch.where(nl_part2, ret_nl_part2, coords))

        # denormalize
        coords = (coords + 1) / 2 * w_h

    coords_homo = torch.cat(
        [coords, torch.ones_like(coords[:, :, [0]])], dim=-1)  # b x n x 3

    inv_matrix = torch.linalg.inv(matrix)  # b x 3 x 3
    # inv_matrix = np.linalg.inv(matrix)
    coords_homo = torch.bmm(
        coords_homo, inv_matrix.permute(0, 2, 1))  # b x n x 3
    return coords_homo[:, :, :2] / coords_homo[:, :, [2, 2]]
def tanh_warp_transform(
        coords: torch.Tensor, matrix: torch.Tensor,
        warp_factor: float, warped_shape: Tuple[int, int]):
    """ Tanh-warp function.

    Args:
        coords (torch.Tensor): b x n x 2 (x, y). The original coordinates.
        matrix: b x 3 x 3. A matrix that transforms un-normalized coordinates
            from the original image to the aligned yet not-warped image.
        warp_factor (float): The warp factor.
            0 means linear transform, 1 means full tanh warp.
        warped_shape (tuple): [height, width].

    Returns:
        torch.Tensor: b x n x 2 (x, y). The transformed coordinates.
    """
    h, w, *_ = warped_shape
    # h -= 1
    # w -= 1
    w_h = torch.tensor([[w, h]]).to(coords)

    coords_homo = torch.cat(
        [coords, torch.ones_like(coords[:, :, [0]])], dim=-1)  # b x n x 3

    coords_homo = torch.bmm(coords_homo, matrix.transpose(2, 1))  # b x n x 3
    coords = (coords_homo[:, :, :2] / coords_homo[:, :, [2, 2]])  # b x n x 2

    if warp_factor > 0:
        # normalize coordinates to [-1, +1]
        coords = coords / w_h * 2 - 1

        nl_part1 = coords > 1.0 - warp_factor
        nl_part2 = coords < -1.0 + warp_factor

        ret_nl_part1 = torch.tanh(
            (coords - 1.0 + warp_factor) /
            warp_factor) * warp_factor + \
            1.0 - warp_factor
        ret_nl_part2 = torch.tanh(
            (coords + 1.0 - warp_factor) /
            warp_factor) * warp_factor - \
            1.0 + warp_factor

        coords = torch.where(nl_part1, ret_nl_part1,
                             torch.where(nl_part2, ret_nl_part2, coords))

        # denormalize
        coords = (coords + 1) / 2 * w_h

    return coords
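The two warp functions are intended to be inverses of one another for points that land inside the crop. A small round-trip sketch (matrix and points are made up, not quoted from the repository) illustrates this:

    import torch

    landmarks = torch.tensor([[[210.0, 230.0], [330.0, 228.0], [270.0, 300.0],
                               [225.0, 360.0], [320.0, 358.0]]])
    M = get_face_align_matrix(landmarks, target_shape=(448, 448))

    pts = torch.tensor([[[250.0, 260.0], [300.0, 340.0]]])           # original-image points
    warped = tanh_warp_transform(pts, M, warp_factor=0.8, warped_shape=(448, 448))
    back = inverted_tanh_warp_transform(warped, M, warp_factor=0.8, warped_shape=(448, 448))
    print((back - pts).abs().max())                                  # expected to be small for in-crop points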
def make_tanh_warp_grid(matrix: torch.Tensor, warp_factor: float,
                        warped_shape: Tuple[int, int],
                        orig_shape: Tuple[int, int]):
    """
    Args:
        matrix: b x 3 x 3 matrix.
        warp_factor: The warping factor. `warp_factor=1.0` represents a vanilla tanh-warping,
            `warp_factor=0.0` represents a cropping.
        warped_shape: The target image shape to transform to.
        orig_shape: The original image shape that is transformed from.

    Returns:
        torch.Tensor: b x h x w x 2 (x, y).
    """
    orig_h, orig_w, *_ = orig_shape
    w_h = torch.tensor([orig_w, orig_h]).to(matrix).reshape(1, 1, 1, 2)
    return _forge_grid(
        matrix.size(0), matrix.device,
        warped_shape,
        functools.partial(inverted_tanh_warp_transform,
                          matrix=matrix,
                          warp_factor=warp_factor,
                          warped_shape=warped_shape)) / w_h * 2 - 1
def make_inverted_tanh_warp_grid(matrix: torch.Tensor, warp_factor: float,
                                 warped_shape: Tuple[int, int],
                                 orig_shape: Tuple[int, int]):
    """
    Args:
        matrix: b x 3 x 3 matrix.
        warp_factor: The warping factor. `warp_factor=1.0` represents a vanilla tanh-warping,
            `warp_factor=0.0` represents a cropping.
        warped_shape: The target image shape to transform to.
        orig_shape: The original image shape that is transformed from.

    Returns:
        torch.Tensor: b x h x w x 2 (x, y).
    """
    h, w, *_ = warped_shape
    w_h = torch.tensor([w, h]).to(matrix).reshape(1, 1, 1, 2)
    return _forge_grid(
        matrix.size(0), matrix.device,
        orig_shape,
        functools.partial(tanh_warp_transform,
                          matrix=matrix,
                          warp_factor=warp_factor,
                          warped_shape=warped_shape)) / w_h * 2 - 1
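Putting the pieces together, a plausible end-to-end usage sketch (image size, landmarks, crop size and `warp_factor` are all made up, and this is not quoted from the repository): compute an alignment matrix from landmarks, build a sampling grid, and crop the face with `F.grid_sample`; the inverted grid maps per-pixel results on the crop back onto the original image:

    import torch
    import torch.nn.functional as F

    image = torch.rand(1, 3, 720, 1280)                              # b x c x h x w
    landmarks = torch.tensor([[[610.0, 330.0], [730.0, 328.0], [670.0, 400.0],
                               [625.0, 460.0], [720.0, 458.0]]])

    matrix = get_face_align_matrix(landmarks, target_shape=(448, 448))
    grid = make_tanh_warp_grid(matrix, warp_factor=0.8,
                               warped_shape=(448, 448), orig_shape=(720, 1280))
    face_crop = F.grid_sample(image, grid, mode='bilinear', align_corners=False)        # 1 x 3 x 448 x 448

    inv_grid = make_inverted_tanh_warp_grid(matrix, warp_factor=0.8,
                                            warped_shape=(448, 448), orig_shape=(720, 1280))
    restored = F.grid_sample(face_crop, inv_grid, mode='bilinear', align_corners=False)  # 1 x 3 x 720 x 1280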