Hakureirm committed
Commit be58931 (verified) · Parent: 1af9c54

Update app.py

Files changed (1): app.py (+134, -81)
app.py CHANGED
@@ -12,6 +12,7 @@ import tempfile
 import imageio
 from tqdm import tqdm
 import logging
 
 # New: configure logging
 logging.basicConfig(
@@ -49,7 +50,7 @@ def login(username, password):
 @spaces.GPU(duration=300)
 def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8):
     """
-    Process the video and run mouse detection
     Args:
         video_path: path to the input video
         process_seconds: duration to process, in seconds
@@ -167,7 +168,7 @@ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8)
 
     pbar.close()  # close the progress bar
     video_writer.release()
-    logger.info(f"Video processing finished, {frame_count} frames processed")
 
     # Generate the analysis report
     confidences = [float(det['confidence'].strip('%'))/100 for info in detection_info for det in info['detections']]
@@ -183,118 +184,114 @@ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8)
     - Confidence threshold: {conf_threshold:.2f}
     - Maximum detections: {max_det}
     - Processing duration: {process_seconds} s
-
     Analysis results:
     - Frames processed: {frame_count}
     - Average mice detected per frame: {np.mean([info['count'] for info in detection_info]):.1f}
     - Max detections in a frame: {max([info['count'] for info in detection_info])}
     - Min detections in a frame: {min([info['count'] for info in detection_info])}
-
     Confidence distribution:
     {confidence_report}
     """
 
-    def filter_trajectories(positions, width, height, max_jump_distance=100):
-        """
-        Filter outlier points from the trajectory
-        Args:
-            positions: list of positions [[x1,y1], [x2,y2],...]
-            width: video width
-            height: video height
-            max_jump_distance: maximum allowed jump distance
-        """
         if len(positions) < 3:
             return positions
 
-        filtered_positions = []
-        last_valid_pos = None
 
-        for i, pos in enumerate(positions):
-            x, y = pos
-
-            if not (0 <= x < width and 0 <= y < height):
-                continue
-
-            if last_valid_pos is None:
-                filtered_positions.append(pos)
-                last_valid_pos = pos
-                continue
 
-            distance = np.sqrt((x - last_valid_pos[0])**2 + (y - last_valid_pos[1])**2)
 
-            if distance > max_jump_distance:
-                if len(filtered_positions) > 0:
-                    next_valid_pos = None
-                    for next_pos in positions[i:]:
-                        nx, ny = next_pos
-                        if (0 <= nx < width and 0 <= ny < height):
-                            next_distance = np.sqrt((nx - last_valid_pos[0])**2 + (ny - last_valid_pos[1])**2)
-                            if next_distance <= max_jump_distance:
-                                next_valid_pos = next_pos
-                                break
-
-                    if next_valid_pos is not None:
-                        steps = max(2, int(distance / max_jump_distance))
-                        for j in range(1, steps):
-                            alpha = j / steps
-                            interp_x = int(last_valid_pos[0] * (1 - alpha) + next_valid_pos[0] * alpha)
-                            interp_y = int(last_valid_pos[1] * (1 - alpha) + next_valid_pos[1] * alpha)
-                            filtered_positions.append([interp_x, interp_y])
-                        filtered_positions.append(next_valid_pos)
-                        last_valid_pos = next_valid_pos
-            else:
-                filtered_positions.append(pos)
-                last_valid_pos = pos
 
-        window_size = 5
-        smoothed_positions = []
 
-        if len(filtered_positions) >= window_size:
-            smoothed_positions.extend(filtered_positions[:window_size//2])
 
-            for i in range(window_size//2, len(filtered_positions) - window_size//2):
-                window = filtered_positions[i-window_size//2:i+window_size//2+1]
-                smoothed_x = int(np.mean([p[0] for p in window]))
-                smoothed_y = int(np.mean([p[1] for p in window]))
-                smoothed_positions.append([smoothed_x, smoothed_y])
 
-            smoothed_positions.extend(filtered_positions[-window_size//2:])
-        else:
-            smoothed_positions = filtered_positions
 
-        return smoothed_positions
 
     # Modified trajectory-image generation
-    trajectory_img = np.zeros((height, width, 3), dtype=np.uint8) + 255
     points = np.array(all_positions, dtype=np.int32)
     if len(points) > 1:
-        filtered_points = filter_trajectories(points.tolist(), width, height)
         points = np.array(filtered_points, dtype=np.int32)
 
         for i in range(len(points) - 1):
             ratio = i / (len(points) - 1)
-            color = (
                 int((1 - ratio) * 255),  # B
                 50,                      # G
                 int(ratio * 255)         # R
-            )
-            cv2.line(trajectory_img, tuple(points[i]), tuple(points[i + 1]), color, 2)
 
-        cv2.circle(trajectory_img, tuple(points[0]), 8, (0, 255, 0), -1)
-        cv2.circle(trajectory_img, tuple(points[-1]), 8, (0, 0, 255), -1)
 
-        arrow_interval = max(len(points) // 20, 1)
-        for i in range(0, len(points) - arrow_interval, arrow_interval):
-            pt1 = tuple(points[i])
-            pt2 = tuple(points[i + arrow_interval])
-            angle = np.arctan2(pt2[1] - pt1[1], pt2[0] - pt1[0])
-            cv2.arrowedLine(trajectory_img, pt1, pt2, (100, 100, 100), 1, tipLength=0.2)
-
-    if np.max(heatmap) > 0:
-        heatmap_normalized = cv2.normalize(heatmap, None, 0, 255, cv2.NORM_MINMAX)
-        heatmap_colored = cv2.applyColorMap(heatmap_normalized.astype(np.uint8), cv2.COLORMAP_JET)
-        alpha = 0.7
-        heatmap_colored = cv2.addWeighted(heatmap_colored, alpha, np.full_like(heatmap_colored, 255), 1-alpha, 0)
 
     trajectory_frames = []
     heatmap_frames = []
@@ -357,6 +354,62 @@ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8)
     logger.info("All processing finished, preparing to return results")
     return output_path, trajectory_path, heatmap_path, trajectory_gif_path, heatmap_gif_path, report
 
 # Create the Gradio interface
 with gr.Blocks() as demo:
     gr.Markdown("# 🐁 小鼠行为分析 (Mice Behavior Analysis)")
 
 import imageio
 from tqdm import tqdm
 import logging
+import torch.nn.functional as F
 
 # New: configure logging
 logging.basicConfig(
 
 @spaces.GPU(duration=300)
 def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8):
     """
+    Process the video and run mouse detection
     Args:
         video_path: path to the input video
         process_seconds: duration to process, in seconds
 
 
     pbar.close()  # close the progress bar
     video_writer.release()
+    logger.info(f"Video processing finished, {frame_count} frames processed")
 
     # Generate the analysis report
     confidences = [float(det['confidence'].strip('%'))/100 for info in detection_info for det in info['detections']]
 
     - Confidence threshold: {conf_threshold:.2f}
     - Maximum detections: {max_det}
     - Processing duration: {process_seconds} s
     Analysis results:
     - Frames processed: {frame_count}
     - Average mice detected per frame: {np.mean([info['count'] for info in detection_info]):.1f}
     - Max detections in a frame: {max([info['count'] for info in detection_info])}
     - Min detections in a frame: {min([info['count'] for info in detection_info])}
     Confidence distribution:
     {confidence_report}
     """
 
+    def filter_trajectories_gpu(positions, width, height, max_jump_distance=100):
+        """GPU-accelerated trajectory filtering"""
         if len(positions) < 3:
             return positions
 
+        # Move the points onto the GPU
+        points = torch.tensor(positions, device=device, dtype=torch.float32)
 
+        # Distances between consecutive points
+        diffs = points[1:] - points[:-1]
+        distances = torch.norm(diffs, dim=1)
+
+        # Find the positions that need interpolation
+        mask = distances > max_jump_distance
+        valid_indices = (~mask).nonzero().squeeze()
+
+        if len(valid_indices) < 2:
+            return positions
+
+        # Interpolate on the GPU
+        filtered_points = []
+        last_valid_idx = 0
+
+        for i in range(len(valid_indices)-1):
+            curr_idx = valid_indices[i].item()
+            next_idx = valid_indices[i+1].item()
 
+            filtered_points.append(points[curr_idx].tolist())
 
+            if next_idx - curr_idx > 1:
+                # Linear interpolation
+                steps = max(2, int((next_idx - curr_idx)))
+                interp_points = torch.linspace(0, 1, steps)
+                start_point = points[curr_idx]
+                end_point = points[next_idx]
+
+                interpolated = start_point[None] * (1 - interp_points[:, None]) + \
+                               end_point[None] * interp_points[:, None]
+
+                filtered_points.extend(interpolated[1:-1].tolist())
 
+        filtered_points.append(points[valid_indices[-1]].tolist())
 
+        # Smoothing
+        if len(filtered_points) >= 5:
+            points_tensor = torch.tensor(filtered_points, device=device)
+            kernel_size = 5
+            padding = kernel_size // 2
 
+            # Smooth with a 1-D convolution
+            weights = torch.ones(1, 1, kernel_size, device=device) / kernel_size
+            smoothed_x = F.conv1d(
+                points_tensor[:, 0].view(1, 1, -1),
+                weights,
+                padding=padding
+            ).squeeze()
+            smoothed_y = F.conv1d(
+                points_tensor[:, 1].view(1, 1, -1),
+                weights,
+                padding=padding
+            ).squeeze()
 
+            smoothed_points = torch.stack([smoothed_x, smoothed_y], dim=1)
+            return smoothed_points.cpu().numpy().tolist()
 
+        return filtered_points
 
     # Modified trajectory-image generation
+    trajectory_img = torch.ones((height, width, 3), device=device) * 255
     points = np.array(all_positions, dtype=np.int32)
     if len(points) > 1:
+        filtered_points = filter_trajectories_gpu(points.tolist(), width, height)
         points = np.array(filtered_points, dtype=np.int32)
 
         for i in range(len(points) - 1):
             ratio = i / (len(points) - 1)
+            color = torch.tensor([
                 int((1 - ratio) * 255),  # B
                 50,                      # G
                 int(ratio * 255)         # R
+            ], device=device)
+
+            # Draw the line segment on the GPU
+            pt1, pt2 = points[i], points[i + 1]
+            draw_line_gpu(trajectory_img, pt1, pt2, color, 2)
 
+    trajectory_img = trajectory_img.cpu().numpy().astype(np.uint8)
+
+    if torch.cuda.is_available():
+        heatmap = torch.zeros((height, width), device=device)
+        for x, y in filtered_points:
+            if 0 <= x < width and 0 <= y < height:
+                temp_heatmap = torch.zeros((height, width), device=device)
+                temp_heatmap[y, x] = 1
+                # Gaussian blur on the GPU
+                temp_heatmap = gaussian_blur_gpu(temp_heatmap, kernel_size=31, sigma=10)
+                heatmap += temp_heatmap
 
+        heatmap = heatmap.cpu().numpy()
 
     trajectory_frames = []
     heatmap_frames = []
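The smoothing step in filter_trajectories_gpu above is a plain moving average run separately over the x and y coordinates with F.conv1d. A minimal, self-contained sketch of the same idea (only torch is assumed; the device choice, the smooth_xy name, and the toy trajectory are illustrative and not taken from app.py):

import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"

def smooth_xy(points, kernel_size=5):
    """Moving-average smoothing of an [N, 2] point list, one coordinate at a time."""
    pts = torch.tensor(points, dtype=torch.float32, device=device)
    weights = torch.ones(1, 1, kernel_size, device=device) / kernel_size  # uniform kernel
    pad = kernel_size // 2
    sx = F.conv1d(pts[:, 0].view(1, 1, -1), weights, padding=pad).squeeze()
    sy = F.conv1d(pts[:, 1].view(1, 1, -1), weights, padding=pad).squeeze()
    return torch.stack([sx, sy], dim=1).cpu().tolist()

noisy = [[i * 10, 100 + (i % 2) * 8] for i in range(12)]  # zig-zag toy track
print(smooth_xy(noisy)[:3])

As in the committed conv1d call, the implicit zero padding means the first and last couple of outputs are averaged with padded zeros and so shrink toward the origin; reflective padding or trimming the ends would avoid that if it matters.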
 
     logger.info("All processing finished, preparing to return results")
     return output_path, trajectory_path, heatmap_path, trajectory_gif_path, heatmap_gif_path, report
 
+def gaussian_blur_gpu(tensor, kernel_size=31, sigma=10):
+    """GPU version of Gaussian blur"""
+    channels = 1
+    kernel = get_gaussian_kernel2d(kernel_size, sigma).to(device)
+    kernel = kernel.view(1, 1, kernel_size, kernel_size)
+    tensor = tensor.view(1, 1, tensor.shape[0], tensor.shape[1])
+
+    return F.conv2d(tensor, kernel, padding=kernel_size//2).squeeze()
+
+def get_gaussian_kernel2d(kernel_size, sigma):
+    """Build a 2-D Gaussian kernel"""
+    kernel_x = torch.linspace(-kernel_size//2, kernel_size//2, kernel_size)
+    x, y = torch.meshgrid(kernel_x, kernel_x)
+    kernel = torch.exp(-(x.pow(2) + y.pow(2)) / (2 * sigma ** 2))
+    return kernel / kernel.sum()
+
+def draw_line_gpu(image, pt1, pt2, color, thickness=1):
+    """GPU version of line drawing"""
+    x1, y1 = pt1
+    x2, y2 = pt2
+    dx = abs(x2 - x1)
+    dy = abs(y2 - y1)
+
+    if dx > dy:
+        steps = dx
+    else:
+        steps = dy
+
+    x_inc = (x2 - x1) / steps
+    y_inc = (y2 - y1) / steps
+
+    x = x1
+    y = y1
+
+    points = torch.zeros((int(steps) + 1, 2), device=device)
+    for i in range(int(steps) + 1):
+        points[i] = torch.tensor([x, y])
+        x += x_inc
+        y += y_inc
+
+    points = points.long()
+    valid_points = (points[:, 0] >= 0) & (points[:, 0] < image.shape[1]) & \
+                   (points[:, 1] >= 0) & (points[:, 1] < image.shape[0])
+    points = points[valid_points]
+
+    if thickness > 1:
+        for dx in range(-thickness//2, thickness//2 + 1):
+            for dy in range(-thickness//2, thickness//2 + 1):
+                offset_points = points + torch.tensor([dx, dy], device=device)
+                valid_offset = (offset_points[:, 0] >= 0) & (offset_points[:, 0] < image.shape[1]) & \
+                               (offset_points[:, 1] >= 0) & (offset_points[:, 1] < image.shape[0])
+                offset_points = offset_points[valid_offset]
+                image[offset_points[:, 1], offset_points[:, 0]] = color
+    else:
+        image[points[:, 1], points[:, 0]] = color
+
 # Create the Gradio interface
 with gr.Blocks() as demo:
     gr.Markdown("# 🐁 小鼠行为分析 (Mice Behavior Analysis)")
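For reference, a standalone sketch of the conv2d-based Gaussian blur that gaussian_blur_gpu and get_gaussian_kernel2d implement, here applied once to an accumulated hit map rather than once per point; the function names, sizes, and device pick below are illustrative and not part of app.py:

import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"

def gaussian_kernel2d(kernel_size=31, sigma=10.0):
    """Normalized 2-D Gaussian kernel."""
    ax = torch.linspace(-(kernel_size // 2), kernel_size // 2, kernel_size)
    x, y = torch.meshgrid(ax, ax, indexing="ij")
    kernel = torch.exp(-(x ** 2 + y ** 2) / (2 * sigma ** 2))
    return (kernel / kernel.sum()).to(device)

def blur_heatmap(points, height, width, kernel_size=31, sigma=10.0):
    """Accumulate integer (x, y) hits into a map, then blur it in one conv2d pass."""
    heat = torch.zeros(height, width, device=device)
    for x, y in points:
        if 0 <= x < width and 0 <= y < height:
            heat[y, x] += 1.0
    kernel = gaussian_kernel2d(kernel_size, sigma).view(1, 1, kernel_size, kernel_size)
    blurred = F.conv2d(heat.view(1, 1, height, width), kernel, padding=kernel_size // 2)
    return blurred.squeeze().cpu()

print(blur_heatmap([(50, 40), (52, 41), (90, 80)], height=120, width=160).shape)

Because convolution is linear, blurring the accumulated map once is equivalent to summing one blurred map per point, as the loop in the commit does, but it needs only a single conv2d call regardless of how many points there are.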