lyhue1991 commited on
Commit
538c868
·
1 Parent(s): bc744c2

Create new file

Browse files
Files changed (1) hide show
  1. app.py +94 -0
app.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from PIL import Image,ImageColor,ImageDraw,ImageFont
3
+ import torch
4
+ from torch import nn
5
+
6
+ import torchvision
7
+ from torchvision import datasets, models, transforms
8
+
9
+ import streamlit as st
10
+
11
+ # 可视化函数
12
+ def plot_detection(image,prediction,idx2names,min_score = 0.8):
13
+ image_result = image.copy()
14
+ boxes,labels,scores = prediction['boxes'],prediction['labels'],prediction['scores']
15
+ draw = ImageDraw.Draw(image_result)
16
+ for idx in range(boxes.shape[0]):
17
+ if scores[idx] >= min_score:
18
+ x1, y1, x2, y2 = boxes[idx][0], boxes[idx][1], boxes[idx][2], boxes[idx][3]
19
+ name = idx2names.get(str(labels[idx].item()))
20
+ score = scores[idx]
21
+ draw.rectangle((x1,y1,x2,y2), fill=None, outline ='lawngreen',width = 2)
22
+ draw.text((x1,y1),name+":\n"+str(round(score.item(),2)),fill="red")
23
+ return image_result
24
+
25
+
26
+ # 加载模型
27
+ @st.cache()
28
+ def load_model():
29
+ num_classes = 91
30
+ model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True,num_classes = num_classes)
31
+ if torch.cuda.is_available():
32
+ model.to("cuda:0")
33
+ model.eval()
34
+ model.idx2names = {'0': 'background', '1': 'person', '2': 'bicycle', '3': 'car',
35
+ '4': 'motorcycle', '5': 'airplane', '6': 'bus', '7': 'train', '8': 'truck', '9': 'boat',
36
+ '10': 'traffic light', '11': 'fire hydrant', '13': 'stop sign',
37
+ '14': 'parking meter', '15': 'bench', '16': 'bird', '17': 'cat',
38
+ '18': 'dog', '19': 'horse', '20': 'sheep', '21': 'cow', '22': 'elephant',
39
+ '23': 'bear', '24': 'zebra', '25': 'giraffe', '27': 'backpack',
40
+ '28': 'umbrella', '31': 'handbag', '32': 'tie', '33': 'suitcase',
41
+ '34': 'frisbee', '35': 'skis', '36': 'snowboard', '37': 'sports ball',
42
+ '38': 'kite','39': 'baseball bat', '40': 'baseball glove', '41': 'skateboard',
43
+ '42': 'surfboard', '43': 'tennis racket', '44': 'bottle', '46': 'wine glass',
44
+ '47': 'cup', '48': 'fork', '49': 'knife', '50': 'spoon', '51': 'bowl',
45
+ '52': 'banana', '53': 'apple', '54': 'sandwich', '55': 'orange',
46
+ '56': 'broccoli', '57': 'carrot', '58': 'hot dog', '59': 'pizza',
47
+ '60': 'donut', '61': 'cake', '62': 'chair', '63': 'couch',
48
+ '64': 'potted plant', '65': 'bed', '67': 'dining table',
49
+ '70': 'toilet', '72': 'tv', '73': 'laptop', '74': 'mouse',
50
+ '75': 'remote', '76': 'keyboard', '77': 'cell phone',
51
+ '78': 'microwave', '79': 'oven', '80': 'toaster',
52
+ '81': 'sink', '82': 'refrigerator', '84': 'book',
53
+ '85': 'clock', '86': 'vase', '87': 'scissors',
54
+ '88': 'teddybear', '89': 'hair drier', '90': 'toothbrush'}
55
+ return model
56
+
57
+ def predict_detection(model,image_path,min_score=0.8):
58
+ # 准备数据
59
+ inputs = []
60
+ img = Image.open(image_path).convert("RGB")
61
+ img_tensor = torch.from_numpy(np.array(img)/255.).permute(2,0,1).float()
62
+ if torch.cuda.is_available():
63
+ img_tensor = img_tensor.cuda()
64
+ inputs.append(img_tensor)
65
+
66
+ # 预测结果
67
+ with torch.no_grad():
68
+ predictions = model(inputs)
69
+
70
+ # 结果可视化
71
+ img_result = plot_detection(img,predictions[0],
72
+ model.idx2names,min_score = min_score)
73
+ return img_result
74
+
75
+ st.title("FasterRCNN功能演示")
76
+
77
+ st.header("FasterRCNN Input:")
78
+ image_file = st.file_uploader("upload a image file(jpg/png) to predict:")
79
+ if image_file is not None:
80
+ try:
81
+ st.image(image_file)
82
+ except Exception as err:
83
+ st.write(err)
84
+ else:
85
+ image_file = "https://tva1.sinaimg.cn/large/e6c9d24egy1h566tcs188j20al0dwabm.jpg"
86
+ st.image(image_file)
87
+
88
+ min_score = st.slider(label="choose the min_score parameter:",min_value=0.1,max_value=0.98,value=0.8)
89
+
90
+ st.header("FasterRCNN Prediction:")
91
+ with st.spinner('waitting for prediction...'):
92
+ model = load_model()
93
+ img_result = predict_detection(model,image_file,min_score=min_score)
94
+ st.image(img_result)