Matteo Sirri commited on
Commit
d41f3b7
·
1 Parent(s): 50aa67b

feat: add cuda support

Browse files
Files changed (1) hide show
  1. app.py +5 -1
app.py CHANGED
@@ -12,7 +12,7 @@ import torchvision.transforms as T
12
 
13
  logging.getLogger('PIL').setLevel(logging.CRITICAL)
14
 
15
- device = torch.device("cpu")
16
 
17
 
18
  def load_model(baseline: bool = False):
@@ -35,7 +35,9 @@ def frcnn_motsynth(image):
35
  model = load_model()
36
  transformEval = presets.DetectionPresetEval()
37
  image_tensor = transformEval(image, None)[0]
 
38
  prediction = model([image_tensor])[0]
 
39
  image_w_bbox = add_bbox(image_tensor, prediction, 0.80)
40
  torchvision.io.write_png(image_w_bbox, "custom_out.png")
41
  return "custom_out.png"
@@ -45,7 +47,9 @@ def frcnn_coco(image):
45
  model = load_model(baseline=True)
46
  transformEval = presets.DetectionPresetEval()
47
  image_tensor = transformEval(image, None)[0]
 
48
  prediction = model([image_tensor])[0]
 
49
  image_w_bbox = add_bbox(image_tensor, prediction, 0.80)
50
  torchvision.io.write_png(image_w_bbox, "baseline_out.png")
51
  return "baseline_out.png"
 
12
 
13
  logging.getLogger('PIL').setLevel(logging.CRITICAL)
14
 
15
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
 
17
 
18
  def load_model(baseline: bool = False):
 
35
  model = load_model()
36
  transformEval = presets.DetectionPresetEval()
37
  image_tensor = transformEval(image, None)[0]
38
+ image_tensor = image_tensor.to(device)
39
  prediction = model([image_tensor])[0]
40
+ prediction = prediction.to(device)
41
  image_w_bbox = add_bbox(image_tensor, prediction, 0.80)
42
  torchvision.io.write_png(image_w_bbox, "custom_out.png")
43
  return "custom_out.png"
 
47
  model = load_model(baseline=True)
48
  transformEval = presets.DetectionPresetEval()
49
  image_tensor = transformEval(image, None)[0]
50
+ image_tensor = image_tensor.to(device)
51
  prediction = model([image_tensor])[0]
52
+ prediction = prediction.to(device)
53
  image_w_bbox = add_bbox(image_tensor, prediction, 0.80)
54
  torchvision.io.write_png(image_w_bbox, "baseline_out.png")
55
  return "baseline_out.png"