File size: 4,405 Bytes
0fe2a53 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import argparse
import cv2
import os
from imutils import paths
from tqdm import tqdm
from config import *
from utils import get_face_enhancer, get_upsampler
def process(image_path, upsampler_name, face_enhancer_name=None, scale=2, device="cpu"):
    """Upscale a single image file and return the resulting image array.

    The image is read with ``cv2.imread(..., cv2.IMREAD_UNCHANGED)`` and then
    handled according to its size:

    * Height or width above 3500 px: the image is returned unchanged.
    * Height and width both below 300 px (and the upsampler is not
      ``"srcnn"``): the image is only resized to twice its size with Lanczos
      interpolation; no model is loaded and ``scale`` is not used.
    * Otherwise: the upsampler (and, if requested, the face enhancer) is
      loaded and applied.

    Args:
        image_path: Path of the image file to read.
        upsampler_name: Model name forwarded to ``get_upsampler``.
        face_enhancer_name: Optional model name forwarded to
            ``get_face_enhancer``. When falsy, only the upsampler runs.
        scale: Requested scaling factor. Values above 4 are clamped to 4.
        device: Device string forwarded to the model factories.

    Returns:
        The processed image array, or ``None`` when the image could not be
        read or any step failed. Failures are printed, never raised, so a
        batch caller can simply skip the file.
    """
    scale = min(scale, 4)  # avoid too large scale value
    try:
        img = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
        if img is None:
            # cv2.imread reports a missing/undecodable file by returning
            # None instead of raising; say so explicitly.
            print(f"global exception: could not read image {image_path}")
            return None

        h, w = img.shape[0:2]
        if h > 3500 or w > 3500:
            # Too large to upscale: hand the image back untouched. It is
            # deliberately NOT converted to RGB: the other direct-return
            # path keeps the channel order produced by cv2.imread, and a
            # BGR->RGB conversion fails on the 1- or 4-channel images that
            # IMREAD_UNCHANGED can yield.
            return img
        if (h < 300 and w < 300) and upsampler_name != "srcnn":
            return cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)

        upsampler = get_upsampler(upsampler_name, device=device)
        if face_enhancer_name:
            face_enhancer = get_face_enhancer(
                face_enhancer_name, scale, upsampler, device=device
            )
        else:
            face_enhancer = None

        try:
            if face_enhancer is not None:
                _, _, output = face_enhancer.enhance(
                    img, has_aligned=False, only_center_face=False, paste_back=True
                )
            else:
                output, _ = upsampler.enhance(img, outscale=scale)
        except RuntimeError as error:
            # Nothing was produced, so there is no ``output`` to return.
            print(f"Runtime error: {error}")
            return None
        return output
    except Exception as error:
        # Best-effort boundary: report the problem and let the caller skip
        # this image instead of aborting a whole batch.
        print(f"global exception: {error}")
        return None
def main(args: argparse.Namespace) -> None:
    """Upscale the image(s) named by ``args.input`` and save them to ``args.output``.

    Args:
        args: Parsed command-line arguments. The attributes read are
            ``input``, ``output``, ``upsampler``, ``face_enhancer``,
            ``scale`` and ``device``.

    Raises:
        ValueError: If ``args.input`` does not exist.
    """
    device = args.device
    scale = args.scale
    upsampler_name = args.upsampler
    face_enhancer_name = args.face_enhancer

    if face_enhancer_name and ("srcnn" in upsampler_name or "anime" in upsampler_name):
        print(
            "Warnings: SRCNN and Anime model aren't compatible with face enhance. We will turn it off for you"
        )
        face_enhancer_name = None

    # Validate the input before touching the filesystem so that a bad
    # invocation does not leave an empty output directory behind.
    if not os.path.exists(args.input):
        raise ValueError("The input directory doesn't exist!")
    if os.path.isdir(args.input):
        # Materialise the listing so the progress bar knows the total.
        image_paths = list(paths.list_images(args.input))
    else:
        image_paths = [args.input]

    os.makedirs(args.output, exist_ok=True)

    with tqdm(image_paths) as pbar:
        for image_path in pbar:
            pbar.set_postfix_str(f"Processing {image_path}")
            upsampled_image = process(
                image_path=image_path,
                upsampler_name=upsampler_name,
                face_enhancer_name=face_enhancer_name,
                scale=scale,
                device=device,
            )
            if upsampled_image is None:
                # process() already reported why this image was skipped.
                continue

            # NOTE(review): only the base name is kept, so two inputs with
            # the same file name (e.g. from different sub-directories, if
            # list_images recurses) would overwrite each other - confirm.
            save_path = os.path.join(args.output, os.path.basename(image_path))
            if not cv2.imwrite(save_path, upsampled_image):
                # imwrite signals some failures by returning False.
                pbar.write(f"Failed to write {save_path}")
def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the upscaling script.

    Returns:
        A parser exposing ``--input``/``-i`` and ``--output``/``-o`` (both
        required), ``--upsampler``, ``--face-enhancer``, ``--scale`` and
        ``--device``.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Upscales an input image or a directory of images, optionally enhancing faces"
        )
    )
    parser.add_argument(
        "--input",
        "-i",
        type=str,
        required=True,
        help="Path to either a single input image or folder of images.",
    )
    parser.add_argument(
        "--output",
        "-o",
        type=str,
        required=True,
        help="Path to the output directory.",
    )
    parser.add_argument(
        "--upsampler",
        type=str,
        default="realesr-general-x4v3",
        choices=[
            "srcnn",
            "RealESRGAN_x2plus",
            "RealESRGAN_x4plus",
            "RealESRNet_x4plus",
            "realesr-general-x4v3",
            "RealESRGAN_x4plus_anime_6B",
            "realesr-animevideov3",
        ],
        help="The type of upsampler model to load",
    )
    parser.add_argument(
        "--face-enhancer",
        type=str,
        choices=["GFPGANv1.3", "GFPGANv1.4", "RestoreFormer"],
        help="The type of face enhancer model to load",
    )
    parser.add_argument(
        "--scale",
        type=float,
        default=2,
        choices=[1.5, 2, 2.5, 3, 3.5, 4],
        help="scaling factor",
    )
    parser.add_argument(
        "--device", type=str, default="cuda", help="The device to run upsampling on."
    )
    return parser


if __name__ == "__main__":
    main(build_arg_parser().parse_args())
|