# Scraped from a Hugging Face Spaces page (Space status at capture time: "Runtime error").
# File size: 3,365 bytes.
import streamlit as st
import matplotlib.pyplot as plt
import numpy as np
import cv2
import PIL
import torch
from classifier import CustomEfficientNet
from model import get_model, predict, prepare_prediction, predict_class
# Cache heavyweight model loading across Streamlit reruns when the installed
# Streamlit offers a resource cache; otherwise fall back to loading directly.
_cache_resource = (
    getattr(st, 'cache_resource', None)
    or getattr(st, 'experimental_singleton', None)
    or (lambda func: func)
)


@_cache_resource
def _load_detector():
    """Build the box detector from ``checkpoint.ckpt`` via ``get_model``."""
    print('Creating the model')
    return get_model('checkpoint.ckpt')


@_cache_resource
def _load_classifier():
    """Build the 7-class ``CustomEfficientNet`` and load its weights onto the CPU."""
    print('Loading the classifier')
    net = CustomEfficientNet(target_size=7, pretrained=False)
    net.load_state_dict(torch.load('class_efficientB0_taco_7_class.pth', map_location='cpu'))
    # Inference only: switch dropout / batch-norm layers to evaluation mode.
    # predict_class may also do this (not visible here); eval() is idempotent.
    net.eval()
    return net


model = _load_detector()
classifier = _load_classifier()
def plot_img_no_mask(image, boxes, labels):
    """Draw the predicted boxes and class names on ``image`` and save it as ``img.png``.

    Args:
        image: Image to annotate; anything ``np.array`` can convert (the caller
            passes the image returned by ``prepare_prediction``). The colours
            below are written as RGB triples, so the image is presumably RGB —
            TODO confirm.
        boxes: Torch tensor with one ``[x1, y1, x2, y2]`` row per detection;
            coordinates are truncated to integer pixels.
        labels: Indexable sequence of class ids in ``0..6``, aligned with ``boxes``.

    Side effects:
        Writes the rendered figure to ``img.png`` in the working directory
        (the caller reads it back) and closes the figure afterwards.
    """
    # Class id -> RGB drawing colour.
    colors = {
        0: (255, 255, 0),
        1: (255, 0, 0),
        2: (0, 0, 255),
        3: (0, 128, 0),
        4: (255, 165, 0),
        5: (230, 230, 250),
        6: (192, 192, 192),
    }
    # Class id -> caption drawn above each box.
    texts = {
        0: 'plastic',
        1: 'dangerous',
        2: 'carton',
        3: 'glass',
        4: 'organic',
        5: 'rest',
        6: 'other',
    }
    boxes = boxes.cpu().detach().numpy().astype(np.int32)
    # Without a private copy cv2.rectangle raises an error, so draw on a copy.
    # One copy is enough: every box is drawn onto the same canvas.
    image = np.array(image).copy()
    fig, ax = plt.subplots(1, 1, figsize=(12, 6))
    try:
        for i, box in enumerate(boxes):
            label = labels[i]
            color = colors[label]
            # Plain Python ints are the safest input for OpenCV point arguments.
            x1, y1, x2, y2 = (int(v) for v in box)
            cv2.rectangle(image, (x1, y1), (x2, y2), color, thickness=5)
            # Caption is anchored 10 px above the box's top-left corner.
            cv2.putText(image, texts[label], (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 4, thickness=5, color=color)
        ax.axis('off')
        ax.imshow(image)
        fig.savefig("img.png", bbox_inches='tight')
    finally:
        # pyplot keeps every figure it creates referenced until it is closed;
        # this runs once per prediction, so unclosed figures would accumulate.
        plt.close(fig)
st.subheader('Upload Custom Image')

# Session-state key remembering which example image the user picked. A button
# click is only True on the single rerun it triggers, so without this the
# choice would vanish as soon as any other widget (e.g. a slider) changes.
_EXAMPLE_KEY = 'selected_example'


def _forget_example():
    """Uploader callback: a new or removed upload replaces any chosen example."""
    if _EXAMPLE_KEY in st.session_state:
        del st.session_state[_EXAMPLE_KEY]


image_file = st.file_uploader("Upload Images", type=["png", "jpg", "jpeg"],
                              on_change=_forget_example)

st.subheader('Example Images')
example_imgs = [
    'example_imgs/basura_4_2.jpg',
    'example_imgs/basura_1.jpg',
    'example_imgs/basura_3.jpg',
]
for number, example_path in enumerate(example_imgs, start=1):
    with st.container():
        st.image(example_path, width=150, caption=str(number))
        if st.button('Select Image', key=f'Image_{number}'):
            st.session_state[_EXAMPLE_KEY] = example_path
# A chosen example overrides the uploaded file.
if _EXAMPLE_KEY in st.session_state:
    image_file = st.session_state[_EXAMPLE_KEY]

st.subheader('Detection parameters')
detection_threshold = st.slider('Detection threshold',
                                min_value=0.0,
                                max_value=1.0,
                                value=0.5,
                                step=0.1)
nms_threshold = st.slider('NMS threshold',
                          min_value=0.0,
                          max_value=1.0,
                          value=0.3,
                          step=0.1)

st.subheader('Prediction')
if image_file is not None:
    print('Getting predictions')
    if isinstance(image_file, str):
        # Example image: hand the file path to predict.
        data = image_file
    else:
        # getvalue() returns the whole upload regardless of the stream position,
        # whereas read() yields b'' if the buffer was already consumed.
        data = image_file.getvalue()
    pred_dict = predict(model, data, detection_threshold)
    print('Fixing the preds')
    boxes, image = prepare_prediction(pred_dict, nms_threshold)
    print('Predicting classes')
    labels = predict_class(classifier, image, boxes)
    print('Plotting')
    # NOTE(review): img.png is a single shared file; concurrent sessions may
    # overwrite each other's output.
    plot_img_no_mask(image, boxes, labels)
    # st.image accepts a local path directly; this avoids depending on the
    # PIL.Image submodule having been imported elsewhere.
    st.image('img.png', width=750)