import io
import urllib.request

import numpy as np
from PIL import Image, ImageColor, ImageDraw, ImageFont
import torch
from torch import nn
import torchvision
from torchvision import datasets, models, transforms
import streamlit as st

# Label id (as a string) -> class name for the detector's COCO labels.
# The numbering has gaps (e.g. 12, 26); plot_detection falls back to the
# numeric id for any label that is not in this mapping.
COCO_IDX2NAMES = {
    '0': 'background', '1': 'person', '2': 'bicycle', '3': 'car', '4': 'motorcycle',
    '5': 'airplane', '6': 'bus', '7': 'train', '8': 'truck', '9': 'boat',
    '10': 'traffic light', '11': 'fire hydrant', '13': 'stop sign', '14': 'parking meter',
    '15': 'bench', '16': 'bird', '17': 'cat', '18': 'dog', '19': 'horse', '20': 'sheep',
    '21': 'cow', '22': 'elephant', '23': 'bear', '24': 'zebra', '25': 'giraffe',
    '27': 'backpack', '28': 'umbrella', '31': 'handbag', '32': 'tie', '33': 'suitcase',
    '34': 'frisbee', '35': 'skis', '36': 'snowboard', '37': 'sports ball', '38': 'kite',
    '39': 'baseball bat', '40': 'baseball glove', '41': 'skateboard', '42': 'surfboard',
    '43': 'tennis racket', '44': 'bottle', '46': 'wine glass', '47': 'cup', '48': 'fork',
    '49': 'knife', '50': 'spoon', '51': 'bowl', '52': 'banana', '53': 'apple',
    '54': 'sandwich', '55': 'orange', '56': 'broccoli', '57': 'carrot', '58': 'hot dog',
    '59': 'pizza', '60': 'donut', '61': 'cake', '62': 'chair', '63': 'couch',
    '64': 'potted plant', '65': 'bed', '67': 'dining table', '70': 'toilet', '72': 'tv',
    '73': 'laptop', '74': 'mouse', '75': 'remote', '76': 'keyboard', '77': 'cell phone',
    '78': 'microwave', '79': 'oven', '80': 'toaster', '81': 'sink', '82': 'refrigerator',
    '84': 'book', '85': 'clock', '86': 'vase', '87': 'scissors', '88': 'teddybear',
    '89': 'hair drier', '90': 'toothbrush',
}

# Image shown and analysed when the user has not uploaded anything.
DEFAULT_IMAGE_URL = "https://tva1.sinaimg.cn/large/e6c9d24egy1h566tcs188j20al0dwabm.jpg"


# Visualization
def plot_detection(image, prediction, idx2names, min_score=0.8):
    """Draw the detections scoring at least ``min_score`` onto a copy of ``image``.

    Args:
        image: PIL image the prediction was computed on; it is not modified.
        prediction: one element of the detection model's output: a dict with
            ``'boxes'`` (one ``x1, y1, x2, y2`` row per detection), ``'labels'``
            and ``'scores'`` tensors of matching length.
        idx2names: mapping from the label id *as a string* to a class name.
        min_score: detections scoring below this threshold are skipped.

    Returns:
        A new PIL image with a green box and a red "name:\\nscore" caption per
        kept detection.
    """
    image_result = image.copy()
    draw = ImageDraw.Draw(image_result)
    # Convert to plain Python numbers so PIL never receives (possibly CUDA) tensors.
    boxes = prediction['boxes'].detach().cpu().tolist()
    labels = prediction['labels'].detach().cpu().tolist()
    scores = prediction['scores'].detach().cpu().tolist()
    for (x1, y1, x2, y2), label, score in zip(boxes, labels, scores):
        if score < min_score:
            continue
        # Labels missing from the mapping fall back to their numeric id instead of
        # crashing on ``None + str``.
        name = idx2names.get(str(label), str(label))
        draw.rectangle((x1, y1, x2, y2), fill=None, outline='lawngreen', width=2)
        draw.text((x1, y1), name + ":\n" + str(round(score, 2)), fill="red")
    return image_result


def _cache_resource(func):
    """Cache ``func``'s result across Streamlit reruns with whichever API is available."""
    if hasattr(st, "cache_resource"):
        return st.cache_resource(func)
    # Older Streamlit only has st.cache; the cached value is a mutable torch model,
    # so disable its output-mutation hashing check.
    return st.cache(allow_output_mutation=True)(func)


def _build_detector(num_classes):
    """Construct the COCO-pretrained Faster R-CNN (ResNet-50 FPN) detector."""
    try:
        from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
    except ImportError:
        # torchvision < 0.13 only understands the legacy ``pretrained`` flag.
        return torchvision.models.detection.fasterrcnn_resnet50_fpn(
            pretrained=True, num_classes=num_classes)
    # COCO_V1 is the checkpoint that ``pretrained=True`` used to select.
    return torchvision.models.detection.fasterrcnn_resnet50_fpn(
        weights=FasterRCNN_ResNet50_FPN_Weights.COCO_V1, num_classes=num_classes)


# Load the model
@_cache_resource
def load_model():
    """Load the pretrained detector once per Streamlit session process.

    The model is moved to ``cuda:0`` when CUDA is available and put in eval mode.
    It also gets an ``idx2names`` attribute (label id string -> class name)
    used by ``predict_detection``.
    """
    num_classes = 91
    model = _build_detector(num_classes)
    if torch.cuda.is_available():
        model.to("cuda:0")
    model.eval()
    model.idx2names = dict(COCO_IDX2NAMES)
    return model


def _open_image(image_source):
    """Open ``image_source`` as an RGB PIL image.

    Accepts an http(s) URL, a filesystem path, or a binary file-like object
    (such as the object returned by ``st.file_uploader``).
    """
    if isinstance(image_source, str) and image_source.startswith(("http://", "https://")):
        # PIL cannot open URLs itself, so download the bytes first.
        with urllib.request.urlopen(image_source, timeout=30) as response:
            image_source = io.BytesIO(response.read())
    elif hasattr(image_source, "seek"):
        # The stream may already have been read earlier in the script; rewind it.
        image_source.seek(0)
    with Image.open(image_source) as img:
        return img.convert("RGB")


def predict_detection(model, image_path, min_score=0.8):
    """Run ``model`` on one image and return that image with its detections drawn.

    Args:
        model: detector returned by ``load_model`` (must expose ``idx2names``).
        image_path: image URL, file path, or binary file-like object.
        min_score: minimum score for a detection to be drawn.

    Returns:
        A PIL image produced by ``plot_detection``.
    """
    # Prepare the input: CHW float tensor scaled to [0, 1], on the model's device.
    img = _open_image(image_path)
    img_tensor = transforms.functional.to_tensor(img)
    device = next(model.parameters()).device
    inputs = [img_tensor.to(device)]

    # Inference only; no gradients needed.
    with torch.no_grad():
        predictions = model(inputs)

    # Visualize the result for the single input image.
    return plot_detection(img, predictions[0], model.idx2names, min_score=min_score)


st.title("FasterRCNN功能演示")

st.header("FasterRCNN Input:")
image_file = st.file_uploader("upload an image file(jpg/png) to predict:")

if image_file is not None:
    try:
        st.image(image_file)
    except Exception as err:
        st.write(err)
        # There is no usable image, so running the prediction would only crash.
        st.stop()
else:
    image_file = DEFAULT_IMAGE_URL
    st.image(image_file)

min_score = st.slider(label="choose the min_score parameter:", min_value=0.1, max_value=0.98, value=0.8)

st.header("FasterRCNN Prediction:")
with st.spinner('waiting for prediction...'):
    model = load_model()
    img_result = predict_detection(model, image_file, min_score=min_score)
    st.image(img_result)