|
import argparse
|
|
import cv2
|
|
import os
|
|
|
|
from imutils import paths
|
|
from tqdm import tqdm
|
|
from config import *
|
|
from utils import get_face_enhancer, get_upsampler
|
|
|
|
|
|
def process(image_path, upsampler_name, face_enhancer_name=None, scale=2, device="cpu"):
    """Upscale one image file and return the resulting pixel array.

    Args:
        image_path: Path of the image, read with ``cv2.imread`` using
            ``cv2.IMREAD_UNCHANGED``.
        upsampler_name: Model name forwarded to ``get_upsampler``.
        face_enhancer_name: Optional model name forwarded to
            ``get_face_enhancer``; when falsy only the upsampler is run.
        scale: Requested output scale. Values above 4 are clamped to 4.
        device: Device string forwarded to both model factories.

    Returns:
        The processed image, or ``None`` when the file could not be read or
        processing failed. Errors are printed, never raised.
    """
    if scale > 4:
        scale = 4

    # Best-effort boundary: one bad image must not abort a whole batch, so any
    # failure is reported and turned into a ``None`` result.
    try:
        img = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
        if img is None:
            # cv2.imread signals an unreadable/missing file by returning None.
            print(f"global exception: could not read image {image_path}")
            return None

        h, w = img.shape[0:2]

        # Very large inputs are passed through without upscaling. The array is
        # returned in the channel order cv2.imread produced (same as the
        # small-image branch below) because the caller saves it with
        # cv2.imwrite; converting to RGB here would swap red and blue on disk
        # and would fail outright for grayscale or BGRA inputs.
        if h > 3500 or w > 3500:
            return img

        # Small inputs get a plain 2x Lanczos resize instead of a model pass
        # (the "srcnn" upsampler is excluded from this shortcut).
        if (h < 300 and w < 300) and upsampler_name != "srcnn":
            return cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)

        upsampler = get_upsampler(upsampler_name, device=device)

        if face_enhancer_name:
            face_enhancer = get_face_enhancer(
                face_enhancer_name, scale, upsampler, device=device
            )
        else:
            face_enhancer = None

        try:
            if face_enhancer is not None:
                _, _, output = face_enhancer.enhance(
                    img, has_aligned=False, only_center_face=False, paste_back=True
                )
            else:
                output, _ = upsampler.enhance(img, outscale=scale)
        except RuntimeError as error:
            # Nothing was produced; return explicitly instead of touching an
            # unbound ``output`` (which used to trigger a second, misleading
            # "global exception" message).
            print(f"Runtime error: {error}")
            return None

        return output
    except Exception as error:
        print(f"global exception: {error}")
        return None
|
|
|
|
|
|
def main(args: argparse.Namespace) -> None:
    """Upscale the image(s) named by ``args.input`` into ``args.output``.

    Args:
        args: Parsed command-line options. The attributes read are ``device``,
            ``scale``, ``upsampler``, ``face_enhancer``, ``input`` and
            ``output``.

    Raises:
        ValueError: If ``args.input`` does not exist.
    """
    device = args.device
    scale = args.scale

    upsampler_name = args.upsampler
    face_enhancer_name = args.face_enhancer

    if face_enhancer_name and ("srcnn" in upsampler_name or "anime" in upsampler_name):
        print(
            "Warnings: SRCNN and Anime model aren't compatible with face enhance. We will turn it off for you"
        )
        face_enhancer_name = None

    os.makedirs(args.output, exist_ok=True)
    if not os.path.exists(args.input):
        raise ValueError("The input directory doesn't exist!")

    if os.path.isdir(args.input):
        # Materialise the listing so tqdm knows the total and can show real
        # progress instead of a bare counter.
        image_paths = list(paths.list_images(args.input))
    else:
        # A single file was given rather than a directory.
        image_paths = [args.input]

    with tqdm(image_paths) as pbar:
        for image_path in pbar:
            filename = os.path.basename(image_path)
            pbar.set_postfix_str(f"Processing {image_path}")
            upsampled_image = process(
                image_path=image_path,
                upsampler_name=upsampler_name,
                face_enhancer_name=face_enhancer_name,
                scale=scale,
                device=device,
            )
            # ``process`` returns None for images it could not handle; those
            # are skipped (it has already printed the reason).
            if upsampled_image is None:
                continue

            save_path = os.path.join(args.output, filename)
            # cv2.imwrite reports failure through its boolean return value;
            # surface it rather than silently losing the result.
            if not cv2.imwrite(save_path, upsampled_image):
                tqdm.write(f"Warning: failed to write {save_path}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: build the parser, parse argv, and hand the
    # resulting namespace to main().
    parser = argparse.ArgumentParser(
        description=(
            "Upscales an input image or directory of images, optionally enhancing faces"
        )
    )

    parser.add_argument(
        "--input",
        "-i",
        type=str,
        required=True,
        help="Path to either a single input image or folder of images.",
    )

    parser.add_argument(
        "--output",
        "-o",
        type=str,
        required=True,
        help="Path to the output directory.",
    )

    parser.add_argument(
        "--upsampler",
        type=str,
        default="realesr-general-x4v3",
        choices=[
            "srcnn",
            "RealESRGAN_x2plus",
            "RealESRGAN_x4plus",
            "RealESRNet_x4plus",
            "realesr-general-x4v3",
            "RealESRGAN_x4plus_anime_6B",
            "realesr-animevideov3",
        ],
        help="The type of upsampler model to load",
    )

    # No default: face enhancement is off unless a model is named explicitly.
    parser.add_argument(
        "--face-enhancer",
        type=str,
        choices=["GFPGANv1.3", "GFPGANv1.4", "RestoreFormer"],
        help="The type of face enhancer model to load",
    )

    # Parsed as float so fractional factors such as 1.5 are accepted; the
    # integer entries in ``choices`` still match (2.0 == 2).
    parser.add_argument(
        "--scale",
        type=float,
        default=2,
        choices=[1.5, 2, 2.5, 3, 3.5, 4],
        help="scaling factor",
    )
    parser.add_argument(
        "--device", type=str, default="cuda", help="The device to run upsampling on."
    )
    args = parser.parse_args()
    main(args)
|
|
|