File size: 4,352 Bytes
a1074ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os
import torch
import cv2
from PIL import Image
import numpy as np
from huggingface_hub import hf_hub_url, cached_download
from .rrdbnet_arch import RRDBNet
from .utils import (
    pad_reflect,
    split_image_into_overlapping_patches,
    stich_together,
    unpad_image,
)

# Hugging Face Hub coordinates (repository id + file name) of pretrained
# weights, keyed by the integer scale that RealESRGAN.load_weights looks up.
# NOTE(review): load_weights only permits automatic download for scales
# 2, 4 and 8, so the entry under key 6 is not reachable through that path
# as written — confirm whether it is intentional.
HF_MODELS = {
    2: {
        "repo_id": "sberbank-ai/Real-ESRGAN",
        "filename": "RealESRGAN_x2.pth",
    },
    4: {
        "repo_id": "sberbank-ai/Real-ESRGAN",
        "filename": "RealESRGAN_x4.pth",
    },
    6: {
        "repo_id": "alicangonullu/ESRGAN-Ulti",
        "filename": "RealESRGAN_x4plus.pth",
    },
    8: {
        "repo_id": "sberbank-ai/Real-ESRGAN",
        "filename": "RealESRGAN_x8.pth",
    },
}

class RealESRGAN:
    """Inference wrapper around an ``RRDBNet`` model for image upscaling.

    Typical use is: construct, call :meth:`load_weights`, then call
    :meth:`predict` on an image. :meth:`face_enhance` is an optional extra
    step that delegates to the third-party ``gfpgan`` package.
    """

    # Legacy default location of the GFPGAN checkpoint. Kept only so that
    # existing callers of face_enhance() behave exactly as before; new callers
    # should pass ``model_path`` explicitly.
    _DEFAULT_GFPGAN_MODEL_PATH = r"C:\Users\Admin\Downloads\Term 3\Big Data Capstone Project\Real-ESRGAN-GFP\Img-Upscale-AI\model\GFPGANv1.3.pth"

    def __init__(self, device, anime=False, scale=4):
        """Build the (still unweighted) network.

        Args:
            device: Device the model and input patches are moved to.
            anime: When true the network is built with 6 blocks instead
                of 23.
            scale: Upscaling factor, passed to ``RRDBNet`` and used to size
                the output in :meth:`predict`.
        """
        self.device = device
        self.scale = scale
        # The two variants differ only in the number of blocks.
        self.model = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=6 if anime else 23,
            num_grow_ch=32,
            scale=scale,
        )

    def load_weights(self, model_path, download=True):
        """Load a checkpoint into the model, downloading it first if needed.

        Args:
            model_path: Path of the checkpoint file. When it does not exist
                and ``download`` is true, the file is fetched from the
                repository listed in ``HF_MODELS`` for ``self.scale`` and
                stored at this path.
            download: Whether a missing checkpoint may be downloaded.

        Raises:
            AssertionError: If a download is required but ``self.scale`` is
                not 2, 4 or 8.
        """
        if not os.path.exists(model_path) and download:
            # Raised explicitly rather than via ``assert`` so the check is not
            # stripped under ``python -O``; the exception type and message are
            # kept for backward compatibility.
            if self.scale not in (2, 4, 8):
                raise AssertionError("You can download models only with scales: 2, 4, 8")
            config = HF_MODELS[self.scale]
            cache_dir = os.path.dirname(model_path)
            local_filename = os.path.basename(model_path)
            config_file_url = hf_hub_url(repo_id=config["repo_id"], filename=config["filename"])
            # NOTE(review): cached_download is deprecated in recent
            # huggingface_hub releases — verify against the pinned version and
            # consider migrating to hf_hub_download.
            cached_download(config_file_url, cache_dir=cache_dir, force_filename=local_filename)
            print("Weights downloaded to:", os.path.join(cache_dir, local_filename))

        # Deserialize onto CPU so a checkpoint saved from a CUDA device can
        # still be loaded on a CPU-only host; the model is moved to
        # ``self.device`` below.
        loadnet = torch.load(model_path, map_location="cpu")
        # Checkpoints may wrap the state dict under "params" or "params_ema";
        # otherwise the loaded object is used as the state dict itself.
        if "params" in loadnet:
            state_dict = loadnet["params"]
        elif "params_ema" in loadnet:
            state_dict = loadnet["params_ema"]
        else:
            state_dict = loadnet
        self.model.load_state_dict(state_dict, strict=True)
        self.model.eval()
        self.model.to(self.device)

    @torch.cuda.amp.autocast()
    def predict(self, lr_image, batch_size=4, patches_size=192, padding=24, pad_size=15):
        """Upscale an image by ``self.scale`` and return it as a PIL image.

        The image is padded, split into overlapping patches, run through the
        model in batches, and the upscaled patches are stitched back together.
        The whole method runs under CUDA autocast.

        Args:
            lr_image: Input image; converted with ``np.array``.
                NOTE(review): presumably an H x W x 3 image — the output
                shape below hard-codes 3 channels; confirm against callers.
            batch_size: Number of patches per forward pass.
            patches_size: Patch size handed to
                ``split_image_into_overlapping_patches``.
            padding: Per-patch padding handed to the split helper; the
                stitching step receives it multiplied by ``self.scale``.
            pad_size: Border added by ``pad_reflect`` before splitting and
                removed (multiplied by ``self.scale``) after stitching.

        Returns:
            PIL.Image.Image: The upscaled image built from a ``uint8`` array.
        """
        scale = self.scale
        device = self.device
        lr_image = np.array(lr_image)
        lr_image = pad_reflect(lr_image, pad_size)

        patches, p_shape = split_image_into_overlapping_patches(
            lr_image, patch_size=patches_size, padding_size=padding
        )
        # Scale pixel values by 1/255 and move the channel axis to position 1.
        img = torch.FloatTensor(patches / 255).permute((0, 3, 1, 2)).to(device).detach()

        with torch.no_grad():
            # Collect per-batch outputs and concatenate once, instead of
            # re-copying the growing result on every iteration. max(..., 1)
            # keeps the unconditional first forward pass of the original.
            outputs = [
                self.model(img[start : start + batch_size])
                for start in range(0, max(img.shape[0], 1), batch_size)
            ]
            res = torch.cat(outputs, 0)

        sr_image = res.permute((0, 2, 3, 1)).clamp_(0, 1).cpu()
        np_sr_image = sr_image.numpy()

        # Target shapes are the input shapes multiplied by the upscaling factor.
        padded_size_scaled = tuple(np.multiply(p_shape[0:2], scale)) + (3,)
        scaled_image_shape = tuple(np.multiply(lr_image.shape[0:2], scale)) + (3,)
        np_sr_image = stich_together(
            np_sr_image,
            padded_image_shape=padded_size_scaled,
            target_shape=scaled_image_shape,
            padding_size=padding * scale,
        )
        sr_img = (np_sr_image * 255).astype(np.uint8)
        # Remove the border added by pad_reflect, now scaled up.
        sr_img = unpad_image(sr_img, pad_size * scale)
        sr_img = Image.fromarray(sr_img)
        return sr_img

    def face_enhance(self, img, scale=4, model_path=None):
        """Run GFPGAN face enhancement on ``img``.

        Args:
            img: Image passed unchanged to ``GFPGANer.enhance``.
            scale: Value passed as ``upscale`` to ``GFPGANer``.
            model_path: Path to the GFPGAN checkpoint. Defaults to the legacy
                hard-coded location, which only exists on the original
                author's machine — pass an explicit path elsewhere.

        Returns:
            The third element of the tuple returned by ``GFPGANer.enhance``.
        """
        # Imported lazily so gfpgan is only required when this method is used.
        from gfpgan import GFPGANer

        if model_path is None:
            model_path = self._DEFAULT_GFPGAN_MODEL_PATH
        face_enhancer = GFPGANer(
            model_path=model_path,
            upscale=scale,
            arch="clean",
            channel_multiplier=2,
        )
        _, _, output = face_enhancer.enhance(
            img, has_aligned=False, only_center_face=False, paste_back=True
        )
        return output