Update test_code/inference.py

#4
Files changed (1)
  1. test_code/inference.py +83 -132
test_code/inference.py CHANGED
@@ -1,153 +1,104 @@
- '''
-     This is file is to execute the inference for a single image or a folder input
- '''
- import argparse
- import os, sys, cv2, shutil, warnings
- import torch
  import gradio as gr
- from torchvision.transforms import ToTensor
- from torchvision.utils import save_image
- warnings.simplefilter("default")
- os.environ["PYTHONWARNINGS"] = "default"
-
-
- # Import files from the local folder
- root_path = os.path.abspath('.')
- sys.path.append(root_path)
- from test_code.test_utils import load_grl, load_rrdb, load_cunet
-
-
-
- @torch.no_grad # You must add these time, else it will have Out of Memory
- def super_resolve_img(generator, input_path, output_path=None, weight_dtype=torch.float32, downsample_threshold=720, crop_for_4x=True):
-     ''' Super Resolve a low resolution image
-         Args:
-             generator (torch):          the generator class that is already loaded
-             input_path (str):           the path to the input lr images
-             output_path (str):          the directory to store the generated images
-             weight_dtype (bool):        the weight type (float32/float16)
-             downsample_threshold (int): the threshold of height/width (short side) to downsample the input
-             crop_for_4x (bool):         whether we crop the lr images to match 4x scale (needed for some situation)
-     '''
-     print("Processing image {}".format(input_path))
-
-     # Read the image and do preprocess
-     img_lr = cv2.imread(input_path)
-     h, w, c = img_lr.shape
-
-     # Downsample if needed
-     short_side = min(h, w)
-     if downsample_threshold != -1 and short_side > downsample_threshold:
-         resize_ratio = short_side / downsample_threshold
-         img_lr = cv2.resize(img_lr, (int(w/resize_ratio), int(h/resize_ratio)), interpolation = cv2.INTER_LINEAR)
-
-     # Crop if needed
      if crop_for_4x:
-         h, w, _ = img_lr.shape
          if h % 4 != 0:
-             img_lr = img_lr[:4*(h//4),:,:]
          if w % 4 != 0:
-             img_lr = img_lr[:,:4*(w//4),:]
-
-     # Check if the size is out of the boundary
-     h, w, c = img_lr.shape
-     if h*w > 720*1280:
-         raise gr.Error("The input image size is too large. The largest area we support is 720x1280=921600 pixel!")
-
-     # Transform to tensor
-     img_lr = cv2.cvtColor(img_lr, cv2.COLOR_BGR2RGB)
-     img_lr = ToTensor()(img_lr).unsqueeze(0).cuda() # Use tensor format
-     img_lr = img_lr.to(dtype=weight_dtype)
-
-     # Model inference
-     print("lr shape is ", img_lr.shape)
-     super_resolved_img = generator(img_lr)
-
-     # Store the generated result
-     with torch.cuda.amp.autocast():
-         if output_path is not None:
-             save_image(super_resolved_img, output_path)
-
-     # Empty the cache every time you finish processing one image
-     torch.cuda.empty_cache()
-
-     return super_resolved_img
-
-
- if __name__ == "__main__":
-
-     # Fundamental setting
-     parser = argparse.ArgumentParser()
-     parser.add_argument('--input_dir', type = str, default = '__assets__/lr_inputs', help="Can be either single image input or a folder input")
-     parser.add_argument('--model', type = str, default = 'GRL', help=" 'GRL' || 'RRDB' (for ESRNET & ESRGAN) || 'CUNET' (for Real-ESRGAN) ")
-     parser.add_argument('--scale', type = int, default = 4, help="Up scaler factor")
-     parser.add_argument('--weight_path', type = str, default = 'pretrained/4x_APISR_GRL_GAN_generator.pth', help="Weight path directory, usually under saved_models folder")
-     parser.add_argument('--store_dir', type = str, default = 'sample_outputs', help="The folder to store the super-resolved images")
-     parser.add_argument('--float16_inference', type = bool, default = False, help="Float16 inference, only useful in RRDB now") # Currently, this is only supported in RRDB, there is some bug with GRL model
-     args = parser.parse_args()
-
-     # Sample Command
-     # 4x GRL (Default): python test_code/inference.py --model GRL --scale 4 --weight_path pretrained/4x_APISR_GRL_GAN_generator.pth
-     # 2x RRDB: python test_code/inference.py --model RRDB --scale 2 --weight_path pretrained/2x_APISR_RRDB_GAN_generator.pth
-
-
-     # Read argument and prepare the folder needed
-     input_dir = args.input_dir
-     model = args.model
-     weight_path = args.weight_path
-     store_dir = args.store_dir
-     scale = args.scale
-     float16_inference = args.float16_inference
-
-
-     # Check the path of the weight
-     if not os.path.exists(weight_path):
-         print("we cannot locate weight path ", weight_path)
-         # TODO: I am not sure if I should automatically download weight from github release based on the upscale factor and model name.
-         os._exit(0)
-
-
-     # Prepare the store folder
-     if os.path.exists(store_dir):
-         shutil.rmtree(store_dir)
-     os.makedirs(store_dir)
-
-
-     # Define the weight type
-     if float16_inference:
-         torch.backends.cudnn.benchmark = True
-         weight_dtype = torch.float16
-     else:
-         weight_dtype = torch.float32
-
-
-     # Load the model
-     if model == "GRL":
-         generator = load_grl(weight_path, scale=scale) # GRL for Real-World SR only support 4x upscaling
-     elif model == "RRDB":
-         generator = load_rrdb(weight_path, scale=scale) # Can be any size
-     generator = generator.to(dtype=weight_dtype)
-
-
-     # Take the input path and do inference
-     if os.path.isdir(store_dir): # If the input is a directory, we will iterate it
-         for filename in sorted(os.listdir(input_dir)):
-             input_path = os.path.join(input_dir, filename)
-             output_path = os.path.join(store_dir, filename)
-             # In default, we will automatically use crop to match 4x size
-             super_resolve_img(generator, input_path, output_path, weight_dtype, crop_for_4x=True)
-
-     else: # If the input is a single image, we will process it directly and write on the same folder
-         filename = os.path.split(input_dir)[-1].split('.')[0]
-         output_path = os.path.join(store_dir, filename+"_"+str(scale)+"x.png")
-         # In default, we will automatically use crop to match 4x size
-         super_resolve_img(generator, input_dir, output_path, weight_dtype, crop_for_4x=True)
-
+ import os
+ import cv2
+ import numpy as np
+ import onnxruntime as ort
  import gradio as gr
+ from PIL import Image
+
+
+ # Path to the model in the Hugging Face Space
+ MODEL_PATH = "pretrained/4xGRL.onnx"  # Adjust this if the model is stored in a different location
+
+
+ # Preprocessing function for images (similar to the original script)
+ def preprocess_image(img, target_height=180, target_width=320, crop_for_4x=True, downsample_threshold=720):
+     ''' Preprocess the image to match model input expectations '''
+     # Gradio passes a PIL image, which is already RGB, so no BGR->RGB swap is
+     # needed; convert("RGB") also normalizes grayscale/RGBA uploads to 3 channels
+     img_rgb = np.array(img.convert("RGB"))
+
+     # Downsample if the short side exceeds the threshold
+     h, w, _ = img_rgb.shape
+     short_side = min(h, w)
+     if short_side > downsample_threshold:
+         resize_ratio = short_side / downsample_threshold
+         img_rgb = cv2.resize(img_rgb, (int(w / resize_ratio), int(h / resize_ratio)), interpolation=cv2.INTER_LINEAR)
+
+     # Crop to match 4x scaling if needed
      if crop_for_4x:
+         h, w, _ = img_rgb.shape
          if h % 4 != 0:
+             img_rgb = img_rgb[:4 * (h // 4), :, :]
          if w % 4 != 0:
+             img_rgb = img_rgb[:, :4 * (w // 4), :]
+
+     # Resize to the model's fixed input size (width 320, height 180)
+     img_resized = cv2.resize(img_rgb, (target_width, target_height))
+
+     return img_resized
+
+
+ # Inference function that runs the ONNX model on one image
+ def inference(img, model_name="4xGRL"):
+     try:
+         if model_name == "4xGRL":
+             # Load the ONNX model
+             ort_session = ort.InferenceSession(MODEL_PATH)
+
+             # Preprocess the image (downsample, crop, resize)
+             img_resized = preprocess_image(img)
+
+             # Prepare the input in the (N, C, H, W) float32 layout the model expects.
+             # Pixel values are scaled to [0, 1], matching the ToTensor convention
+             # of the original PyTorch pipeline
+             input_image = np.transpose(img_resized, (2, 0, 1))    # (H, W, C) -> (C, H, W)
+             input_image = np.expand_dims(input_image, axis=0)     # Add batch dimension
+             input_image = input_image.astype(np.float32) / 255.0  # Normalize to [0, 1]
+
+             # Run the model
+             ort_inputs = {ort_session.get_inputs()[0].name: input_image}
+             ort_outs = ort_session.run(None, ort_inputs)
+
+             # Post-process the output (the super-resolved image is the first output),
+             # mapping [0, 1] back to the [0, 255] uint8 range
+             output_image = np.transpose(ort_outs[0].squeeze(), (1, 2, 0))  # (C, H, W) -> (H, W, C)
+             output_image = np.clip(output_image * 255.0, 0, 255).astype(np.uint8)
+
+             # Convert the output to a PIL image for Gradio
+             return Image.fromarray(output_image)
+
+         else:
+             raise Exception("Model not supported")
+
+     except Exception as error:
+         # Raise a gr.Error so the failure shows in the UI; returning a plain
+         # string to an Image output would itself fail to render
+         raise gr.Error(f"An error occurred: {error}")
+
+
+ # Gradio interface
+ def create_interface():
+     with gr.Blocks() as demo:
+         gr.Markdown("# Anime Super-Resolution using ONNX")
+         gr.Markdown("Upload an anime image to enhance it using the 4xGRL model.")
+
+         # File input for the image
+         with gr.Row():
+             input_image = gr.Image(type="pil", label="Upload Image", interactive=True)
+
+         # Process button
+         with gr.Row():
+             process_button = gr.Button("Process Image")
+
+         # Output for the result image
+         with gr.Row():
+             result_image = gr.Image(type="pil", label="Processed Image")
+
+         # Wire the button to the inference function
+         process_button.click(inference, inputs=input_image, outputs=result_image)
+
+     return demo
+
+ # Launch the app
+ demo = create_interface()
+ demo.launch(share=True)  # Note: share=True has no effect on Hugging Face Spaces
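
This change assumes `pretrained/4xGRL.onnx` already exists in the Space; the PR itself does not produce it. For reference, below is a minimal, untested sketch of how such a file could be exported from the checkpoint used by the removed script. The `load_grl` loader and checkpoint path come from the old code; the fixed 1x3x180x320 shape mirrors the app's preprocessing, while the device handling and opset version are assumptions, not part of this PR.

```python
# Hypothetical export sketch: produce pretrained/4xGRL.onnx from the original
# PyTorch checkpoint (assumes test_code.test_utils is importable from the repo root)
import torch
from test_code.test_utils import load_grl

generator = load_grl("pretrained/4x_APISR_GRL_GAN_generator.pth", scale=4)
generator.eval()

# Fixed input shape matching the app's preprocessing: batch 1, RGB, 180x320,
# placed on whatever device the loader put the model on
device = next(generator.parameters()).device
dummy_input = torch.randn(1, 3, 180, 320, device=device)

torch.onnx.export(
    generator,
    dummy_input,
    "pretrained/4xGRL.onnx",
    input_names=["input"],
    output_names=["output"],
    opset_version=17,  # assumed; pick an opset your onnxruntime build supports
)
```

Exporting with a fixed shape is what forces the hard 180x320 resize in `preprocess_image`; exporting with `dynamic_axes` for height and width would let the app keep the original downsample-and-crop sizing instead.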