Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -50,7 +50,7 @@ def login(username, password):
|
|
50 |
@spaces.GPU(duration=300)
|
51 |
def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8):
|
52 |
"""
|
53 |
-
|
54 |
Args:
|
55 |
video_path: 输入视频路径
|
56 |
process_seconds: 处理时长(秒)
|
@@ -282,63 +282,45 @@ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8)
|
|
282 |
trajectory_img = trajectory_img.cpu().numpy().astype(np.uint8)
|
283 |
|
284 |
# 修改热力图生成部分
|
285 |
-
def generate_heatmap(points, width, height
|
286 |
-
"""
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
if 0 <= x < width and 0 <= y < height:
|
298 |
-
temp_heatmap = torch.zeros((height, width), device=device)
|
299 |
-
temp_heatmap[y, x] = 1.0
|
300 |
-
# 使用GPU的高斯模糊
|
301 |
-
temp_heatmap = gaussian_blur_gpu(temp_heatmap, kernel_size=31, sigma=10)
|
302 |
-
heatmap += temp_heatmap
|
303 |
-
except (ValueError, TypeError, IndexError) as e:
|
304 |
-
logger.warning(f"跳过无效点 {point}: {str(e)}")
|
305 |
-
continue
|
306 |
|
307 |
-
|
308 |
-
|
309 |
-
logger.info("使用CPU生成热力图")
|
310 |
-
heatmap = np.zeros((height, width), dtype=np.float32)
|
311 |
-
for point in points:
|
312 |
-
try:
|
313 |
-
x = int(round(float(point[0])))
|
314 |
-
y = int(round(float(point[1])))
|
315 |
-
|
316 |
-
if 0 <= x < width and 0 <= y < height:
|
317 |
-
temp_heatmap = np.zeros((height, width), dtype=np.float32)
|
318 |
-
temp_heatmap[y, x] = 1.0
|
319 |
-
temp_heatmap = cv2.GaussianBlur(temp_heatmap, (31, 31), 10)
|
320 |
-
heatmap += temp_heatmap
|
321 |
-
except (ValueError, TypeError, IndexError) as e:
|
322 |
-
logger.warning(f"跳过无效点 {point}: {str(e)}")
|
323 |
-
continue
|
324 |
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
|
|
|
|
|
|
|
|
|
|
330 |
|
331 |
# 生成热力图
|
332 |
logger.info("开始生成热力图")
|
333 |
-
heatmap = generate_heatmap(filtered_points, width, height
|
334 |
|
335 |
-
#
|
336 |
if np.max(heatmap) > 0:
|
337 |
heatmap_normalized = cv2.normalize(heatmap, None, 0, 255, cv2.NORM_MINMAX)
|
338 |
heatmap_colored = cv2.applyColorMap(heatmap_normalized.astype(np.uint8), cv2.COLORMAP_JET)
|
|
|
339 |
heatmap_colored = cv2.addWeighted(heatmap_colored, 0.7, np.full_like(heatmap_colored, 255), 0.3, 0)
|
340 |
else:
|
341 |
-
#
|
342 |
heatmap_colored = np.full((height, width, 3), 255, dtype=np.uint8)
|
343 |
|
344 |
logger.info("热力图生成完成")
|
@@ -503,7 +485,7 @@ with gr.Blocks() as demo:
|
|
503 |
value=1,
|
504 |
step=1,
|
505 |
label="最大检测数量",
|
506 |
-
info="
|
507 |
)
|
508 |
process_btn = gr.Button("开始处��")
|
509 |
|
|
|
50 |
@spaces.GPU(duration=300)
|
51 |
def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8):
|
52 |
"""
|
53 |
+
处理视频并进行小鼠检测
|
54 |
Args:
|
55 |
video_path: 输入视频路径
|
56 |
process_seconds: 处理时长(秒)
|
|
|
282 |
trajectory_img = trajectory_img.cpu().numpy().astype(np.uint8)
|
283 |
|
284 |
# 修改热力图生成部分
|
285 |
+
def generate_heatmap(points, width, height):
|
286 |
+
"""简化版本的热力图生成函数"""
|
287 |
+
logger.info(f"开始生成热力图,共 {len(points)} 个点")
|
288 |
+
|
289 |
+
# 使用numpy生成热力图
|
290 |
+
heatmap = np.zeros((height, width), dtype=np.float32)
|
291 |
+
|
292 |
+
for point in points:
|
293 |
+
try:
|
294 |
+
# 确保是整数坐标
|
295 |
+
x = min(max(0, int(point[0])), width-1)
|
296 |
+
y = min(max(0, int(point[1])), height-1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
297 |
|
298 |
+
# 在点位置标记1
|
299 |
+
heatmap[y, x] = 1.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
300 |
|
301 |
+
# 直接使用OpenCV的高斯模糊
|
302 |
+
temp_heatmap = cv2.GaussianBlur(heatmap, (31, 31), 10)
|
303 |
+
heatmap = temp_heatmap
|
304 |
+
|
305 |
+
except Exception as e:
|
306 |
+
logger.warning(f"处理点 {point} 时出错: {e}")
|
307 |
+
continue
|
308 |
+
|
309 |
+
logger.info("热力图生成完成")
|
310 |
+
return heatmap
|
311 |
|
312 |
# 生成热力图
|
313 |
logger.info("开始生成热力图")
|
314 |
+
heatmap = generate_heatmap(filtered_points, width, height)
|
315 |
|
316 |
+
# 归一化和着色
|
317 |
if np.max(heatmap) > 0:
|
318 |
heatmap_normalized = cv2.normalize(heatmap, None, 0, 255, cv2.NORM_MINMAX)
|
319 |
heatmap_colored = cv2.applyColorMap(heatmap_normalized.astype(np.uint8), cv2.COLORMAP_JET)
|
320 |
+
# 添加白色背景
|
321 |
heatmap_colored = cv2.addWeighted(heatmap_colored, 0.7, np.full_like(heatmap_colored, 255), 0.3, 0)
|
322 |
else:
|
323 |
+
# 如果没有检测到点,返回白色图像
|
324 |
heatmap_colored = np.full((height, width, 3), 255, dtype=np.uint8)
|
325 |
|
326 |
logger.info("热力图生成完成")
|
|
|
485 |
value=1,
|
486 |
step=1,
|
487 |
label="最大检测数量",
|
488 |
+
info="每帧���多检测的目标数量"
|
489 |
)
|
490 |
process_btn = gr.Button("开始处��")
|
491 |
|