alexnasa committed (verified)
Commit 92c1bc3 · 1 Parent(s): 5049961

Upload transform.py
src/pixel3dmm/preprocessing/facer/facer/transform.py ADDED
@@ -0,0 +1,397 @@
from typing import List, Dict, Callable, Tuple, Optional
import torch
import torch.nn.functional as F
import functools
import numpy as np


def get_crop_and_resize_matrix(
        box: torch.Tensor, target_shape: Tuple[int, int],
        target_face_scale: float = 1.0, make_square_crop: bool = True,
        offset_xy: Optional[Tuple[float, float]] = None, align_corners: bool = True,
        offset_box_coords: bool = False) -> torch.Tensor:
    """
    Args:
        box: b x 4 (x1, y1, x2, y2)
        align_corners (bool): Set this to `True` only if the box you give has coordinates
            ranging from `0` to `h-1` or `w-1`.

        offset_box_coords (bool): Set this to `True` if the box you give has coordinates
            ranging from `0` to `h` or `w`.

            Set this to `False` if the box coordinates range from `-0.5` to `h-0.5` or `w-0.5`.

            If the box coordinates range from `0` to `h-1` or `w-1`, set `align_corners=True`.

    Returns:
        torch.Tensor: b x 3 x 3.
    """
    if offset_xy is None:
        offset_xy = (0.0, 0.0)

    x1, y1, x2, y2 = box.split(1, dim=1)  # b x 1
    cx = (x1 + x2) / 2 + offset_xy[0]
    cy = (y1 + y2) / 2 + offset_xy[1]
    rx = (x2 - x1) / 2 / target_face_scale
    ry = (y2 - y1) / 2 / target_face_scale
    if make_square_crop:
        rx = ry = torch.maximum(rx, ry)

    x1, y1, x2, y2 = cx - rx, cy - ry, cx + rx, cy + ry

    h, w, *_ = target_shape

    zeros_pl = torch.zeros_like(x1)
    ones_pl = torch.ones_like(x1)

    if align_corners:
        # x -> (x - x1) / (x2 - x1) * (w - 1)
        # y -> (y - y1) / (y2 - y1) * (h - 1)
        ax = 1.0 / (x2 - x1) * (w - 1)
        ay = 1.0 / (y2 - y1) * (h - 1)
        matrix = torch.cat([
            ax, zeros_pl, -x1 * ax,
            zeros_pl, ay, -y1 * ay,
            zeros_pl, zeros_pl, ones_pl
        ], dim=1).reshape(-1, 3, 3)  # b x 3 x 3
    else:
        if offset_box_coords:
            # x1, x2 in [0, w], y1, y2 in [0, h]
            # First offset x1, x2, y1, y2 so they range over
            # [-0.5, w-0.5] and [-0.5, h-0.5], converting these
            # pixel coordinates into boundary coordinates.
            x1, x2, y1, y2 = x1 - 0.5, x2 - 0.5, y1 - 0.5, y2 - 0.5

        # x -> (x - x1) / (x2 - x1) * w - 0.5
        # y -> (y - y1) / (y2 - y1) * h - 0.5
        ax = 1.0 / (x2 - x1) * w
        ay = 1.0 / (y2 - y1) * h
        matrix = torch.cat([
            ax, zeros_pl, -x1 * ax - 0.5 * ones_pl,
            zeros_pl, ay, -y1 * ay - 0.5 * ones_pl,
            zeros_pl, zeros_pl, ones_pl
        ], dim=1).reshape(-1, 3, 3)  # b x 3 x 3
    return matrix


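
# A minimal, illustrative usage sketch for get_crop_and_resize_matrix. The
# `_example_*` helper and the box values are placeholders added for
# illustration only: the returned 3x3 matrix maps original-image pixel
# coordinates into the crop, so with a square box and align_corners=True the
# box corners should land on the output corners.
def _example_crop_and_resize_matrix():
    box = torch.tensor([[40.0, 60.0, 240.0, 260.0]])            # b x 4 (x1, y1, x2, y2)
    matrix = get_crop_and_resize_matrix(box, target_shape=(256, 256),
                                         align_corners=True)    # b x 3 x 3
    corners = torch.tensor([[[40.0, 60.0, 1.0],
                             [240.0, 260.0, 1.0]]])             # b x 2 x 3 (homogeneous)
    mapped = corners @ matrix.transpose(1, 2)
    print(mapped[..., :2] / mapped[..., 2:])  # approx. (0, 0) and (255, 255)
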
def get_similarity_transform_matrix(
        from_pts: torch.Tensor, to_pts: torch.Tensor) -> torch.Tensor:
    """
    Args:
        from_pts, to_pts: b x n x 2

    Returns:
        torch.Tensor: b x 3 x 3
    """
    mfrom = from_pts.mean(dim=1, keepdim=True)  # b x 1 x 2
    mto = to_pts.mean(dim=1, keepdim=True)  # b x 1 x 2

    a1 = (from_pts - mfrom).square().sum([1, 2], keepdim=False)  # b
    c1 = ((to_pts - mto) * (from_pts - mfrom)).sum([1, 2], keepdim=False)  # b

    to_delta = to_pts - mto
    from_delta = from_pts - mfrom
    c2 = (to_delta[:, :, 0] * from_delta[:, :, 1] -
          to_delta[:, :, 1] * from_delta[:, :, 0]).sum([1], keepdim=False)  # b

    a = c1 / a1
    b = c2 / a1
    dx = mto[:, 0, 0] - a * mfrom[:, 0, 0] - b * mfrom[:, 0, 1]  # b
    dy = mto[:, 0, 1] + b * mfrom[:, 0, 0] - a * mfrom[:, 0, 1]  # b

    ones_pl = torch.ones_like(a1)
    zeros_pl = torch.zeros_like(a1)

    return torch.stack([
        a, b, dx,
        -b, a, dy,
        zeros_pl, zeros_pl, ones_pl,
    ], dim=-1).reshape(-1, 3, 3)


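
# A minimal sanity-check sketch for get_similarity_transform_matrix. The
# `_example_*` helper and the scale/rotation/translation values are arbitrary
# placeholders: points related by a known similarity transform should be
# recovered exactly (up to float precision) by the estimated 3x3 matrix.
def _example_similarity_transform():
    torch.manual_seed(0)
    from_pts = torch.rand(1, 5, 2) * 100                        # b x n x 2
    theta, scale, shift = 0.3, 1.5, torch.tensor([10.0, -5.0])
    rot = torch.tensor([[np.cos(theta), -np.sin(theta)],
                        [np.sin(theta),  np.cos(theta)]], dtype=torch.float32)
    to_pts = (scale * from_pts) @ rot.T + shift                 # b x n x 2
    matrix = get_similarity_transform_matrix(from_pts, to_pts)  # b x 3 x 3
    homo = torch.cat([from_pts, torch.ones_like(from_pts[..., :1])], dim=-1)
    mapped = (homo @ matrix.transpose(1, 2))[..., :2]
    print(torch.allclose(mapped, to_pts, atol=1e-3))            # expected: True
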
@functools.lru_cache()
def _standard_face_pts():
    pts = torch.tensor([
        196.0, 226.0,
        316.0, 226.0,
        256.0, 286.0,
        220.0, 360.4,
        292.0, 360.4], dtype=torch.float32) / 256.0 - 1.0
    return torch.reshape(pts, (5, 2))


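
# Illustrative sketch only: the canonical 5-point template above is stored in a
# normalized range roughly within [-1, 1]; get_face_align_matrix below rescales
# it to pixel coordinates of the target crop. The hypothetical helper and the
# 512 x 512 crop size are placeholders.
def _example_standard_face_pts():
    std_pts = _standard_face_pts()                      # 5 x 2, roughly in [-1, 1]
    h = w = 512
    target_pts = (std_pts + 1) * torch.tensor([w - 1, h - 1]) / 2.0
    print(target_pts)  # eyes, nose tip, and mouth corners in pixel coordinates
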
def get_face_align_matrix(
        face_pts: torch.Tensor, target_shape: Tuple[int, int],
        target_face_scale: float = 1.0, offset_xy: Optional[Tuple[float, float]] = None,
        target_pts: Optional[torch.Tensor] = None):

    if target_pts is None:
        with torch.no_grad():
            std_pts = _standard_face_pts().to(face_pts)  # [-1, 1]
            h, w, *_ = target_shape
            target_pts = (std_pts * target_face_scale + 1) * \
                torch.tensor([w - 1, h - 1]).to(face_pts) / 2.0
            if offset_xy is not None:
                target_pts[:, 0] += offset_xy[0]
                target_pts[:, 1] += offset_xy[1]
    else:
        target_pts = target_pts.to(face_pts)

    if target_pts.dim() == 2:
        target_pts = target_pts.unsqueeze(0)
    if target_pts.size(0) == 1:
        target_pts = target_pts.broadcast_to(face_pts.shape)

    assert target_pts.shape == face_pts.shape

    return get_similarity_transform_matrix(face_pts, target_pts)


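
# A minimal usage sketch for get_face_align_matrix. The `_example_*` helper and
# the landmark values are placeholders standing in for detector output
# (left eye, right eye, nose tip, left/right mouth corners).
def _example_face_align_matrix():
    face_pts = torch.tensor([[[120.0, 150.0], [200.0, 148.0], [162.0, 200.0],
                              [130.0, 240.0], [195.0, 238.0]]])      # b x 5 x 2
    matrix = get_face_align_matrix(face_pts, target_shape=(448, 448))  # b x 3 x 3
    homo = torch.cat([face_pts, torch.ones_like(face_pts[..., :1])], dim=-1)
    aligned = (homo @ matrix.transpose(1, 2))[..., :2]
    print(aligned)  # landmarks near the canonical positions inside the 448 x 448 crop
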
def rot90(v):
    return np.array([-v[1], v[0]])


def get_quad(lm: torch.Tensor):
    # lm: N x 2
    lm = lm.detach().cpu().numpy()
    # Choose oriented crop rectangle.
    eye_avg = (lm[0] + lm[1]) * 0.5 + 0.5
    mouth_avg = (lm[3] + lm[4]) * 0.5 + 0.5
    eye_to_eye = lm[1] - lm[0]
    eye_to_mouth = mouth_avg - eye_avg
    x = eye_to_eye - rot90(eye_to_mouth)
    x /= np.hypot(*x)
    x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
    y = rot90(x)
    c = eye_avg + eye_to_mouth * 0.1
    quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
    quad_for_coeffs = quad[[0, 3, 2, 1]]  # reorder the corners
    return torch.from_numpy(quad_for_coeffs).float()


def get_face_align_matrix_celebm(
        face_pts: torch.Tensor, target_shape: Tuple[int, int], bbox_scale_factor: float = 1.0):

    face_pts = torch.stack([get_quad(pts) for pts in face_pts], dim=0).to(face_pts)
    face_mean = face_pts.mean(axis=1).unsqueeze(1)
    diff = face_pts - face_mean
    face_pts = face_mean + torch.tensor([[[1.5, 1.5]]], device=diff.device) * diff
    assert target_shape[0] == target_shape[1]
    diagonal = torch.norm(face_pts[:, 0, :] - face_pts[:, 2, :], dim=-1)
    min_bbox_size = 350
    max_bbox_size = 500
    bbox_scale_factor = bbox_scale_factor + torch.clamp(
        (max_bbox_size - diagonal) / (max_bbox_size - min_bbox_size), 0, 1)
    print(bbox_scale_factor)
    target_size = target_shape[0] / bbox_scale_factor
    # target_pts = torch.as_tensor([[0, 0], [target_size, 0], [target_size, target_size], [0, target_size]]).to(face_pts)
    target_ptss = []
    for tidx in range(target_size.shape[0]):
        target_pts = torch.as_tensor(
            [[0, 0], [target_size[tidx], 0],
             [target_size[tidx], target_size[tidx]], [0, target_size[tidx]]]).to(face_pts)
        target_pts += int((target_shape[0] - target_size[tidx]) / 2)
        target_ptss.append(target_pts)
    target_pts = torch.stack(target_ptss, dim=0)

    # if target_pts.dim() == 2:
    #     target_pts = target_pts.unsqueeze(0)
    # if target_pts.size(0) == 1:
    #     target_pts = target_pts.broadcast_to(face_pts.shape)

    assert target_pts.shape == face_pts.shape

    return get_similarity_transform_matrix(face_pts, target_pts)


@functools.lru_cache(maxsize=128)
def _meshgrid(h, w) -> Tuple[torch.Tensor, torch.Tensor]:
    yy, xx = torch.meshgrid(torch.arange(h).float(),
                            torch.arange(w).float(),
                            indexing='ij')
    return yy, xx


def _forge_grid(batch_size: int, device: torch.device,
                output_shape: Tuple[int, int],
                fn: Callable[[torch.Tensor], torch.Tensor]
                ) -> torch.Tensor:
    """ Forge transform maps with a given function `fn`.

    Args:
        batch_size (int): The batch size b.
        device (torch.device): The device to create the grid on.
        output_shape (tuple): (h, w, ...).
        fn (Callable[[torch.Tensor], torch.Tensor]): The function that accepts
            a b x n x 2 array and outputs the transformed b x n x 2 array. Both
            input and output store (x, y) coordinates.

    Returns:
        torch.Tensor: b x h x w x 2, where the entry at pixel (y, x) holds the
        (x, y) coordinates returned by `fn` for that pixel.
    """
    h, w, *_ = output_shape
    yy, xx = _meshgrid(h, w)  # h x w
    yy = yy.unsqueeze(0).broadcast_to(batch_size, h, w).to(device)
    xx = xx.unsqueeze(0).broadcast_to(batch_size, h, w).to(device)

    in_xxyy = torch.stack(
        [xx, yy], dim=-1).reshape([batch_size, h * w, 2])  # b x (h x w) x 2
    out_xxyy: torch.Tensor = fn(in_xxyy)  # b x (h x w) x 2
    return out_xxyy.reshape(batch_size, h, w, 2)


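
# A minimal sketch for _forge_grid, added for illustration: with an identity
# mapping the forged grid simply holds each pixel's own (x, y) coordinates,
# which makes the output layout easy to inspect. The helper name and the
# 4 x 6 shape are placeholders.
def _example_forge_grid():
    grid = _forge_grid(batch_size=1, device=torch.device('cpu'),
                       output_shape=(4, 6), fn=lambda xy: xy)   # 1 x 4 x 6 x 2
    print(grid[0, 2, 5])  # tensor([5., 2.]) -> (x, y) of the pixel at row 2, col 5
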
def _safe_arctanh(x: torch.Tensor, eps: float = 0.001) -> torch.Tensor:
    # Clamp away from +/-1 to avoid infinities from arctanh.
    return torch.clamp(x, -1 + eps, 1 - eps).arctanh()


def inverted_tanh_warp_transform(coords: torch.Tensor, matrix: torch.Tensor,
                                 warp_factor: float, warped_shape: Tuple[int, int]):
    """ Inverted tanh-warp function.

    Args:
        coords (torch.Tensor): b x n x 2 (x, y). The transformed coordinates.
        matrix: b x 3 x 3. A matrix that transforms un-normalized coordinates
            from the original image to the aligned yet not-warped image.
        warp_factor (float): The warp factor.
            0 means linear transform, 1 means full tanh warp.
        warped_shape (tuple): [height, width].

    Returns:
        torch.Tensor: b x n x 2 (x, y). The original coordinates.
    """
    h, w, *_ = warped_shape
    # h -= 1
    # w -= 1

    w_h = torch.tensor([[w, h]]).to(coords)

    if warp_factor > 0:
        # normalize coordinates to [-1, +1]
        coords = coords / w_h * 2 - 1

        nl_part1 = coords > 1.0 - warp_factor
        nl_part2 = coords < -1.0 + warp_factor

        ret_nl_part1 = _safe_arctanh(
            (coords - 1.0 + warp_factor) /
            warp_factor) * warp_factor + \
            1.0 - warp_factor
        ret_nl_part2 = _safe_arctanh(
            (coords + 1.0 - warp_factor) /
            warp_factor) * warp_factor - \
            1.0 + warp_factor

        coords = torch.where(nl_part1, ret_nl_part1,
                             torch.where(nl_part2, ret_nl_part2, coords))

        # denormalize
        coords = (coords + 1) / 2 * w_h

    coords_homo = torch.cat(
        [coords, torch.ones_like(coords[:, :, [0]])], dim=-1)  # b x n x 3

    inv_matrix = torch.linalg.inv(matrix)  # b x 3 x 3
    # inv_matrix = np.linalg.inv(matrix)
    coords_homo = torch.bmm(
        coords_homo, inv_matrix.permute(0, 2, 1))  # b x n x 3
    return coords_homo[:, :, :2] / coords_homo[:, :, [2, 2]]


def tanh_warp_transform(
        coords: torch.Tensor, matrix: torch.Tensor,
        warp_factor: float, warped_shape: Tuple[int, int]):
    """ Tanh-warp function.

    Args:
        coords (torch.Tensor): b x n x 2 (x, y). The original coordinates.
        matrix: b x 3 x 3. A matrix that transforms un-normalized coordinates
            from the original image to the aligned yet not-warped image.
        warp_factor (float): The warp factor.
            0 means linear transform, 1 means full tanh warp.
        warped_shape (tuple): [height, width].

    Returns:
        torch.Tensor: b x n x 2 (x, y). The transformed coordinates.
    """
    h, w, *_ = warped_shape
    # h -= 1
    # w -= 1
    w_h = torch.tensor([[w, h]]).to(coords)

    coords_homo = torch.cat(
        [coords, torch.ones_like(coords[:, :, [0]])], dim=-1)  # b x n x 3

    coords_homo = torch.bmm(coords_homo, matrix.transpose(2, 1))  # b x n x 3
    coords = (coords_homo[:, :, :2] / coords_homo[:, :, [2, 2]])  # b x n x 2

    if warp_factor > 0:
        # normalize coordinates to [-1, +1]
        coords = coords / w_h * 2 - 1

        nl_part1 = coords > 1.0 - warp_factor
        nl_part2 = coords < -1.0 + warp_factor

        ret_nl_part1 = torch.tanh(
            (coords - 1.0 + warp_factor) /
            warp_factor) * warp_factor + \
            1.0 - warp_factor
        ret_nl_part2 = torch.tanh(
            (coords + 1.0 - warp_factor) /
            warp_factor) * warp_factor - \
            1.0 + warp_factor

        coords = torch.where(nl_part1, ret_nl_part1,
                             torch.where(nl_part2, ret_nl_part2, coords))

        # denormalize
        coords = (coords + 1) / 2 * w_h

    return coords


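
# Round-trip sketch, added for illustration: tanh_warp_transform followed by
# inverted_tanh_warp_transform should approximately recover the original
# coordinates. The identity matrix, the coordinates, and warp_factor=0.8 are
# arbitrary placeholder choices.
def _example_tanh_warp_roundtrip():
    matrix = torch.eye(3).unsqueeze(0)                        # b x 3 x 3
    coords = torch.tensor([[[10.0, 20.0], [200.0, 180.0]]])   # b x n x 2
    warped = tanh_warp_transform(coords, matrix, warp_factor=0.8,
                                 warped_shape=(256, 256))
    restored = inverted_tanh_warp_transform(warped, matrix, warp_factor=0.8,
                                            warped_shape=(256, 256))
    print(torch.allclose(restored, coords, atol=1e-2))        # expected: True
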
def make_tanh_warp_grid(matrix: torch.Tensor, warp_factor: float,
                        warped_shape: Tuple[int, int],
                        orig_shape: Tuple[int, int]):
    """
    Args:
        matrix: b x 3 x 3 matrix.
        warp_factor: The warping factor. `warp_factor=1.0` represents a vanilla tanh-warping,
            `warp_factor=0.0` represents a cropping.
        warped_shape: The target image shape to transform to.
        orig_shape: The original image shape that is transformed from.

    Returns:
        torch.Tensor: b x h x w x 2 (x, y).
    """
    orig_h, orig_w, *_ = orig_shape
    w_h = torch.tensor([orig_w, orig_h]).to(matrix).reshape(1, 1, 1, 2)
    return _forge_grid(
        matrix.size(0), matrix.device,
        warped_shape,
        functools.partial(inverted_tanh_warp_transform,
                          matrix=matrix,
                          warp_factor=warp_factor,
                          warped_shape=warped_shape)) / w_h * 2 - 1


def make_inverted_tanh_warp_grid(matrix: torch.Tensor, warp_factor: float,
                                 warped_shape: Tuple[int, int],
                                 orig_shape: Tuple[int, int]):
    """
    Args:
        matrix: b x 3 x 3 matrix.
        warp_factor: The warping factor. `warp_factor=1.0` represents a vanilla tanh-warping,
            `warp_factor=0.0` represents a cropping.
        warped_shape: The target image shape to transform to.
        orig_shape: The original image shape that is transformed from.

    Returns:
        torch.Tensor: b x h x w x 2 (x, y).
    """
    h, w, *_ = warped_shape
    w_h = torch.tensor([w, h]).to(matrix).reshape(1, 1, 1, 2)
    return _forge_grid(
        matrix.size(0), matrix.device,
        orig_shape,
        functools.partial(tanh_warp_transform,
                          matrix=matrix,
                          warp_factor=warp_factor,
                          warped_shape=warped_shape)) / w_h * 2 - 1
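

# End-to-end sketch, added for illustration: build an alignment matrix from
# 5 landmarks, forge the warp grid, and sample the aligned crop with
# F.grid_sample; the inverted grid maps the crop back to the original frame.
# The random image, the landmark values, warp_factor=0.8, and the 448 x 448
# crop size are placeholder choices, not values prescribed by this module.
def _example_warp_pipeline():
    image = torch.rand(1, 3, 480, 640)                         # b x c x H x W
    face_pts = torch.tensor([[[260.0, 210.0], [340.0, 208.0], [300.0, 260.0],
                              [270.0, 300.0], [335.0, 298.0]]])  # b x 5 x 2
    matrix = get_face_align_matrix(face_pts, target_shape=(448, 448))
    grid = make_tanh_warp_grid(matrix, warp_factor=0.8,
                               warped_shape=(448, 448),
                               orig_shape=(480, 640))           # b x 448 x 448 x 2
    crop = F.grid_sample(image, grid, mode='bilinear', align_corners=False)
    inv_grid = make_inverted_tanh_warp_grid(matrix, warp_factor=0.8,
                                            warped_shape=(448, 448),
                                            orig_shape=(480, 640))
    restored = F.grid_sample(crop, inv_grid, mode='bilinear', align_corners=False)
    print(crop.shape, restored.shape)  # (1, 3, 448, 448), (1, 3, 480, 640)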