Spaces:
Runtime error
Runtime error
Create new file
Browse files
app.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from PIL import Image,ImageColor,ImageDraw,ImageFont
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
|
6 |
+
import torchvision
|
7 |
+
from torchvision import datasets, models, transforms
|
8 |
+
|
9 |
+
import streamlit as st
|
10 |
+
|
11 |
+
# 可视化函数
|
12 |
+
def plot_detection(image,prediction,idx2names,min_score = 0.8):
|
13 |
+
image_result = image.copy()
|
14 |
+
boxes,labels,scores = prediction['boxes'],prediction['labels'],prediction['scores']
|
15 |
+
draw = ImageDraw.Draw(image_result)
|
16 |
+
for idx in range(boxes.shape[0]):
|
17 |
+
if scores[idx] >= min_score:
|
18 |
+
x1, y1, x2, y2 = boxes[idx][0], boxes[idx][1], boxes[idx][2], boxes[idx][3]
|
19 |
+
name = idx2names.get(str(labels[idx].item()))
|
20 |
+
score = scores[idx]
|
21 |
+
draw.rectangle((x1,y1,x2,y2), fill=None, outline ='lawngreen',width = 2)
|
22 |
+
draw.text((x1,y1),name+":\n"+str(round(score.item(),2)),fill="red")
|
23 |
+
return image_result
|
24 |
+
|
25 |
+
|
26 |
+
# 加载模型
|
27 |
+
@st.cache()
|
28 |
+
def load_model():
|
29 |
+
num_classes = 91
|
30 |
+
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True,num_classes = num_classes)
|
31 |
+
if torch.cuda.is_available():
|
32 |
+
model.to("cuda:0")
|
33 |
+
model.eval()
|
34 |
+
model.idx2names = {'0': 'background', '1': 'person', '2': 'bicycle', '3': 'car',
|
35 |
+
'4': 'motorcycle', '5': 'airplane', '6': 'bus', '7': 'train', '8': 'truck', '9': 'boat',
|
36 |
+
'10': 'traffic light', '11': 'fire hydrant', '13': 'stop sign',
|
37 |
+
'14': 'parking meter', '15': 'bench', '16': 'bird', '17': 'cat',
|
38 |
+
'18': 'dog', '19': 'horse', '20': 'sheep', '21': 'cow', '22': 'elephant',
|
39 |
+
'23': 'bear', '24': 'zebra', '25': 'giraffe', '27': 'backpack',
|
40 |
+
'28': 'umbrella', '31': 'handbag', '32': 'tie', '33': 'suitcase',
|
41 |
+
'34': 'frisbee', '35': 'skis', '36': 'snowboard', '37': 'sports ball',
|
42 |
+
'38': 'kite','39': 'baseball bat', '40': 'baseball glove', '41': 'skateboard',
|
43 |
+
'42': 'surfboard', '43': 'tennis racket', '44': 'bottle', '46': 'wine glass',
|
44 |
+
'47': 'cup', '48': 'fork', '49': 'knife', '50': 'spoon', '51': 'bowl',
|
45 |
+
'52': 'banana', '53': 'apple', '54': 'sandwich', '55': 'orange',
|
46 |
+
'56': 'broccoli', '57': 'carrot', '58': 'hot dog', '59': 'pizza',
|
47 |
+
'60': 'donut', '61': 'cake', '62': 'chair', '63': 'couch',
|
48 |
+
'64': 'potted plant', '65': 'bed', '67': 'dining table',
|
49 |
+
'70': 'toilet', '72': 'tv', '73': 'laptop', '74': 'mouse',
|
50 |
+
'75': 'remote', '76': 'keyboard', '77': 'cell phone',
|
51 |
+
'78': 'microwave', '79': 'oven', '80': 'toaster',
|
52 |
+
'81': 'sink', '82': 'refrigerator', '84': 'book',
|
53 |
+
'85': 'clock', '86': 'vase', '87': 'scissors',
|
54 |
+
'88': 'teddybear', '89': 'hair drier', '90': 'toothbrush'}
|
55 |
+
return model
|
56 |
+
|
57 |
+
def predict_detection(model,image_path,min_score=0.8):
|
58 |
+
# 准备数据
|
59 |
+
inputs = []
|
60 |
+
img = Image.open(image_path).convert("RGB")
|
61 |
+
img_tensor = torch.from_numpy(np.array(img)/255.).permute(2,0,1).float()
|
62 |
+
if torch.cuda.is_available():
|
63 |
+
img_tensor = img_tensor.cuda()
|
64 |
+
inputs.append(img_tensor)
|
65 |
+
|
66 |
+
# 预测结果
|
67 |
+
with torch.no_grad():
|
68 |
+
predictions = model(inputs)
|
69 |
+
|
70 |
+
# 结果可视化
|
71 |
+
img_result = plot_detection(img,predictions[0],
|
72 |
+
model.idx2names,min_score = min_score)
|
73 |
+
return img_result
|
74 |
+
|
75 |
+
st.title("FasterRCNN功能演示")
|
76 |
+
|
77 |
+
st.header("FasterRCNN Input:")
|
78 |
+
image_file = st.file_uploader("upload a image file(jpg/png) to predict:")
|
79 |
+
if image_file is not None:
|
80 |
+
try:
|
81 |
+
st.image(image_file)
|
82 |
+
except Exception as err:
|
83 |
+
st.write(err)
|
84 |
+
else:
|
85 |
+
image_file = "https://tva1.sinaimg.cn/large/e6c9d24egy1h566tcs188j20al0dwabm.jpg"
|
86 |
+
st.image(image_file)
|
87 |
+
|
88 |
+
min_score = st.slider(label="choose the min_score parameter:",min_value=0.1,max_value=0.98,value=0.8)
|
89 |
+
|
90 |
+
st.header("FasterRCNN Prediction:")
|
91 |
+
with st.spinner('waitting for prediction...'):
|
92 |
+
model = load_model()
|
93 |
+
img_result = predict_detection(model,image_file,min_score=min_score)
|
94 |
+
st.image(img_result)
|