import numpy as np
import cv2
import onnxruntime
import gradio as gr

article_text = """
<div style="text-align: center;">
    <p>Enjoying the tool? Buy me a coffee and get exclusive prompt guides!</p>
    <p><i>Instantly unlock helpful tips for creating better prompts!</i></p>
    <div style="display: flex; justify-content: center;">
        <a href="https://piczify.lemonsqueezy.com/buy/0f5206fa-68e8-42f6-9ca8-4f80c587c83e">
            <img src="https://www.buymeacoffee.com/assets/img/custom_images/yellow_img.png" 
                 alt="Buy Me a Coffee" 
                 style="height: 40px; width: auto; box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2); border-radius: 10px;">
        </a>
    </div>
</div>
"""


def pre_process(img: np.ndarray) -> np.ndarray:
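    """Convert an H, W, C uint8 image into a 1, C, H, W float32 batch for the model."""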
    # H, W, C -> C, H, W
    img = np.transpose(img[:, :, 0:3], (2, 0, 1))
    # C, H, W -> 1, C, H, W
    img = np.expand_dims(img, axis=0).astype(np.float32)
    return img


def post_process(img: np.ndarray) -> np.ndarray:
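    """Convert a 1, C, H, W model output back to an H, W, C uint8 image (channel order reversed)."""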
    # 1, C, H, W -> C, H, W
    img = np.squeeze(img)
    # C, H, W -> H, W, C
    img = np.transpose(img, (1, 2, 0))[:, :, ::-1].astype(np.uint8)
    return img


def inference(model_path: str, img_array: np.ndarray) -> np.ndarray:
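    """Run the ONNX super-resolution model on a pre-processed image batch."""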
    options = onnxruntime.SessionOptions()
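    # Restrict ONNX Runtime to a single intra-op and inter-op thread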
    options.intra_op_num_threads = 1
    options.inter_op_num_threads = 1
    ort_session = onnxruntime.InferenceSession(model_path, options)
    ort_inputs = {ort_session.get_inputs()[0].name: img_array}
    ort_outs = ort_session.run(None, ort_inputs)

    return ort_outs[0]


def convert_pil_to_cv2(image):
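    """Convert a PIL image to an OpenCV-style array (BGR/BGRA channel order)."""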
    open_cv_image = np.array(image)
    if open_cv_image.ndim == 3 and open_cv_image.shape[2] == 4:
        # RGBA -> BGRA (keep alpha in the last channel)
        open_cv_image = cv2.cvtColor(open_cv_image, cv2.COLOR_RGBA2BGRA)
    elif open_cv_image.ndim == 3:
        # RGB -> BGR
        open_cv_image = open_cv_image[:, :, ::-1].copy()
    return open_cv_image


def upscale(image, model):
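    """Upscale a PIL image with the chosen model; an alpha channel, if present, is upscaled separately."""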
    model_path = f"models/{model}.ort"
    img = convert_pil_to_cv2(image)
    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    if img.shape[2] == 4:
        alpha = img[:, :, 3]  # single-channel alpha
        alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2BGR)  # replicate to 3 channels
        alpha_output = post_process(inference(model_path, pre_process(alpha)))  # channels reversed by post_process
        alpha_output = cv2.cvtColor(alpha_output, cv2.COLOR_BGR2GRAY)  # back to single channel

        img = img[:, :, 0:3]  # BGR
        image_output = post_process(inference(model_path, pre_process(img)))  # RGB (post_process reverses channel order)
        image_output = cv2.cvtColor(image_output, cv2.COLOR_BGR2BGRA)  # append an opaque alpha channel
        image_output[:, :, 3] = alpha_output  # replace it with the upscaled alpha

    elif img.shape[2] == 3:
        image_output = post_process(inference(model_path, pre_process(img)))  # RGB (post_process reverses channel order)

    return image_output


css = ".output-image, .input-image, .image-preview {height: 480px !important} "
model_choices = ["modelx2", "modelx2 25 JXL", "modelx4", "minecraft_modelx4"]

gr.Interface(
    fn=upscale,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Radio(
            model_choices,
            type="value",
            value="modelx4",
            label="Choose Upscaler",
        )
    ],
    # additional_inputs=[
    #     gr.Radio(
    #         model_choices,
    #         type="value",
    #         value="modelx4",
    #         label="Choose Upscaler",
    #     )
    # ],
    outputs="image",
    # title="Image Upscaler PRO ⚡",
    # description="Model: [Anchor-based Plain Net for Mobile Image Super-Resolution](https://arxiv.org/abs/2105.09750). Repository: [SR Mobile PyTorch](https://github.com/w11wo/sr_mobile_pytorch)",
    description="""
                <div style="text-align: center;">
                    <h1>Image Upscaler PRO ⚡</h1>
                    <a href="https://arxiv.org/abs/2105.09750">
                        <img src="https://img.shields.io/badge/arXiv-2105.09750-b31b1b.svg" alt="Arxiv" style="display:inline-block;">
                    </a>
                    <p>Anchor-based Plain Net for Mobile Image Super-Resolution</p>
                </div>
            """,
    article=article_text,
    allow_flagging="never",
    css=css,
).launch()