Spaces:
Runtime error
Runtime error
xcurvnubaim
commited on
Commit
•
a9f1bab
1
Parent(s):
765c987
feat: add object detection
Browse files- main.py +111 -2
- requirements.txt +4 -1
main.py
CHANGED
@@ -3,11 +3,17 @@ from fastapi import FastAPI, File, UploadFile
|
|
3 |
import tensorflow as tf
|
4 |
from PIL import Image
|
5 |
from io import BytesIO
|
|
|
|
|
|
|
|
|
6 |
|
7 |
app = FastAPI()
|
8 |
|
9 |
labels = []
|
10 |
-
|
|
|
|
|
11 |
with open("labels.txt") as f:
|
12 |
for line in f:
|
13 |
labels.append(line.replace('\n', ''))
|
@@ -17,13 +23,116 @@ def classify_image(img):
|
|
17 |
img_array = np.asarray(img.resize((224, 224)))[..., :3]
|
18 |
img_array = img_array.reshape((1, 224, 224, 3)) # Add batch dimension
|
19 |
img_array = tf.keras.applications.efficientnet.preprocess_input(img_array)
|
20 |
-
prediction =
|
21 |
confidences = {labels[i]: float(prediction[i]) for i in range(90)}
|
22 |
# Sort the confidences dictionary by value and get the top 3 items
|
23 |
# top_3_confidences = dict(sorted(confidences.items(), key=lambda item: item[1], reverse=True)[:3])
|
24 |
|
25 |
return confidences
|
26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
@app.post("/predict")
|
28 |
async def predict(file: bytes = File(...)):
|
29 |
img = Image.open(BytesIO(file))
|
|
|
3 |
import tensorflow as tf
|
4 |
from PIL import Image
|
5 |
from io import BytesIO
|
6 |
+
from ultralytics import YOLO
|
7 |
+
import cv2
|
8 |
+
from datetime import datetime
|
9 |
+
from fastapi.responses import FileResponse
|
10 |
|
11 |
app = FastAPI()
|
12 |
|
13 |
labels = []
|
14 |
+
classification_model = tf.keras.models.load_model('./models.h5')
|
15 |
+
detection_model = YOLO('./best.pt')
|
16 |
+
|
17 |
with open("labels.txt") as f:
|
18 |
for line in f:
|
19 |
labels.append(line.replace('\n', ''))
|
|
|
23 |
img_array = np.asarray(img.resize((224, 224)))[..., :3]
|
24 |
img_array = img_array.reshape((1, 224, 224, 3)) # Add batch dimension
|
25 |
img_array = tf.keras.applications.efficientnet.preprocess_input(img_array)
|
26 |
+
prediction = classification_model.predict(img_array).flatten()
|
27 |
confidences = {labels[i]: float(prediction[i]) for i in range(90)}
|
28 |
# Sort the confidences dictionary by value and get the top 3 items
|
29 |
# top_3_confidences = dict(sorted(confidences.items(), key=lambda item: item[1], reverse=True)[:3])
|
30 |
|
31 |
return confidences
|
32 |
|
33 |
+
def animal_detect_and_classify(img_path):
|
34 |
+
# Read the image
|
35 |
+
img = cv2.imread(img_path)
|
36 |
+
|
37 |
+
# Pass the image through the detection model and get the result
|
38 |
+
detect_results = detection_model(img)
|
39 |
+
|
40 |
+
combined_results = []
|
41 |
+
# print("dss", detect_results[0])
|
42 |
+
# Iterate over the detected objects
|
43 |
+
# Iterate over detections
|
44 |
+
for result in detect_results:
|
45 |
+
for box in result.boxes:
|
46 |
+
# print(box)
|
47 |
+
# Crop the RoI
|
48 |
+
x1, y1, x2, y2 = map(int, box.xyxy[0])
|
49 |
+
detect_img = img[y1:y2, x1:x2]
|
50 |
+
# Convert the image to RGB format
|
51 |
+
detect_img = cv2.cvtColor(detect_img, cv2.COLOR_BGR2RGB)
|
52 |
+
|
53 |
+
# Resize the input image to the expected shape (224, 224)
|
54 |
+
detect_img = cv2.resize(detect_img, (224, 224))
|
55 |
+
|
56 |
+
# Convert the image to a numpy array
|
57 |
+
inp_array = np.array(detect_img)
|
58 |
+
|
59 |
+
# Reshape the array to match the expected input shape
|
60 |
+
inp_array = inp_array.reshape((-1, 224, 224, 3))
|
61 |
+
|
62 |
+
# Preprocess the input array
|
63 |
+
inp_array = tf.keras.applications.efficientnet.preprocess_input(inp_array)
|
64 |
+
|
65 |
+
# Make predictions using the classification model
|
66 |
+
prediction = classification_model.predict(inp_array)
|
67 |
+
# Map predictions to labels
|
68 |
+
threshold = 0.75
|
69 |
+
predicted_labels = [labels[np.argmax(pred)] if np.max(pred) >= threshold else "animal" for pred in prediction]
|
70 |
+
print(predicted_labels)
|
71 |
+
combined_results.append(((x1, y1, x2, y2), predicted_labels))
|
72 |
+
|
73 |
+
return combined_results
|
74 |
+
|
75 |
+
def generate_color(class_name):
|
76 |
+
# Generate a hash from the class name
|
77 |
+
color_hash = hash(class_name)
|
78 |
+
print(color_hash)
|
79 |
+
# Normalize the hash value to fit within the range of valid color values (0-255)
|
80 |
+
color_hash = abs(color_hash) % 16777216
|
81 |
+
R = color_hash//(256*256)
|
82 |
+
G = (color_hash//256) % 256
|
83 |
+
B = color_hash % 256
|
84 |
+
# Convert the hash value to RGB color format
|
85 |
+
color = (R, G, B)
|
86 |
+
|
87 |
+
return color
|
88 |
+
|
89 |
+
def plot_detected_rectangles(image, detections, output_path):
|
90 |
+
# Create a copy of the image to draw on
|
91 |
+
img_with_rectangles = image.copy()
|
92 |
+
|
93 |
+
# Iterate over each detected rectangle and its corresponding class name
|
94 |
+
for rectangle, class_names in detections:
|
95 |
+
# Extract the coordinates of the rectangle
|
96 |
+
x1, y1, x2, y2 = rectangle
|
97 |
+
|
98 |
+
# Generate a random color
|
99 |
+
color = generate_color(class_names[0])
|
100 |
+
|
101 |
+
# Draw the rectangle on the image
|
102 |
+
cv2.rectangle(img_with_rectangles, (x1, y1), (x2, y2), color, 2)
|
103 |
+
|
104 |
+
# Put the class names above the rectangle
|
105 |
+
for i, class_name in enumerate(class_names):
|
106 |
+
cv2.putText(img_with_rectangles, class_name, (x1, y1 - 10 - i*20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
|
107 |
+
|
108 |
+
# Show the image with rectangles and class names
|
109 |
+
cv2.imwrite(output_path, img_with_rectangles)
|
110 |
+
|
111 |
+
|
112 |
+
# Call the animal_detect_and_classify function to get detections
|
113 |
+
detections = animal_detect_and_classify('/content/cat_tiger.jpg')
|
114 |
+
|
115 |
+
# Plot the detected rectangles with their corresponding class names
|
116 |
+
plot_detected_rectangles(cv2.imread('/content/cat_tiger.jpg'), detections)
|
117 |
+
|
118 |
+
|
119 |
+
@app.post("/predict/v2")
|
120 |
+
async def predict_v2(file: UploadFile = File(...)):
|
121 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
122 |
+
filename = timestamp + file.filename
|
123 |
+
contents = await file.read()
|
124 |
+
image = Image.open(BytesIO(contents))
|
125 |
+
image.save("input/" + filename)
|
126 |
+
detections = animal_detect_and_classify("input/" + filename)
|
127 |
+
plot_detected_rectangles(cv2.imread("input/" + filename), detections, "output/" + filename)
|
128 |
+
return {"message": "Detection and classification completed successfully"}
|
129 |
+
|
130 |
+
@app.get("/image/")
|
131 |
+
async def get_image(image_name: str):
|
132 |
+
# Assume the images are stored in a directory named "images"
|
133 |
+
image_path = f"images/{image_name}"
|
134 |
+
return FileResponse(image_path)
|
135 |
+
|
136 |
@app.post("/predict")
|
137 |
async def predict(file: bytes = File(...)):
|
138 |
img = Image.open(BytesIO(file))
|
requirements.txt
CHANGED
@@ -10,4 +10,7 @@ uvicorn
|
|
10 |
python-multipart
|
11 |
numpy==1.25.2
|
12 |
Pillow==9.4.0
|
13 |
-
keras==2.15.0
|
|
|
|
|
|
|
|
10 |
python-multipart
|
11 |
numpy==1.25.2
|
12 |
Pillow==9.4.0
|
13 |
+
keras==2.15.0
|
14 |
+
ultralytics
|
15 |
+
squarify
|
16 |
+
cv2
|