Spaces:
Running
Running
File size: 2,435 Bytes
0f2d9f6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
import os
import cv2
import sys
import torch
import os.path as osp
from gfpgan import GFPGANer
from basicsr.utils.download_util import load_file_from_url
# Absolute path of the directory two levels above this file.
_THIS_FILE = os.path.abspath(__file__)
ROOT_DIR = os.path.dirname(os.path.dirname(_THIS_FILE))

# The same location, obtained by climbing two parent-directory components.
# It is appended to ``sys.path``, presumably so that packages living under
# that directory can be imported by the statements that follow.
_two_levels_up = osp.join(__file__, osp.pardir, osp.pardir)
root_path = osp.abspath(_two_levels_up)
sys.path.append(root_path)
from SR_Inference.inference_sr_utils import RealEsrUpsamplerZoo
class GFPGAN:
    """Face-restoration wrapper around ``GFPGANer`` using the GFPGAN v1.3 weights.

    Construction builds a background upsampler through ``RealEsrUpsamplerZoo``,
    makes sure the ``GFPGANv1.3.pth`` checkpoint is available under
    ``<ROOT_DIR>/SR_Inference/gfpgan/weights`` (fetching it from the GFPGAN
    GitHub release when the file is missing) and stores the resulting
    ``GFPGANer`` instance in ``self.sr_model``.

    Calling the instance with an image returns the restored image produced by
    ``GFPGANer.enhance``.
    """

    def __init__(
        self,
        upscale=2,
        bg_upsampler_name="realesrgan",
        prefered_net_in_upsampler="RRDBNet",
    ):
        """Create the background upsampler and load the GFPGAN model.

        Args:
            upscale: Upscaling factor; coerced with ``int()`` and passed to both
                ``RealEsrUpsamplerZoo`` and ``GFPGANer``.
            bg_upsampler_name: Forwarded unchanged to ``RealEsrUpsamplerZoo``.
            prefered_net_in_upsampler: Forwarded unchanged to
                ``RealEsrUpsamplerZoo``.
        """
        upscale = int(upscale)
        # Run on the GPU when one is available, otherwise fall back to the CPU.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # ------------------------ set up background upsampler ------------------------
        upsampler_zoo = RealEsrUpsamplerZoo(
            upscale=upscale,
            bg_upsampler_name=bg_upsampler_name,
            prefered_net_in_upsampler=prefered_net_in_upsampler,
        )
        bg_upsampler = upsampler_zoo.bg_upsampler

        # ------------------------ load model ------------------------
        gfpgan_weights_path = os.path.join(
            ROOT_DIR, "SR_Inference", "gfpgan", "weights"
        )
        gfpgan_model_path = os.path.join(gfpgan_weights_path, "GFPGANv1.3.pth")
        # Only hit the network when the checkpoint is not already on disk.
        if not os.path.isfile(gfpgan_model_path):
            url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth"
            gfpgan_model_path = load_file_from_url(
                url=url,
                model_dir=gfpgan_weights_path,
                progress=True,
                file_name="GFPGANv1.3.pth",
            )

        self.sr_model = GFPGANer(
            upscale=upscale,
            bg_upsampler=bg_upsampler,
            model_path=gfpgan_model_path,
            device=device,
        )

    def __call__(self, img):
        """Restore/enhance ``img`` using the GFPGAN model.

        Args:
            img: Image handed directly to ``GFPGANer.enhance`` (presumably an
                array as returned by ``cv2.imread`` -- confirm against callers).

        Returns:
            The third element of the tuple returned by ``enhance``, i.e. the
            restored image.

        Raises:
            ValueError: If ``img`` is ``None``, which is what ``cv2.imread``
                returns when a file cannot be read.
        """
        # Fail early with a clear message instead of an opaque error raised
        # from inside the restoration pipeline.
        if img is None:
            raise ValueError(
                "img is None; the input image could not be loaded "
                "(cv2.imread returns None for missing or unreadable files)"
            )

        # ------------------------ restore/enhance image using GFPGAN model ------------------------
        # enhance() yields (cropped faces, restored faces, restored image);
        # only the restored image is needed here.
        _, _, sr_img = self.sr_model.enhance(img)
        return sr_img
if __name__ == "__main__":
    # Manual smoke test: restore one sample frame and write the result to disk.
    gfpgan = GFPGAN(
        upscale=2, bg_upsampler_name="realesrgan", prefered_net_in_upsampler="RRDBNet"
    )

    input_path = f"{ROOT_DIR}/data/EyeDentify/Wo_SR/original/1/1/frame_01.png"
    img = cv2.imread(input_path)
    # cv2.imread does not raise on a missing or undecodable file; it returns None.
    if img is None:
        raise FileNotFoundError(f"Could not read input image: {input_path}")

    sr_img = gfpgan(img=img)

    saving_dir = f"{ROOT_DIR}/rough_works/SR_imgs"
    os.makedirs(saving_dir, exist_ok=True)
    output_path = f"{saving_dir}/sr_img_gfpgan.png"
    # cv2.imwrite reports failure through its boolean return value.
    if not cv2.imwrite(output_path, sr_img):
        raise OSError(f"Could not write output image: {output_path}")
|