Hakureirm commited on
Commit
ed63729
·
verified ·
1 Parent(s): 3217754

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -35
app.py CHANGED
@@ -50,7 +50,7 @@ def login(username, password):
50
  @spaces.GPU(duration=300)
51
  def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8):
52
  """
53
- 处理视频并进行��鼠检测
54
  Args:
55
  video_path: 输入视频路径
56
  process_seconds: 处理时长(秒)
@@ -282,43 +282,66 @@ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8)
282
  trajectory_img = trajectory_img.cpu().numpy().astype(np.uint8)
283
 
284
  # 修改热力图生成部分
285
- if torch.cuda.is_available():
286
- logger.info("使用GPU生成热力图")
287
  try:
288
- heatmap = torch.zeros((height, width), device=device)
289
- for pos in filtered_points:
290
- # 确保坐标是整数并且在有效范围内
291
- x, y = map(int, pos) # 明确转换为整数
292
- if 0 <= x < width and 0 <= y < height:
293
- temp_heatmap = torch.zeros((height, width), device=device)
294
- temp_heatmap[int(y), int(x)] = 1 # 再次确保是整数
295
- # 使用GPU的高斯模糊
296
- temp_heatmap = gaussian_blur_gpu(temp_heatmap, kernel_size=31, sigma=10)
297
- heatmap += temp_heatmap
298
-
299
- heatmap = heatmap.cpu().numpy()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300
  except Exception as e:
301
- logger.error(f"GPU热力图生成失败: {str(e)}")
302
- # 回退到CPU处理
303
- logger.info("切换到CPU生成热力图")
304
- heatmap = np.zeros((height, width), dtype=np.float32)
305
- for pos in filtered_points:
306
- x, y = map(int, pos)
307
- if 0 <= x < width and 0 <= y < height:
308
- temp_heatmap = np.zeros((height, width), dtype=np.float32)
309
- temp_heatmap[int(y), int(x)] = 1
310
- temp_heatmap = cv2.GaussianBlur(temp_heatmap, (31, 31), 10)
311
- heatmap += temp_heatmap
 
 
312
  else:
313
- logger.info("使用CPU生成热力图")
314
- heatmap = np.zeros((height, width), dtype=np.float32)
315
- for pos in filtered_points:
316
- x, y = map(int, pos)
317
- if 0 <= x < width and 0 <= y < height:
318
- temp_heatmap = np.zeros((height, width), dtype=np.float32)
319
- temp_heatmap[int(y), int(x)] = 1
320
- temp_heatmap = cv2.GaussianBlur(temp_heatmap, (31, 31), 10)
321
- heatmap += temp_heatmap
322
 
323
  trajectory_frames = []
324
  heatmap_frames = []
 
50
  @spaces.GPU(duration=300)
51
  def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8):
52
  """
53
+ 处理视频并进行鼠检测
54
  Args:
55
  video_path: 输入视频路径
56
  process_seconds: 处理时长(秒)
 
282
  trajectory_img = trajectory_img.cpu().numpy().astype(np.uint8)
283
 
284
  # 修改热力图生成部分
285
+ def generate_heatmap(points, width, height, device='cuda'):
286
+ """生成热力图的独立函数"""
287
  try:
288
+ if device == 'cuda' and torch.cuda.is_available():
289
+ logger.info("使用GPU生成热力图")
290
+ heatmap = torch.zeros((height, width), device=device)
291
+ for point in points:
292
+ try:
293
+ # 确保坐标是整数
294
+ x = int(round(float(point[0])))
295
+ y = int(round(float(point[1])))
296
+
297
+ if 0 <= x < width and 0 <= y < height:
298
+ temp_heatmap = torch.zeros((height, width), device=device)
299
+ temp_heatmap[y, x] = 1.0
300
+ # 使用GPU的高斯模糊
301
+ temp_heatmap = gaussian_blur_gpu(temp_heatmap, kernel_size=31, sigma=10)
302
+ heatmap += temp_heatmap
303
+ except (ValueError, TypeError, IndexError) as e:
304
+ logger.warning(f"跳过无效点 {point}: {str(e)}")
305
+ continue
306
+
307
+ return heatmap.cpu().numpy()
308
+ else:
309
+ logger.info("使用CPU生成热力图")
310
+ heatmap = np.zeros((height, width), dtype=np.float32)
311
+ for point in points:
312
+ try:
313
+ x = int(round(float(point[0])))
314
+ y = int(round(float(point[1])))
315
+
316
+ if 0 <= x < width and 0 <= y < height:
317
+ temp_heatmap = np.zeros((height, width), dtype=np.float32)
318
+ temp_heatmap[y, x] = 1.0
319
+ temp_heatmap = cv2.GaussianBlur(temp_heatmap, (31, 31), 10)
320
+ heatmap += temp_heatmap
321
+ except (ValueError, TypeError, IndexError) as e:
322
+ logger.warning(f"跳过无效点 {point}: {str(e)}")
323
+ continue
324
+
325
+ return heatmap
326
  except Exception as e:
327
+ logger.error(f"热力图生成失败: {str(e)}")
328
+ # 返回空热力图
329
+ return np.zeros((height, width), dtype=np.float32)
330
+
331
+ # 生成热力图
332
+ logger.info("开始生成热力图")
333
+ heatmap = generate_heatmap(filtered_points, width, height, device=device)
334
+
335
+ # 归一化和着色热力图
336
+ if np.max(heatmap) > 0:
337
+ heatmap_normalized = cv2.normalize(heatmap, None, 0, 255, cv2.NORM_MINMAX)
338
+ heatmap_colored = cv2.applyColorMap(heatmap_normalized.astype(np.uint8), cv2.COLORMAP_JET)
339
+ heatmap_colored = cv2.addWeighted(heatmap_colored, 0.7, np.full_like(heatmap_colored, 255), 0.3, 0)
340
  else:
341
+ # 如果热力图全为0,创建一个空白的彩色图像
342
+ heatmap_colored = np.full((height, width, 3), 255, dtype=np.uint8)
343
+
344
+ logger.info("热力图生成完成")
 
 
 
 
 
345
 
346
  trajectory_frames = []
347
  heatmap_frames = []