File size: 3,077 Bytes
0f2d9f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
import cv2
import sys
import torch
import numpy as np
import os.path as osp
from PIL import Image
from basicsr.utils import img2tensor

# Repository root: the directory two levels above this file.
ROOT_DIR = osp.dirname(osp.dirname(osp.abspath(__file__)))

# Put the repository root on the import path so the project-local
# ``SR_Inference`` package imported below resolves when this file is run
# directly as a script.
root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
sys.path.append(root_path)
from SR_Inference.hat.hat_arch import HATArch


class HAT:
    """Inference wrapper around the project's HAT super-resolution network.

    Builds an ``HATArch`` model, loads the ImageNet-pretrained checkpoint that
    matches ``upscale`` from ``<ROOT_DIR>/SR_Inference/hat/weights`` and exposes
    the model as a callable mapping one image array to its super-resolved
    version.
    """

    def __init__(
        self,
        upscale=2,
        in_chans=3,
        img_size=(480, 640),
        window_size=16,
        compress_ratio=3,
        squeeze_factor=30,
        conv_scale=0.01,
        overlap_ratio=0.5,
        img_range=1.0,
        depths=None,
        embed_dim=180,
        num_heads=None,
        mlp_ratio=2,
        upsampler="pixelshuffle",
        resi_connection="1conv",
    ):
        """Build the network and load its pretrained weights.

        Args:
            upscale: Upscaling factor (coerced with ``int``). It is passed to
                ``HATArch`` and also selects the checkpoint file
                ``HAT_SRx{upscale}_ImageNet-pretrain.pth``.
            in_chans, img_size, window_size, compress_ratio, squeeze_factor,
            conv_scale, overlap_ratio, img_range, embed_dim, mlp_ratio,
            upsampler, resi_connection: Forwarded unchanged to ``HATArch``.
                Those that affect parameter names/shapes must agree with the
                checkpoint, otherwise ``load_state_dict`` raises.
            depths: Per-stage depths; ``None`` means ``[6, 6, 6, 6, 6, 6]``.
            num_heads: Per-stage head counts; ``None`` means
                ``[6, 6, 6, 6, 6, 6]``.

        Raises:
            FileNotFoundError: If the checkpoint file does not exist.
            KeyError: If the checkpoint has neither ``params_ema`` nor ``params``.
        """
        upscale = int(upscale)
        # Fresh lists per instance instead of shared mutable defaults.
        if depths is None:
            depths = [6, 6, 6, 6, 6, 6]
        if num_heads is None:
            num_heads = [6, 6, 6, 6, 6, 6]
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # ------------------ load model for img enhancement -------------------
        self.sr_model = HATArch(
            img_size=img_size,
            upscale=upscale,
            in_chans=in_chans,
            window_size=window_size,
            compress_ratio=compress_ratio,
            squeeze_factor=squeeze_factor,
            conv_scale=conv_scale,
            overlap_ratio=overlap_ratio,
            img_range=img_range,
            depths=depths,
            embed_dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            upsampler=upsampler,
            resi_connection=resi_connection,
        ).to(self.device)

        ckpt_path = os.path.join(
            ROOT_DIR,
            "SR_Inference",
            "hat",
            "weights",
            f"HAT_SRx{upscale}_ImageNet-pretrain.pth",
        )
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        loadnet = torch.load(ckpt_path, map_location=self.device)
        # Prefer the EMA weights when the checkpoint provides them.
        keyname = "params_ema" if "params_ema" in loadnet else "params"

        self.sr_model.load_state_dict(loadnet[keyname])
        self.sr_model.eval()

    @staticmethod
    def _to_uint8_image(output):
        """Convert a (C, H, W) output tensor to a C-contiguous (H, W, C) uint8 array.

        Values are min-max stretched to 0-255 (truncated, not rounded). A
        constant tensor has no range to stretch, so it is clipped to [0, 1]
        and scaled instead of dividing by zero.
        """
        arr = output.permute(1, 2, 0).cpu().numpy()
        lo, hi = arr.min(), arr.max()
        span = hi - lo
        if span > 0:
            arr = (arr - lo) / span
        else:
            # 0 / 0 would yield NaN, whose cast to uint8 is undefined.
            arr = np.clip(arr, 0.0, 1.0)
        # The permuted view keeps a channel-first memory layout; make it
        # C-contiguous before handing it to OpenCV.
        return np.ascontiguousarray((arr * 255).astype(np.uint8))

    @torch.no_grad()
    def __call__(self, img):
        """Super-resolve a single image.

        Args:
            img: Image array in BGR channel order, expected to hold 0-255
                values since it is divided by 255 (e.g. from ``cv2.imread``).
                Assumed H x W x 3 -- TODO confirm no other layouts are needed.

        Returns:
            ``np.uint8`` image in BGR order at the network's output resolution,
            with intensities min-max stretched per image to the 0-255 range.
        """
        img_tensor = (
            img2tensor(imgs=img / 255.0, bgr2rgb=True, float32=True)
            .unsqueeze(0)
            .to(self.device)
        )
        # Drop the batch dimension; the network output is RGB.
        restored_img = self._to_uint8_image(self.sr_model(img_tensor)[0])
        return cv2.cvtColor(restored_img, cv2.COLOR_RGB2BGR)


if __name__ == "__main__":
    # Smoke test: super-resolve one EyeDentify frame 2x and save it to disk.
    img_path = f"{ROOT_DIR}/data/EyeDentify/Wo_SR/original/1/1/frame_01.png"
    img = cv2.imread(img_path)
    # cv2.imread reports failure by returning None rather than raising; check
    # before building the model so a bad path fails fast.
    if img is None:
        raise FileNotFoundError(f"Could not read input image: {img_path}")

    hat = HAT(upscale=2)
    sr_img = hat(img=img)

    saving_dir = f"{ROOT_DIR}/rough_works/SR_imgs"
    os.makedirs(saving_dir, exist_ok=True)
    out_path = f"{saving_dir}/sr_img_hat.png"
    # cv2.imwrite likewise signals failure only through its boolean result.
    if not cv2.imwrite(out_path, sr_img):
        raise OSError(f"Failed to write super-resolved image: {out_path}")