qubvel-hf (HF staff) committed
Commit e7c7d09 (parent: 2ed3d59)

Init model on GPU

Files changed (2):
  1. app.py +11 -3
  2. inference_gradio.py +21 -24
app.py CHANGED
@@ -1,10 +1,10 @@
+import torch
 import gradio as gr
 import spaces
 from inference_gradio import inference_one_image, model_init
 
 MODEL_PATH = "./checkpoints/docres.pkl"
 
-model = model_init(MODEL_PATH)
 possible_tasks = [
     "dewarping",
     "deshadowing",
@@ -13,14 +13,22 @@ possible_tasks = [
     "binarization",
 ]
 
-@spaces.GPU
+@spaces.GPU(duration=90)
 def run_tasks(image, tasks):
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    # load model
+    model = model_init(MODEL_PATH, device)
+
+    # run inference
     bgr_image = image[..., ::-1].copy()
-    bgr_restored_image = inference_one_image(model, bgr_image, tasks)
+    bgr_restored_image = inference_one_image(model, bgr_image, tasks, device)
     if bgr_restored_image.ndim == 3:
         rgb_image = bgr_restored_image[..., ::-1]
     else:
         rgb_image = bgr_restored_image
+
     return rgb_image
 
 
inference_gradio.py CHANGED
@@ -14,9 +14,6 @@ sys.path.append("./data/MBD/")
 from data.MBD.infer import net1_net2_infer_single_im
 
 
-DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-
-
 def dewarp_prompt(img):
     mask = net1_net2_infer_single_im(img, "data/MBD/checkpoint/mbd.pkl")
     base_coord = utils.getBasecoord(256, 256) / 256
@@ -122,7 +119,7 @@ def binarization_promptv2(img):
     )
 
 
-def dewarping(model, im_org):
+def dewarping(model, im_org, device):
     INPUT_SIZE = 256
     im_masked, prompt_org = dewarp_prompt(im_org.copy())
 
@@ -131,10 +128,10 @@ def dewarping(model, im_org):
     im_masked = cv2.resize(im_masked, (INPUT_SIZE, INPUT_SIZE))
     im_masked = im_masked / 255.0
     im_masked = torch.from_numpy(im_masked.transpose(2, 0, 1)).unsqueeze(0)
-    im_masked = im_masked.float().to(DEVICE)
+    im_masked = im_masked.float().to(device)
 
     prompt = torch.from_numpy(prompt_org.transpose(2, 0, 1)).unsqueeze(0)
-    prompt = prompt.float().to(DEVICE)
+    prompt = prompt.float().to(device)
 
     in_im = torch.cat((im_masked, prompt), dim=1)
 
@@ -158,7 +155,7 @@ def dewarping(model, im_org):
     return prompt_org[:, :, 0], prompt_org[:, :, 1], prompt_org[:, :, 2], out_im
 
 
-def appearance(model, im_org):
+def appearance(model, im_org, device):
     MAX_SIZE = 1600
     # obtain im and prompt
     h, w = im_org.shape[:2]
@@ -176,7 +173,7 @@ def appearance(model, im_org):
     in_im = torch.from_numpy(in_im.transpose(2, 0, 1)).unsqueeze(0)
 
     # inference
-    in_im = in_im.half().to(DEVICE)
+    in_im = in_im.half().to(device)
     model = model.half()
     with torch.no_grad():
         pred = model(in_im)
@@ -198,7 +195,7 @@ def appearance(model, im_org):
     return prompt[:, :, 0], prompt[:, :, 1], prompt[:, :, 2], out_im
 
 
-def deshadowing(model, im_org):
+def deshadowing(model, im_org, device):
     MAX_SIZE = 1600
     # obtain im and prompt
     h, w = im_org.shape[:2]
@@ -216,7 +213,7 @@ def deshadowing(model, im_org):
     in_im = torch.from_numpy(in_im.transpose(2, 0, 1)).unsqueeze(0)
 
     # inference
-    in_im = in_im.half().to(DEVICE)
+    in_im = in_im.half().to(device)
     model = model.half()
     with torch.no_grad():
         pred = model(in_im)
@@ -238,16 +235,16 @@ def deshadowing(model, im_org):
     return prompt[:, :, 0], prompt[:, :, 1], prompt[:, :, 2], out_im
 
 
-def deblurring(model, im_org):
+def deblurring(model, im_org, device):
     # setup image
     in_im, padding_h, padding_w = stride_integral(im_org, 8)
     prompt = deblur_prompt(in_im)
     in_im = np.concatenate((in_im, prompt), -1)
     in_im = in_im / 255.0
     in_im = torch.from_numpy(in_im.transpose(2, 0, 1)).unsqueeze(0)
-    in_im = in_im.half().to(DEVICE)
+    in_im = in_im.half().to(device)
     # inference
-    model.to(DEVICE)
+    model.to(device)
     model.eval()
     model = model.half()
     with torch.no_grad():
@@ -260,7 +257,7 @@ def deblurring(model, im_org):
     return prompt[:, :, 0], prompt[:, :, 1], prompt[:, :, 2], out_im
 
 
-def binarization(model, im_org):
+def binarization(model, im_org, device):
     im, padding_h, padding_w = stride_integral(im_org, 8)
     prompt = binarization_promptv2(im)
     h, w = im.shape[:2]
@@ -268,7 +265,7 @@ def binarization(model, im_org):
 
     in_im = in_im / 255.0
     in_im = torch.from_numpy(in_im.transpose(2, 0, 1)).unsqueeze(0)
-    in_im = in_im.to(DEVICE)
+    in_im = in_im.to(device)
     model = model.half()
     in_im = in_im.half()
     with torch.no_grad():
@@ -283,7 +280,7 @@ def binarization(model, im_org):
     return prompt[:, :, 0], prompt[:, :, 1], prompt[:, :, 2], out_im
 
 
-def model_init(model_path):
+def model_init(model_path, device):
     # prepare model
     model = restormer_arch.Restormer(
         inp_channels=6,
@@ -298,7 +295,7 @@ def model_init(model_path):
         dual_pixel_task=True,
     )
 
-    if DEVICE == "cpu":
+    if device == "cpu":
         state = convert_state_dict(
             torch.load(model_path, map_location="cpu")["model_state"]
         )
@@ -309,7 +306,7 @@ def model_init(model_path):
     model.load_state_dict(state)
 
     model.eval()
-    model = model.to(DEVICE)
+    model = model.to(device)
     return model
 
 
@@ -328,11 +325,11 @@ def resize(image, max_size):
     return image
 
 
-def inference_one_image(model, image, tasks):
+def inference_one_image(model, image, tasks, device):
     # image should be in BGR format
 
     if "dewarping" in tasks:
-        *_, image = dewarping(model, image)
+        *_, image = dewarping(model, image, device)
 
     # if only dewarping return here
     if len(tasks) == 1 and "dewarping" in tasks:
@@ -341,12 +338,12 @@ def inference_one_image(model, image, tasks):
     image = resize(image, 1536)
 
     if "deshadowing" in tasks:
-        *_, image = deshadowing(model, image)
+        *_, image = deshadowing(model, image, device)
     if "appearance" in tasks:
-        *_, image = appearance(model, image)
+        *_, image = appearance(model, image, device)
     if "deblurring" in tasks:
-        *_, image = deblurring(model, image)
+        *_, image = deblurring(model, image, device)
     if "binarization" in tasks:
-        *_, image = binarization(model, image)
+        *_, image = binarization(model, image, device)
 
     return image
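
With the module-level DEVICE constant removed, inference_gradio.py no longer probes CUDA at import time; the caller chooses the device and threads it through every helper. A minimal sketch of the new calling contract ("input.png" is a hypothetical path; the signatures match the diff above):

import cv2
import torch

from inference_gradio import inference_one_image, model_init

# the caller, not the module, decides where the model runs
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model_init("./checkpoints/docres.pkl", device)

bgr = cv2.imread("input.png")  # OpenCV reads BGR, which the pipeline expects
restored = inference_one_image(model, bgr, ["dewarping", "binarization"], device)
cv2.imwrite("restored.png", restored)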