Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -10,7 +10,16 @@ import streamlit as st
|
|
| 10 |
warnings.filterwarnings("ignore", category=UserWarning)
|
| 11 |
from tempfile import NamedTemporaryFile
|
| 12 |
|
|
|
|
|
|
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
|
| 16 |
MODEL_PATH = "SD_model_weights.pth"
|
|
@@ -43,8 +52,62 @@ def detect_object(IMAGE_PATH):
|
|
| 43 |
num_list = filtered_indices[0].tolist()
|
| 44 |
filtered_labels = [labels[i] for i in num_list]
|
| 45 |
show_labeled_image(image, filtered_boxes, filtered_labels)
|
| 46 |
-
|
|
|
|
| 47 |
#img_array = img_to_array(img)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
file = st.file_uploader('Upload an Image',type=(["jpeg","jpg","png"]))
|
| 50 |
|
|
|
|
| 10 |
warnings.filterwarnings("ignore", category=UserWarning)
|
| 11 |
from tempfile import NamedTemporaryFile
|
| 12 |
|
| 13 |
+
import cv2
|
| 14 |
+
import matplotlib.patches as patches
|
| 15 |
|
| 16 |
+
import torch
|
| 17 |
+
|
| 18 |
+
import matplotlib.image as mpimg
|
| 19 |
+
import os
|
| 20 |
+
|
| 21 |
+
from detecto.utils import reverse_normalize, normalize_transform, _is_iterable
|
| 22 |
+
from torchvision import transforms
|
| 23 |
|
| 24 |
|
| 25 |
MODEL_PATH = "SD_model_weights.pth"
|
|
|
|
| 52 |
num_list = filtered_indices[0].tolist()
|
| 53 |
filtered_labels = [labels[i] for i in num_list]
|
| 54 |
show_labeled_image(image, filtered_boxes, filtered_labels)
|
| 55 |
+
|
| 56 |
+
show_image(image,filtered_boxes,filtered_labels)
|
| 57 |
#img_array = img_to_array(img)
|
| 58 |
+
def show_image(image, boxes, labels=None):
|
| 59 |
+
"""Show the image along with the specified boxes around detected objects.
|
| 60 |
+
Also displays each box's label if a list of labels is provided.
|
| 61 |
+
:param image: The image to plot. If the image is a normalized
|
| 62 |
+
torch.Tensor object, it will automatically be reverse-normalized
|
| 63 |
+
and converted to a PIL image for plotting.
|
| 64 |
+
:type image: numpy.ndarray or torch.Tensor
|
| 65 |
+
:param boxes: A torch tensor of size (N, 4) where N is the number
|
| 66 |
+
of boxes to plot, or simply size 4 if N is 1.
|
| 67 |
+
:type boxes: torch.Tensor
|
| 68 |
+
:param labels: (Optional) A list of size N giving the labels of
|
| 69 |
+
each box (labels[i] corresponds to boxes[i]). Defaults to None.
|
| 70 |
+
:type labels: torch.Tensor or None
|
| 71 |
+
**Example**::
|
| 72 |
+
>>> from detecto.core import Model
|
| 73 |
+
>>> from detecto.utils import read_image
|
| 74 |
+
>>> from detecto.visualize import show_labeled_image
|
| 75 |
+
>>> model = Model.load('model_weights.pth', ['tick', 'gate'])
|
| 76 |
+
>>> image = read_image('image.jpg')
|
| 77 |
+
>>> labels, boxes, scores = model.predict(image)
|
| 78 |
+
>>> show_labeled_image(image, boxes, labels)
|
| 79 |
+
"""
|
| 80 |
+
fig, ax = plt.subplots(1)
|
| 81 |
+
# If the image is already a tensor, convert it back to a PILImage
|
| 82 |
+
# and reverse normalize it
|
| 83 |
+
if isinstance(image, torch.Tensor):
|
| 84 |
+
image = reverse_normalize(image)
|
| 85 |
+
image = transforms.ToPILImage()(image)
|
| 86 |
+
ax.imshow(image)
|
| 87 |
+
|
| 88 |
+
# Show a single box or multiple if provided
|
| 89 |
+
if boxes.ndim == 1:
|
| 90 |
+
boxes = boxes.view(1, 4)
|
| 91 |
+
|
| 92 |
+
if labels is not None and not _is_iterable(labels):
|
| 93 |
+
labels = [labels]
|
| 94 |
+
|
| 95 |
+
# Plot each box
|
| 96 |
+
for i in range(2):
|
| 97 |
+
box = boxes[i]
|
| 98 |
+
width, height = (box[2] - box[0]).item(), (box[3] - box[1]).item()
|
| 99 |
+
initial_pos = (box[0].item(), box[1].item())
|
| 100 |
+
rect = patches.Rectangle(initial_pos, width, height, linewidth=1,
|
| 101 |
+
edgecolor='r', facecolor='none')
|
| 102 |
+
if labels:
|
| 103 |
+
ax.text(box[0] + 5, box[1] - 5, '{}'.format(labels[i]), color='red')
|
| 104 |
+
|
| 105 |
+
ax.add_patch(rect)
|
| 106 |
+
|
| 107 |
+
cp = os.path.abspath(os.getcwd()) + '/foo.png'
|
| 108 |
+
plt.savefig(cp)
|
| 109 |
+
plt.close(fig)
|
| 110 |
+
#print(type(plt
|
| 111 |
|
| 112 |
file = st.file_uploader('Upload an Image',type=(["jpeg","jpg","png"]))
|
| 113 |
|