Update test_code/inference.py #4
by Arrcttacsrks · opened

test_code/inference.py  +83 -132
test_code/inference.py
CHANGED
@@ -1,153 +1,104 @@
-
-
-
-import argparse
-import os, sys, cv2, shutil, warnings
-import torch
+import os
+import cv2
+import numpy as np
+import onnxruntime as ort
 import gradio as gr
-from torchvision.transforms import ToTensor
-from torchvision.utils import save_image
-warnings.simplefilter("default")
-os.environ["PYTHONWARNINGS"] = "default"
-
-
-# Import files from the local folder
-root_path = os.path.abspath('.')
-sys.path.append(root_path)
-from test_code.test_utils import load_grl, load_rrdb, load_cunet
-
-
-
-@torch.no_grad # You must add these time, else it will have Out of Memory
-def super_resolve_img(generator, input_path, output_path=None, weight_dtype=torch.float32, downsample_threshold=720, crop_for_4x=True):
-    ''' Super Resolve a low resolution image
-    Args:
-        generator (torch): the generator class that is already loaded
-        input_path (str): the path to the input lr images
-        output_path (str): the directory to store the generated images
-        weight_dtype (bool): the weight type (float32/float16)
-        downsample_threshold (int): the threshold of height/width (short side) to downsample the input
-        crop_for_4x (bool): whether we crop the lr images to match 4x scale (needed for some situation)
-    '''
-    print("Processing image {}".format(input_path))
-
-    # Read the image and do preprocess
-    img_lr = cv2.imread(input_path)
-    h, w, c = img_lr.shape
+from PIL import Image

+# Path to the model in Hugging Face Space
+MODEL_PATH = "pretrained/4xGRL.onnx" # Adjust this if the model is stored in a different location

-
-
-
-
-    img_lr = cv2.resize(img_lr, (int(w/resize_ratio), int(h/resize_ratio)), interpolation = cv2.INTER_LINEAR)
+# Preprocessing function for images (similar to original script)
+def preprocess_image(img, target_height=180, target_width=320, crop_for_4x=True, downsample_threshold=720):
+    ''' Preprocess the image to match model input expectations '''
+    img = np.array(img)

+    # Convert to RGB (OpenCV uses BGR by default)
+    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+    # Resize if necessary (downsample based on the downsample threshold)
+    h, w, _ = img_rgb.shape
+    short_side = min(h, w)

+    # Downsample if the short side exceeds the threshold
+    if short_side > downsample_threshold:
+        resize_ratio = short_side / downsample_threshold
+        img_rgb = cv2.resize(img_rgb, (int(w / resize_ratio), int(h / resize_ratio)), interpolation=cv2.INTER_LINEAR)
+
+    # Crop to match 4x scaling if needed
     if crop_for_4x:
-        h, w, _ = img_lr.shape
+        h, w, _ = img_rgb.shape
         if h % 4 != 0:
-
+            img_rgb = img_rgb[:4 * (h // 4), :, :]
         if w % 4 != 0:
-
-
-    # Check if the size is out of the boundary
-    h, w, c = img_lr.shape
-    if h*w > 720*1280:
-        raise gr.Error("The input image size is too large. The largest area we support is 720x1280=921600 pixel!")
-
-
-    # Transform to tensor
-    img_lr = cv2.cvtColor(img_lr, cv2.COLOR_BGR2RGB)
-    img_lr = ToTensor()(img_lr).unsqueeze(0).cuda() # Use tensor format
-    img_lr = img_lr.to(dtype=weight_dtype)
+            img_rgb = img_rgb[:, :4 * (w // 4), :]

+    # Resize the image to match the model's expected input size (e.g., 180x320)
+    img_resized = cv2.resize(img_rgb, (target_width, target_height)) # Resize to 180x320

-
-    print("lr shape is ", img_lr.shape)
-    super_resolved_img = generator(img_lr)
+    return img_resized

-
-
-
-
+# Inference function to process image using ONNX model
+def inference(img, model_name="4xGRL"):
+    try:
+        # Ensure correct dtype for ONNX
+        weight_dtype = np.float32 # ONNX uses numpy arrays, so use np.float32
+
+        if model_name == "4xGRL":
+            # Load the ONNX model
+            ort_session = ort.InferenceSession(MODEL_PATH)

-
-
-
-    return super_resolved_img
+            # Preprocess the image (resize, crop, etc.)
+            img_resized = preprocess_image(img)

+            # Prepare the input in the format expected by the model (e.g., (N, C, H, W))
+            input_image = np.transpose(img_resized, (2, 0, 1)) # Convert to (C, H, W)
+            input_image = np.expand_dims(input_image, axis=0) # Add batch dimension
+            input_image = input_image.astype(weight_dtype) # Convert to float32

+            # Run the model
+            ort_inputs = {ort_session.get_inputs()[0].name: input_image}
+            ort_outs = ort_session.run(None, ort_inputs)

+            # Post-process the output
+            output_image = ort_outs[0] # Assuming the model output is in the first position
+            output_image = np.transpose(output_image.squeeze(), (1, 2, 0)) # Convert to (H, W, C)
+            output_image = np.clip(output_image, 0, 255).astype(np.uint8) # Ensure valid image range

-
-
-# Fundamental setting
-parser = argparse.ArgumentParser()
-parser.add_argument('--input_dir', type = str, default = '__assets__/lr_inputs', help="Can be either single image input or a folder input")
-parser.add_argument('--model', type = str, default = 'GRL', help=" 'GRL' || 'RRDB' (for ESRNET & ESRGAN) || 'CUNET' (for Real-ESRGAN) ")
-parser.add_argument('--scale', type = int, default = 4, help="Up scaler factor")
-parser.add_argument('--weight_path', type = str, default = 'pretrained/4x_APISR_GRL_GAN_generator.pth', help="Weight path directory, usually under saved_models folder")
-parser.add_argument('--store_dir', type = str, default = 'sample_outputs', help="The folder to store the super-resolved images")
-parser.add_argument('--float16_inference', type = bool, default = False, help="Float16 inference, only useful in RRDB now") # Currently, this is only supported in RRDB, there is some bug with GRL model
-args = parser.parse_args()
-
-# Sample Command
-# 4x GRL (Default): python test_code/inference.py --model GRL --scale 4 --weight_path pretrained/4x_APISR_GRL_GAN_generator.pth
-# 2x RRDB: python test_code/inference.py --model RRDB --scale 2 --weight_path pretrained/2x_APISR_RRDB_GAN_generator.pth
-
-
-# Read argument and prepare the folder needed
-input_dir = args.input_dir
-model = args.model
-weight_path = args.weight_path
-store_dir = args.store_dir
-scale = args.scale
-float16_inference = args.float16_inference
-
-
-# Check the path of the weight
-if not os.path.exists(weight_path):
-    print("we cannot locate weight path ", weight_path)
-    # TODO: I am not sure if I should automatically download weight from github release based on the upscale factor and model name.
-    os._exit(0)
-
-
-# Prepare the store folder
-if os.path.exists(store_dir):
-    shutil.rmtree(store_dir)
-os.makedirs(store_dir)
+            # Convert output to PIL Image for Gradio
+            output_pil = Image.fromarray(output_image)

+            return output_pil

+        else:
+            raise Exception("Model not supported")

-
-
-    torch.backends.cudnn.benchmark = True
-    weight_dtype = torch.float16
-else:
-    weight_dtype = torch.float32
-
+    except Exception as error:
+        return f"An error occurred: {error}"

-
-
-
-
-
-
-
+# Gradio interface
+def create_interface():
+    with gr.Blocks() as demo:
+        gr.Markdown("# Anime Super-Resolution using ONNX")
+        gr.Markdown("Upload an anime image to enhance it using the 4xGRL model.")
+
+        # File input for image
+        with gr.Row():
+            input_image = gr.Image(type="pil", label="Upload Image", interactive=True)
+
+        # Process button
+        with gr.Row():
+            process_button = gr.Button("Process Image")
+
+        # Output for result image
+        with gr.Row():
+            result_image = gr.Image(type="pil", label="Processed Image")
+
+        # Functionality for processing image
+        process_button.click(inference, inputs=input_image, outputs=result_image)
+
+    return demo

-
-
-
-        input_path = os.path.join(input_dir, filename)
-        output_path = os.path.join(store_dir, filename)
-        # In default, we will automatically use crop to match 4x size
-        super_resolve_img(generator, input_path, output_path, weight_dtype, crop_for_4x=True)
-
-else: # If the input is a single image, we will process it directly and write on the same folder
-    filename = os.path.split(input_dir)[-1].split('.')[0]
-    output_path = os.path.join(store_dir, filename+"_"+str(scale)+"x.png")
-    # In default, we will automatically use crop to match 4x size
-    super_resolve_img(generator, input_dir, output_path, weight_dtype, crop_for_4x=True)
-
-
+# Launch the app
+demo = create_interface()
+demo.launch(share=True)
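
Note: the updated script loads pretrained/4xGRL.onnx, which this diff does not create. Below is a minimal export sketch for producing that file from the original PyTorch GRL checkpoint; it assumes the load_grl helper from the removed imports is still available, and the helper signature, checkpoint path, input size, and opset shown are illustrative assumptions rather than part of this change.

# Hypothetical export step (not part of this diff). Assumes test_code/test_utils.py
# still provides load_grl and that the 4x GRL checkpoint from the original script
# is present; the fixed 1x3x180x320 dummy input mirrors preprocess_image above.
import torch
from test_code.test_utils import load_grl

def export_grl_to_onnx(weight_path="pretrained/4x_APISR_GRL_GAN_generator.pth",
                       onnx_path="pretrained/4xGRL.onnx"):
    generator = load_grl(weight_path, scale=4)  # signature assumed from the removed script
    generator.eval()
    device = next(generator.parameters()).device
    dummy = torch.randn(1, 3, 180, 320, device=device)  # (N, C, H, W)
    torch.onnx.export(
        generator,
        dummy,
        onnx_path,
        input_names=["input"],
        output_names=["output"],
        opset_version=17,
    )

if __name__ == "__main__":
    export_grl_to_onnx()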
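For a quick sanity check of the exported model against the fixed 180x320 input that preprocess_image produces, the declared I/O shapes can be inspected with onnxruntime. This is a sketch assuming pretrained/4xGRL.onnx exists; the shapes in the comments are assumptions based on the 4x scale.

# Inspect the ONNX model's declared input/output shapes (sketch; assumes the
# file exists). Useful for confirming the size preprocess_image must produce.
import onnxruntime as ort

session = ort.InferenceSession("pretrained/4xGRL.onnx")
for tensor in session.get_inputs():
    print("input :", tensor.name, tensor.shape)   # e.g. [1, 3, 180, 320] (assumed)
for tensor in session.get_outputs():
    print("output:", tensor.name, tensor.shape)   # e.g. [1, 3, 720, 1280] for 4x (assumed)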