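"""Gradio demo for pedestrian detection with Faster R-CNN under domain shift.

Two detectors run in parallel on the same input image: torchvision's
COCO-pretrained baseline and a model with a 2-class head loaded from the
MOT17 fine-tuned checkpoint (model_split3_FT_MOT17.pth).
"""
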
import logging
import os
import os.path as osp

import gradio as gr
import torch
import torchvision
from torchvision.models.detection.faster_rcnn import (
    FastRCNNPredictor,
    fasterrcnn_resnet50_fpn,
)

from src.detection.graph_utils import add_bbox
from src.detection.vision import presets

# Silence PIL's verbose logging.
logging.getLogger('PIL').setLevel(logging.CRITICAL)

# Run inference on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_model(baseline: bool = False):
    """Build a Faster R-CNN detector and return it in eval mode on `device`.

    With baseline=True, torchvision's COCO-pretrained weights are used as-is.
    Otherwise the box predictor is replaced by a 2-class head (background +
    pedestrian) and the MOT17 fine-tuned checkpoint is loaded from the
    current working directory.
    """
    if baseline:
        model = fasterrcnn_resnet50_fpn(weights="DEFAULT")
    else:
        model = fasterrcnn_resnet50_fpn()
        # Swap in a 2-class predictor before loading the fine-tuned weights.
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 2)
        checkpoint = torch.load(
            osp.join(os.getcwd(), "model_split3_FT_MOT17.pth"),
            map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    model.to(device)
    model.eval()
    return model


def frcnn_motsynth(image):
    # Detector trained on MOTSynth and fine-tuned on MOT17.
    model = load_model()
    transform_eval = presets.DetectionPresetEval()
    image_tensor = transform_eval(image, None)[0].to(device)
    with torch.no_grad():
        prediction = model([image_tensor])[0]
    # Overlay the predicted boxes and save the result to disk.
    image_w_bbox = add_bbox(image_tensor, prediction, 0.80)
    torchvision.io.write_png(image_w_bbox, "custom_out.png")
    return "custom_out.png"


def frcnn_coco(image):
    # COCO-pretrained baseline detector.
    model = load_model(baseline=True)
    transform_eval = presets.DetectionPresetEval()
    image_tensor = transform_eval(image, None)[0].to(device)
    with torch.no_grad():
        prediction = model([image_tensor])[0]
    image_w_bbox = add_bbox(image_tensor, prediction, 0.80)
    torchvision.io.write_png(image_w_bbox, "baseline_out.png")
    return "baseline_out.png"


title = "Domain shift adaptation for pedestrian detection with Faster R-CNN"
description = '<p style="text-align:center">School in AI: Deep Learning, Vision and Language for Industry (second edition) - final project work by Matteo Sirri.</p>'
examples = ["001.jpg", "002.jpg", "003.jpg", "004.jpg",
            "005.jpg", "006.jpg", "007.jpg"]

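# Both interfaces take the same input image; gr.Parallel shows their outputs
# side by side so the baseline and the adapted model can be compared directly.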
io_baseline = gr.Interface(frcnn_coco, gr.Image(type="pil"), gr.Image(
    type="file", label="Baseline Model trained on COCO + FT on MOT17"))

io_custom = gr.Interface(frcnn_motsynth, gr.Image(type="pil"), gr.Image(
    type="file", label="Faster R-CNN trained on MOTSynth + FT on MOT17"))

gr.Parallel(io_baseline, io_custom, title=title,
            description=description, examples=examples).launch(enable_queue=True)