File size: 6,196 Bytes
ed697ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
from argparse import Namespace
import os
from tqdm import tqdm
from PIL import Image
import numpy as np
import torch
from torch.utils.data import DataLoader

import sys
sys.path.append(".")
sys.path.append("..")

from configs import data_configs
from datasets.inference_dataset import InferenceDataset
from datasets.augmentations import AgeTransformer
from utils.common import log_image
from options.test_options import TestOptions
from models.psp import pSp


def _build_comparison_grid(source_image, image_results, opts, resize_amount):
    """Assemble the two-row comparison image saved for one source image.

    The top row starts with an all-zero tile followed by every reference
    image; the bottom row starts with the source image followed by the
    style-mixed output obtained with the reference directly above it.

    Args:
        source_image: the source image, in whatever form ``log_image`` accepts.
        image_results: sequence of ``[reference, mixed_output]`` pairs.
        opts: options object forwarded unchanged to ``log_image``.
        resize_amount: size tuple passed to ``.resize`` for every tile.

    Returns:
        A numpy array with both rows stacked along axis 0.
    """
    source_tile = np.array(log_image(source_image, opts).resize(resize_amount))
    # the zero tile keeps the reference row aligned with the source column
    ref_tiles = [np.zeros_like(source_tile)]
    mixed_tiles = [source_tile]
    for ref_input, mixing_result in image_results:
        ref_tiles.append(np.array(log_image(ref_input, opts).resize(resize_amount)))
        mixed_tiles.append(np.array(log_image(mixing_result, opts).resize(resize_amount)))
    # concatenate each row once instead of re-allocating it per reference
    top_row = np.concatenate(ref_tiles, axis=1)
    bottom_row = np.concatenate(mixed_tiles, axis=1)
    return np.concatenate([top_row, bottom_row], axis=0)


def run():
    """Run reference-guided inference and save one comparison grid per source image.

    Test options are parsed from the command line and merged over the options
    stored in the checkpoint (command-line values win). For every target age
    in the comma-separated ``opts.target_age``, each source image is combined
    with each reference image listed in ``opts.ref_images_paths_file`` and the
    resulting grid is written to
    ``<exp_dir>/reference_guided_inference/<target_age>/<source file name>``.

    At most ``opts.n_images`` source images are processed per target age; when
    ``opts.n_images`` is ``None`` the whole source dataset is used.
    """
    test_opts = TestOptions().parse()

    out_path_results = os.path.join(test_opts.exp_dir, 'reference_guided_inference')
    os.makedirs(out_path_results, exist_ok=True)

    # update test options with options used during training
    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
    ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu')
    opts = ckpt['opts']
    opts.update(vars(test_opts))
    opts = Namespace(**opts)

    net = pSp(opts)
    net.eval()
    net.cuda()

    age_transformers = [AgeTransformer(target_age=age) for age in opts.target_age.split(',')]

    print(f'Loading dataset for {opts.dataset_type}')
    dataset_args = data_configs.DATASETS[opts.dataset_type]
    transforms_dict = dataset_args['transforms'](opts).get_transforms()

    source_dataset = InferenceDataset(root=opts.data_path,
                                      transform=transforms_dict['transform_inference'],
                                      opts=opts)
    source_dataloader = DataLoader(source_dataset,
                                   batch_size=opts.test_batch_size,
                                   shuffle=False,
                                   num_workers=int(opts.test_workers),
                                   drop_last=False)

    # references are served one at a time so each one yields its own column
    ref_dataset = InferenceDataset(paths_list=opts.ref_images_paths_file,
                                   transform=transforms_dict['transform_inference'],
                                   opts=opts)
    ref_dataloader = DataLoader(ref_dataset,
                                batch_size=1,
                                shuffle=False,
                                num_workers=1,
                                drop_last=False)

    if opts.n_images is None:
        opts.n_images = len(source_dataset)

    for age_transformer in age_transformers:
        target_age = age_transformer.target_age
        print(f"Running on target age: {target_age}")
        age_save_path = os.path.join(out_path_results, str(target_age))
        os.makedirs(age_save_path, exist_ok=True)
        # index of the next source image to save; relies on shuffle=False so
        # that it lines up with source_dataset.paths
        global_i = 0
        for source_batch in tqdm(source_dataloader):
            if global_i >= opts.n_images:
                break
            # move the batch to the GPU once rather than once per reference
            source_batch = source_batch.cuda().float()
            results_per_source = [[] for _ in range(len(source_batch))]
            with torch.no_grad():
                for ref_batch in ref_dataloader:
                    ref_batch = ref_batch.cuda().float()
                    # kept inside the reference loop: the transformer may not be
                    # deterministic, so hoisting could change results -- TODO confirm
                    source_input_age_batch = torch.stack(
                        [age_transformer(img.cpu()).to('cuda') for img in source_batch])

                    # compute w+ of ref images to be injected for style-mixing
                    ref_latents = net.pretrained_encoder(ref_batch) + net.latent_avg

                    # run age transformation on source images with style-mixing
                    res_batch_mixed = run_on_batch(source_input_age_batch, net, opts, latent_to_inject=ref_latents)

                    # store results
                    for idx, per_source in enumerate(results_per_source):
                        per_source.append([ref_batch[0], res_batch_mixed[idx]])

                # save results
                resize_amount = (256, 256) if opts.resize_outputs else (1024, 1024)
                for image_idx, image_results in enumerate(results_per_source):
                    # enforce the limit per image, not only per batch, so the
                    # last batch cannot overshoot n_images
                    if global_i >= opts.n_images:
                        break
                    input_im_path = source_dataset.paths[global_i]
                    res = _build_comparison_grid(source_batch[image_idx], image_results, opts, resize_amount)
                    save_path = os.path.join(age_save_path, os.path.basename(input_im_path))
                    Image.fromarray(res).save(save_path)
                    global_i += 1


def run_on_batch(inputs, net, opts, latent_to_inject=None):
    """Run ``net`` on a batch, optionally injecting a latent for style-mixing.

    Args:
        inputs: batch of inputs for ``net``; iterated item by item when
            ``opts.latent_mask`` is set.
        net: callable model.
        opts: options object; reads ``latent_mask``, ``resize_outputs`` and,
            when a mask is set, ``mix_alpha``.
        latent_to_inject: value forwarded as ``inject_latent`` when a mask is set.

    Returns:
        With no latent mask, whatever ``net`` returns for the whole batch.
        Otherwise the per-item outputs concatenated along dim 0.
    """
    # no mask requested: a single plain forward pass over the whole batch
    if opts.latent_mask is None:
        return net(inputs, randomize_noise=False, resize=opts.resize_outputs)

    # the mask arrives as a comma-separated string of layer indices
    mask_indices = list(map(int, opts.latent_mask.split(",")))

    mixed_outputs = []
    for single_input in inputs:
        # get output image with injected style vector; the returned latent is unused
        output, _ = net(single_input.unsqueeze(0).to("cuda").float(),
                        latent_mask=mask_indices,
                        inject_latent=latent_to_inject,
                        alpha=opts.mix_alpha,
                        resize=opts.resize_outputs,
                        return_latents=True)
        mixed_outputs.append(output)
    return torch.cat(mixed_outputs, dim=0)


# Script entry point: only run inference when executed directly, not on import.
if __name__ == "__main__":
    run()