# Hugging Face Space source listing (Space status at capture time: "Build error"; file size: 2,603 bytes).
import os.path as osp
import os
import gradio as gr
import torch
import logging
import torchvision
from torchvision.models.detection.faster_rcnn import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from src.detection.graph_utils import add_bbox
from src.detection.vision import presets
import torchvision.transforms as T
# Only let CRITICAL records from the 'PIL' logger through; anything less
# severe from PIL is dropped.
logging.getLogger('PIL').setLevel(logging.CRITICAL)

# Device that models and input tensors are placed on: the default CUDA device
# when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
def load_model(baseline: bool = False, checkpoint_path=None):
    """Build a Faster R-CNN (ResNet-50 FPN) detector ready for inference.

    Args:
        baseline: If True, return torchvision's ``fasterrcnn_resnet50_fpn``
            built with ``weights="DEFAULT"``. Otherwise build the same
            architecture, replace its box predictor with a 2-class
            ``FastRCNNPredictor`` and load the fine-tuned weights from
            *checkpoint_path*.
        checkpoint_path: Path of the fine-tuned checkpoint; ignored when
            *baseline* is True. Defaults to ``model_split3_FT_MOT17.pth`` in
            the current working directory. The loaded object must be indexable
            by ``"model"``, yielding the model's state dict.

    Returns:
        The model, moved to the module-level ``device`` and switched to eval
        mode.
    """
    if baseline:
        model = fasterrcnn_resnet50_fpn(weights="DEFAULT")
    else:
        model = fasterrcnn_resnet50_fpn()
        # Swap the classification head for a 2-class one so it matches the
        # fine-tuned checkpoint (presumably background + pedestrian — TODO
        # confirm against the training code).
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 2)
        if checkpoint_path is None:
            checkpoint_path = osp.join(os.getcwd(), "model_split3_FT_MOT17.pth")
        # Map to CPU first so a checkpoint saved on a GPU still loads on a
        # CPU-only host; the model is moved to `device` below.
        # NOTE(review): torch >= 2.6 defaults torch.load to weights_only=True,
        # which rejects checkpoints containing arbitrary pickled objects.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    model.to(device)
    model.eval()
    return model
# Per-process cache of loaded models, keyed by the ``baseline`` flag, so the
# weights are read from disk once instead of on every request. Both models
# stay resident in memory once used.
_MODEL_CACHE = {}


def _get_model(baseline: bool = False):
    """Return the model for *baseline*, building it with ``load_model`` on first use."""
    if baseline not in _MODEL_CACHE:
        _MODEL_CACHE[baseline] = load_model(baseline=baseline)
    return _MODEL_CACHE[baseline]


def _detect_and_save(model, image, out_path: str, score_threshold: float = 0.80) -> str:
    """Run *model* on *image*, draw its boxes and write the result as a PNG.

    Args:
        model: Detector returned by ``load_model`` (already in eval mode).
        image: Input image, passed to ``presets.DetectionPresetEval``.
        out_path: File the annotated PNG is written to.
        score_threshold: Passed to ``add_bbox`` as its threshold argument.

    Returns:
        *out_path*.
    """
    transform_eval = presets.DetectionPresetEval()
    image_tensor = transform_eval(image, None)[0].to(device)
    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        prediction = model([image_tensor])[0]
    image_w_bbox = add_bbox(image_tensor, prediction, score_threshold)
    # write_png only accepts CPU tensors, while image_tensor may live on CUDA.
    # .cpu() is a no-op if add_bbox already returned a CPU tensor.
    torchvision.io.write_png(image_w_bbox.cpu(), out_path)
    return out_path


def frcnn_motsynth(image):
    """Detect with the fine-tuned model and return the annotated PNG's path.

    Args:
        image: Input image (a PIL image when called from the Gradio app).

    Returns:
        ``"custom_out.png"``, the file the annotated image was written to.
        NOTE(review): a fixed file name means concurrent requests would
        overwrite each other's output.
    """
    return _detect_and_save(_get_model(baseline=False), image, "custom_out.png")


def frcnn_coco(image):
    """Detect with the baseline torchvision model and return the annotated PNG's path.

    Args:
        image: Input image (a PIL image when called from the Gradio app).

    Returns:
        ``"baseline_out.png"``, the file the annotated image was written to.
    """
    return _detect_and_save(_get_model(baseline=True), image, "baseline_out.png")
title = "Domain shift adaption on pedestrian detection with Faster R-CNN"
description = '<p style="text-align:center">School in AI: Deep Learning, Vision and Language for Industry - second edition final project work by Matteo Sirri.</p> '
# Clickable example inputs (relative paths, shared by both interfaces).
examples = [
    "001.jpg", "002.jpg", "003.jpg",
    "004.jpg", "005.jpg", "006.jpg", "007.jpg",
]

# The output components receive file paths from frcnn_coco / frcnn_motsynth.
# gr.Image only accepts type in {"numpy", "pil", "filepath"}; the previous
# type="file" raised a ValueError when the component was constructed.
# NOTE(review): load_model(baseline=True) builds torchvision's default COCO
# weights; confirm that "+ FT on MOT17" in this label is accurate.
io_baseline = gr.Interface(
    frcnn_coco,
    gr.Image(type="pil"),
    gr.Image(type="filepath", label="Baseline Model trained on COCO + FT on MOT17"),
)
io_custom = gr.Interface(
    frcnn_motsynth,
    gr.Image(type="pil"),
    gr.Image(type="filepath", label="Faster R-CNN trained on MOTSynth + FT on MOT17"),
)
# Show both models side by side on the same input image.
demo = gr.Parallel(
    io_baseline,
    io_custom,
    title=title,
    description=description,
    examples=examples,
)

if __name__ == "__main__":
    demo.launch(enable_queue=True)
|