File size: 3,943 Bytes
0f2d9f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import os
import cv2
import torch
from realesrgan import RealESRGANer
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
from basicsr.utils.download_util import load_file_from_url

# Project root: the directory one level above the folder holding this module.
# Upsampler weights are stored beneath it in SR_Inference/<name>/weights.
ROOT_DIR = os.path.abspath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)
)


class RealEsrUpsamplerZoo:
    """Build a Real-ESRGAN background upsampler from a small zoo of pretrained weights.

    The supported ``(bg_upsampler_name, upscale)`` combinations are listed in
    ``_PRETRAINED``. Weight files are looked up under
    ``<ROOT_DIR>/SR_Inference/<bg_upsampler_name>/weights``. A missing file is
    downloaded from its release URL.

    Attributes:
        upscale: Integer upscaling factor.
        bg_upsampler: The configured ``RealESRGANer`` instance.
    """

    # (bg_upsampler_name, upscale) -> (weights file name, download URL, RRDBNet num_block).
    # The anime checkpoint is the 6-block ("_6B") variant. Building the default
    # 23-block RRDBNet for it makes the state-dict load fail.
    _PRETRAINED = {
        ("realesrgan", 2): (
            "RealESRGAN_x2plus.pth",
            "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
            23,
        ),
        ("realesrgan", 4): (
            "RealESRGAN_x4plus.pth",
            "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
            23,
        ),
        ("realesrnet", 4): (
            "RealESRNet_x4plus.pth",
            "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth",
            23,
        ),
        ("anime", 4): (
            "RealESRGAN_x4plus_anime_6B.pth",
            "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
            6,
        ),
    }

    def __init__(
        self,
        upscale=2,
        bg_upsampler_name="realesrgan",
        prefered_net_in_upsampler="RRDBNet",
        *,
        tile=0,
        tile_pad=0,
        pre_pad=0,
        half=False,
    ):
        """Select, fetch (if needed) and load the requested upsampler.

        Args:
            upscale: Upscaling factor; coerced to ``int``.
            bg_upsampler_name: One of "realesrgan", "realesrnet" or "anime".
            prefered_net_in_upsampler: Network to load the weights into
                ("RRDBNet" or "SRVGGNetCompact").
            tile, tile_pad, pre_pad, half: Forwarded unchanged to
                ``RealESRGANer``. The defaults reproduce the previously
                hard-coded settings.

        Raises:
            ValueError: If the upsampler name is unknown, the name/scale
                combination has no pretrained weights, or the network name
                is unknown. ValueError subclasses Exception, so existing
                ``except Exception`` handlers keep working.
        """
        self.upscale = int(upscale)

        # Validate the request before any model construction or file-system access.
        known_names = {name for name, _ in self._PRETRAINED}
        if bg_upsampler_name not in known_names:
            raise ValueError(f"No model implemented for: {bg_upsampler_name}")
        try:
            file_name, url, num_block = self._PRETRAINED[
                (bg_upsampler_name, self.upscale)
            ]
        except KeyError:
            raise ValueError(
                f"{bg_upsampler_name} model not available for upscaling x{self.upscale}"
            ) from None

        # NOTE(review): the checkpoints above look like RRDBNet checkpoints.
        # Choosing "SRVGGNetCompact" presumably fails when RealESRGANer loads
        # the weights. Confirm before relying on that option.
        model = self.get_prefered_net(
            prefered_net_in_upsampler, self.upscale, num_block=num_block
        )

        # ------------------------ load background upsampler model ------------------------
        weights_path = os.path.join(
            ROOT_DIR, "SR_Inference", bg_upsampler_name, "weights"
        )
        model_path = os.path.join(weights_path, file_name)
        if not os.path.isfile(model_path):
            model_path = load_file_from_url(
                url=url, model_dir=weights_path, progress=True, file_name=None
            )

        self.bg_upsampler = RealESRGANer(
            scale=self.upscale,
            model_path=model_path,
            model=model,
            tile=tile,
            tile_pad=tile_pad,
            pre_pad=pre_pad,
            half=half,
        )

    @staticmethod
    def get_prefered_net(prefered_net_in_upsampler, upscale=2, num_block=23):
        """Instantiate the network that the pretrained weights are loaded into.

        Args:
            prefered_net_in_upsampler: "RRDBNet" or "SRVGGNetCompact".
            upscale: Upscaling factor passed to the network; coerced to ``int``.
            num_block: Number of RRDB blocks. Only used for "RRDBNet": 23 for
                the x2plus/x4plus checkpoints, 6 for the anime_6B checkpoint.

        Returns:
            A freshly constructed network (no weights loaded).

        Raises:
            ValueError: If ``prefered_net_in_upsampler`` is not recognised.
        """
        if prefered_net_in_upsampler == "RRDBNet":
            return RRDBNet(
                num_in_ch=3,
                num_out_ch=3,
                num_feat=64,
                num_block=num_block,
                num_grow_ch=32,
                scale=int(upscale),
            )
        if prefered_net_in_upsampler == "SRVGGNetCompact":
            return SRVGGNetCompact(
                num_in_ch=3,
                num_out_ch=3,
                num_feat=64,
                num_conv=16,
                upscale=int(upscale),
                act_type="prelu",
            )
        raise ValueError(f"No net named: {prefered_net_in_upsampler} implemented!")