Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -281,19 +281,44 @@ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8)
|
|
281 |
|
282 |
trajectory_img = trajectory_img.cpu().numpy().astype(np.uint8)
|
283 |
|
|
|
284 |
if torch.cuda.is_available():
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
289 |
if 0 <= x < width and 0 <= y < height:
|
290 |
-
temp_heatmap =
|
291 |
-
temp_heatmap[y, x] = 1
|
292 |
-
|
293 |
-
temp_heatmap = gaussian_blur_gpu(temp_heatmap, kernel_size=31, sigma=10)
|
294 |
heatmap += temp_heatmap
|
295 |
-
|
296 |
-
heatmap = heatmap.cpu().numpy()
|
297 |
|
298 |
trajectory_frames = []
|
299 |
heatmap_frames = []
|
@@ -372,7 +397,7 @@ def gaussian_blur_gpu(tensor, kernel_size=31, sigma=10):
|
|
372 |
def get_gaussian_kernel2d(kernel_size, sigma):
|
373 |
"""生成2D高斯核"""
|
374 |
kernel_x = torch.linspace(-kernel_size//2, kernel_size//2, kernel_size)
|
375 |
-
x, y = torch.meshgrid(kernel_x, kernel_x)
|
376 |
kernel = torch.exp(-(x.pow(2) + y.pow(2)) / (2 * sigma ** 2))
|
377 |
return kernel / kernel.sum()
|
378 |
|
@@ -432,7 +457,7 @@ with gr.Blocks() as demo:
|
|
432 |
login_msg = gr.Textbox(label="消息", interactive=False)
|
433 |
|
434 |
with gr.Group(visible=False) as main_interface:
|
435 |
-
gr.Markdown("
|
436 |
|
437 |
with gr.Row():
|
438 |
with gr.Column():
|
@@ -473,7 +498,7 @@ with gr.Blocks() as demo:
|
|
473 |
### 使用说明
|
474 |
1. 上传视频文件
|
475 |
2. 设置处理参数:
|
476 |
-
-
|
477 |
- 置信度阈值:检测的置信度要求(越高越严格)
|
478 |
- 最大检测数量:每帧最多检测的目标数量
|
479 |
3. 等待处理完成
|
|
|
281 |
|
282 |
trajectory_img = trajectory_img.cpu().numpy().astype(np.uint8)
|
283 |
|
284 |
+
# 修改热力图生成部分
|
285 |
if torch.cuda.is_available():
|
286 |
+
logger.info("使用GPU生成热力图")
|
287 |
+
try:
|
288 |
+
heatmap = torch.zeros((height, width), device=device)
|
289 |
+
for pos in filtered_points:
|
290 |
+
# 确保坐标是整数并且在有效范围内
|
291 |
+
x, y = map(int, pos) # 明确转换为整数
|
292 |
+
if 0 <= x < width and 0 <= y < height:
|
293 |
+
temp_heatmap = torch.zeros((height, width), device=device)
|
294 |
+
temp_heatmap[int(y), int(x)] = 1 # 再次确保是整数
|
295 |
+
# 使用GPU的高斯模糊
|
296 |
+
temp_heatmap = gaussian_blur_gpu(temp_heatmap, kernel_size=31, sigma=10)
|
297 |
+
heatmap += temp_heatmap
|
298 |
+
|
299 |
+
heatmap = heatmap.cpu().numpy()
|
300 |
+
except Exception as e:
|
301 |
+
logger.error(f"GPU热力图生成失败: {str(e)}")
|
302 |
+
# 回退到CPU处理
|
303 |
+
logger.info("切换到CPU生成热力图")
|
304 |
+
heatmap = np.zeros((height, width), dtype=np.float32)
|
305 |
+
for pos in filtered_points:
|
306 |
+
x, y = map(int, pos)
|
307 |
+
if 0 <= x < width and 0 <= y < height:
|
308 |
+
temp_heatmap = np.zeros((height, width), dtype=np.float32)
|
309 |
+
temp_heatmap[int(y), int(x)] = 1
|
310 |
+
temp_heatmap = cv2.GaussianBlur(temp_heatmap, (31, 31), 10)
|
311 |
+
heatmap += temp_heatmap
|
312 |
+
else:
|
313 |
+
logger.info("使用CPU生成热力图")
|
314 |
+
heatmap = np.zeros((height, width), dtype=np.float32)
|
315 |
+
for pos in filtered_points:
|
316 |
+
x, y = map(int, pos)
|
317 |
if 0 <= x < width and 0 <= y < height:
|
318 |
+
temp_heatmap = np.zeros((height, width), dtype=np.float32)
|
319 |
+
temp_heatmap[int(y), int(x)] = 1
|
320 |
+
temp_heatmap = cv2.GaussianBlur(temp_heatmap, (31, 31), 10)
|
|
|
321 |
heatmap += temp_heatmap
|
|
|
|
|
322 |
|
323 |
trajectory_frames = []
|
324 |
heatmap_frames = []
|
|
|
397 |
def get_gaussian_kernel2d(kernel_size, sigma):
|
398 |
"""生成2D高斯核"""
|
399 |
kernel_x = torch.linspace(-kernel_size//2, kernel_size//2, kernel_size)
|
400 |
+
x, y = torch.meshgrid(kernel_x, kernel_x, indexing='ij')
|
401 |
kernel = torch.exp(-(x.pow(2) + y.pow(2)) / (2 * sigma ** 2))
|
402 |
return kernel / kernel.sum()
|
403 |
|
|
|
457 |
login_msg = gr.Textbox(label="消息", interactive=False)
|
458 |
|
459 |
with gr.Group(visible=False) as main_interface:
|
460 |
+
gr.Markdown("上传视频来检测和分析小鼠行为 | Upload a video to detect and analyze mice behavior")
|
461 |
|
462 |
with gr.Row():
|
463 |
with gr.Column():
|
|
|
498 |
### 使用说明
|
499 |
1. 上传视频文件
|
500 |
2. 设置处理参数:
|
501 |
+
- 处理时间:需要分析的视频时长(秒)
|
502 |
- 置信度阈值:检测的置信度要求(越高越严格)
|
503 |
- 最大检测数量:每帧最多检测的目标数量
|
504 |
3. 等待处理完成
|