alexnasa committed
Commit be22660 · verified · 1 Parent(s): 901df39

Upload network_inference.py

Files changed (1):
src/pixel3dmm/network_inference.py +191 -0
src/pixel3dmm/network_inference.py ADDED
@@ -0,0 +1,191 @@
+ import traceback
+
+ from tqdm import tqdm
+ import os
+ import torch
+ import numpy as np
+ from PIL import Image
+ from omegaconf import OmegaConf
+ from time import time
+
+ from pixel3dmm.utils.uv import uv_pred_to_mesh
+ from pixel3dmm.lightning.p3dmm_system import system as p3dmm_system
+ #from pixel3dmm.lightning.system_flame_params_legacy import system as system_flame_params_legacy
+ from pixel3dmm import env_paths
+
+
+ def pad_to_3_channels(img):
+     """Pad an H,W,C array with zero channels so that it always has 3 channels."""
+     if img.shape[-1] == 3:
+         return img
+     elif img.shape[-1] == 1:
+         return np.concatenate([img, np.zeros_like(img[..., :1]), np.zeros_like(img[..., :1])], axis=-1)
+     elif img.shape[-1] == 2:
+         return np.concatenate([img, np.zeros_like(img[..., :1])], axis=-1)
+     else:
+         raise ValueError('too many channels in prediction!')
+
+
+ def gaussian_fn(M, std):
+     """Returns a 1D Gaussian window of length M."""
+     n = torch.arange(0, M) - (M - 1.0) / 2.0
+     sig2 = 2 * std * std
+     w = torch.exp(-n ** 2 / sig2)
+     return w
+
+
+ def gkern(kernlen=256, std=128):
+     """Returns a 2D Gaussian kernel array (wider along x than along y)."""
+     gkern1d_x = gaussian_fn(kernlen, std=std * 5)
+     gkern1d_y = gaussian_fn(kernlen, std=std)
+     gkern2d = torch.outer(gkern1d_y, gkern1d_x)
+     return gkern2d
+
+
+ # vertex subset covering the wide face region
+ valid_verts = np.load(f'{env_paths.VALID_VERTICES_WIDE_REGION}')
+
+
+ def normals_n_uvs(cfg, model):
+     if cfg.model.prediction_type == 'flame_params':
+         cfg.data.mirror_aug = False
+
+     # feature map resolution depends on the backbone
+     if cfg.model.feature_map_type == 'DINO':
+         feature_map_size = 32
+     elif cfg.model.feature_map_type == 'sapiens':
+         feature_map_size = 64
+
+     batch_size = 1  # cfg.inference_batch_size
+
+     prediction_types = cfg.model.prediction_type.split(',')
+
+     # Gaussian smoothing kernel (currently unused below)
+     conv = torch.nn.Conv2d(in_channels=1, out_channels=1, kernel_size=11, bias=False, padding='same')
+     g_weights = gkern(11, 2)
+     g_weights /= torch.sum(g_weights)
+     conv.weight = torch.nn.Parameter(g_weights.unsqueeze(0).unsqueeze(0))
+
+     OUT_NAMES = str(cfg.video_name).split(',')
+
+     print(f"""
+     <<<<<<<< STARTING PIXEL3DMM INFERENCE for {cfg.video_name} in {prediction_types} MODE >>>>>>>>
+     """)
+
+     for OUT_NAME in OUT_NAMES:
+         folder = f'{env_paths.PREPROCESSED_DATA}/{OUT_NAME}/'
+         IMAGE_FOLDER = f'{folder}/cropped'
+         SEGMENTATION_FOLDER = f'{folder}/seg_og/'
+
+         out_folders = {}
+         out_folders_wGT = {}
+         out_folders_viz = {}
+
+         for prediction_type in prediction_types:
+             out_folders[prediction_type] = f'{env_paths.PREPROCESSED_DATA}/{OUT_NAME}/p3dmm/{prediction_type}/'
+             out_folders_wGT[prediction_type] = f'{env_paths.PREPROCESSED_DATA}/{OUT_NAME}/p3dmm_wGT/{prediction_type}/'
+             os.makedirs(out_folders[prediction_type], exist_ok=True)
+             os.makedirs(out_folders_wGT[prediction_type], exist_ok=True)
+             out_folders_viz[prediction_type] = f'{env_paths.PREPROCESSED_DATA}/{OUT_NAME}/p3dmm_extraViz/{prediction_type}/'
+             os.makedirs(out_folders_viz[prediction_type], exist_ok=True)
+
+         image_names = os.listdir(f'{IMAGE_FOLDER}')
+         image_names.sort()
+
+         # NOTE: this only checks the last prediction_type from the loop above,
+         # and returns early, skipping any remaining OUT_NAMES
+         if os.path.exists(out_folders[prediction_type]):
+             if len(os.listdir(out_folders[prediction_type])) == len(image_names):
+                 return
+
+         for i in tqdm(range(len(image_names))):
+             try:
+                 # the network expects 512x512 input images, normalized to [0, 1]
+                 img = np.array(Image.open(f'{IMAGE_FOLDER}/{image_names[i]}').resize((512, 512))) / 255
+                 img = torch.from_numpy(img)[None, None].float().cuda()  # 1, 1, 512, 512, 3
+                 img_seg = np.array(Image.open(f'{SEGMENTATION_FOLDER}/{image_names[i][:-4]}.png').resize((512, 512), Image.NEAREST))
+                 if len(img_seg.shape) == 3:
+                     img_seg = img_seg[..., 0]
+                 #img_seg = np.array(Image.open(f'{SEGMENTATION_FOLDER}/{int(image_names[i][:-4])*3:05d}.png').resize((512, 512), Image.NEAREST))
+                 # select face-region segmentation classes, excluding label 11
+                 mask = ((img_seg == 2) | ((img_seg > 3) & (img_seg < 14))) & ~(img_seg == 11)
+                 mask = torch.from_numpy(mask).long().cuda()[None, None]  # 1, 1, 512, 512
+                 #mask = torch.ones_like(img[..., 0]).cuda().bool()
+                 batch = {
+                     'tar_msk': mask,
+                     'tar_rgb': img,
+                 }
+                 # horizontally flipped copy for test-time mirror augmentation
+                 batch_mirrored = {
+                     'tar_rgb': torch.flip(batch['tar_rgb'], dims=[3]).cuda(),
+                     'tar_msk': torch.flip(batch['tar_msk'], dims=[3]).cuda(),
+                 }
+
+                 with torch.no_grad():
+                     output, conf = model.net(batch)
+                     output_mirrored, conf = model.net(batch_mirrored)
+
+                 # average each prediction with its mirrored counterpart;
+                 # the x component changes sign under mirroring
+                 if 'uv_map' in output:
+                     flipped_pred = torch.flip(output_mirrored['uv_map'], dims=[4])
+                     flipped_pred[:, :, 0, :, :] *= -1
+                     flipped_pred[:, :, 0, :, :] += 2 * 0.0075
+                     output['uv_map'] = (output['uv_map'] + flipped_pred) / 2
+                 if 'normals' in output:
+                     flipped_pred = torch.flip(output_mirrored['normals'], dims=[4])
+                     flipped_pred[:, :, 0, :, :] *= -1
+                     output['normals'] = (output['normals'] + flipped_pred) / 2
+                 if 'disps' in output:
+                     flipped_pred = torch.flip(output_mirrored['disps'], dims=[4])
+                     flipped_pred[:, :, 0, :, :] *= -1
+                     output['disps'] = (output['disps'] + flipped_pred) / 2
+
+                 for prediction_type in prediction_types:
+                     for i_batch in range(batch_size):
+                         i_view = 0
+                         gt_rgb = batch['tar_rgb']
+
+                         # normalize to [0, 1] range
+                         if prediction_type == 'uv_map':
+                             tmp_output = torch.clamp((output[prediction_type][i_batch, i_view] + 1) / 2, 0, 1)
+                         elif prediction_type == 'disps':
+                             tmp_output = torch.clamp((output[prediction_type][i_batch, i_view] + 50) / 100, 0, 1)
+                         elif prediction_type in ['normals', 'normals_can']:
+                             tmp_output = output[prediction_type][i_batch, i_view]
+                             tmp_output = tmp_output / torch.norm(tmp_output, dim=0).unsqueeze(0)
+                             tmp_output = torch.clamp((tmp_output + 1) / 2, 0, 1)
+                             # undo the "weird" normals convention used during preprocessing
+                             tmp_output = torch.stack(
+                                 [tmp_output[0, ...], 1 - tmp_output[2, ...], 1 - tmp_output[1, ...]],
+                                 dim=0)
+
+                         # side-by-side visualization: input image next to prediction
+                         content = [
+                             gt_rgb[i_batch, i_view].detach().cpu().numpy(),
+                             pad_to_3_channels(tmp_output.permute(1, 2, 0).detach().cpu().float().numpy()),
+                         ]
+                         catted = (np.concatenate(content, axis=1) * 255).astype(np.uint8)
+                         Image.fromarray(catted).save(f'{out_folders_wGT[prediction_type]}/{image_names[i]}')
+
+                         # prediction on its own
+                         Image.fromarray(
+                             pad_to_3_channels(
+                                 tmp_output.permute(1, 2, 0).detach().cpu().float().numpy() * 255).astype(
+                                 np.uint8)).save(
+                             f'{out_folders[prediction_type]}/{image_names[i][:-4]}.png')
+
+                         # this visualization is quite slow, therefore it is disabled by default
+                         if prediction_type == 'uv_map' and cfg.viz_uv_mesh:
+                             to_show_non_mirr = uv_pred_to_mesh(
+                                 output[prediction_type][i_batch:i_batch + 1, ...],
+                                 batch['tar_msk'][i_batch:i_batch + 1, ...],
+                                 batch['tar_rgb'][i_batch:i_batch + 1, ...],
+                                 right_ear=[537, 1334, 857, 554, 941],
+                                 left_ear=[541, 476, 237, 502, 286],
+                             )
+                             Image.fromarray(to_show_non_mirr).save(f'{out_folders_viz[prediction_type]}/{image_names[i]}')
+
+             except Exception:
+                 traceback.print_exc()
+
+     print(f"""
+     <<<<<<<< FINISHED PIXEL3DMM INFERENCE for {cfg.video_name} in {prediction_types} MODE >>>>>>>>
+     """)
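
For reference, here is a minimal, hypothetical driver for normals_n_uvs(). The config keys mirror the ones the function actually reads (model.prediction_type, model.feature_map_type, video_name, viz_uv_mesh); the checkpoint path and the use of load_from_checkpoint() on the imported Lightning system are assumptions on my part, not something this commit confirms.

    # sketch only: assumes p3dmm_system is a LightningModule that restores
    # via load_from_checkpoint(); the checkpoint path is a placeholder
    from omegaconf import OmegaConf

    from pixel3dmm.lightning.p3dmm_system import system as p3dmm_system
    from pixel3dmm.network_inference import normals_n_uvs

    cfg = OmegaConf.create({
        'video_name': 'example_video',  # folder name under env_paths.PREPROCESSED_DATA
        'viz_uv_mesh': False,           # keep the slow UV-mesh visualization off
        'model': {
            'prediction_type': 'uv_map,normals',  # comma-separated, as split() expects
            'feature_map_type': 'DINO',
        },
        'data': {'mirror_aug': False},
    })

    model = p3dmm_system.load_from_checkpoint('<path/to/checkpoint.ckpt>')  # assumed API
    model = model.cuda().eval()

    normals_n_uvs(cfg, model)

Note that importing pixel3dmm.network_inference loads valid_verts from env_paths.VALID_VERTICES_WIDE_REGION at import time, so the environment paths must be configured before this sketch can run.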