Update app.py
app.py CHANGED
@@ -8,73 +8,157 @@ import gradio as gr
 import numpy as np
 import requests
 from PIL import Image
+
+import gradio as gr
+import cv2
+import tempfile
+import numpy as np
 import torch
 from torchvision import transforms
 from PIL import Image
+import matplotlib.pyplot as plt
+from io import BytesIO
 
 # Load the YOLO model
-model_path = "./best
+model_path = "./best.pt"
 model = torch.jit.load(model_path, map_location=torch.device("cpu"))
 model.eval()
 
-[old lines 20-77 removed; their content is not preserved in this view]
+transform = transforms.Compose([
+    transforms.Resize((640, 640)),
+    transforms.ToTensor(),
+])
+
+OBJECT_NAMES = ['enemies']
+
+def detect_objects_in_image(image):
+    img_tensor = transform(image).unsqueeze(0)
+    orig_w, orig_h = image.size
+
+    with torch.no_grad():
+        pred = model(img_tensor)[0]
+
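+    # The TorchScript export may return a tuple of tensors; convert to NumPy
+    # before post-processing (assumed YOLO-style output layout).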
+    if isinstance(pred[0], torch.Tensor):
+        pred = [p.cpu().numpy() for p in pred]
+
+    pred = np.concatenate(pred, axis=0)
+    conf_thres = 0.25
+    mask = pred[:, 4] > conf_thres
+    pred = pred[mask]
+
+    if len(pred) == 0:
+        return Image.fromarray(np.array(image)), None  # Return only image and None for graph
+
+    boxes, scores, class_probs = pred[:, :4], pred[:, 4], pred[:, 5:]
+    class_ids = np.argmax(class_probs, axis=1)
+
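+    # Convert (cx, cy, w, h) centers to (x1, y1, x2, y2) corners in place.
+    # The order of the four assignments matters: x2/y2 reuse the freshly
+    # written x1/y1 while columns 2 and 3 still hold w and h.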
+    boxes[:, 0] = boxes[:, 0] - (boxes[:, 2] / 2)
+    boxes[:, 1] = boxes[:, 1] - (boxes[:, 3] / 2)
+    boxes[:, 2] = boxes[:, 0] + boxes[:, 2]
+    boxes[:, 3] = boxes[:, 1] + boxes[:, 3]
+
+    boxes[:, [0, 2]] *= orig_w / 640
+    boxes[:, [1, 3]] *= orig_h / 640
+    boxes = np.clip(boxes, 0, [orig_w, orig_h, orig_w, orig_h])
+
+    # cv2.dnn.NMSBoxes expects (x, y, w, h) boxes, so convert back from corners
+    nms_boxes = [[x1, y1, x2 - x1, y2 - y1] for x1, y1, x2, y2 in boxes.tolist()]
+    indices = cv2.dnn.NMSBoxes(nms_boxes, scores.tolist(), conf_thres, 0.5)
+
+    object_counts = {name: 0 for name in OBJECT_NAMES}
+    img_array = np.array(image)
+
+    if len(indices) > 0:
+        for i in indices.flatten():
+            x1, y1, x2, y2 = map(int, boxes[i])
+            cls = class_ids[i]
+            object_name = OBJECT_NAMES[cls] if cls < len(OBJECT_NAMES) else f"Unknown ({cls})"
+            if object_name in object_counts:
+                object_counts[object_name] += 1
+            cv2.rectangle(img_array, (x1, y1), (x2, y2), (0, 255, 0), 2)
+            cv2.putText(img_array, f"{object_name}: {scores[i]:.2f}", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
+
+    # Generate and return graph instead of dictionary
+    graph_image = generate_vehicle_count_graph(object_counts)
+
+    return Image.fromarray(img_array), graph_image  # Now returning only 2 outputs
+
+
+def generate_vehicle_count_graph(object_counts):
+    color_palette = ['#4C9ACD', '#88B8A3', '#7F9C9C', '#D1A3B5', '#A1C6EA', '#FFB6C1', '#F0E68C', '#D3B0D8', '#F8A5D1', '#B8B8D1']
+
+    fig, ax = plt.subplots(figsize=(8, 5))
+    labels = list(object_counts.keys())
+    values = list(object_counts.values())
+
+    ax.bar(labels, values, color=color_palette[:len(labels)])
+    ax.set_xlabel("Vehicle Categories", fontsize=12, fontweight='bold')
+    ax.set_ylabel("Number of Vehicles", fontsize=12, fontweight='bold')
+    ax.set_title("Detected Vehicles in Image", fontsize=14, fontweight='bold')
+
+    plt.xticks(rotation=45, ha='right', fontsize=10)
+    plt.yticks(fontsize=10)
+    plt.tight_layout()
+
+    buf = BytesIO()
+    plt.savefig(buf, format='png')
+    buf.seek(0)
+
+    plt.close(fig)  # Close the figure so repeated calls don't leak memory
+
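+    # Image.open reads from buf lazily; the returned image keeps a reference
+    # to buf, so the buffer stays alive as long as the image does.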
+    return Image.open(buf)
+
+def detect_objects_in_video(video_input):
+    cap = cv2.VideoCapture(video_input)
+    if not cap.isOpened():
+        return "Error: Cannot open video file.", None  # Returning a second value (None) to match expected outputs
+
+    frame_width, frame_height, fps = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)), int(cap.get(cv2.CAP_PROP_FPS))
+    temp_video_output = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
+    out = cv2.VideoWriter(temp_video_output, cv2.VideoWriter_fourcc(*'mp4v'), fps, (frame_width, frame_height))
+
+    # Initialize the counts for vehicle categories
+    total_counts = {name: 0 for name in ['car', 'truck', 'bus', 'motorcycle', 'bicycle']}
+    graph_image = None  # Guard against NameError when the video yields no frames
+
+    while cap.isOpened():
+        ret, frame = cap.read()
+        if not ret:
+            break
+
+        image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+
+        # Get frame with detected objects and graph
+        frame_with_boxes, graph_image = detect_objects_in_image(image)
+
+        # Convert image back to OpenCV format for writing video
+        out.write(cv2.cvtColor(np.array(frame_with_boxes), cv2.COLOR_RGB2BGR))
+
+    cap.release()
+    out.release()
+
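+    # graph_image is whatever the last processed frame produced (None if no
+    # frame was processed or the last frame had no detections).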
+    return temp_video_output, graph_image  # Return both expected outputs
 
 def greet(name):
     return "Hello " + name + "!!"