Matteo Sirri committed
Commit e0452e0 · 1 Parent(s): feca2a9

fix: add model

Files changed (4)
  1. app.py +17 -16
  2. configs/__init__.py +0 -0
  3. configs/path_cfg.py +0 -19
  4. model_split3_FT_MOT17.pth +3 -0
app.py CHANGED
@@ -3,9 +3,8 @@ import gradio as gr
 import torch
 import logging
 import torchvision
-from torchvision.models.detection.faster_rcnn import fasterrcnn_resnet50_fpn_v2
+from torchvision.models.detection.faster_rcnn import fasterrcnn_resnet50_fpn
 from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
-from configs.path_cfg import OUTPUT_DIR
 from src.detection.graph_utils import add_bbox
 from src.detection.vision import presets
 logging.getLogger('PIL').setLevel(logging.CRITICAL)
@@ -13,48 +12,50 @@ logging.getLogger('PIL').setLevel(logging.CRITICAL)
 
 def load_model(baseline: bool = False):
     if baseline:
-        model = fasterrcnn_resnet50_fpn_v2(
+        model = fasterrcnn_resnet50_fpn(
             weights="DEFAULT")
     else:
-        model = fasterrcnn_resnet50_fpn_v2()
+        model = fasterrcnn_resnet50_fpn()
         in_features = model.roi_heads.box_predictor.cls_score.in_features
         model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 2)
-        checkpoint = torch.load(osp.join(OUTPUT_DIR, "detection_logs",
-                                         "fasterrcnn_training", "checkpoint.pth"), map_location="cpu")
+        checkpoint = torch.load(
+            "model_split_3_FT_MOT17.pth", map_location="cpu")
         model.load_state_dict(checkpoint["model"])
+    device = torch.device('cuda:0')
+    model.to(device)
     model.eval()
     return model
 
 
-def detect_with_resnet50Model_finetuning_motsynth(image):
+def frcnn_motsynth(image):
     model = load_model(baseline=True)
     transformEval = presets.DetectionPresetEval()
     image_tensor = transformEval(image, None)[0]
     prediction = model([image_tensor])[0]
-    image_w_bbox = add_bbox(image_tensor, prediction, 0.85)
+    image_w_bbox = add_bbox(image_tensor, prediction, 0.80)
     torchvision.io.write_png(image_w_bbox, "custom_out.png")
     return "custom_out.png"
 
 
-def detect_with_resnet50Model_baseline(image):
+def frcnn_coco(image):
     model = load_model(baseline=True)
     transformEval = presets.DetectionPresetEval()
     image_tensor = transformEval(image, None)[0]
     prediction = model([image_tensor])[0]
-    image_w_bbox = add_bbox(image_tensor, prediction, 0.85)
+    image_w_bbox = add_bbox(image_tensor, prediction, 0.80)
     torchvision.io.write_png(image_w_bbox, "baseline_out.png")
     return "baseline_out.png"
 
 
-title = "Performance comparision of Faster R-CNN for people detection with syntetic data"
-description = "<p style='text-align: center'>Performance comparision of Faster RCNN models for people detection using MOTSynth and MOT17"
+title = "Domain shift adaption on pedestrian detection with Faster R-CNN"
+description = "![alt text](http://www.aiacademy.unimore.it/media/news/ai-logo-white_2ND_EDITION.png)"
 examples = "/input_examples"
 
-io_baseline = gr.Interface(detect_with_resnet50Model_baseline, gr.Image(type="pil"), gr.Image(
-    type="file", shape=(1920, 1080), label="Baseline Faster RCNN Model pretrained on COCO dataset"))
+io_baseline = gr.Interface(frcnn_coco, gr.Image(type="pil"), gr.Image(
+    type="file", shape=(1920, 1080), label="Baseline Model trained on COCO + FT on MOT17"))
 
-io_custom = gr.Interface(detect_with_resnet50Model_finetuning_motsynth, gr.Image(type="pil"), gr.Image(
-    type="file", shape=(1920, 1080), label="Faster RCNN Model pretrained on COCO dataset + FT on MOTSynth"))
+io_custom = gr.Interface(frcnn_motsynth, gr.Image(type="pil"), gr.Image(
+    type="file", shape=(1920, 1080), label="Faster R-CNN trained on MOTSynth + FT on MOT17"))
 
 gr.Parallel(io_baseline, io_custom, title=title,
             description=description, examples=examples).launch(enable_queue=True)
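
Two details in the new load_model are worth flagging: the filename passed to torch.load ("model_split_3_FT_MOT17.pth") does not match the file this commit actually adds (model_split3_FT_MOT17.pth), and the hardcoded torch.device('cuda:0') raises a RuntimeError on CPU-only hosts such as a free Gradio Space. A minimal device-agnostic sketch, assuming the committed filename and the same {"model": state_dict} checkpoint layout (load_model_portable is a hypothetical name, not part of the repo):

import torch
from torchvision.models.detection.faster_rcnn import (FastRCNNPredictor,
                                                      fasterrcnn_resnet50_fpn)


def load_model_portable(baseline: bool = False):
    # Use the GPU when present, otherwise fall back to CPU
    # (assumption: CPU inference is acceptable for the demo).
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if baseline:
        model = fasterrcnn_resnet50_fpn(weights="DEFAULT")
    else:
        model = fasterrcnn_resnet50_fpn()
        # Swap in a 2-class head (person vs. background), as in the diff.
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 2)
        # Filename as added by this commit; the committed code loads the
        # slightly different "model_split_3_FT_MOT17.pth".
        checkpoint = torch.load("model_split3_FT_MOT17.pth", map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model.to(device).eval()

Note also that both Gradio callbacks pass baseline=True, so as committed the fine-tuned branch (and therefore the checkpoint) is never exercised; presumably frcnn_motsynth was meant to call load_model(baseline=False).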
configs/__init__.py DELETED
File without changes
configs/path_cfg.py DELETED
@@ -1,19 +0,0 @@
-import os
-import sys
-import os
-
-IN_COLAB = False
-if 'COLAB_GPU' in os.environ:
-    IN_COLAB = True
-
-cwd = os.getcwd()
-
-if(IN_COLAB):
-    MOTSYNTH_ROOT = '/content/gdrive/MyDrive/CVCS/storage/MOTSynth'
-    MOTCHA_ROOT = '/content/gdrive/MyDrive/CVCS/storage/MOTChallenge'
-    OUTPUT_DIR = '/content/gdrive/MyDrive/CVCS/storage/motsynth_output'
-else:
-    # windows config
-    MOTSYNTH_ROOT = cwd + '\storage\MOTSynth'
-    MOTCHA_ROOT = cwd + '\storage\MOTChallenge'
-    OUTPUT_DIR = cwd + '\storage\motsynth_output'
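
The deleted Windows branch concatenated paths with raw backslashes (cwd + '\storage\MOTSynth'), which only works because \s and \M happen not to be escape sequences, and it breaks on other platforms. A portable sketch of the same config using os.path.join (same roots as above; a hypothetical refactor, not code from the repo):

import os

# True when running inside Google Colab, as in the deleted config.
IN_COLAB = 'COLAB_GPU' in os.environ

if IN_COLAB:
    _STORAGE = '/content/gdrive/MyDrive/CVCS/storage'
else:
    # os.path.join picks the right separator on Windows and POSIX alike.
    _STORAGE = os.path.join(os.getcwd(), 'storage')

MOTSYNTH_ROOT = os.path.join(_STORAGE, 'MOTSynth')
MOTCHA_ROOT = os.path.join(_STORAGE, 'MOTChallenge')
OUTPUT_DIR = os.path.join(_STORAGE, 'motsynth_output')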
model_split3_FT_MOT17.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:53116b936ee59ca7cd9f29ef99bc8bf1dc591b6e8955f6c380b083454535923d
+size 330056867
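
The added .pth entry is a Git LFS pointer: a clone without git-lfs fetches only this three-line stub, and torch.load would then fail on what is actually a small text file. A sketch of a sanity check (hypothetical helper, not in the repo) that hashes the file and compares it against the pointer's oid above:

import hashlib

# sha256 and filename taken from the LFS pointer in this commit.
EXPECTED_SHA256 = "53116b936ee59ca7cd9f29ef99bc8bf1dc591b6e8955f6c380b083454535923d"


def is_real_checkpoint(path: str = "model_split3_FT_MOT17.pth") -> bool:
    # An unfetched pointer is ~130 bytes of text; the real file is ~330 MB.
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest() == EXPECTED_SHA256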