# pupilsense / SR_Inference/inference_gfpgan.py
# Author: vijul.shah
# Commit 0f2d9f6: "End-to-End Pipeline Configured"
# (Original file-view metadata: raw | history blame | 2.44 kB)
import os
import cv2
import sys
import torch
import os.path as osp
from gfpgan import GFPGANer
from basicsr.utils.download_util import load_file_from_url
# Repository root: the parent of the directory containing this file
# (i.e. the directory that holds the ``SR_Inference`` package).
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Kept as an alias for backward compatibility; previously recomputed separately.
root_path = ROOT_DIR
# Make ``SR_Inference`` importable when this file is run as a script; avoid
# stacking duplicate entries if the module is imported/reloaded repeatedly.
if root_path not in sys.path:
    sys.path.append(root_path)
from SR_Inference.inference_sr_utils import RealEsrUpsamplerZoo
class GFPGAN:
    """Callable wrapper around ``GFPGANer`` for face restoration / super-resolution.

    On construction it builds a background upsampler through
    ``RealEsrUpsamplerZoo``. It also makes sure the GFPGAN v1.3 checkpoint exists
    under ``<ROOT_DIR>/SR_Inference/gfpgan/weights``, downloading it if missing.
    Calling the instance on an image returns the restored image.
    """

    # Checkpoint file name and its release download URL.
    MODEL_FILE_NAME = "GFPGANv1.3.pth"
    MODEL_URL = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth"

    def __init__(
        self,
        upscale=2,
        bg_upsampler_name="realesrgan",
        prefered_net_in_upsampler="RRDBNet",
    ):
        """Build the background upsampler and load the GFPGAN model.

        Args:
            upscale: Upscaling factor; coerced to ``int`` and passed to both the
                background upsampler and ``GFPGANer``.
            bg_upsampler_name: Background upsampler name, forwarded to
                ``RealEsrUpsamplerZoo``.
            prefered_net_in_upsampler: Network name forwarded to
                ``RealEsrUpsamplerZoo``.
        """
        upscale = int(upscale)
        # Use the GPU when one is available, otherwise fall back to CPU.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # ------------------------ set up background upsampler ------------------------
        upsampler_zoo = RealEsrUpsamplerZoo(
            upscale=upscale,
            bg_upsampler_name=bg_upsampler_name,
            prefered_net_in_upsampler=prefered_net_in_upsampler,
        )
        bg_upsampler = upsampler_zoo.bg_upsampler

        # ------------------------ load model ------------------------
        weights_dir = os.path.join(ROOT_DIR, "SR_Inference", "gfpgan", "weights")
        model_path = self._ensure_weights(weights_dir)
        self.sr_model = GFPGANer(
            upscale=upscale,
            bg_upsampler=bg_upsampler,
            model_path=model_path,
            device=device,
        )

    @classmethod
    def _ensure_weights(cls, weights_dir):
        """Return the checkpoint path inside *weights_dir*, downloading it first if absent."""
        model_path = os.path.join(weights_dir, cls.MODEL_FILE_NAME)
        if os.path.isfile(model_path):
            return model_path
        return load_file_from_url(
            url=cls.MODEL_URL,
            model_dir=weights_dir,
            progress=True,
            file_name=cls.MODEL_FILE_NAME,
        )

    def __call__(self, img):
        """Restore/enhance *img* with GFPGAN and return the full output image.

        Args:
            img: Image accepted by ``GFPGANer.enhance``. In this file it is the
                array returned by ``cv2.imread``; other callers are assumed to
                pass the same kind of array (TODO confirm).

        Returns:
            The third element of ``GFPGANer.enhance``'s result (the restored image).

        Raises:
            ValueError: If *img* is ``None``, e.g. because ``cv2.imread`` could
                not read the input file.
        """
        if img is None:
            raise ValueError(
                "GFPGAN received img=None; was the input image read successfully?"
            )
        # ------------------------ restore/enhance image using GFPGAN model ------------------------
        # enhance() also returns the cropped and restored face patches; only
        # the full restored image is needed here.
        _, _, sr_img = self.sr_model.enhance(img)
        return sr_img
if __name__ == "__main__":
    # Smoke test: restore one EyeDentify frame with GFPGAN and save the result.
    input_path = os.path.join(
        ROOT_DIR, "data", "EyeDentify", "Wo_SR", "original", "1", "1", "frame_01.png"
    )
    # cv2.imread signals failure by returning None rather than raising; check it
    # before loading (and possibly downloading) the model.
    img = cv2.imread(input_path)
    if img is None:
        raise FileNotFoundError(f"Could not read input image: {input_path}")

    gfpgan = GFPGAN(
        upscale=2, bg_upsampler_name="realesrgan", prefered_net_in_upsampler="RRDBNet"
    )
    sr_img = gfpgan(img=img)

    saving_dir = os.path.join(ROOT_DIR, "rough_works", "SR_imgs")
    os.makedirs(saving_dir, exist_ok=True)
    output_path = os.path.join(saving_dir, "sr_img_gfpgan.png")
    # cv2.imwrite also reports failure through its boolean return value.
    if not cv2.imwrite(output_path, sr_img):
        raise OSError(f"Failed to write output image: {output_path}")