FasterRCNNDemo / app.py
lyhue1991's picture
Update app.py
9f387f9
raw
history blame
3.87 kB
import numpy as np
from PIL import Image,ImageColor,ImageDraw,ImageFont
import torch
from torch import nn
import torchvision
from torchvision import datasets, models, transforms
import streamlit as st
# 可视化函数
def plot_detection(image,prediction,idx2names,min_score = 0.8):
image_result = image.copy()
boxes,labels,scores = prediction['boxes'],prediction['labels'],prediction['scores']
draw = ImageDraw.Draw(image_result)
for idx in range(boxes.shape[0]):
if scores[idx] >= min_score:
x1, y1, x2, y2 = boxes[idx][0], boxes[idx][1], boxes[idx][2], boxes[idx][3]
name = idx2names.get(str(labels[idx].item()))
score = scores[idx]
draw.rectangle((x1,y1,x2,y2), fill=None, outline ='lawngreen',width = 2)
draw.text((x1,y1),name+":\n"+str(round(score.item(),2)),fill="red")
return image_result
# 加载模型
@st.cache()
def load_model():
num_classes = 91
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True,num_classes = num_classes)
if torch.cuda.is_available():
model.to("cuda:0")
model.eval()
model.idx2names = {'0': 'background', '1': 'person', '2': 'bicycle', '3': 'car',
'4': 'motorcycle', '5': 'airplane', '6': 'bus', '7': 'train', '8': 'truck', '9': 'boat',
'10': 'traffic light', '11': 'fire hydrant', '13': 'stop sign',
'14': 'parking meter', '15': 'bench', '16': 'bird', '17': 'cat',
'18': 'dog', '19': 'horse', '20': 'sheep', '21': 'cow', '22': 'elephant',
'23': 'bear', '24': 'zebra', '25': 'giraffe', '27': 'backpack',
'28': 'umbrella', '31': 'handbag', '32': 'tie', '33': 'suitcase',
'34': 'frisbee', '35': 'skis', '36': 'snowboard', '37': 'sports ball',
'38': 'kite','39': 'baseball bat', '40': 'baseball glove', '41': 'skateboard',
'42': 'surfboard', '43': 'tennis racket', '44': 'bottle', '46': 'wine glass',
'47': 'cup', '48': 'fork', '49': 'knife', '50': 'spoon', '51': 'bowl',
'52': 'banana', '53': 'apple', '54': 'sandwich', '55': 'orange',
'56': 'broccoli', '57': 'carrot', '58': 'hot dog', '59': 'pizza',
'60': 'donut', '61': 'cake', '62': 'chair', '63': 'couch',
'64': 'potted plant', '65': 'bed', '67': 'dining table',
'70': 'toilet', '72': 'tv', '73': 'laptop', '74': 'mouse',
'75': 'remote', '76': 'keyboard', '77': 'cell phone',
'78': 'microwave', '79': 'oven', '80': 'toaster',
'81': 'sink', '82': 'refrigerator', '84': 'book',
'85': 'clock', '86': 'vase', '87': 'scissors',
'88': 'teddybear', '89': 'hair drier', '90': 'toothbrush'}
return model
def predict_detection(model,image_path,min_score=0.8):
# 准备数据
inputs = []
img = Image.open(image_path).convert("RGB")
img_tensor = torch.from_numpy(np.array(img)/255.).permute(2,0,1).float()
if torch.cuda.is_available():
img_tensor = img_tensor.cuda()
inputs.append(img_tensor)
# 预测结果
with torch.no_grad():
predictions = model(inputs)
# 结果可视化
img_result = plot_detection(img,predictions[0],
model.idx2names,min_score = min_score)
return img_result
st.title("FasterRCNN功能演示")
st.header("FasterRCNN Input:")
image_file = st.file_uploader("upload a image file(jpg/png) to predict:")
if image_file is not None:
try:
st.image(image_file)
except Exception as err:
st.write(err)
else:
image_file = "horseman.png"
st.image(image_file)
min_score = st.slider(label="choose the min_score parameter:",min_value=0.1,max_value=0.98,value=0.8)
st.header("FasterRCNN Prediction:")
with st.spinner('waitting for prediction...'):
model = load_model()
img_result = predict_detection(model,image_file,min_score=min_score)
st.image(img_result)