Matteo Sirri committed on
Commit cf40191 · 1 Parent(s): 3e01e59

fix: fix typo

Files changed (1)
  1. app.py +11 -7
app.py CHANGED
@@ -8,6 +8,8 @@ from torchvision.models.detection.faster_rcnn import fasterrcnn_resnet50_fpn
 from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
 from src.detection.graph_utils import add_bbox
 from src.detection.vision import presets
+import torchvision.transforms as T
+
 logging.getLogger('PIL').setLevel(logging.CRITICAL)
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -36,8 +38,9 @@ def frcnn_motsynth(image):
     image_tensor = image_tensor.to(device)
     prediction = model([image_tensor])[0]
     image_w_bbox = add_bbox(image_tensor, prediction, 0.80)
-    torchvision.io.write_png(image_w_bbox, "custom_out.png")
-    return "custom_out.png"
+    transform = T.ToPILImage()
+    output = transform(image_w_bbox)
+    return output
 
 
 def frcnn_coco(image):
@@ -47,8 +50,9 @@ def frcnn_coco(image):
     image_tensor = image_tensor.to(device)
     prediction = model([image_tensor])[0]
     image_w_bbox = add_bbox(image_tensor, prediction, 0.80)
-    torchvision.io.write_png(image_w_bbox, "baseline_out.png")
-    return "baseline_out.png"
+    transform = T.ToPILImage()
+    output = transform(image_w_bbox)
+    return output
 
 
 title = "Domain shift adaption on pedestrian detection with Faster R-CNN"
@@ -57,10 +61,10 @@ examples = ["001.jpg", "002.jpg", "003.jpg",
             "004.jpg", "005.jpg", "006.jpg", "007.jpg", ]
 
 io_baseline = gr.Interface(frcnn_coco, gr.Image(type="pil"), gr.Image(
-    type="file", shape=(1920, 1080), label="Baseline Model trained on COCO + FT on MOT17"))
+    type="pil", label="Baseline Model trained on COCO + FT on MOT17"))
 
-io_custom = gr.Interface(frcnn_motsynth, gr.Image(type="pil"), gr.Image(
-    type="file", shape=(1920, 1080), label="Faster R-CNN trained on MOTSynth + FT on MOT17"))
+io_custom = gr.Interface(frcnn_motsynth, inputs=examples, outputs=gr.Image(
+    type="pil", label="Faster R-CNN trained on MOTSynth + FT on MOT17"))
 
 gr.Parallel(io_baseline, io_custom, title=title,
             description=description, examples=examples).launch(enable_queue=True)
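
For context on the change above: instead of writing the annotated tensor to disk with torchvision.io.write_png and returning a file path, both inference functions now convert the tensor to a PIL image in memory. A minimal sketch of that conversion, assuming add_bbox returns a CHW uint8 tensor (the dummy tensor and its size below are illustrative, not taken from the repo):

import torch
import torchvision.transforms as T

image_w_bbox = torch.zeros((3, 1080, 1920), dtype=torch.uint8)  # stand-in for add_bbox output
output = T.ToPILImage()(image_w_bbox)                           # PIL.Image.Image that Gradio can display
print(output.size)                                              # (1920, 1080) -> (width, height)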
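
The diff only shows the tail of frcnn_motsynth and frcnn_coco. Below is a self-contained sketch of the same return-a-PIL-image pattern using only public torchvision APIs; draw_bounding_boxes stands in for the repo's add_bbox helper, and the pretrained COCO weights, the 0.80 score threshold, and the preprocessing step are assumptions (torchvision >= 0.13 for the weights argument):

import torch
import torchvision
import torchvision.transforms as T
from torchvision.utils import draw_bounding_boxes

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")
model.eval().to(device)

def detect_to_pil(image, score_threshold=0.80):
    # PIL image -> float CHW tensor in [0, 1] (the real app uses src.detection.vision.presets here)
    image_tensor = T.ToTensor()(image).to(device)
    with torch.no_grad():
        prediction = model([image_tensor])[0]
    keep = prediction["scores"] >= score_threshold
    # draw the surviving boxes on a uint8 copy of the frame, standing in for add_bbox(image_tensor, prediction, 0.80)
    frame_uint8 = (image_tensor * 255).to(torch.uint8).cpu()
    image_w_bbox = draw_bounding_boxes(frame_uint8, prediction["boxes"][keep].cpu(), width=2)
    # same final step as the commit: return an in-memory PIL image instead of a file path
    return T.ToPILImage()(image_w_bbox)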
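
On the Gradio side, the output components move from type="file" (a saved PNG) to type="pil", matching the new return values. A hedged sketch of how the two interfaces are composed with gr.Parallel under Gradio 3.x; the component arguments are illustrative (both interfaces here take a single PIL image input), and frcnn_coco, frcnn_motsynth, title, description and examples are the names already defined in app.py:

import gradio as gr

io_baseline = gr.Interface(
    frcnn_coco,                       # baseline: COCO pretraining + MOT17 fine-tuning
    gr.Image(type="pil"),
    gr.Image(type="pil", label="Baseline Model trained on COCO + FT on MOT17"),
)
io_custom = gr.Interface(
    frcnn_motsynth,                   # custom: MOTSynth pretraining + MOT17 fine-tuning
    gr.Image(type="pil"),
    gr.Image(type="pil", label="Faster R-CNN trained on MOTSynth + FT on MOT17"),
)

# gr.Parallel sends the same input image to both interfaces and shows their outputs side by side
gr.Parallel(io_baseline, io_custom, title=title,
            description=description, examples=examples).launch(enable_queue=True)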