Hakureirm commited on
Commit
3217754
·
verified ·
1 Parent(s): 71b2d6f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -13
app.py CHANGED
@@ -281,19 +281,44 @@ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8)
281
 
282
  trajectory_img = trajectory_img.cpu().numpy().astype(np.uint8)
283
 
 
284
  if torch.cuda.is_available():
285
- heatmap = torch.zeros((height, width), device=device)
286
- for x, y in filtered_points:
287
- # 确保坐标是整数
288
- x, y = int(x), int(y)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289
  if 0 <= x < width and 0 <= y < height:
290
- temp_heatmap = torch.zeros((height, width), device=device)
291
- temp_heatmap[y, x] = 1
292
- # 使用GPU的高斯模糊
293
- temp_heatmap = gaussian_blur_gpu(temp_heatmap, kernel_size=31, sigma=10)
294
  heatmap += temp_heatmap
295
-
296
- heatmap = heatmap.cpu().numpy()
297
 
298
  trajectory_frames = []
299
  heatmap_frames = []
@@ -372,7 +397,7 @@ def gaussian_blur_gpu(tensor, kernel_size=31, sigma=10):
372
  def get_gaussian_kernel2d(kernel_size, sigma):
373
  """生成2D高斯核"""
374
  kernel_x = torch.linspace(-kernel_size//2, kernel_size//2, kernel_size)
375
- x, y = torch.meshgrid(kernel_x, kernel_x)
376
  kernel = torch.exp(-(x.pow(2) + y.pow(2)) / (2 * sigma ** 2))
377
  return kernel / kernel.sum()
378
 
@@ -432,7 +457,7 @@ with gr.Blocks() as demo:
432
  login_msg = gr.Textbox(label="消息", interactive=False)
433
 
434
  with gr.Group(visible=False) as main_interface:
435
- gr.Markdown("上传视频来检测和���析小鼠行为 | Upload a video to detect and analyze mice behavior")
436
 
437
  with gr.Row():
438
  with gr.Column():
@@ -473,7 +498,7 @@ with gr.Blocks() as demo:
473
  ### 使用说明
474
  1. 上传视频文件
475
  2. 设置处理参数:
476
- - 处理时长:需要分析的视频时长(秒)
477
  - 置信度阈值:检测的置信度要求(越高越严格)
478
  - 最大检测数量:每帧最多检测的目标数量
479
  3. 等待处理完成
 
281
 
282
  trajectory_img = trajectory_img.cpu().numpy().astype(np.uint8)
283
 
284
+ # 修改热力图生成部分
285
  if torch.cuda.is_available():
286
+ logger.info("使用GPU生成热力图")
287
+ try:
288
+ heatmap = torch.zeros((height, width), device=device)
289
+ for pos in filtered_points:
290
+ # 确保坐标是整数并且在有效范围内
291
+ x, y = map(int, pos) # 明确转换为整数
292
+ if 0 <= x < width and 0 <= y < height:
293
+ temp_heatmap = torch.zeros((height, width), device=device)
294
+ temp_heatmap[int(y), int(x)] = 1 # 再次确保是整数
295
+ # 使用GPU的高斯模糊
296
+ temp_heatmap = gaussian_blur_gpu(temp_heatmap, kernel_size=31, sigma=10)
297
+ heatmap += temp_heatmap
298
+
299
+ heatmap = heatmap.cpu().numpy()
300
+ except Exception as e:
301
+ logger.error(f"GPU热力图生成失败: {str(e)}")
302
+ # 回退到CPU处理
303
+ logger.info("切换到CPU生成热力图")
304
+ heatmap = np.zeros((height, width), dtype=np.float32)
305
+ for pos in filtered_points:
306
+ x, y = map(int, pos)
307
+ if 0 <= x < width and 0 <= y < height:
308
+ temp_heatmap = np.zeros((height, width), dtype=np.float32)
309
+ temp_heatmap[int(y), int(x)] = 1
310
+ temp_heatmap = cv2.GaussianBlur(temp_heatmap, (31, 31), 10)
311
+ heatmap += temp_heatmap
312
+ else:
313
+ logger.info("使用CPU生成热力图")
314
+ heatmap = np.zeros((height, width), dtype=np.float32)
315
+ for pos in filtered_points:
316
+ x, y = map(int, pos)
317
  if 0 <= x < width and 0 <= y < height:
318
+ temp_heatmap = np.zeros((height, width), dtype=np.float32)
319
+ temp_heatmap[int(y), int(x)] = 1
320
+ temp_heatmap = cv2.GaussianBlur(temp_heatmap, (31, 31), 10)
 
321
  heatmap += temp_heatmap
 
 
322
 
323
  trajectory_frames = []
324
  heatmap_frames = []
 
397
  def get_gaussian_kernel2d(kernel_size, sigma):
398
  """生成2D高斯核"""
399
  kernel_x = torch.linspace(-kernel_size//2, kernel_size//2, kernel_size)
400
+ x, y = torch.meshgrid(kernel_x, kernel_x, indexing='ij')
401
  kernel = torch.exp(-(x.pow(2) + y.pow(2)) / (2 * sigma ** 2))
402
  return kernel / kernel.sum()
403
 
 
457
  login_msg = gr.Textbox(label="消息", interactive=False)
458
 
459
  with gr.Group(visible=False) as main_interface:
460
+ gr.Markdown("上传视频来检测和分析小鼠行为 | Upload a video to detect and analyze mice behavior")
461
 
462
  with gr.Row():
463
  with gr.Column():
 
498
  ### 使用说明
499
  1. 上传视频文件
500
  2. 设置处理参数:
501
+ - 处理时间:需要分析的视频时长(秒)
502
  - 置信度阈值:检测的置信度要求(越高越严格)
503
  - 最大检测数量:每帧最多检测的目标数量
504
  3. 等待处理完成