File size: 2,462 Bytes
6709fc9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os

from PIL import Image
import torch
from torch.utils.data import DataLoader

from datasets.inference_dataset import InferenceDataset
from datasets.process_image import ImageProcessor
from models.styleres import StyleRes
from options.inference_options import InferenceOptions
from options import Settings
from utils import parse_config
from tqdm import tqdm

def initialize_styleres(checkpoint_path, device):
    """Build a StyleRes model ready for inference.

    Assigns ``device`` to the global ``Settings.device`` *before* the model
    is constructed, loads the checkpoint, moves the model with
    ``send_to_device``, switches it to eval mode and freezes all parameters.

    Args:
        checkpoint_path: Path forwarded to ``StyleRes.load_ckpt``.
        device: Value stored in ``Settings.device`` (presumably consulted by
            ``StyleRes`` / ``send_to_device`` — confirm in models.styleres).

    Returns:
        The StyleRes instance, in eval mode, with ``requires_grad`` disabled
        on every parameter.
    """
    # Order matters: the global device must be set before StyleRes() runs.
    Settings.device = device
    styleres = StyleRes()
    styleres.load_ckpt(checkpoint_path)
    styleres.send_to_device()
    styleres.eval()
    # Inference only: freeze every weight so no gradients are tracked.
    for weight in styleres.parameters():
        weight.requires_grad = False
    return styleres

def _select_device():
    """Return the CUDA device when available, otherwise the CPU device."""
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _prepare_output_dirs(output_dir, edit_configs):
    """Create ``output_dir`` and one sub-directory per edit config.

    Each config gets an ``outdir`` attribute holding its values joined with
    ``'_'``. That name is used as the sub-directory under ``output_dir``.
    """
    os.makedirs(output_dir, exist_ok=True)
    for edit_config in edit_configs:
        # join() consumes the values before ``outdir`` is assigned, so the
        # new attribute never contributes to its own name.
        edit_config.outdir = '_'.join(str(value) for value in edit_config.values())
        os.makedirs(os.path.join(output_dir, edit_config.outdir), exist_ok=True)


def run():
    """Apply every configured edit to the images in ``args.datadir``.

    Command-line options come from ``InferenceOptions().parse()``.
    At most ``args.n_images`` input images are processed; the default is the
    whole dataset. For every edit config, each edited image is written to
    ``<args.outdir>/<edit_config.outdir>/<image name>``. Outputs are resized
    to 256x256 when ``args.resize_outputs`` is set, otherwise to 1024x1024.
    """
    args = InferenceOptions().parse()
    edit_configs = parse_config(args.edit_configs)
    # Always bind a device; fall back to the CPU when CUDA is absent.
    device = _select_device()

    dataset = InferenceDataset(args.datadir, aligner_path=args.aligner_path)
    print(f"Dataset is created. Number of images is {len(dataset)}")
    dataloader = DataLoader(dataset, batch_size=args.test_batch_size,
                            shuffle=False,
                            num_workers=int(args.test_workers),
                            drop_last=False)

    if args.n_images is None:
        args.n_images = len(dataset)

    output_dir = args.outdir
    _prepare_output_dirs(output_dir, edit_configs)

    resize_amount = (256, 256) if args.resize_outputs else (1024, 1024)

    model = initialize_styleres(args.checkpoint_path, device)

    processed = 0
    for data in tqdm(dataloader):
        remaining = args.n_images - processed
        if remaining <= 0:
            break
        # Trim the final batch so no more than args.n_images are written.
        batch = data['image'][:remaining]
        names = data['name'][:remaining]
        processed += batch.shape[0]
        for edit_config in edit_configs:
            images = model.edit_images(batch, edit_config)
            images = ImageProcessor.postprocess_image(images.detach().cpu().numpy())
            for image, save_name in zip(images, names):
                pil_img = Image.fromarray(image).resize(resize_amount)
                pil_img.save(os.path.join(output_dir, edit_config.outdir, save_name))


if __name__ == '__main__':
    # Entry point when executed as a script; importing the module runs nothing.
    run()