Priyanka-Kumavat-At-TE's picture
Update app.py
bb2f729
raw
history blame
6.13 kB
import streamlit as st
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.transforms import ToTensor
from PIL import Image, ImageDraw
import cv2
import numpy as np
import pandas as pd
import os
import tempfile
from tempfile import NamedTemporaryFile
# Build a Faster R-CNN (ResNet-50 FPN) whose 91-class head matches the saved
# checkpoint, so the state dict below lines up with the module's parameters.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(num_classes=91)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# map_location remaps the stored tensors onto `device`, which lets a checkpoint
# that was saved on a GPU load on a CPU-only host.
model.load_state_dict(torch.load("frcnn_model.pth", map_location=device))
# Input tensors are sent to `device` before inference, so the weights have to
# live on the same device or the forward pass fails with a device mismatch.
model.to(device)
# Display names for detections, looked up by the integer label the model
# predicts (index 0 is the background entry). The 'N/A' entries keep every
# name at its label position -- presumably the unused COCO ids; confirm.
classes = [
    # labels 0-9
    '__background__', 'person', 'bicycle', 'car', 'motorcycle',
    'airplane', 'bus', 'train', 'truck', 'boat',
    # labels 10-19
    'traffic light', 'fire hydrant', 'N/A', 'stop sign', 'parking meter',
    'bench', 'bird', 'cat', 'dog', 'horse',
    # labels 20-29
    'sheep', 'cow', 'elephant', 'bear', 'zebra',
    'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A',
    # labels 30-39
    'N/A', 'handbag', 'tie', 'suitcase', 'frisbee',
    'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
    # labels 40-49
    'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle',
    'N/A', 'wine glass', 'cup', 'fork', 'knife',
    # labels 50-59
    'spoon', 'bowl', 'banana', 'apple', 'sandwich',
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    # labels 60-69
    'donut', 'cake', 'chair', 'couch', 'potted plant',
    'bed', 'N/A', 'dining table', 'N/A', 'N/A',
    # labels 70-79
    'toilet', 'N/A', 'tv', 'laptop', 'mouse',
    'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
    # labels 80-89
    'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    # label 90
    'toothbrush',
]
# Minimum confidence score a detection must exceed to be kept and drawn.
# It is compared against the model's per-box scores; it is not an IoU value.
threshold = 0.5
# Introductory copy shown directly under the page title.
_intro_text = """ The Faster R-CNN (Region-based Convolutional Neural Network) is a cutting-edge object detection model that combines deep
learning with region proposal networks to achieve highly accurate object detection in images.
It is trained on a large dataset of images and can detect a wide range of objects with high precision and recall.
The model is based on the ResNet-50 architecture, which allows it to capture complex visual features from the input image.
It uses a two-stage approach, first proposing regions of interest (RoIs) in the image and then classifying and refining the
object boundaries within these RoIs. This approach makes it extremely efficient and accurate in detecting multiple objects
in a single image.
"""

st.title(""" Image Object Detections """)
st.write(_intro_text)

# Sample pictures displayed in the sidebar as a gallery.
images = [
    "test2.jpg",
    "img7.jpg",
    "img20.jpg",
    "img23.jpg",
    "test1.jpg",
    "img18.jpg",
    "img3.jpg",
    "img15.jpg",
    "img17.jpg",
]
st.sidebar.write("Choose an Image")
st.sidebar.image(images)
def detect_objects(image_path):
    """Detect objects in one image and render the annotated result in the app.

    The image is opened with PIL, converted to RGB and run through the
    module-level ``model``. Detections whose score exceeds the module-level
    ``threshold`` are outlined in red and labelled in yellow with the name
    taken from ``classes``; the annotated image is then shown via Streamlit.

    Args:
        image_path: Location of the image to analyse; it is handed straight
            to ``PIL.Image.open``.

    Returns:
        None. The result is written to the Streamlit page.
    """
    image = Image.open(image_path).convert('RGB')

    # Put the input on whichever device the weights actually live on, so the
    # forward pass cannot hit a CPU/GPU mismatch.
    model_device = next(model.parameters()).device
    image_tensor = ToTensor()(image).to(model_device)

    # Inference mode: no gradient tracking, and the model returns predictions
    # rather than training losses.
    model.eval()
    with torch.no_grad():
        predictions = model([image_tensor])

    # One image went in, so only the first prediction dict is relevant.
    detections = predictions[0]
    scores = detections['scores'].cpu().numpy()
    keep = scores > threshold
    boxes = detections['boxes'].cpu().numpy()[keep]
    labels = detections['labels'].cpu().numpy()[keep]

    draw = ImageDraw.Draw(image)
    for box, label in zip(boxes, labels):
        # Plain Python numbers keep PIL independent of NumPy scalar types.
        x0, y0, x1, y1 = box.tolist()
        draw.rectangle([(x0, y0), (x1, y1)], outline='red')
        # The class name is anchored at the box's top-left corner.
        draw.text((x0, y0), classes[int(label)], fill='yellow')

    st.write("Objects detected in the image are: ")
    st.image(image, use_column_width=True)
# Upload widget: show the chosen image, then run detection on it.
file = st.file_uploader('Upload an Image', type=(["jpeg", "jpg", "png"]))
if file is None:
    st.write("Please upload an image file")
else:
    image = Image.open(file)
    st.write("Input Image")
    st.image(image, use_column_width=True)
    # detect_objects is given a path, so the upload is spooled to a temporary
    # file that keeps the original extension.
    with NamedTemporaryFile(dir='.', suffix=os.path.splitext(file.name)[1]) as f:
        f.write(file.getbuffer())
        # Without a flush, small uploads can still sit in the write buffer and
        # the file would be read back empty or truncated when reopened by name.
        f.flush()
        # Must run inside the `with`: the temporary file is removed on exit.
        detect_objects(f.name)
# if file is None:
# st.write("Please upload an image file")
# else:
# image = Image.open(file)
# st.write("Input Image")
# st.image(image, use_column_width=True)
# with NamedTemporaryFile(dir='.', suffix='.jpeg') as f: # this line gives error and only accepts .jpeg and so used above snippet
# f.write(file.getbuffer()) # which will accepts all formats of images.
# # your_function_which_takes_a_path(f.name)
# detect_objects(f.name)
# Closing blurb rendered at the bottom of the page.
_closing_text = """ This Streamlit app provides a user-friendly interface for uploading an image and visualizing the output of the Faster R-CNN
model. It displays the uploaded image along with the predicted objects highlighted with bounding box overlays. The app allows
users to explore the detected objects in the image, providing valuable insights and understanding of the model's predictions.
It can be used for a wide range of applications, such as object recognition, image analysis, and visual storytelling.
Whether it's identifying objects in real-world images or understanding the capabilities of state-of-the-art object detection
models, this Streamlit app powered by Faster R-CNN is a powerful tool for computer vision tasks.
"""
st.write(_closing_text)