Spaces:
Running
Running
File size: 3,997 Bytes
6e92463 f4b2099 6e92463 a625565 6e92463 a625565 6e92463 9eab909 6e92463 a625565 6e92463 a625565 6e92463 a625565 6e92463 3b8537d 6e92463 3f6e15f 219dbcc 3f6e15f 6e92463 3f6e15f 9eab909 ac13efd 21db336 9eab909 b607e0d 2530a47 b607e0d 866db1b ac13efd 3b8537d b607e0d f4b2099 2faccac ebffb66 fe40b13 723195d fe40b13 6a3450c f4b2099 6e92463 3f6e15f 219dbcc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
import functools

import cv2
import gradio as gr
import numpy as np
import onnxruntime
# HTML rendered beneath the interface (passed as `article=` to gr.Interface):
# a centred support blurb with a "Buy Me a Coffee" image linking to an
# external purchase page.
article_text = """
<div style="text-align: center;">
<p>Enjoying the tool? Buy me a coffee and get exclusive prompt guides!</p>
<p><i>Instantly unlock helpful tips for creating better prompts!</i></p>
<div style="display: flex; justify-content: center;">
<a href="https://piczify.lemonsqueezy.com/buy/0f5206fa-68e8-42f6-9ca8-4f80c587c83e">
<img src="https://www.buymeacoffee.com/assets/img/custom_images/yellow_img.png"
alt="Buy Me a Coffee"
style="height: 40px; width: auto; box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2); border-radius: 10px;">
</a>
</div>
</div>
"""
def pre_process(img: np.ndarray) -> np.ndarray:
    """Convert an ``H x W x C`` image into a ``1 x 3 x H x W`` float32 batch.

    Only the first three channels are kept (any extra channel such as alpha
    is dropped). Values are cast to float32 but not rescaled, so the model
    sees the same numeric range as the input array.

    Args:
        img: Channels-last image array, expected to have at least three
            channels.

    Returns:
        A C-contiguous float32 array of shape ``(1, 3, H, W)``.
    """
    # H, W, C -> C, H, W (keeping only the first three channels).
    chw = np.transpose(img[:, :, 0:3], (2, 0, 1))
    # C, H, W -> 1, C, H, W. The transpose is a strided view, and astype()
    # would preserve that layout; materialise a dense C-ordered float32
    # buffer so ONNX Runtime gets a contiguous tensor.
    return np.ascontiguousarray(chw[np.newaxis], dtype=np.float32)
def post_process(img: np.ndarray) -> np.ndarray:
    """Convert a model output batch back into an ``H x W x C`` uint8 image.

    The channel order is reversed on the way out. In this app the model is
    fed BGR data (see ``convert_pil_to_cv2``), so the result is RGB.

    Args:
        img: Model output of shape ``(1, C, H, W)``. An unbatched
            ``(C, H, W)`` array is also accepted.

    Returns:
        A uint8 array of shape ``(H, W, C)`` with channels reversed.
    """
    # 1, C, H, W -> C, H, W. Drop only the batch axis: a bare np.squeeze()
    # would also remove H or W when either is 1 and break the transpose.
    if img.ndim == 4:
        img = img[0]
    # C, H, W -> H, W, C, then reverse the channel order.
    img = np.transpose(img, (1, 2, 0))[:, :, ::-1]
    # Super-resolution outputs can overshoot [0, 255]. Casting out-of-range
    # floats straight to uint8 wraps around or is platform-dependent, so
    # round to the nearest integer and clip first.
    return np.clip(np.rint(img), 0, 255).astype(np.uint8)
@functools.lru_cache(maxsize=8)
def _get_session(model_path: str) -> "onnxruntime.InferenceSession":
    """Return a cached ONNX Runtime session for *model_path*.

    Creating a session loads the model from disk, so each path is loaded
    once and reused across requests. A failed load raises and is not cached.
    The cache does not notice later changes to the file on disk.
    """
    options = onnxruntime.SessionOptions()
    # One intra-op and one inter-op thread, as originally configured.
    options.intra_op_num_threads = 1
    options.inter_op_num_threads = 1
    return onnxruntime.InferenceSession(model_path, options)


def inference(model_path: str, img_array: np.ndarray) -> np.ndarray:
    """Run the model at *model_path* on *img_array* and return its first output.

    Args:
        model_path: Path to an ONNX / ORT model file.
        img_array: Tensor bound to the model's first input. In this file it
            is the ``(1, 3, H, W)`` float32 batch built by ``pre_process``.

    Returns:
        The first output tensor produced by the session.
    """
    # Sessions are reused; ONNX Runtime documents concurrent run() calls on
    # one session as thread-safe.
    session = _get_session(model_path)
    feeds = {session.get_inputs()[0].name: img_array}
    return session.run(None, feeds)[0]
def convert_pil_to_cv2(image):
    """Convert a PIL image (or array-like) to an OpenCV-style numpy array.

    Colour channels are reordered from RGB(A) to BGR(A), and an alpha
    channel stays in last position. Single-channel (2-D) images, and arrays
    with fewer than three channels, are returned unchanged.

    Args:
        image: A PIL image or anything ``np.array`` accepts.

    Returns:
        A new numpy array that does not share memory with *image*.
    """
    arr = np.array(image)
    if arr.ndim != 3 or arr.shape[2] < 3:
        # Grayscale (or gray + alpha): there are no colour channels to swap.
        return arr
    # RGB -> BGR on the first three channels only. Reversing every channel
    # (arr[:, :, ::-1]) would turn RGBA into ABGR and move alpha to index 0.
    return np.concatenate([arr[:, :, 2::-1], arr[:, :, 3:]], axis=2)
def upscale(image, model):
    """Upscale *image* with the ONNX model named *model*.

    Args:
        image: PIL image from the Gradio input, or ``None`` if nothing was
            uploaded.
        model: Model name, loaded from ``models/{model}.ort``.

    Returns:
        A uint8 image array. It has 3 channels for colour or grayscale input,
        and 4 channels when the input carries alpha (the alpha channel is
        upscaled separately).

    Raises:
        gr.Error: if no image was given or its channel count is unsupported.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    model_path = f"models/{model}.ort"
    img = convert_pil_to_cv2(image)
    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    channels = img.shape[2]
    if channels == 4:
        # pre_process feeds the model only three channels, so alpha is
        # upscaled on its own after replicating it into three channels.
        alpha = cv2.cvtColor(img[:, :, 3], cv2.COLOR_GRAY2BGR)
        alpha_output = post_process(inference(model_path, pre_process(alpha)))
        alpha_output = cv2.cvtColor(alpha_output, cv2.COLOR_RGB2GRAY)
        # post_process reverses channel order: BGR in -> RGB out.
        color = img[:, :, 0:3]
        image_output = post_process(inference(model_path, pre_process(color)))
        image_output = cv2.cvtColor(image_output, cv2.COLOR_RGB2RGBA)
        image_output[:, :, 3] = alpha_output
    elif channels == 3:
        # BGR in -> RGB out (see post_process).
        image_output = post_process(inference(model_path, pre_process(img)))
    else:
        # Reject anything else (e.g. 2-channel gray + alpha) explicitly,
        # rather than reaching `return` with `image_output` unset.
        raise gr.Error(f"Unsupported number of image channels: {channels}")
    return image_output
# Custom CSS pinning the input, output and preview image panes to 480px high.
css = ".output-image, .input-image, .image-preview {height: 480px !important} "

# Model names offered in the UI; upscale() loads models/<name>.ort for each.
model_choices = ["modelx2", "modelx2 25 JXL", "modelx4", "minecraft_modelx4"]

# Header HTML shown above the interface. Model: Anchor-based Plain Net for
# Mobile Image Super-Resolution (arXiv 2105.09750); reference
# implementation: https://github.com/w11wo/sr_mobile_pytorch
description_html = """
<div style="text-align: center;">
<h1>Image Upscaler PRO ⚡</h1>
<a href="https://arxiv.org/abs/2105.09750">
<img src="https://img.shields.io/badge/arXiv-2105.09750-b31b1b.svg" alt="Arxiv" style="display:inline-block;">
</a>
<p>Anchor-based Plain Net for Mobile Image Super-Resolution</p>
</div>
"""

# Components are created in the same order as before: image first, then
# the model selector.
image_input = gr.Image(type="pil", label="Input Image")
model_selector = gr.Radio(
    model_choices,
    type="value",
    value="modelx4",
    label="Choose Upscaler",
)

demo = gr.Interface(
    fn=upscale,
    inputs=[image_input, model_selector],
    outputs="image",
    description=description_html,
    article=article_text,
    allow_flagging="never",
    css=css,
)
demo.launch()
|