Mathdesenvnonimate commited on
Commit
85456b1
·
verified ·
1 Parent(s): b8575d8

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -0
app.py CHANGED
@@ -80,6 +80,86 @@ def check_input_image(input_image):
80
 
81
 
82
  def preprocess(input_image, do_remove_background, foreground_ratio):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  def fill_background(image):
84
  image = np.array(image).astype(np.float32) / 255.0
85
  image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
 
80
 
81
 
82
  def preprocess(input_image, do_remove_background, foreground_ratio):
83
+ def pre_process(img: np.array) -> np.array:
84
+ # H, W, C -> C, H, W
85
+ img = np.transpose(img[:, :, 0:3], (2, 0, 1))
86
+ # C, H, W -> 1, C, H, W
87
+ img = np.expand_dims(img, axis=0).astype(np.float32)
88
+ return img
89
+
90
+
91
+ def post_process(img: np.array) -> np.array:
92
+ # 1, C, H, W -> C, H, W
93
+ img = np.squeeze(img)
94
+ # C, H, W -> H, W, C
95
+ img = np.transpose(img, (1, 2, 0))[:, :, ::-1].astype(np.uint8)
96
+ return img
97
+
98
+
99
+ def inference(model_path: str, img_array: np.array) -> np.array:
100
+ options = onnxruntime.SessionOptions()
101
+ options.intra_op_num_threads = 1
102
+ options.inter_op_num_threads = 1
103
+ ort_session = onnxruntime.InferenceSession(model_path, options)
104
+ ort_inputs = {ort_session.get_inputs()[0].name: img_array}
105
+ ort_outs = ort_session.run(None, ort_inputs)
106
+
107
+ return ort_outs[0]
108
+
109
+
110
+ def convert_pil_to_cv2(input_image):
111
+ # pil_image = image.convert("RGB")
112
+ open_cv_image = np.array(input_image)
113
+ # RGB to BGR
114
+ open_cv_image = open_cv_image[:, :, ::-1].copy()
115
+ return open_cv_image
116
+
117
+
118
+ def upscale(image, model):
119
+ model_path = f"models/{model}.ort"
120
+ img = convert_pil_to_cv2(image)
121
+ if img.ndim == 2:
122
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
123
+
124
+ if img.shape[2] == 4:
125
+ alpha = img[:, :, 3] # GRAY
126
+ alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2BGR) # BGR
127
+ alpha_output = post_process(inference(model_path, pre_process(alpha))) # BGR
128
+ alpha_output = cv2.cvtColor(alpha_output, cv2.COLOR_BGR2GRAY) # GRAY
129
+
130
+ img = img[:, :, 0:3] # BGR
131
+ image_output = post_process(inference(model_path, pre_process(img))) # BGR
132
+ image_output = cv2.cvtColor(image_output, cv2.COLOR_BGR2BGRA) # BGRA
133
+ image_output[:, :, 3] = alpha_output
134
+
135
+ elif img.shape[2] == 3:
136
+ image_output = post_process(inference(model_path, pre_process(img))) # BGR
137
+
138
+ return image_output
139
+
140
+
141
+
142
+ def fill_background(image):
143
+ image = np.array(image).astype(np.float32) / 255.0
144
+ image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
145
+ image = Image.fromarray((image * 255.0).astype(np.uint8))
146
+ return image
147
+
148
+
149
+
150
+
151
+ if do_remove_background:
152
+ image = image_output.convert("RGB")
153
+ image = remove_background(image, rembg_session)
154
+ image = resize_foreground(image, foreground_ratio)
155
+ image = fill_background(image)
156
+ else:
157
+ image = image_output
158
+ if image.mode == "RGBA":
159
+ image = fill_background(image)
160
+ return image
161
+
162
+
163
  def fill_background(image):
164
  image = np.array(image).astype(np.float32) / 255.0
165
  image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5