Update app.py
app.py CHANGED
@@ -22,8 +22,6 @@ from io import BytesIO
 # Load the YOLO model
 from models.common import DetectMultiBackend

-
-
 weights_path = "./last.pt"
 device = torch.device("cpu")  # Correctly define the device
 model = DetectMultiBackend(weights_path, device=device)  # Load YOLOv5 model correctly
@@ -49,24 +47,23 @@ transform = transforms.Compose([

 OBJECT_NAMES = ['enemies']

-def detect_objects_in_image(image):

+def detect_objects_in_image(image):
     # Ensure image is a PIL Image
+    data = Image.fromarray(array)
     if isinstance(image, torch.Tensor):
         image = transforms.ToPILImage()(image)  # Convert tensor to PIL image
-
-    if isinstance(image, Image.Image):
+
+    if isinstance(data, Image.Image):
         orig_w, orig_h = image.size  # PIL image size returns (width, height)
     else:
-
+        raise TypeError(f"Expected a PIL Image but got ")
         print(type(image))
         exit()

     # Apply transformation
     img_tensor = transform(image).unsqueeze(0)

-
-
     with torch.no_grad():
         pred = model(img_tensor)[0]

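Note on the hunk above: the added `data = Image.fromarray(array)` references a name `array` that is not defined anywhere in this hunk, the new `isinstance(data, Image.Image)` check no longer validates the `image` that is actually used on the next line, and the `TypeError` message, as captured here, trails off after "but got". A minimal sketch of an input-normalization helper (not the committed code; `to_pil` is a hypothetical name) that would accept the tensor, NumPy array, or PIL inputs a Gradio image component may pass:

import numpy as np
import torch
from PIL import Image
from torchvision import transforms

def to_pil(image):
    # Normalize whatever the Gradio component hands us into a PIL image.
    if isinstance(image, torch.Tensor):
        return transforms.ToPILImage()(image)
    if isinstance(image, np.ndarray):
        return Image.fromarray(image)
    if isinstance(image, Image.Image):
        return image
    # Report the offending type instead of raising with an empty message.
    raise TypeError(f"Expected a PIL Image, NumPy array, or tensor, got {type(image)}")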
@@ -106,7 +103,8 @@ def detect_objects_in_image(image):
         if object_name in object_counts:
             object_counts[object_name] += 1
         cv2.rectangle(img_array, (x1, y1), (x2, y2), (0, 255, 0), 2)
-        cv2.putText(img_array, f"{object_name}: {scores[i]:.2f}", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
+        cv2.putText(img_array, f"{object_name}: {scores[i]:.2f}", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
+                    (0, 255, 0), 2)

     # Generate and return graph instead of dictionary
     graph_image = generate_vehicle_count_graph(object_counts)
@@ -116,58 +114,61 @@ def detect_objects_in_image(image):

 # def generate_vehicle_count_graph(object_counts):
 # color_palette = ['#4C9ACD', '#88B8A3', '#7F9C9C', '#D1A3B5', '#A1C6EA', '#FFB6C1', '#F0E68C', '#D3B0D8', '#F8A5D1', '#B8B8D1']
-
+
 # fig, ax = plt.subplots(figsize=(8, 5))
 # labels = list(object_counts.keys())
 # values = list(object_counts.values())
-
+
 # ax.bar(labels, values, color=color_palette[:len(labels)])
-
+
 # ax.set_xlabel("Vehicle Categories", fontsize=12, fontweight='bold')
 # ax.set_ylabel("Number of Vehicles", fontsize=12, fontweight='bold')
 # ax.set_title("Detected Vehicles in Image", fontsize=14, fontweight='bold')
-
+
 # plt.xticks(rotation=45, ha='right', fontsize=10)
 # plt.yticks(fontsize=10)
-
+
 # plt.tight_layout()

 # buf = BytesIO()
 # plt.savefig(buf, format='png')
 # buf.seek(0)
-
+
 # return Image.open(buf)

 def generate_vehicle_count_graph(object_counts):
-    color_palette = ['#4C9ACD', '#88B8A3', '#7F9C9C', '#D1A3B5', '#A1C6EA', '#FFB6C1', '#F0E68C', '#D3B0D8', '#F8A5D1',
-                     '#B8B8D1']
+    color_palette = ['#4C9ACD', '#88B8A3', '#7F9C9C', '#D1A3B5', '#A1C6EA', '#FFB6C1', '#F0E68C', '#D3B0D8', '#F8A5D1',
+                     '#B8B8D1']
+
     fig, ax = plt.subplots(figsize=(8, 5))
     labels = list(object_counts.keys())
    values = list(object_counts.values())
-
+
     ax.bar(labels, values, color=color_palette[:len(labels)])
     ax.set_xlabel("Vehicle Categories", fontsize=12, fontweight='bold')
     ax.set_ylabel("Number of Vehicles", fontsize=12, fontweight='bold')
     ax.set_title("Detected Vehicles in Image", fontsize=14, fontweight='bold')
-
+
     plt.xticks(rotation=45, ha='right', fontsize=10)
     plt.yticks(fontsize=10)
     plt.tight_layout()
-
+
     buf = BytesIO()
     plt.savefig(buf, format='png')
     buf.seek(0)
-
+
     plt.close(fig)  # ✅ CLOSE THE FIGURE TO FREE MEMORY
-
+
     return Image.open(buf)
-
+
+
 def detect_objects_in_video(video_input):
     cap = cv2.VideoCapture(video_input)
     if not cap.isOpened():
         return "Error: Cannot open video file.", None  # Returning a second value (None) to match expected outputs

-    frame_width, frame_height, fps = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)), int(cap.get(cv2.CAP_PROP_FPS))
+    frame_width, frame_height, fps = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(
+        cap.get(cv2.CAP_PROP_FRAME_HEIGHT)), int(cap.get(cv2.CAP_PROP_FPS))
     temp_video_output = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
     out = cv2.VideoWriter(temp_video_output, cv2.VideoWriter_fourcc(*'mp4v'), fps, (frame_width, frame_height))

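The last change in the hunk above only re-wraps the one-liner that reads the capture properties. Worth noting, though not touched by this commit: `cap.get(cv2.CAP_PROP_FPS)` can report 0 for some sources, which would hand `cv2.VideoWriter` an invalid frame rate. A minimal defensive sketch (file names are hypothetical, not from the commit):

import cv2

cap = cv2.VideoCapture("input.mp4")  # hypothetical input path
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = cap.get(cv2.CAP_PROP_FPS) or 30.0  # fall back to an assumed 30 FPS if the property is unavailable
out = cv2.VideoWriter("output.mp4", cv2.VideoWriter_fourcc(*'mp4v'), fps, (frame_width, frame_height))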
@@ -192,34 +193,37 @@ def detect_objects_in_video(video_input):

     return temp_video_output, graph_image  # Return both expected outputs

+
 def greet(name):
     return "Hello " + name + "!!"

+
 demo = gr.Interface(fn=greet, inputs="text", outputs="text")

 from urllib.request import urlretrieve

 # get image examples from github
-urlretrieve("https://github.com/SamDaaLamb/ValorantTracker/blob/main/clip2_-1450-_jpg.jpg?raw=true",
-
-
-
+urlretrieve("https://github.com/SamDaaLamb/ValorantTracker/blob/main/clip2_-1450-_jpg.jpg?raw=true",
+            "clip2_-1450-_jpg.jpg")  # make sure to use "copy image address when copying image from Github"
+urlretrieve("https://github.com/SamDaaLamb/ValorantTracker/blob/main/clip2_-539-_jpg.jpg?raw=true",
+            "clip2_-539-_jpg.jpg")
+examples = [  # need to manually delete cache everytime new examples are added
+    ["clip2_-1450-_jpg.jpg"],
     ["clip2_-539-_jpg.jpg"]]

-
 # define app features and run
 title = "SpecLab Demo"
 description = "<p style='text-align: center'>Gradio demo for an ASPP model architecture trained on the SpecLab dataset. To use it, simply add your image, or click one of the examples to load them. Since this demo is run on CPU only, please allow additional time for processing. </p>"
 article = "<p style='text-align: center'><a href='https://github.com/Nano1337/SpecLab'>Github Repo</a></p>"
 css = "#0 {object-fit: contain;} #1 {object-fit: contain;}"
-demo = gr.Interface(fn=detect_objects_in_image,
-                    title=title,
+demo = gr.Interface(fn=detect_objects_in_image,
+                    title=title,
                     description=description,
                     article=article,
-                    inputs=gr.Image(elem_id=0, show_label=False),
+                    inputs=gr.Image(elem_id=0, show_label=False),
                     outputs=gr.Image(elem_id=1, show_label=False),
-                    css=css,
-                    examples=examples,
+                    css=css,
+                    examples=examples,
                     cache_examples=True,
                     allow_flagging='never')
 demo.launch()
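The added `examples` list keeps the reminder that the cache must be cleared manually whenever examples change: with `cache_examples=True`, Gradio precomputes and stores outputs for the example inputs, so stale results can survive edits to the list. A small sketch of clearing that cache before `demo.launch()`, assuming the default folder name used by Gradio 3.x (`gradio_cached_examples`):

import os
import shutil

CACHE_DIR = "gradio_cached_examples"  # assumed default location of Gradio's example cache
if os.path.isdir(CACHE_DIR):
    shutil.rmtree(CACHE_DIR)  # forces Gradio to regenerate cached example outputs on the next launch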