File size: 5,775 Bytes
e30d598
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import argparse
import glob
import time
from xml.sax import parse
import numpy as np
import os
import cv2
import torch
import torchvision.transforms as transforms
from skimage import io
from basicsr.utils import imwrite, tensor2img
from basicsr.utils.face_util import FaceRestorationHelper
import torch.nn.functional as F

from basicsr.utils.registry import ARCH_REGISTRY

if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    parser = argparse.ArgumentParser()

    parser.add_argument('--w', type=float, default=0.5, help='Balance the quality and fidelity')
    parser.add_argument('--upscale_factor', type=int, default=2)
    parser.add_argument('--test_path', type=str, default='./inputs/cropped_faces')
    parser.add_argument('--has_aligned', action='store_true', help='Input are cropped and aligned faces')
    parser.add_argument('--upsample_num_times', type=int, default=1, help='Upsample the image before face detection')
    parser.add_argument('--save_inverse_affine', action='store_true')
    parser.add_argument('--only_keep_largest', action='store_true')
    parser.add_argument('--draw_box', action='store_true')

    # The following are the paths for dlib models
    parser.add_argument(
        '--detection_path', type=str,
        default='weights/dlib/mmod_human_face_detector-4cb19393.dat'
    )
    parser.add_argument(
        '--landmark5_path', type=str,
        default='weights/dlib/shape_predictor_5_face_landmarks-c4b1e980.dat'
    )
    parser.add_argument(
        '--landmark68_path', type=str,
        default='weights/dlib/shape_predictor_68_face_landmarks-fbdc2cb8.dat'
    )

    args = parser.parse_args()
    if args.test_path.endswith('/'):  # solve when path ends with /
        args.test_path = args.test_path[:-1]

    w = args.w
    result_root = f'results/{os.path.basename(args.test_path)}_{w}'

    # set up the Network
    net = ARCH_REGISTRY.get('CodeFormer')(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, 
                                            connect_list=['32', '64', '128', '256']).to(device)

    ckpt_path = 'weights/CodeFormer/codeformer.pth'
    checkpoint = torch.load(ckpt_path)['params_ema']
    net.load_state_dict(checkpoint)
    net.eval()


    save_crop_root = os.path.join(result_root, 'cropped_faces')
    save_restore_root = os.path.join(result_root, 'restored_faces')
    save_final_root = os.path.join(result_root, 'final_results')
    save_input_root = os.path.join(result_root, 'inputs')

    face_helper = FaceRestorationHelper(args.upscale_factor, face_size=512)
    face_helper.init_dlib(args.detection_path, args.landmark5_path, args.landmark68_path)

    # scan all the jpg and png images
    for img_path in sorted(glob.glob(os.path.join(args.test_path, '*.[jp][pn]g'))):
        img_name = os.path.basename(img_path)
        print(f'Processing: {img_name}')
        if args.has_aligned: 
            # the input faces are already cropped and aligned
            img = cv2.imread(img_path, cv2.IMREAD_COLOR)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
            face_helper.cropped_faces = [img]
            cropped_faces = face_helper.cropped_faces
        else:
            # detect faces
            num_det_faces = face_helper.detect_faces(
                img_path, upsample_num_times=args.upsample_num_times, only_keep_largest=args.only_keep_largest)
            # get 5 face landmarks for each face
            num_landmarks = face_helper.get_face_landmarks_5()
            print(f'\tDetect {num_det_faces} faces, {num_landmarks} landmarks.')
            # warp and crop each face
            save_crop_path = os.path.join(save_crop_root, img_name)
            face_helper.warp_crop_faces(save_crop_path, save_inverse_affine_path=None)
            cropped_faces = face_helper.cropped_faces

            # get 68 landmarks for each cropped face
            # num_landmarks = face_helper.get_face_landmarks_68()
            # print(f'\tDetect {num_landmarks} faces for 68 landmarks.')
            # assert len(cropped_faces) == len(face_helper.all_landmarks_68)

        # TODO
        # face_helper.free_dlib_gpu_memory()

        # face restoration for each cropped face
        for idx, cropped_face in enumerate(cropped_faces):
            # prepare data
            cropped_face = transforms.ToTensor()(cropped_face)
            cropped_face = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(cropped_face)
            cropped_face = cropped_face.unsqueeze(0).to(device)

            try:
                with torch.no_grad():
                    output = net(cropped_face, w=w, adain=True)[0]
                    restored_face = tensor2img(output, min_max=(-1, 1))
                del output
                torch.cuda.empty_cache()
            except Exception as error:
                print(f'\tFailed inference for CodeFormer: {error}')
                restored_face = tensor2img(cropped_face, min_max=(-1, 1))

            path = os.path.splitext(os.path.join(save_restore_root, img_name))[0]
            if not args.has_aligned:
                save_path = f'{path}_{idx:02d}.png'
                face_helper.add_restored_face(restored_face)
            else:
                save_path = f'{path}.png'
            imwrite(restored_face, save_path)

        if not args.has_aligned:
            # paste each restored face to the input image
            face_helper.paste_faces_to_input_image(os.path.join(save_final_root, img_name), draw_box=args.draw_box)

        # clean all the intermediate results to process the next image
        face_helper.clean_all()

    print(f'\nAll results are saved in {result_root}')