# Spaces: Runtime error
import os | |
from tqdm import tqdm | |
import torch | |
from torch.utils.data import DataLoader | |
from logger import Logger, Visualizer | |
import numpy as np | |
import imageio | |
def reconstruction(config, inpainting_network, kp_detector, bg_predictor, dense_motion_network, checkpoint, log_dir, dataset):
    """Reconstruct each video in ``dataset`` from its first frame and report the mean L1 error.

    For every video, frame 0 is the source image and each frame (frame 0
    included) is used in turn as the driving frame. The predicted frames of a
    video are concatenated along axis 1 and written as one PNG to
    ``<log_dir>/reconstruction/png/<name>.png``. The mean absolute difference
    between each prediction and its driving frame is collected over all frames
    of all videos, and the overall mean is printed at the end.

    Args:
        config: Mapping; ``config['visualizer_params']`` is expanded as keyword
            arguments to ``Visualizer``.
        inpainting_network: Called as ``inpainting_network(source, dense_motion)``;
            the result is indexed with ``'prediction'`` and receives the
            ``'kp_source'`` / ``'kp_driving'`` entries.
        kp_detector: Called on a single frame to obtain key points.
        bg_predictor: Optional; when not ``None`` it is called as
            ``bg_predictor(source, driving)`` and its result is passed on as
            ``bg_param``.
        dense_motion_network: Called with ``source_image``, ``kp_driving``,
            ``kp_source``, ``bg_param`` and ``dropout_flag=False``.
        checkpoint: Checkpoint handed to ``Logger.load_cpk``; must not be ``None``.
        log_dir: Base output directory.
        dataset: Dataset wrapped in a ``DataLoader`` with ``batch_size=1``; each
            batch is indexed with ``'video'`` and ``'name'``.

    Raises:
        AttributeError: If ``checkpoint`` is ``None``.
    """
    # Fail before touching the file system or the networks.
    if checkpoint is None:
        raise AttributeError("Checkpoint should be specified for mode='reconstruction'.")

    Logger.load_cpk(checkpoint, inpainting_network=inpainting_network, kp_detector=kp_detector,
                    bg_predictor=bg_predictor, dense_motion_network=dense_motion_network)

    log_dir = os.path.join(log_dir, 'reconstruction')
    png_dir = os.path.join(log_dir, 'png')
    # png_dir lives inside log_dir, so one call creates both; exist_ok avoids
    # the check-then-create race of a separate os.path.exists() test.
    os.makedirs(png_dir, exist_ok=True)

    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)

    # Explicit None check: truthiness of a module depends on whether it defines
    # __len__ (container modules do), which is not what is being asked here.
    has_bg_predictor = bg_predictor is not None

    inpainting_network.eval()
    kp_detector.eval()
    dense_motion_network.eval()
    if has_bg_predictor:
        bg_predictor.eval()

    # Built once: its constructor arguments do not change between frames.
    visualizer = Visualizer(**config['visualizer_params'])
    use_cuda = torch.cuda.is_available()

    loss_list = []
    # Wrapping the dataloader itself (rather than enumerate(...)) lets tqdm
    # read its length and display a total / ETA.
    for x in tqdm(dataloader):
        with torch.no_grad():
            video = x['video']
            if use_cuda:
                video = video.cuda()

            # Axis 2 is iterated as the frame axis; frame 0 is the source.
            source = video[:, :, 0]
            kp_source = kp_detector(source)

            predictions = []
            for frame_idx in range(video.shape[2]):
                driving = video[:, :, frame_idx]
                kp_driving = kp_detector(driving)

                bg_params = bg_predictor(source, driving) if has_bg_predictor else None

                dense_motion = dense_motion_network(source_image=source, kp_driving=kp_driving,
                                                    kp_source=kp_source, bg_param=bg_params,
                                                    dropout_flag=False)
                out = inpainting_network(source, dense_motion)
                out['kp_source'] = kp_source
                out['kp_driving'] = kp_driving

                # Move axis 1 to the end and keep the only batch element
                # (assumes prediction is (batch, channel, height, width) - TODO confirm).
                frame = np.transpose(out['prediction'].detach().cpu().numpy(), [0, 2, 3, 1])[0]
                predictions.append(frame)

                # NOTE(review): the rendered visualization is not saved or
                # returned anywhere in this function; the call is kept only to
                # preserve existing behaviour - confirm whether it can be dropped.
                visualizer.visualize(source=source, driving=driving, out=out)

                loss = torch.abs(out['prediction'] - driving).mean().cpu().numpy()
                loss_list.append(loss)

            # One image per video: all predicted frames joined along axis 1.
            predictions = np.concatenate(predictions, axis=1)
            # Scaling by 255 presumably expects predictions in [0, 1] - verify against the network.
            imageio.imsave(os.path.join(png_dir, x['name'][0] + '.png'), (255 * predictions).astype(np.uint8))

    print("Reconstruction loss: %s" % np.mean(loss_list))