Spaces:
Build error
Build error
File size: 2,330 Bytes
169e11c e0452e0 169e11c e0452e0 169e11c e0452e0 169e11c e0452e0 169e11c e0452e0 169e11c e0452e0 d32b68f 169e11c e0452e0 169e11c e0452e0 169e11c e0452e0 169e11c e0452e0 dea24b0 169e11c e0452e0 169e11c e0452e0 169e11c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
import functools
import logging
import os.path as osp

import gradio as gr
import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.faster_rcnn import fasterrcnn_resnet50_fpn

from src.detection.graph_utils import add_bbox
from src.detection.vision import presets
# Raise the "PIL" logger's threshold to CRITICAL so that lower-severity
# records emitted by PIL are dropped from the app's log output.
logging.getLogger(name="PIL").setLevel(level=logging.CRITICAL)
@functools.lru_cache(maxsize=4)
def load_model(
    baseline: bool = False,
    checkpoint_path: str = "model_split_3_FT_MOT17.pth",
    device=None,
):
    """Build a Faster R-CNN (ResNet-50 FPN) detector ready for inference.

    Args:
        baseline: If True, build the model with torchvision's default
            pretrained weights (``weights="DEFAULT"``). If False, build an
            un-initialised model, replace its box predictor with a 2-class
            ``FastRCNNPredictor`` and load the ``"model"`` state dict stored
            in ``checkpoint_path``.
        checkpoint_path: Checkpoint file read when ``baseline`` is False.
        device: Target device (``str`` or ``torch.device``). ``None`` selects
            ``cuda:0`` when CUDA is available and the CPU otherwise.

    Returns:
        The model, moved to the selected device and switched to eval mode.
        Results are cached per argument combination, so repeated calls with
        the same arguments return the same model object instead of
        rebuilding it and re-reading weights on every request.
    """
    if baseline:
        model = fasterrcnn_resnet50_fpn(weights="DEFAULT")
    else:
        model = fasterrcnn_resnet50_fpn()
        # Swap the classification head for a 2-class one so the checkpoint's
        # head weights fit (presumably background + pedestrian — confirm
        # against the training code).
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 2)
        # Map to CPU first so the checkpoint loads on machines without a GPU.
        # torch.load unpickles the file: only point this at trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(checkpoint["model"])

    # The original hard-coded 'cuda:0', which raises on CPU-only hosts.
    if device is None:
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model.to(torch.device(device))
    model.eval()
    return model
def _detect_and_save(image, baseline: bool, out_path: str, threshold: float = 0.80) -> str:
    """Run one detector on ``image``, draw its boxes and write the result as PNG.

    Args:
        image: Input image as delivered by the Gradio input component
            (configured with ``type="pil"`` in this app).
        baseline: Forwarded to ``load_model`` to choose which detector runs.
        out_path: Destination path of the annotated PNG.
        threshold: Third argument passed to ``add_bbox`` (presumably a
            minimum detection score — confirm in ``graph_utils``).

    Returns:
        ``out_path``, so Gradio can display the written file.
    """
    model = load_model(baseline=baseline)
    model_device = next(model.parameters()).device
    # DetectionPresetEval is called as transform(image, target); element 0 is
    # taken as the image tensor (assumed to return an (image, target) pair).
    image_tensor = presets.DetectionPresetEval()(image, None)[0]
    # Pure inference: skip autograd bookkeeping, and feed the model a copy of
    # the input that lives on the same device as its weights.
    with torch.no_grad():
        prediction = model([image_tensor.to(model_device)])[0]
    # Drawing and PNG encoding use the CPU image tensor, so bring the
    # prediction tensors back to the CPU as well.
    prediction = {key: value.cpu() for key, value in prediction.items()}
    image_w_bbox = add_bbox(image_tensor, prediction, threshold)
    torchvision.io.write_png(image_w_bbox, out_path)
    return out_path


def frcnn_motsynth(image):
    """Annotate ``image`` with the custom (checkpoint-loaded) detector.

    Uses ``load_model(baseline=False)``; the original called the baseline
    model here, making this output identical to ``frcnn_coco``.
    Returns the path of the written PNG (``"custom_out.png"``).
    """
    return _detect_and_save(image, baseline=False, out_path="custom_out.png")


def frcnn_coco(image):
    """Annotate ``image`` with the baseline detector (torchvision default weights).

    Returns the path of the written PNG (``"baseline_out.png"``).
    """
    return _detect_and_save(image, baseline=True, out_path="baseline_out.png")
# Module-level Gradio app: two side-by-side interfaces (baseline vs custom
# detector) fed the same input, launched as soon as this script runs.
title = "Domain shift adaption on pedestrian detection with Faster R-CNN"
description = ""
# Resolve the examples folder next to this script. The previous value,
# "/input_examples", pointed at the filesystem root rather than the app dir.
examples = osp.join(osp.dirname(osp.abspath(__file__)), "input_examples")

# NOTE(review): this label says "COCO + FT on MOT17", but frcnn_coco uses
# load_model(baseline=True), which loads torchvision's default weights and no
# MOT17 checkpoint — confirm which description is intended.
io_baseline = gr.Interface(
    frcnn_coco,
    gr.Image(type="pil"),
    gr.Image(
        type="file",
        shape=(1920, 1080),
        label="Baseline Model trained on COCO + FT on MOT17",
    ),
)
io_custom = gr.Interface(
    frcnn_motsynth,
    gr.Image(type="pil"),
    gr.Image(
        type="file",
        shape=(1920, 1080),
        label="Faster R-CNN trained on MOTSynth + FT on MOT17",
    ),
)

# NOTE(review): `type="file"`, `shape=` and `enable_queue=` are
# Gradio-version-specific arguments (deprecated/removed in newer releases) —
# confirm against the Gradio version pinned for this Space.
gr.Parallel(
    io_baseline,
    io_custom,
    title=title,
    description=description,
    examples=examples,
).launch(enable_queue=True)
|