Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -50,7 +50,7 @@ def login(username, password):
|
|
50 |
@spaces.GPU(duration=300)
|
51 |
def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8):
|
52 |
"""
|
53 |
-
|
54 |
Args:
|
55 |
video_path: 输入视频路径
|
56 |
process_seconds: 处理时长(秒)
|
@@ -282,43 +282,66 @@ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8)
|
|
282 |
trajectory_img = trajectory_img.cpu().numpy().astype(np.uint8)
|
283 |
|
284 |
# 修改热力图生成部分
|
285 |
-
|
286 |
-
|
287 |
try:
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
300 |
except Exception as e:
|
301 |
-
logger.error(f"
|
302 |
-
#
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
|
|
|
|
312 |
else:
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
if 0 <= x < width and 0 <= y < height:
|
318 |
-
temp_heatmap = np.zeros((height, width), dtype=np.float32)
|
319 |
-
temp_heatmap[int(y), int(x)] = 1
|
320 |
-
temp_heatmap = cv2.GaussianBlur(temp_heatmap, (31, 31), 10)
|
321 |
-
heatmap += temp_heatmap
|
322 |
|
323 |
trajectory_frames = []
|
324 |
heatmap_frames = []
|
|
|
50 |
@spaces.GPU(duration=300)
|
51 |
def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8):
|
52 |
"""
|
53 |
+
处理视频并进行鼠检测
|
54 |
Args:
|
55 |
video_path: 输入视频路径
|
56 |
process_seconds: 处理时长(秒)
|
|
|
282 |
trajectory_img = trajectory_img.cpu().numpy().astype(np.uint8)
|
283 |
|
284 |
# 修改热力图生成部分
|
285 |
+
def generate_heatmap(points, width, height, device='cuda'):
|
286 |
+
"""生成热力图的独立函数"""
|
287 |
try:
|
288 |
+
if device == 'cuda' and torch.cuda.is_available():
|
289 |
+
logger.info("使用GPU生成热力图")
|
290 |
+
heatmap = torch.zeros((height, width), device=device)
|
291 |
+
for point in points:
|
292 |
+
try:
|
293 |
+
# 确保坐标是整数
|
294 |
+
x = int(round(float(point[0])))
|
295 |
+
y = int(round(float(point[1])))
|
296 |
+
|
297 |
+
if 0 <= x < width and 0 <= y < height:
|
298 |
+
temp_heatmap = torch.zeros((height, width), device=device)
|
299 |
+
temp_heatmap[y, x] = 1.0
|
300 |
+
# 使用GPU的高斯模糊
|
301 |
+
temp_heatmap = gaussian_blur_gpu(temp_heatmap, kernel_size=31, sigma=10)
|
302 |
+
heatmap += temp_heatmap
|
303 |
+
except (ValueError, TypeError, IndexError) as e:
|
304 |
+
logger.warning(f"跳过无效点 {point}: {str(e)}")
|
305 |
+
continue
|
306 |
+
|
307 |
+
return heatmap.cpu().numpy()
|
308 |
+
else:
|
309 |
+
logger.info("使用CPU生成热力图")
|
310 |
+
heatmap = np.zeros((height, width), dtype=np.float32)
|
311 |
+
for point in points:
|
312 |
+
try:
|
313 |
+
x = int(round(float(point[0])))
|
314 |
+
y = int(round(float(point[1])))
|
315 |
+
|
316 |
+
if 0 <= x < width and 0 <= y < height:
|
317 |
+
temp_heatmap = np.zeros((height, width), dtype=np.float32)
|
318 |
+
temp_heatmap[y, x] = 1.0
|
319 |
+
temp_heatmap = cv2.GaussianBlur(temp_heatmap, (31, 31), 10)
|
320 |
+
heatmap += temp_heatmap
|
321 |
+
except (ValueError, TypeError, IndexError) as e:
|
322 |
+
logger.warning(f"跳过无效点 {point}: {str(e)}")
|
323 |
+
continue
|
324 |
+
|
325 |
+
return heatmap
|
326 |
except Exception as e:
|
327 |
+
logger.error(f"热力图生成失败: {str(e)}")
|
328 |
+
# 返回空热力图
|
329 |
+
return np.zeros((height, width), dtype=np.float32)
|
330 |
+
|
331 |
+
# 生成热力图
|
332 |
+
logger.info("开始生成热力图")
|
333 |
+
heatmap = generate_heatmap(filtered_points, width, height, device=device)
|
334 |
+
|
335 |
+
# 归一化和着色热力图
|
336 |
+
if np.max(heatmap) > 0:
|
337 |
+
heatmap_normalized = cv2.normalize(heatmap, None, 0, 255, cv2.NORM_MINMAX)
|
338 |
+
heatmap_colored = cv2.applyColorMap(heatmap_normalized.astype(np.uint8), cv2.COLORMAP_JET)
|
339 |
+
heatmap_colored = cv2.addWeighted(heatmap_colored, 0.7, np.full_like(heatmap_colored, 255), 0.3, 0)
|
340 |
else:
|
341 |
+
# 如果热力图全为0,创建一个空白的彩色图像
|
342 |
+
heatmap_colored = np.full((height, width, 3), 255, dtype=np.uint8)
|
343 |
+
|
344 |
+
logger.info("热力图生成完成")
|
|
|
|
|
|
|
|
|
|
|
345 |
|
346 |
trajectory_frames = []
|
347 |
heatmap_frames = []
|