alexnasa committed
Commit 901df39 · verified · 1 Parent(s): e02fa45

Delete src/network_inference.py

Files changed (1)
  1. src/network_inference.py +0 -191
src/network_inference.py DELETED
@@ -1,191 +0,0 @@
- import traceback
-
- from tqdm import tqdm
- import os
- import torch
- import numpy as np
- from PIL import Image
- from omegaconf import OmegaConf
- from time import time
-
- from pixel3dmm.utils.uv import uv_pred_to_mesh
- from pixel3dmm.lightning.p3dmm_system import system as p3dmm_system
- # from pixel3dmm.lightning.system_flame_params_legacy import system as system_flame_params_legacy
- from pixel3dmm import env_paths
-
-
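- # Pads a 1- or 2-channel prediction map with zero channels so it can be saved as RGB.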
- def pad_to_3_channels(img):
-     if img.shape[-1] == 3:
-         return img
-     elif img.shape[-1] == 1:
-         return np.concatenate([img, np.zeros_like(img[..., :1]), np.zeros_like(img[..., :1])], axis=-1)
-     elif img.shape[-1] == 2:
-         return np.concatenate([img, np.zeros_like(img[..., :1])], axis=-1)
-     else:
-         raise ValueError('unexpected number of channels in prediction map!')
-
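- # Unnormalized 1D Gaussian window of length M with standard deviation std.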
- def gaussian_fn(M, std):
-     n = torch.arange(0, M) - (M - 1.0) / 2.0
-     sig2 = 2 * std * std
-     w = torch.exp(-n ** 2 / sig2)
-     return w
-
-
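- # Note: the x window uses 5x the standard deviation of the y window, so the
- # resulting kernel is anisotropic (wider horizontally).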
- def gkern(kernlen=256, std=128):
-     """Returns a 2D Gaussian kernel array."""
-     gkern1d_x = gaussian_fn(kernlen, std=std * 5)
-     gkern1d_y = gaussian_fn(kernlen, std=std)
-     gkern2d = torch.outer(gkern1d_y, gkern1d_x)
-     return gkern2d
-
-
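- # Vertex indices for the "wide" face region (per the env_paths name); loaded at
- # import time but not referenced elsewhere in this file.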
- valid_verts = np.load(env_paths.VALID_VERTICES_WIDE_REGION)
-
-
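- # Runs network inference on every cropped frame of each requested video and
- # writes the predicted maps (and optional visualizations) to the preprocessing folders.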
- def normals_n_uvs(cfg, model):
-     if cfg.model.prediction_type == 'flame_params':
-         cfg.data.mirror_aug = False
-
-     # feature map resolution depends on the backbone
-     if cfg.model.feature_map_type == 'DINO':
-         feature_map_size = 32
-     elif cfg.model.feature_map_type == 'sapiens':
-         feature_map_size = 64
-
-     batch_size = 1  # cfg.inference_batch_size
-
-     prediction_types = cfg.model.prediction_type.split(',')
-
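-     # fixed (non-learned) Gaussian smoothing layer; built here but not applied
-     # anywhere else in this file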
-     conv = torch.nn.Conv2d(in_channels=1, out_channels=1, kernel_size=11, bias=False, padding='same')
-     g_weights = gkern(11, 2)
-     g_weights /= torch.sum(g_weights)
-     conv.weight = torch.nn.Parameter(g_weights.unsqueeze(0).unsqueeze(0))
-
-     OUT_NAMES = str(cfg.video_name).split(',')
-
-     print(f"""
-     <<<<<<<< STARTING PIXEL3DMM INFERENCE for {cfg.video_name} in {prediction_types} MODE >>>>>>>>
-     """)
-
-     for OUT_NAME in OUT_NAMES:
-         folder = f'{env_paths.PREPROCESSED_DATA}/{OUT_NAME}/'
-         IMAGE_FOLDER = f'{folder}/cropped'
-         SEGMENTATION_FOLDER = f'{folder}/seg_og/'
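-         # expected preprocessing layout: <video>/cropped/ holds the cropped RGB
-         # frames, <video>/seg_og/ the matching face-parsing maps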
-
-         out_folders = {}
-         out_folders_wGT = {}
-         out_folders_viz = {}
-
-         for prediction_type in prediction_types:
-             out_folders[prediction_type] = f'{env_paths.PREPROCESSED_DATA}/{OUT_NAME}/p3dmm/{prediction_type}/'
-             out_folders_wGT[prediction_type] = f'{env_paths.PREPROCESSED_DATA}/{OUT_NAME}/p3dmm_wGT/{prediction_type}/'
-             os.makedirs(out_folders[prediction_type], exist_ok=True)
-             os.makedirs(out_folders_wGT[prediction_type], exist_ok=True)
-             out_folders_viz[prediction_type] = f'{env_paths.PREPROCESSED_DATA}/{OUT_NAME}/p3dmm_extraViz/{prediction_type}/'
-             os.makedirs(out_folders_viz[prediction_type], exist_ok=True)
-
-         image_names = os.listdir(IMAGE_FOLDER)
-         image_names.sort()
-
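-         # skip this video if every frame already has an output
-         # (note: this only inspects the last entry of prediction_types)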
-         if len(os.listdir(out_folders[prediction_type])) == len(image_names):
-             continue
-
-         for i in tqdm(range(len(image_names))):
-             try:
-                 # the network expects 512x512 input, normalized to [0, 1]
-                 img = np.array(Image.open(f'{IMAGE_FOLDER}/{image_names[i]}').resize((512, 512))) / 255
-                 img = torch.from_numpy(img)[None, None].float().cuda()  # 1, 1, 512, 512, 3
-                 img_seg = np.array(Image.open(f'{SEGMENTATION_FOLDER}/{image_names[i][:-4]}.png').resize((512, 512), Image.NEAREST))
-                 if len(img_seg.shape) == 3:
-                     img_seg = img_seg[..., 0]
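-                 # face-region mask from the parsing labels: keep classes 2 and 4-13,
-                 # minus class 11 (class semantics assumed from the face-parsing
-                 # scheme used during preprocessing)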
-                 mask = ((img_seg == 2) | ((img_seg > 3) & (img_seg < 14))) & ~(img_seg == 11)
-                 mask = torch.from_numpy(mask).long().cuda()[None, None]  # 1, 1, 512, 512
-                 batch = {
-                     'tar_msk': mask,
-                     'tar_rgb': img,
-                 }
-                 batch_mirrored = {
-                     'tar_rgb': torch.flip(batch['tar_rgb'], dims=[3]).cuda(),
-                     'tar_msk': torch.flip(batch['tar_msk'], dims=[3]).cuda(),
-                 }
-
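-                 # forward both the original and the mirrored crop; model.net returns
-                 # a dict of prediction maps plus a confidence output (per the
-                 # original variable name)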
-                 with torch.no_grad():
-                     output, conf = model.net(batch)
-                     output_mirrored, _ = model.net(batch_mirrored)
-
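-                 # test-time mirror averaging: flip each mirrored prediction back,
-                 # negate its x component, and average with the direct prediction; the
-                 # small 2 * 0.0075 shift appears to re-center the u coordinate of the
-                 # UV map after mirroring (constant taken from the original code)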
-                 if 'uv_map' in output:
-                     flipped_pred = torch.flip(output_mirrored['uv_map'], dims=[4])
-                     flipped_pred[:, :, 0, :, :] *= -1
-                     flipped_pred[:, :, 0, :, :] += 2 * 0.0075
-                     output['uv_map'] = (output['uv_map'] + flipped_pred) / 2
-                 if 'normals' in output:
-                     flipped_pred = torch.flip(output_mirrored['normals'], dims=[4])
-                     flipped_pred[:, :, 0, :, :] *= -1
-                     output['normals'] = (output['normals'] + flipped_pred) / 2
-                 if 'disps' in output:
-                     flipped_pred = torch.flip(output_mirrored['disps'], dims=[4])
-                     flipped_pred[:, :, 0, :, :] *= -1
-                     output['disps'] = (output['disps'] + flipped_pred) / 2
-
-                 for prediction_type in prediction_types:
-                     for i_batch in range(batch_size):
-                         i_view = 0
-                         gt_rgb = batch['tar_rgb']
-
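-                         # raw output ranges (inferred from the rescaling below):
-                         # uv_map in [-1, 1], disps roughly [-50, 50], normals unit vectors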
-                         # normalize to [0, 1] range
-                         if prediction_type == 'uv_map':
-                             tmp_output = torch.clamp((output[prediction_type][i_batch, i_view] + 1) / 2, 0, 1)
-                         elif prediction_type == 'disps':
-                             tmp_output = torch.clamp((output[prediction_type][i_batch, i_view] + 50) / 100, 0, 1)
-                         elif prediction_type in ['normals', 'normals_can']:
-                             tmp_output = output[prediction_type][i_batch, i_view]
-                             tmp_output = tmp_output / torch.norm(tmp_output, dim=0).unsqueeze(0)
-                             tmp_output = torch.clamp((tmp_output + 1) / 2, 0, 1)
-                             # undo the normals channel convention used during preprocessing
-                             tmp_output = torch.stack(
-                                 [tmp_output[0, ...], 1 - tmp_output[2, ...], 1 - tmp_output[1, ...]],
-                                 dim=0)
-
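-                         # side-by-side image: input RGB next to the prediction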
-                         content = [
-                             gt_rgb[i_batch, i_view].detach().cpu().numpy(),
-                             pad_to_3_channels(tmp_output.permute(1, 2, 0).detach().cpu().float().numpy()),
-                         ]
-                         catted = (np.concatenate(content, axis=1) * 255).astype(np.uint8)
-                         Image.fromarray(catted).save(f'{out_folders_wGT[prediction_type]}/{image_names[i]}')
-
-                         pred_only = pad_to_3_channels(
-                             tmp_output.permute(1, 2, 0).detach().cpu().float().numpy() * 255).astype(np.uint8)
-                         Image.fromarray(pred_only).save(f'{out_folders[prediction_type]}/{image_names[i][:-4]}.png')
-
-                         # this visualization is quite slow, so it is disabled by default
-                         if prediction_type == 'uv_map' and cfg.viz_uv_mesh:
-                             to_show_non_mirr = uv_pred_to_mesh(
-                                 output[prediction_type][i_batch:i_batch + 1, ...],
-                                 batch['tar_msk'][i_batch:i_batch + 1, ...],
-                                 batch['tar_rgb'][i_batch:i_batch + 1, ...],
-                                 right_ear=[537, 1334, 857, 554, 941],
-                                 left_ear=[541, 476, 237, 502, 286],
-                             )
-                             Image.fromarray(to_show_non_mirr).save(f'{out_folders_viz[prediction_type]}/{image_names[i]}')
-
-             except Exception:
-                 traceback.print_exc()
-
-     print(f"""
-     <<<<<<<< FINISHED PIXEL3DMM INFERENCE for {cfg.video_name} in {prediction_types} MODE >>>>>>>>
-     """)
-