Matteo Sirri committed on
Commit
0888435
·
1 Parent(s): dadda42

fix: fix type

Browse files
app.py CHANGED
@@ -38,9 +38,8 @@ def frcnn_motsynth(image):
38
  image_tensor = image_tensor.to(device)
39
  prediction = model([image_tensor])[0]
40
  image_w_bbox = add_bbox(image_tensor, prediction, 0.80)
41
- transform = T.ToPILImage()
42
- output = transform(image_w_bbox)
43
- return output
44
 
45
 
46
  def frcnn_coco(image):
@@ -50,9 +49,8 @@ def frcnn_coco(image):
50
  image_tensor = image_tensor.to(device)
51
  prediction = model([image_tensor])[0]
52
  image_w_bbox = add_bbox(image_tensor, prediction, 0.80)
53
- transform = T.ToPILImage()
54
- output = transform(image_w_bbox)
55
- return output
56
 
57
 
58
  title = "Domain shift adaption on pedestrian detection with Faster R-CNN"
@@ -61,7 +59,7 @@ examples = ["001.jpg", "002.jpg", "003.jpg",
61
  "004.jpg", "005.jpg", "006.jpg", "007.jpg", ]
62
 
63
  io_baseline = gr.Interface(frcnn_coco, gr.Image(type="pil"), gr.Image(
64
- type="pil", label="Baseline Model trained on COCO + FT on MOT17"))
65
 
66
  io_custom = gr.Interface(frcnn_motsynth, gr.Image(type="pil"), gr.Image(
67
  type="pil", label="Faster R-CNN trained on MOTSynth + FT on MOT17"))
 
38
  image_tensor = image_tensor.to(device)
39
  prediction = model([image_tensor])[0]
40
  image_w_bbox = add_bbox(image_tensor, prediction, 0.80)
41
+ torchvision.io.write_png(image_w_bbox, "custom_out.png")
42
+ return "custom_out.png"
 
43
 
44
 
45
  def frcnn_coco(image):
 
49
  image_tensor = image_tensor.to(device)
50
  prediction = model([image_tensor])[0]
51
  image_w_bbox = add_bbox(image_tensor, prediction, 0.80)
52
+ torchvision.io.write_png(image_w_bbox, "baseline_out.png")
53
+ return "baseline_out.png"
 
54
 
55
 
56
  title = "Domain shift adaption on pedestrian detection with Faster R-CNN"
 
59
  "004.jpg", "005.jpg", "006.jpg", "007.jpg", ]
60
 
61
  io_baseline = gr.Interface(frcnn_coco, gr.Image(type="pil"), gr.Image(
62
+ type="file", label="Baseline Model trained on COCO + FT on MOT17"))
63
 
64
  io_custom = gr.Interface(frcnn_motsynth, gr.Image(type="pil"), gr.Image(
65
  type="pil", label="Faster R-CNN trained on MOTSynth + FT on MOT17"))
app.py.7d07da4a5b8438bc3eb3c4039d0839bb.tmp ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path as osp
2
+ import os
3
+ import gradio as gr
4
+ import torch
5
+ import logging
6
+ import torchvision
7
+ from torchvision.models.detection.faster_rcnn import fasterrcnn_resnet50_fpn
8
+ from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
9
+ from src.detection.graph_utils import add_bbox
10
+ from src.detection.vision import presets
11
+ import torchvision.transforms as T
12
+
13
+ logging.getLogger('PIL').setLevel(logging.CRITICAL)
14
+
15
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+
17
+
18
+ def load_model(baseline: bool = False):
19
+ if baseline:
20
+ model = fasterrcnn_resnet50_fpn(
21
+ weights="DEFAULT")
22
+ else:
23
+ model = fasterrcnn_resnet50_fpn()
24
+ in_features = model.roi_heads.box_predictor.cls_score.in_features
25
+ model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 2)
26
+ checkpoint = torch.load(
27
+ osp.join(os.getcwd(), "model_split3_FT_MOT17.pth"), map_location="cpu")
28
+ model.load_state_dict(checkpoint["model"])
29
+ model.to(device)
30
+ model.eval()
31
+ return model
32
+
33
+
34
+ def frcnn_motsynth(image):
35
+ model = load_model()
36
+ transformEval = presets.DetectionPresetEval()
37
+ image_tensor = transformEval(image, None)[0]
38
+ image_tensor = image_tensor.to(device)
39
+ prediction = model([image_tensor])[0]
40
+ image_w_bbox = add_bbox(image_tensor, prediction, 0.80)
41
+ torchvision.io.write_png(image_w_bbox, "custom_out.png")
42
+ return "custom_out.png"
43
+
44
+
45
+ def frcnn_coco(image):
46
+ model = load_model(baseline=True)
47
+ transformEval = presets.DetectionPresetEval()
48
+ image_tensor = transformEval(image, None)[0]
49
+ image_tensor = image_tensor.to(device)
50
+ prediction = model([image_tensor])[0]
51
+ image_w_bbox = add_bbox(image_tensor, prediction, 0.80)
52
+ torchvision.io.write_png(image_w_bbox, "baseline_out.png")
53
+ return "baseline_out.png"
54
+
55
+
56
+ title = "Domain shift adaption on pedestrian detection with Faster R-CNN"
57
+ description = '<p style="text-align:center">School in AI: Deep Learning, Vision and Language for Industry - second edition final project work by Matteo Sirri.</p> '
58
+ examples = ["001.jpg", "002.jpg", "003.jpg",
59
+ "004.jpg", "005.jpg", "006.jpg", "007.jpg", ]
60
+
61
+ io_baseline = gr.Interface(frcnn_coco, gr.Image(type="pil"), gr.Image(
62
+ type="file", label="Baseline Model trained on COCO + FT on MOT17"))
63
+
64
+ io_custom = gr.Interface(frcnn_motsynth, gr.Image(type="pil"), gr.Image(
65
+ type="file", label="Faster R-CNN trained on MOTSynth + FT on MOT17"))
66
+
67
+ gr.Parallel(io_baseline, io_custom, title=title,
68
+ description=description, examples=examples).launch(enable_queue=True)