import os
import cv2
import sys
import torch
import numpy as np
import os.path as osp

from basicsr.utils import img2tensor

# Repository root: two levels up from this file. It is appended to sys.path
# so that SR_Inference can be imported when running this script directly.
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(ROOT_DIR)

from SR_Inference.hat.hat_arch import HATArch

class HAT:
    def __init__(
        self,
        upscale=2,
        in_chans=3,
        img_size=(480, 640),
        window_size=16,
        compress_ratio=3,
        squeeze_factor=30,
        conv_scale=0.01,
        overlap_ratio=0.5,
        img_range=1.0,
        depths=[6, 6, 6, 6, 6, 6],
        embed_dim=180,
        num_heads=[6, 6, 6, 6, 6, 6],
        mlp_ratio=2,
        upsampler="pixelshuffle",
        resi_connection="1conv",
    ):
        upscale = int(upscale)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # ------------------ load model for img enhancement -------------------
        self.sr_model = HATArch(
            img_size=img_size,
            upscale=upscale,
            in_chans=in_chans,
            window_size=window_size,
            compress_ratio=compress_ratio,
            squeeze_factor=squeeze_factor,
            conv_scale=conv_scale,
            overlap_ratio=overlap_ratio,
            img_range=img_range,
            depths=depths,
            embed_dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            upsampler=upsampler,
            resi_connection=resi_connection,
        ).to(self.device)

        ckpt_path = os.path.join(
            ROOT_DIR,
            "SR_Inference",
            "hat",
            "weights",
            f"HAT_SRx{upscale}_ImageNet-pretrain.pth",
        )
        loadnet = torch.load(ckpt_path, map_location=self.device)
        # Prefer the EMA weights when the checkpoint provides them.
        keyname = "params_ema" if "params_ema" in loadnet else "params"
        self.sr_model.load_state_dict(loadnet[keyname])
        self.sr_model.eval()

    @torch.no_grad()
    def __call__(self, img):
        # BGR uint8 image -> normalized RGB tensor of shape (1, C, H, W).
        img_tensor = (
            img2tensor(imgs=img / 255.0, bgr2rgb=True, float32=True)
            .unsqueeze(0)
            .to(self.device)
        )
        restored_img = self.sr_model(img_tensor)[0]
        restored_img = restored_img.permute(1, 2, 0).cpu().numpy()
        # Min-max stretch the network output back into [0, 255].
        restored_img = (restored_img - restored_img.min()) / (
            restored_img.max() - restored_img.min()
        )
        restored_img = (restored_img * 255).astype(np.uint8)
        # Convert RGB back to BGR so the result can be saved with OpenCV.
        sr_img = cv2.cvtColor(restored_img, cv2.COLOR_RGB2BGR)
        return sr_img

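# The helper below is not part of the original script; it is a minimal usage
# sketch showing how the HAT wrapper above could be applied to a whole folder
# of frames. The directory arguments are placeholders; adjust them to your layout.
def upscale_folder(hat_model, in_dir, out_dir):
    """Run the HAT wrapper on every readable image in in_dir and save to out_dir."""
    os.makedirs(out_dir, exist_ok=True)
    for name in sorted(os.listdir(in_dir)):
        if not name.lower().endswith((".png", ".jpg", ".jpeg")):
            continue
        frame = cv2.imread(osp.join(in_dir, name))
        if frame is None:
            continue  # skip files OpenCV cannot decode
        cv2.imwrite(osp.join(out_dir, name), hat_model(img=frame))
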

if __name__ == "__main__":
    hat = HAT(upscale=2)
    img = cv2.imread(f"{ROOT_DIR}/data/EyeDentify/Wo_SR/original/1/1/frame_01.png")
    sr_img = hat(img=img)

    saving_dir = f"{ROOT_DIR}/rough_works/SR_imgs"
    os.makedirs(saving_dir, exist_ok=True)
    cv2.imwrite(f"{saving_dir}/sr_img_hat.png", sr_img)