File size: 4,320 Bytes
0f2d9f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import os
import cv2
import sys
import torch
import os.path as osp
from basicsr.utils import img2tensor, tensor2img
from torchvision.transforms.functional import normalize
from facexlib.utils.face_restoration_helper import FaceRestoreHelper

# Repository root: two directory levels above this file.
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

# Make the repository root importable so the ``SR_Inference`` imports below
# resolve. It is the same directory as ROOT_DIR, so reuse it rather than
# recomputing it a second way. Skip the append when the entry is already
# present (e.g. on module reload) so sys.path does not accumulate duplicates.
root_path = ROOT_DIR
if root_path not in sys.path:
    sys.path.append(root_path)
from SR_Inference.codeformer.codeformer_arch import CodeFormerArch
from SR_Inference.inference_sr_utils import RealEsrUpsamplerZoo


class CodeFormer:
    """Upscale an image and restore its largest face with CodeFormer.

    The whole frame is upscaled by a background upsampler obtained from
    ``RealEsrUpsamplerZoo`` ("realesrgan" by default). The largest detected
    face is aligned to a 512x512 crop, restored by ``CodeFormerArch`` and
    pasted back onto the upscaled background via facexlib's
    ``FaceRestoreHelper``.

    Args:
        upscale: Output scale factor; cast to ``int``.
        bg_upsampler_name: Background upsampler name passed to
            ``RealEsrUpsamplerZoo``.
        prefered_net_in_upsampler: Network preference passed to
            ``RealEsrUpsamplerZoo``.
        fidelity_weight: Passed as ``w`` to the CodeFormer forward pass. In
            upstream CodeFormer, larger values favour fidelity to the input
            over restoration quality (confirm for this arch).
    """

    def __init__(
        self,
        upscale=2,
        bg_upsampler_name="realesrgan",
        prefered_net_in_upsampler="RRDBNet",
        fidelity_weight=0.8,
    ):
        self.upscale = int(upscale)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.fidelity_weight = fidelity_weight

        # ------------------------ set up background upsampler ------------------------
        upsampler_zoo = RealEsrUpsamplerZoo(
            upscale=self.upscale,
            bg_upsampler_name=bg_upsampler_name,
            prefered_net_in_upsampler=prefered_net_in_upsampler,
        )
        # NOTE(review): may be None for some names (not verifiable from here);
        # __call__ then falls back to a plain Lanczos resize.
        self.bg_upsampler = upsampler_zoo.bg_upsampler

        # ------------------ set up FaceRestoreHelper -------------------
        # Face detection / parsing weights live in the GFPGAN weights folder.
        gfpgan_weights_path = os.path.join(
            ROOT_DIR, "SR_Inference", "gfpgan", "weights"
        )
        self.face_restorer_helper = FaceRestoreHelper(
            upscale_factor=self.upscale,
            face_size=512,
            crop_ratio=(1, 1),
            det_model="retinaface_resnet50",
            save_ext="png",
            use_parse=True,
            device=self.device,
            model_rootpath=gfpgan_weights_path,
        )

        # ------------------ load model -------------------
        self.sr_model = self._load_model()

    def _load_model(self):
        """Build ``CodeFormerArch`` on ``self.device``, load the checkpoint, set eval mode."""
        model = CodeFormerArch().to(self.device)
        ckpt_path = os.path.join(
            ROOT_DIR, "SR_Inference", "codeformer", "weights", "codeformer_v0.1.0.pth"
        )
        # NOTE: torch.load unpickles the file; only point it at trusted checkpoints.
        checkpoint = torch.load(ckpt_path, map_location=self.device)
        # Prefer the EMA weights when the checkpoint provides them.
        keyname = "params_ema" if "params_ema" in checkpoint else "params"
        model.load_state_dict(checkpoint[keyname])
        model.eval()
        return model

    def _upsample_background(self, img):
        """Upscale the full frame with the background upsampler, or return None if there is none."""
        if self.bg_upsampler is None:
            return None
        return self.bg_upsampler.enhance(img, outscale=self.upscale)[0]

    @torch.no_grad()
    def _restore_face(self, cropped_face):
        """Run CodeFormer on one aligned face crop and return it as a uint8 BGR image."""
        face_t = img2tensor(imgs=cropped_face / 255.0, bgr2rgb=True, float32=True)
        # Map [0, 1] -> [-1, 1]; the output is mapped back with min_max=(-1, 1).
        normalize(
            tensor=face_t,
            mean=(0.5, 0.5, 0.5),
            std=(0.5, 0.5, 0.5),
            inplace=True,
        )
        face_t = face_t.unsqueeze(0).to(self.device)

        # ------------------- restore/enhance image using CodeFormerArch model -------------------
        # Element 0 of the model output is the restored image tensor.
        output = self.sr_model(face_t, w=self.fidelity_weight, adain=True)[0]
        restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
        return restored_face.astype("uint8")

    @torch.no_grad()
    def __call__(self, img):
        """Upscale ``img`` and restore its largest detected face.

        Args:
            img: BGR image array, e.g. as returned by ``cv2.imread``.

        Returns:
            The upscaled image with the restored face pasted back. When no
            face is detected, just the upscaled background.

        Raises:
            ValueError: If ``img`` is None (e.g. a failed ``cv2.imread``).
        """
        if img is None:
            raise ValueError(
                "img is None; check that the input image was read successfully"
            )

        bg_img = self._upsample_background(img)

        helper = self.face_restorer_helper
        helper.clean_all()
        helper.read_image(img)
        # Only the largest face is kept, so at most one crop is restored.
        helper.get_face_landmarks_5(
            only_keep_largest=True, only_center_face=False, eye_dist_threshold=5
        )
        helper.align_warp_face()

        if len(helper.cropped_faces) == 0:
            if bg_img is None:
                # No background upsampler: plain Lanczos resize of the input.
                height, width = img.shape[:2]
                bg_img = cv2.resize(
                    img,
                    (width * self.upscale, height * self.upscale),
                    interpolation=cv2.INTER_LANCZOS4,
                )
            return bg_img

        restored_face = self._restore_face(helper.cropped_faces[0])
        helper.add_restored_face(restored_face)
        helper.get_inverse_affine(None)

        # facexlib resizes the input itself when upsample_img is None.
        return helper.paste_faces_to_input_image(upsample_img=bg_img)


if __name__ == "__main__":
    # Smoke run: upscale one sample frame 2x with fidelity weight 1.0 and save it.
    img_path = os.path.join(
        ROOT_DIR, "data", "EyeDentify", "Wo_SR", "original", "1", "1", "frame_01.png"
    )
    # cv2.imread signals failure by returning None instead of raising; check
    # before paying for model loading.
    img = cv2.imread(img_path)
    if img is None:
        raise FileNotFoundError(f"Could not read image: {img_path}")

    codeformer = CodeFormer(upscale=2, fidelity_weight=1.0)
    sr_img = codeformer(img=img)

    saving_dir = os.path.join(ROOT_DIR, "rough_works", "SR_imgs")
    os.makedirs(saving_dir, exist_ok=True)
    save_path = os.path.join(saving_dir, "sr_img.png")
    # cv2.imwrite likewise reports failure only through its return value.
    if not cv2.imwrite(save_path, sr_img):
        raise OSError(f"Failed to write image: {save_path}")