Hakureirm commited on
Commit
72f82cd
·
verified ·
1 Parent(s): ed63729

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -48
app.py CHANGED
@@ -50,7 +50,7 @@ def login(username, password):
50
  @spaces.GPU(duration=300)
51
  def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8):
52
  """
53
- 处理视频并进行鼠检测
54
  Args:
55
  video_path: 输入视频路径
56
  process_seconds: 处理时长(秒)
@@ -282,63 +282,45 @@ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8)
282
  trajectory_img = trajectory_img.cpu().numpy().astype(np.uint8)
283
 
284
  # 修改热力图生成部分
285
- def generate_heatmap(points, width, height, device='cuda'):
286
- """生成热力图的独立函数"""
287
- try:
288
- if device == 'cuda' and torch.cuda.is_available():
289
- logger.info("使用GPU生成热力图")
290
- heatmap = torch.zeros((height, width), device=device)
291
- for point in points:
292
- try:
293
- # 确保坐标是整数
294
- x = int(round(float(point[0])))
295
- y = int(round(float(point[1])))
296
-
297
- if 0 <= x < width and 0 <= y < height:
298
- temp_heatmap = torch.zeros((height, width), device=device)
299
- temp_heatmap[y, x] = 1.0
300
- # 使用GPU的高斯模糊
301
- temp_heatmap = gaussian_blur_gpu(temp_heatmap, kernel_size=31, sigma=10)
302
- heatmap += temp_heatmap
303
- except (ValueError, TypeError, IndexError) as e:
304
- logger.warning(f"跳过无效点 {point}: {str(e)}")
305
- continue
306
 
307
- return heatmap.cpu().numpy()
308
- else:
309
- logger.info("使用CPU生成热力图")
310
- heatmap = np.zeros((height, width), dtype=np.float32)
311
- for point in points:
312
- try:
313
- x = int(round(float(point[0])))
314
- y = int(round(float(point[1])))
315
-
316
- if 0 <= x < width and 0 <= y < height:
317
- temp_heatmap = np.zeros((height, width), dtype=np.float32)
318
- temp_heatmap[y, x] = 1.0
319
- temp_heatmap = cv2.GaussianBlur(temp_heatmap, (31, 31), 10)
320
- heatmap += temp_heatmap
321
- except (ValueError, TypeError, IndexError) as e:
322
- logger.warning(f"跳过无效点 {point}: {str(e)}")
323
- continue
324
 
325
- return heatmap
326
- except Exception as e:
327
- logger.error(f"热力图生成失败: {str(e)}")
328
- # 返回空热力图
329
- return np.zeros((height, width), dtype=np.float32)
 
 
 
 
 
330
 
331
  # 生成热力图
332
  logger.info("开始生成热力图")
333
- heatmap = generate_heatmap(filtered_points, width, height, device=device)
334
 
335
- # 归一化和着色热力图
336
  if np.max(heatmap) > 0:
337
  heatmap_normalized = cv2.normalize(heatmap, None, 0, 255, cv2.NORM_MINMAX)
338
  heatmap_colored = cv2.applyColorMap(heatmap_normalized.astype(np.uint8), cv2.COLORMAP_JET)
 
339
  heatmap_colored = cv2.addWeighted(heatmap_colored, 0.7, np.full_like(heatmap_colored, 255), 0.3, 0)
340
  else:
341
- # 如果热力图全为0,创建一个空白的彩色图像
342
  heatmap_colored = np.full((height, width, 3), 255, dtype=np.uint8)
343
 
344
  logger.info("热力图生成完成")
@@ -503,7 +485,7 @@ with gr.Blocks() as demo:
503
  value=1,
504
  step=1,
505
  label="最大检测数量",
506
- info="每帧最多检测的目标数量"
507
  )
508
  process_btn = gr.Button("开始处��")
509
 
 
50
  @spaces.GPU(duration=300)
51
  def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8):
52
  """
53
+ 处理视频并进行小鼠检测
54
  Args:
55
  video_path: 输入视频路径
56
  process_seconds: 处理时长(秒)
 
282
  trajectory_img = trajectory_img.cpu().numpy().astype(np.uint8)
283
 
284
  # 修改热力图生成部分
285
+ def generate_heatmap(points, width, height):
286
+ """简化版本的热力图生成函数"""
287
+ logger.info(f"开始生成热力图,共 {len(points)} 个点")
288
+
289
+ # 使用numpy生成热力图
290
+ heatmap = np.zeros((height, width), dtype=np.float32)
291
+
292
+ for point in points:
293
+ try:
294
+ # 确保是整数坐标
295
+ x = min(max(0, int(point[0])), width-1)
296
+ y = min(max(0, int(point[1])), height-1)
 
 
 
 
 
 
 
 
 
297
 
298
+ # 在点位置标记1
299
+ heatmap[y, x] = 1.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300
 
301
+ # 直接使用OpenCV的高斯模糊
302
+ temp_heatmap = cv2.GaussianBlur(heatmap, (31, 31), 10)
303
+ heatmap = temp_heatmap
304
+
305
+ except Exception as e:
306
+ logger.warning(f"处理点 {point} 时出错: {e}")
307
+ continue
308
+
309
+ logger.info("热力图生成完成")
310
+ return heatmap
311
 
312
  # 生成热力图
313
  logger.info("开始生成热力图")
314
+ heatmap = generate_heatmap(filtered_points, width, height)
315
 
316
+ # 归一化和着色
317
  if np.max(heatmap) > 0:
318
  heatmap_normalized = cv2.normalize(heatmap, None, 0, 255, cv2.NORM_MINMAX)
319
  heatmap_colored = cv2.applyColorMap(heatmap_normalized.astype(np.uint8), cv2.COLORMAP_JET)
320
+ # 添加白色背景
321
  heatmap_colored = cv2.addWeighted(heatmap_colored, 0.7, np.full_like(heatmap_colored, 255), 0.3, 0)
322
  else:
323
+ # 如果没有检测到点,返回白色图像
324
  heatmap_colored = np.full((height, width, 3), 255, dtype=np.uint8)
325
 
326
  logger.info("热力图生成完成")
 
485
  value=1,
486
  step=1,
487
  label="最大检测数量",
488
+ info="每帧���多检测的目标数量"
489
  )
490
  process_btn = gr.Button("开始处��")
491