File size: 2,435 Bytes
0f2d9f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os
import cv2
import sys
import torch
import os.path as osp
from gfpgan import GFPGANer
from basicsr.utils.download_util import load_file_from_url

# Directory two levels above this file (the parent of this file's directory).
# Used below as the base for the weights, input-data and output paths.
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

# Resolves to the same directory as ROOT_DIR, computed with join/pardir.
# It is appended to sys.path before the import that follows -- presumably so
# the top-level `SR_Inference` package is importable when this file is run
# directly as a script (TODO confirm against the project layout).
root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
sys.path.append(root_path)
from SR_Inference.inference_sr_utils import RealEsrUpsamplerZoo


class GFPGAN:
    """Callable wrapper around ``GFPGANer`` using the GFPGAN v1.3 checkpoint.

    Construction builds a background upsampler via ``RealEsrUpsamplerZoo``,
    makes sure the ``GFPGANv1.3.pth`` weights exist under
    ``<ROOT_DIR>/SR_Inference/gfpgan/weights`` (downloading them there when
    the file is missing) and instantiates ``GFPGANer`` on CUDA when available,
    otherwise on CPU. The resulting restorer is stored in ``self.sr_model``.
    """

    def __init__(
        self,
        upscale=2,
        bg_upsampler_name="realesrgan",
        prefered_net_in_upsampler="RRDBNet",
    ):
        """Set up the background upsampler and load the GFPGAN model.

        Args:
            upscale: Upscaling factor; coerced with ``int()`` before use and
                passed to both the background upsampler and ``GFPGANer``.
            bg_upsampler_name: Forwarded unchanged to ``RealEsrUpsamplerZoo``.
            prefered_net_in_upsampler: Forwarded unchanged to
                ``RealEsrUpsamplerZoo``.
        """
        upscale = int(upscale)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # ------------------------ set up background upsampler ------------------------
        upsampler_zoo = RealEsrUpsamplerZoo(
            upscale=upscale,
            bg_upsampler_name=bg_upsampler_name,
            prefered_net_in_upsampler=prefered_net_in_upsampler,
        )
        bg_upsampler = upsampler_zoo.bg_upsampler

        # ------------------------ load model ------------------------
        gfpgan_weights_path = os.path.join(
            ROOT_DIR, "SR_Inference", "gfpgan", "weights"
        )
        gfpgan_model_path = os.path.join(gfpgan_weights_path, "GFPGANv1.3.pth")

        # Fetch the checkpoint only when it is not already on disk; the path
        # returned by the download helper replaces the expected local path.
        if not os.path.isfile(gfpgan_model_path):
            url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth"
            gfpgan_model_path = load_file_from_url(
                url=url,
                model_dir=gfpgan_weights_path,
                progress=True,
                file_name="GFPGANv1.3.pth",
            )

        self.sr_model = GFPGANer(
            upscale=upscale,
            bg_upsampler=bg_upsampler,
            model_path=gfpgan_model_path,
            device=device,
        )

    def __call__(self, img):
        """Restore/enhance ``img`` with the GFPGAN model.

        Args:
            img: Image passed straight to ``self.sr_model.enhance``. The
                script at the bottom of this file passes the result of
                ``cv2.imread`` (presumably a BGR array -- confirm against
                ``GFPGANer.enhance``).

        Returns:
            The third element of the tuple returned by
            ``self.sr_model.enhance(img)`` (the restored full image).

        Raises:
            ValueError: If ``img`` is ``None``.
        """
        # cv2.imread signals an unreadable file by returning None; reject it
        # here with an explicit message instead of handing it to the model.
        if img is None:
            raise ValueError(
                "img is None; expected an image array "
                "(cv2.imread returns None when a file cannot be read)"
            )

        # ------------------------ restore/enhance image using GFPGAN model ------------------------
        # Only the full restored image is needed; the first two elements
        # (cropped and restored faces) are discarded.
        _, _, sr_img = self.sr_model.enhance(img)

        return sr_img


if __name__ == "__main__":
    # Manual smoke test: super-resolve one sample frame and write the result
    # under <ROOT_DIR>/rough_works/SR_imgs.

    # Read the input first so a bad path fails before the (potentially slow)
    # model construction and weight download.
    img_path = f"{ROOT_DIR}/data/EyeDentify/Wo_SR/original/1/1/frame_01.png"
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread returns None instead of raising when the file is
        # missing or cannot be decoded.
        raise FileNotFoundError(f"Could not read input image: {img_path}")

    gfpgan = GFPGAN(
        upscale=2, bg_upsampler_name="realesrgan", prefered_net_in_upsampler="RRDBNet"
    )
    sr_img = gfpgan(img=img)

    saving_dir = f"{ROOT_DIR}/rough_works/SR_imgs"
    os.makedirs(saving_dir, exist_ok=True)
    out_path = f"{saving_dir}/sr_img_gfpgan.png"
    # cv2.imwrite reports failure through its boolean return value.
    if not cv2.imwrite(out_path, sr_img):
        raise OSError(f"Could not write output image: {out_path}")