Spaces:
Running
Running
File size: 3,943 Bytes
0f2d9f6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
import os
import cv2
import torch
from realesrgan import RealESRGANer
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
from basicsr.utils.download_util import load_file_from_url
# Project root: the parent of the directory that contains this file.
# Weight folders are resolved relative to this path.
ROOT_DIR = os.path.normpath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)
)
class RealEsrUpsamplerZoo:
    """Build a ``RealESRGANer`` background upsampler from a small set of checkpoints.

    The constructor selects a checkpoint from ``bg_upsampler_name`` and
    ``upscale``, downloads it into
    ``<ROOT_DIR>/SR_Inference/<bg_upsampler_name>/weights`` when the file is
    not already present, and stores the resulting upsampler on
    ``self.bg_upsampler``.

    Supported combinations:
        * ``"realesrgan"``: x2, x4
        * ``"realesrnet"``: x4
        * ``"anime"``: x4
    """

    # upsampler name -> scale -> (weights file name, download URL, RRDBNet block count)
    # The block count must match the architecture the checkpoint was trained
    # with; the anime "6B" checkpoint is a 6-block RRDBNet (see the upstream
    # Real-ESRGAN inference script), every other entry uses 23 blocks.
    _MODEL_ZOO = {
        "realesrgan": {
            2: (
                "RealESRGAN_x2plus.pth",
                "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
                23,
            ),
            4: (
                "RealESRGAN_x4plus.pth",
                "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
                23,
            ),
        },
        "realesrnet": {
            4: (
                "RealESRNet_x4plus.pth",
                "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth",
                23,
            ),
        },
        "anime": {
            4: (
                "RealESRGAN_x4plus_anime_6B.pth",
                "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
                6,
            ),
        },
    }

    def __init__(
        self,
        upscale=2,
        bg_upsampler_name="realesrgan",
        prefered_net_in_upsampler="RRDBNet",
    ):
        """Create the upsampler and store it on ``self.bg_upsampler``.

        Args:
            upscale: Upscaling factor; converted with ``int()``.
            bg_upsampler_name: One of ``"realesrgan"``, ``"realesrnet"`` or
                ``"anime"``.
            prefered_net_in_upsampler: Network architecture name accepted by
                :meth:`get_prefered_net`.

        Raises:
            ValueError: If the upsampler name is unknown, the scale is not
                available for that upsampler, or the network name is unknown.
        """
        self.upscale = int(upscale)

        # Validate the request first so that a bad name or scale fails fast,
        # before a large network is allocated or any path is touched.
        if bg_upsampler_name not in self._MODEL_ZOO:
            raise ValueError(f"No model implemented for: {bg_upsampler_name}")
        checkpoints = self._MODEL_ZOO[bg_upsampler_name]
        if self.upscale not in checkpoints:
            raise ValueError(
                f"{bg_upsampler_name} model not available for upscaling x{self.upscale}"
            )
        file_name, url, num_block = checkpoints[self.upscale]

        # ------------------------ set up background upsampler ------------------------
        model = self.get_prefered_net(
            prefered_net_in_upsampler, self.upscale, num_block=num_block
        )
        weights_path = os.path.join(
            ROOT_DIR, "SR_Inference", f"{bg_upsampler_name}", "weights"
        )
        model_path = os.path.join(weights_path, file_name)

        # ------------------------ load background upsampler model ------------------------
        if not os.path.isfile(model_path):
            # Fetch the checkpoint on first use; the helper returns the local path.
            model_path = load_file_from_url(
                url=url, model_dir=weights_path, progress=True, file_name=None
            )

        self.bg_upsampler = RealESRGANer(
            scale=self.upscale,
            model_path=model_path,
            model=model,
            tile=0,
            tile_pad=0,
            pre_pad=0,
            half=False,
        )

    @staticmethod
    def get_prefered_net(prefered_net_in_upsampler, upscale=2, num_block=23):
        """Instantiate the network architecture that will receive the weights.

        Args:
            prefered_net_in_upsampler: ``"RRDBNet"`` or ``"SRVGGNetCompact"``.
            upscale: Upscaling factor passed to the network; converted with
                ``int()``.
            num_block: Number of RRDB blocks. Only used for ``"RRDBNet"``;
                defaults to 23, the value this method always used before the
                parameter existed.

        Returns:
            The constructed, not yet weight-loaded, network.

        Raises:
            ValueError: If ``prefered_net_in_upsampler`` is not a known name.
        """
        if prefered_net_in_upsampler == "RRDBNet":
            return RRDBNet(
                num_in_ch=3,
                num_out_ch=3,
                num_feat=64,
                num_block=int(num_block),
                num_grow_ch=32,
                scale=int(upscale),
            )
        if prefered_net_in_upsampler == "SRVGGNetCompact":
            return SRVGGNetCompact(
                num_in_ch=3,
                num_out_ch=3,
                num_feat=64,
                num_conv=16,
                upscale=int(upscale),
                act_type="prelu",
            )
        raise ValueError(f"No net named: {prefered_net_in_upsampler} implemented!")
|