File size: 3,874 Bytes
538c868
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9f387f9
538c868
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import numpy as np
from PIL import Image,ImageColor,ImageDraw,ImageFont 
import torch
from torch import nn

import torchvision
from torchvision import datasets, models, transforms

import streamlit as st 

# Visualization helper
def plot_detection(image, prediction, idx2names, min_score=0.8):
    """Draw the detections scoring at least ``min_score`` onto a copy of ``image``.

    Args:
        image: PIL image the prediction was computed on; it is not modified.
        prediction: dict with ``'boxes'`` (one ``x1, y1, x2, y2`` row per
            detection), ``'labels'`` and ``'scores'`` tensors, as produced for a
            single image by a torchvision detection model.
        idx2names: mapping from a label id *as a string* to a class name.
            Labels missing from the mapping are captioned with their numeric id.
        min_score: detections whose score is below this value are skipped.

    Returns:
        A new PIL image with a green rectangle per kept detection and a red
        two-line caption (class name, then score rounded to 2 decimals).
    """
    image_result = image.copy()
    draw = ImageDraw.Draw(image_result)
    # .tolist() copies the values to host memory as plain Python numbers, so PIL
    # never receives (possibly CUDA-resident) tensors.
    detections = zip(
        prediction['boxes'].tolist(),
        prediction['labels'].tolist(),
        prediction['scores'].tolist(),
    )
    for (x1, y1, x2, y2), label, score in detections:
        if score >= min_score:
            # Fall back to the raw id: a None name would make the caption
            # concatenation below raise TypeError.
            name = idx2names.get(str(label), str(label))
            draw.rectangle((x1, y1, x2, y2), fill=None, outline='lawngreen', width=2)
            draw.text((x1, y1), name + ":\n" + str(round(score, 2)), fill="red")
    return image_result


# Load the model. The result is cached by Streamlit so the network is built once
# rather than on every script rerun. ``st.cache_resource`` (Streamlit 1.18+) is
# the cache meant for unhashable global resources such as models; older releases
# only provide ``st.cache``, where ``allow_output_mutation=True`` stops Streamlit
# from re-hashing the whole model on every call.
_cache_model = getattr(st, "cache_resource", None) or st.cache(allow_output_mutation=True)


@_cache_model
def load_model():
    """Build the pretrained Faster R-CNN (ResNet-50 FPN) in eval mode.

    The model is moved to ``cuda:0`` when CUDA is available. A ``idx2names``
    attribute is attached: a dict mapping every label id the predictor can
    emit (``'0'`` .. ``'90'``, as strings) to a class name.

    Returns:
        The ready-to-use detection model.
    """
    num_classes = 91
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True,num_classes = num_classes)
    if torch.cuda.is_available():
        model.to("cuda:0")
    model.eval()
    model.idx2names = {'0': 'background', '1': 'person', '2': 'bicycle', '3': 'car', 
       '4': 'motorcycle', '5': 'airplane', '6': 'bus', '7': 'train', '8': 'truck', '9': 'boat', 
       '10': 'traffic light', '11': 'fire hydrant', '13': 'stop sign', 
       '14': 'parking meter', '15': 'bench', '16': 'bird', '17': 'cat',
       '18': 'dog', '19': 'horse', '20': 'sheep', '21': 'cow', '22': 'elephant', 
       '23': 'bear', '24': 'zebra', '25': 'giraffe', '27': 'backpack', 
       '28': 'umbrella', '31': 'handbag', '32': 'tie', '33': 'suitcase',
       '34': 'frisbee', '35': 'skis', '36': 'snowboard', '37': 'sports ball',
       '38': 'kite','39': 'baseball bat', '40': 'baseball glove', '41': 'skateboard',
       '42': 'surfboard', '43': 'tennis racket', '44': 'bottle', '46': 'wine glass', 
       '47': 'cup', '48': 'fork', '49': 'knife', '50': 'spoon', '51': 'bowl',
       '52': 'banana', '53': 'apple', '54': 'sandwich', '55': 'orange', 
       '56': 'broccoli', '57': 'carrot', '58': 'hot dog', '59': 'pizza',
       '60': 'donut', '61': 'cake', '62': 'chair', '63': 'couch', 
       '64': 'potted plant', '65': 'bed', '67': 'dining table',
       '70': 'toilet', '72': 'tv', '73': 'laptop', '74': 'mouse', 
       '75': 'remote', '76': 'keyboard', '77': 'cell phone', 
       '78': 'microwave', '79': 'oven', '80': 'toaster', 
       '81': 'sink', '82': 'refrigerator', '84': 'book',
       '85': 'clock', '86': 'vase', '87': 'scissors',
       '88': 'teddybear', '89': 'hair drier', '90': 'toothbrush'}
    # The literal above skips some ids (12, 26, 29, ...), but the predictor still
    # has num_classes outputs; give those ids a placeholder name so every label
    # the model can emit resolves to a string.
    for idx in range(num_classes):
        model.idx2names.setdefault(str(idx), 'N/A')
    return model

def predict_detection(model, image_path, min_score=0.8):
    """Run ``model`` on one image and return that image with its detections drawn.

    Args:
        model: detection model such as the one returned by ``load_model``; it
            must carry an ``idx2names`` attribute.
        image_path: file path or file-like object that ``PIL.Image.open`` accepts.
        min_score: minimum score for a detection to be drawn.

    Returns:
        The PIL image produced by ``plot_detection``.
    """
    # Prepare the data: decode as RGB; the context manager releases the
    # source file once the converted copy exists.
    with Image.open(image_path) as raw:
        img = raw.convert("RGB")
    # HWC uint8 array -> CHW float tensor scaled to [0, 1].
    img_tensor = torch.from_numpy(np.array(img) / 255.).permute(2, 0, 1).float()
    # Place the input on the device the model's weights actually live on,
    # rather than assuming the model was moved to CUDA whenever it exists.
    device = next(model.parameters()).device
    inputs = [img_tensor.to(device)]

    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        predictions = model(inputs)

    # Visualize the detections of the single input image.
    return plot_detection(img, predictions[0], model.idx2names, min_score=min_score)
    
st.title("FasterRCNN功能演示")

st.header("FasterRCNN Input:")
image_file = st.file_uploader("upload an image file(jpg/png) to predict:")
# Only predict when the input could be displayed: a file that fails here would
# most likely fail again, uncaught, inside predict_detection.
image_ok = True
if image_file is not None:
    try:
        st.image(image_file)
    except Exception as err:
        st.write(err)
        image_ok = False
else:
    # No upload yet: fall back to the default sample image path.
    image_file = "horseman.png"
    st.image(image_file)

min_score = st.slider(label="choose the min_score parameter:",min_value=0.1,max_value=0.98,value=0.8)

st.header("FasterRCNN Prediction:")
if image_ok:
    with st.spinner('waiting for prediction...'):
        model = load_model()
        try:
            img_result = predict_detection(model, image_file, min_score=min_score)
        except OSError as err:
            # PIL reports undecodable input as an OSError subclass
            # (UnidentifiedImageError); show it instead of crashing the app.
            st.write(err)
        else:
            st.image(img_result)