thin-plate-spline-motion-model / reconstruction.py
AlekseyKorshuk's picture
feat: updates
f844f44
raw
history blame
2.93 kB
import os
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from logger import Logger, Visualizer
import numpy as np
import imageio
def reconstruction(config, inpainting_network, kp_detector, bg_predictor, dense_motion_network, checkpoint, log_dir, dataset):
    """Reconstruct each video in ``dataset`` from its first frame and report the mean L1 error.

    For every item yielded by the data loader, frame 0 of ``x['video']`` is
    used as the source image and every frame (including frame 0) is used in
    turn as the driving frame. The predicted frames of one video are joined
    into a single PNG written to ``<log_dir>/reconstruction/png/<name>.png``.
    After all videos are processed, the mean absolute difference between
    predictions and driving frames is printed.

    Args:
        config: Mapping that must contain ``'visualizer_params'``, the keyword
            arguments passed to ``Visualizer``.
        inpainting_network: Callable as ``inpainting_network(source, dense_motion)``;
            its result must be a mapping with a ``'prediction'`` tensor.
        kp_detector: Callable as ``kp_detector(frame)``.
        bg_predictor: Optional; when truthy it is called as
            ``bg_predictor(source, driving)``.
        dense_motion_network: Callable with the keyword arguments
            ``source_image``, ``kp_driving``, ``kp_source``, ``bg_param`` and
            ``dropout_flag``.
        checkpoint: Checkpoint handed to ``Logger.load_cpk``. Must not be ``None``.
        log_dir: Base output directory; results go under its
            ``reconstruction`` sub-directory.
        dataset: Dataset whose items provide ``'video'`` and ``'name'`` entries.

    Raises:
        AttributeError: If ``checkpoint`` is ``None``.
    """
    # Fail fast: nothing below makes sense without trained weights.
    if checkpoint is None:
        raise AttributeError("Checkpoint should be specified for mode='reconstruction'.")

    png_dir = os.path.join(log_dir, 'reconstruction/png')
    log_dir = os.path.join(log_dir, 'reconstruction')

    Logger.load_cpk(checkpoint, inpainting_network=inpainting_network, kp_detector=kp_detector,
                    bg_predictor=bg_predictor, dense_motion_network=dense_motion_network)

    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(png_dir, exist_ok=True)

    inpainting_network.eval()
    kp_detector.eval()
    dense_motion_network.eval()
    if bg_predictor:
        bg_predictor.eval()

    # The visualizer is configured only from `config`, so build it once
    # instead of once per frame.
    # NOTE(review): assumes Visualizer.visualize keeps no per-call state - confirm.
    visualizer = Visualizer(**config['visualizer_params'])

    loss_list = []
    # Wrapping the loader itself (rather than enumerate(...)) lets tqdm read
    # its length and display a total.
    for x in tqdm(dataloader):
        with torch.no_grad():
            predictions = []
            # NOTE(review): collected but never written out in this function.
            visualizations = []

            if torch.cuda.is_available():
                x['video'] = x['video'].cuda()
            video = x['video']

            # Frame 0 is the fixed source for the whole clip, so slice it and
            # detect its keypoints once.
            source = video[:, :, 0]
            kp_source = kp_detector(source)

            for frame_idx in range(video.shape[2]):
                driving = video[:, :, frame_idx]
                kp_driving = kp_detector(driving)

                bg_params = None
                if bg_predictor:
                    bg_params = bg_predictor(source, driving)

                dense_motion = dense_motion_network(source_image=source, kp_driving=kp_driving,
                                                    kp_source=kp_source, bg_param=bg_params,
                                                    dropout_flag=False)
                out = inpainting_network(source, dense_motion)
                out['kp_source'] = kp_source
                out['kp_driving'] = kp_driving

                # Move the channel axis last (presumably NCHW -> NHWC) and drop
                # the batch axis; batch_size is 1.
                frame = np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0]
                predictions.append(frame)

                visualization = visualizer.visualize(source=source, driving=driving, out=out)
                visualizations.append(visualization)

                # Per-frame mean absolute error against the driving frame.
                loss = torch.abs(out['prediction'] - driving).mean().cpu().numpy()
                loss_list.append(loss)

            # Join all predicted frames along axis 1 into one image strip.
            predictions = np.concatenate(predictions, axis=1)
            imageio.imsave(os.path.join(png_dir, x['name'][0] + '.png'), (255 * predictions).astype(np.uint8))

    print("Reconstruction loss: %s" % np.mean(loss_list))