Hakureirm committed on
Commit
2ee8a2f
·
verified ·
1 Parent(s): 72f82cd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -292
app.py CHANGED
@@ -9,18 +9,7 @@ import numpy as np
9
  import cv2
10
  from pathlib import Path
11
  import tempfile
12
- import imageio
13
  from tqdm import tqdm
14
- import logging
15
- import torch.nn.functional as F
16
-
17
- # 新增: 配置logging
18
- logging.basicConfig(
19
- level=logging.INFO,
20
- format='%(asctime)s - %(levelname)s - %(message)s',
21
- datefmt='%Y-%m-%d %H:%M:%S'
22
- )
23
- logger = logging.getLogger(__name__)
24
 
25
  # 从环境变量获取密码
26
  APP_USERNAME = "admin" # 用户名保持固定
@@ -28,7 +17,9 @@ APP_PASSWORD = os.getenv("APP_PASSWORD", "default_password") # 从环境变量
28
 
29
  app = FastAPI()
30
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
31
  model = YOLO('kunin-mice-pose.v0.1.5n.pt')
 
32
 
33
  # 定义认证状态
34
  class AuthState:
@@ -39,38 +30,26 @@ auth_state = AuthState()
39
 
40
  def login(username, password):
41
  """登录验证"""
42
- logger.info(f"用户尝试登录: {username}")
43
  if username == APP_USERNAME and password == APP_PASSWORD:
44
  auth_state.is_logged_in = True
45
- logger.info("登录成功")
46
  return gr.update(visible=False), gr.update(visible=True), "登录成功"
47
- logger.warning("登录失败:用户名或密码错误")
48
  return gr.update(visible=True), gr.update(visible=False), "用户名或密码错误"
49
 
50
- @spaces.GPU(duration=300)
51
  def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8):
52
  """
53
  处理视频并进行小鼠检测
54
- Args:
55
- video_path: 输入视频路径
56
- process_seconds: 处理时长(秒)
57
- conf_threshold: 置信度阈值(0-1)
58
- max_det: 每帧最大检测数量
59
  """
60
- logger.info(f"开始处理视频: {video_path}")
61
- logger.info(f"参数设置 - 处理时长: {process_seconds}秒, 置信度阈值: {conf_threshold}, 最大检测数: {max_det}")
62
 
63
  if not auth_state.is_logged_in:
64
- logger.warning("用户未登录,拒绝访问")
65
  return None, "请先登录"
66
 
67
- # 创建临时目录保存输出视频
68
- logger.info("创建临时输出目录")
69
  with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
70
  output_path = tmp_file.name
71
 
72
- # 获取视频信息
73
- logger.info("读取视频信息")
74
  cap = cv2.VideoCapture(video_path)
75
  fps = int(cap.get(cv2.CAP_PROP_FPS))
76
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
@@ -78,9 +57,9 @@ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8)
78
  total_frames = int(process_seconds * fps) if process_seconds else int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
79
  cap.release()
80
 
81
- logger.info(f"视频信息 - FPS: {fps}, 分辨率: {width}x{height}, 总帧数: {total_frames}")
82
 
83
- # 创建视频写入器
84
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
85
  video_writer = cv2.VideoWriter(
86
  output_path,
@@ -89,11 +68,10 @@ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8)
89
  (width, height)
90
  )
91
 
92
- # 计算基于分辨率的线宽
93
  base_size = min(width, height)
94
- line_thickness = max(1, int(base_size * 0.002)) # 0.2% 的最小边长
95
 
96
- logger.info("开始YOLO模型推理")
97
  results = model.predict(
98
  source=video_path,
99
  device=device,
@@ -108,28 +86,26 @@ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8)
108
  vid_stride=1,
109
  max_det=max_det,
110
  retina_masks=True,
111
- verbose=False # 关闭YOLO默认日志输出
112
  )
113
 
114
- logger.info("开始处理检测结果")
115
  frame_count = 0
116
  detection_info = []
117
  all_positions = []
118
  heatmap = np.zeros((height, width), dtype=np.float32)
119
 
120
- # 新增: 创建进度条
121
- pbar = tqdm(total=total_frames, desc="处理视频", unit="帧")
122
 
123
  for r in results:
124
  frame = r.plot()
125
 
126
- # 收集位置信息
127
  if hasattr(r, 'keypoints') and r.keypoints is not None:
128
  kpts = r.keypoints.data
129
  if isinstance(kpts, torch.Tensor):
130
  kpts = kpts.cpu().numpy()
131
 
132
- if kpts.shape == (1, 8, 3): # [num_objects, num_keypoints, xyz]
133
  x, y = int(kpts[0, 0, 0]), int(kpts[0, 0, 1])
134
  all_positions.append([x, y])
135
 
@@ -141,7 +117,6 @@ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8)
141
  temp_heatmap = cv2.GaussianBlur(temp_heatmap, (kernel_size, kernel_size), sigma)
142
  heatmap += temp_heatmap
143
 
144
- # 收集检测信息
145
  frame_info = {
146
  "frame": frame_count + 1,
147
  "count": len(r.boxes),
@@ -161,16 +136,17 @@ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8)
161
  video_writer.write(frame)
162
 
163
  frame_count += 1
164
- pbar.update(1) # 更新进度条
165
 
166
  if process_seconds and frame_count >= total_frames:
167
  break
168
 
169
- pbar.close() # 关闭进度条
 
 
170
  video_writer.release()
171
- logger.info(f"视频处理完成,共处理 {frame_count} 帧")
172
-
173
- # 生成分析报告
174
  confidences = [float(det['confidence'].strip('%'))/100 for info in detection_info for det in info['detections']]
175
  hist, bins = np.histogram(confidences, bins=5)
176
 
@@ -193,267 +169,113 @@ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8)
193
  {confidence_report}
194
  """
195
 
196
- def filter_trajectories_gpu(positions, width, height, max_jump_distance=100):
197
- """GPU加速版本的轨迹过滤"""
198
  if len(positions) < 3:
199
  return positions
200
 
201
- # 转换为GPU张量
202
- points = torch.tensor(positions, device=device, dtype=torch.float32)
203
-
204
- # 计算相邻点之间的距离
205
- diffs = points[1:] - points[:-1]
206
- distances = torch.norm(diffs, dim=1)
207
-
208
- # 找出需要插值的位置
209
- mask = distances > max_jump_distance
210
- valid_indices = (~mask).nonzero().squeeze()
211
-
212
- if len(valid_indices) < 2:
213
- return positions
214
-
215
- # 使用GPU进行插值
216
- filtered_points = []
217
- last_valid_idx = 0
218
 
219
- for i in range(len(valid_indices)-1):
220
- curr_idx = valid_indices[i].item()
221
- next_idx = valid_indices[i+1].item()
222
 
223
- filtered_points.append(points[curr_idx].tolist())
 
224
 
225
- if next_idx - curr_idx > 1:
226
- # 线性插值
227
- steps = max(2, int((next_idx - curr_idx)))
228
- interp_points = torch.linspace(0, 1, steps)
229
- start_point = points[curr_idx]
230
- end_point = points[next_idx]
231
-
232
- interpolated = start_point[None] * (1 - interp_points[:, None]) + \
233
- end_point[None] * interp_points[:, None]
234
-
235
- filtered_points.extend(interpolated[1:-1].tolist())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
 
237
- filtered_points.append(points[valid_indices[-1]].tolist())
 
238
 
239
- # 平滑处理
240
- if len(filtered_points) >= 5:
241
- points_tensor = torch.tensor(filtered_points, device=device)
242
- kernel_size = 5
243
- padding = kernel_size // 2
244
 
245
- # 使用1D卷积进行平滑
246
- weights = torch.ones(1, 1, kernel_size, device=device) / kernel_size
247
- smoothed_x = F.conv1d(
248
- points_tensor[:, 0].view(1, 1, -1),
249
- weights,
250
- padding=padding
251
- ).squeeze()
252
- smoothed_y = F.conv1d(
253
- points_tensor[:, 1].view(1, 1, -1),
254
- weights,
255
- padding=padding
256
- ).squeeze()
257
 
258
- smoothed_points = torch.stack([smoothed_x, smoothed_y], dim=1)
259
- return smoothed_points.cpu().numpy().tolist()
 
260
 
261
- return filtered_points
262
 
263
- # 修改轨迹图生成部分
264
- trajectory_img = torch.ones((height, width, 3), device=device, dtype=torch.float32)
265
  points = np.array(all_positions, dtype=np.int32)
266
  if len(points) > 1:
267
- filtered_points = filter_trajectories_gpu(points.tolist(), width, height)
268
  points = np.array(filtered_points, dtype=np.int32)
269
 
270
  for i in range(len(points) - 1):
271
  ratio = i / (len(points) - 1)
272
- color = torch.tensor([
273
- int((1 - ratio) * 255), # B
274
- 50, # G
275
- int(ratio * 255) # R
276
- ], device=device, dtype=torch.float32)
277
-
278
- # 使用GPU绘制线段
279
- pt1, pt2 = points[i], points[i + 1]
280
- draw_line_gpu(trajectory_img, pt1, pt2, color, 2)
281
 
282
- trajectory_img = trajectory_img.cpu().numpy().astype(np.uint8)
283
-
284
- # 修改热力图生成部分
285
- def generate_heatmap(points, width, height):
286
- """简化版本的热力图生成函数"""
287
- logger.info(f"开始生成热力图,共 {len(points)} 个点")
288
-
289
- # 使用numpy生成热力图
290
- heatmap = np.zeros((height, width), dtype=np.float32)
291
-
292
- for point in points:
293
- try:
294
- # 确保是整数坐标
295
- x = min(max(0, int(point[0])), width-1)
296
- y = min(max(0, int(point[1])), height-1)
297
-
298
- # 在点位置标记1
299
- heatmap[y, x] = 1.0
300
-
301
- # 直接使用OpenCV的高斯模糊
302
- temp_heatmap = cv2.GaussianBlur(heatmap, (31, 31), 10)
303
- heatmap = temp_heatmap
304
-
305
- except Exception as e:
306
- logger.warning(f"处理点 {point} 时出错: {e}")
307
- continue
308
 
309
- logger.info("热力图生成完成")
310
- return heatmap
311
-
312
- # 生成热力图
313
- logger.info("开始生成热力图")
314
- heatmap = generate_heatmap(filtered_points, width, height)
315
-
316
- # 归一化和着色
317
  if np.max(heatmap) > 0:
318
  heatmap_normalized = cv2.normalize(heatmap, None, 0, 255, cv2.NORM_MINMAX)
319
  heatmap_colored = cv2.applyColorMap(heatmap_normalized.astype(np.uint8), cv2.COLORMAP_JET)
320
- # 添加白色背景
321
- heatmap_colored = cv2.addWeighted(heatmap_colored, 0.7, np.full_like(heatmap_colored, 255), 0.3, 0)
322
- else:
323
- # 如果没有检测到点,返回白色图像
324
- heatmap_colored = np.full((height, width, 3), 255, dtype=np.uint8)
325
-
326
- logger.info("热力图生成完成")
327
-
328
- trajectory_frames = []
329
- heatmap_frames = []
330
-
331
- base_trajectory = np.zeros((height, width, 3), dtype=np.uint8) + 255
332
- base_heatmap = np.zeros((height, width), dtype=np.float32)
333
-
334
- frame_interval = max(1, len(filtered_points) // 30)
335
-
336
- for i in range(0, len(filtered_points), frame_interval):
337
- current_points = filtered_points[:i+1]
338
-
339
- frame_trajectory = base_trajectory.copy()
340
- if len(current_points) > 1:
341
- points = np.array(current_points, dtype=np.int32)
342
- for j in range(len(points) - 1):
343
- ratio = j / (len(current_points) - 1)
344
- color = (
345
- int((1 - ratio) * 255),
346
- 50,
347
- int(ratio * 255)
348
- )
349
- cv2.line(frame_trajectory, tuple(points[j]), tuple(points[j + 1]), color, 2)
350
-
351
- cv2.circle(frame_trajectory, tuple(points[-1]), 8, (0, 0, 255), -1)
352
- trajectory_frames.append(frame_trajectory)
353
-
354
- frame_heatmap = base_heatmap.copy()
355
- for x, y in current_points:
356
- if 0 <= x < width and 0 <= y < height:
357
- temp_heatmap = np.zeros((height, width), dtype=np.float32)
358
- temp_heatmap[y, x] = 1
359
- temp_heatmap = cv2.GaussianBlur(temp_heatmap, (31, 31), 10)
360
- frame_heatmap += temp_heatmap
361
-
362
- if np.max(frame_heatmap) > 0:
363
- frame_heatmap_norm = cv2.normalize(frame_heatmap, None, 0, 255, cv2.NORM_MINMAX)
364
- frame_heatmap_color = cv2.applyColorMap(frame_heatmap_norm.astype(np.uint8), cv2.COLORMAP_JET)
365
- frame_heatmap_color = cv2.addWeighted(frame_heatmap_color, 0.7, np.full_like(frame_heatmap_color, 255), 0.3, 0)
366
- heatmap_frames.append(frame_heatmap_color)
367
-
368
- logger.info("开始生成轨迹图和热力图")
369
- trajectory_gif_path = output_path.replace('.mp4', '_trajectory.gif')
370
- heatmap_gif_path = output_path.replace('.mp4', '_heatmap.gif')
371
-
372
- imageio.mimsave(trajectory_gif_path, trajectory_frames, duration=50)
373
- imageio.mimsave(heatmap_gif_path, heatmap_frames, duration=50)
374
 
 
375
  trajectory_path = output_path.replace('.mp4', '_trajectory.png')
376
  heatmap_path = output_path.replace('.mp4', '_heatmap.png')
377
  cv2.imwrite(trajectory_path, trajectory_img)
378
- if np.max(heatmap) > 0:
379
- heatmap_normalized = cv2.normalize(heatmap, None, 0, 255, cv2.NORM_MINMAX)
380
- heatmap_colored = cv2.applyColorMap(heatmap_normalized.astype(np.uint8), cv2.COLORMAP_JET)
381
- heatmap_colored = cv2.addWeighted(heatmap_colored, 0.7, np.full_like(heatmap_colored, 255), 0.3, 0)
382
  cv2.imwrite(heatmap_path, heatmap_colored)
383
 
384
- logger.info("轨迹图和热力图生成完成")
385
- logger.info("开始生成GIF动画")
386
- imageio.mimsave(trajectory_gif_path, trajectory_frames, duration=50)
387
- imageio.mimsave(heatmap_gif_path, heatmap_frames, duration=50)
388
- logger.info("GIF动画生成完成")
389
-
390
- logger.info("所有处理完成,准备返回结果")
391
- return output_path, trajectory_path, heatmap_path, trajectory_gif_path, heatmap_gif_path, report
392
-
393
- def gaussian_blur_gpu(tensor, kernel_size=31, sigma=10):
394
- """GPU版本的高斯模糊"""
395
- channels = 1
396
- kernel = get_gaussian_kernel2d(kernel_size, sigma).to(device)
397
- kernel = kernel.view(1, 1, kernel_size, kernel_size)
398
- tensor = tensor.view(1, 1, tensor.shape[0], tensor.shape[1])
399
-
400
- return F.conv2d(tensor, kernel, padding=kernel_size//2).squeeze()
401
-
402
- def get_gaussian_kernel2d(kernel_size, sigma):
403
- """生成2D高斯核"""
404
- kernel_x = torch.linspace(-kernel_size//2, kernel_size//2, kernel_size)
405
- x, y = torch.meshgrid(kernel_x, kernel_x, indexing='ij')
406
- kernel = torch.exp(-(x.pow(2) + y.pow(2)) / (2 * sigma ** 2))
407
- return kernel / kernel.sum()
408
-
409
- def draw_line_gpu(image, pt1, pt2, color, thickness=1):
410
- """GPU版本的线段绘制"""
411
- x1, y1 = map(int, pt1) # 确保是整数
412
- x2, y2 = map(int, pt2) # 确保是整数
413
- dx = abs(x2 - x1)
414
- dy = abs(y2 - y1)
415
-
416
- # 防止除零错误
417
- steps = max(dx, dy)
418
- if steps == 0:
419
- # 如果是同一个点,直接画点
420
- if 0 <= x1 < image.shape[1] and 0 <= y1 < image.shape[0]:
421
- image[y1, x1] = color
422
- return
423
-
424
- x_inc = (x2 - x1) / steps
425
- y_inc = (y2 - y1) / steps
426
-
427
- x = x1
428
- y = y1
429
-
430
- points = torch.zeros((int(steps) + 1, 2), device=device)
431
- for i in range(int(steps) + 1):
432
- points[i] = torch.tensor([x, y])
433
- x += x_inc
434
- y += y_inc
435
-
436
- points = points.long() # 转换为整数类型
437
- valid_points = (points[:, 0] >= 0) & (points[:, 0] < image.shape[1]) & \
438
- (points[:, 1] >= 0) & (points[:, 1] < image.shape[0])
439
- points = points[valid_points]
440
-
441
- color = color.to(image.dtype)
442
-
443
- if thickness > 1:
444
- for dx in range(-thickness//2, thickness//2 + 1):
445
- for dy in range(-thickness//2, thickness//2 + 1):
446
- offset_points = points + torch.tensor([dx, dy], device=device, dtype=torch.long)
447
- valid_offset = (offset_points[:, 0] >= 0) & (offset_points[:, 0] < image.shape[1]) & \
448
- (offset_points[:, 1] >= 0) & (offset_points[:, 1] < image.shape[0])
449
- offset_points = offset_points[valid_offset]
450
- image[offset_points[:, 1], offset_points[:, 0]] = color
451
- else:
452
- image[points[:, 1], points[:, 0]] = color
453
 
454
  # 创建 Gradio 界面
455
  with gr.Blocks() as demo:
456
- gr.Markdown("# 🐁 小鼠行为分析 (Mice Behavior Analysis)")
457
 
458
  with gr.Group() as login_interface:
459
  username = gr.Textbox(label="用户名")
@@ -472,7 +294,7 @@ with gr.Blocks() as demo:
472
  value=20
473
  )
474
  conf_threshold = gr.Slider(
475
- minimum=0.0,
476
  maximum=1.0,
477
  value=0.2,
478
  step=0.05,
@@ -485,7 +307,7 @@ with gr.Blocks() as demo:
485
  value=1,
486
  step=1,
487
  label="最大检测数量",
488
- info="每帧���多检测的目标数量"
489
  )
490
  process_btn = gr.Button("开始处理")
491
 
@@ -493,17 +315,14 @@ with gr.Blocks() as demo:
493
  video_output = gr.Video(label="检测结果")
494
  with gr.Row():
495
  trajectory_output = gr.Image(label="运动轨迹")
496
- trajectory_gif_output = gr.Image(label="轨迹动画")
497
- with gr.Row():
498
  heatmap_output = gr.Image(label="热力图")
499
- heatmap_gif_output = gr.Image(label="热力图动画")
500
  report_output = gr.Textbox(label="分析报告")
501
 
502
  gr.Markdown("""
503
  ### 使用说明
504
  1. 上传视频文件
505
  2. 设置处理参数:
506
- - 处理时间:需要分析的视频时长(秒)
507
  - 置信度阈值:检测的置信度要求(越高越严格)
508
  - 最大检测数量:每帧最多检测的目标数量
509
  3. 等待处理完成
@@ -526,20 +345,8 @@ with gr.Blocks() as demo:
526
  process_btn.click(
527
  fn=process_video,
528
  inputs=[video_input, process_seconds, conf_threshold, max_det],
529
- outputs=[video_output, trajectory_output, heatmap_output,
530
- trajectory_gif_output, heatmap_gif_output, report_output]
531
  )
532
 
533
  if __name__ == "__main__":
534
- try:
535
- # GPU相关操作
536
- if torch.cuda.is_available():
537
- logger.info("使用GPU进行轨迹和热力图计算")
538
- # ... GPU操作 ...
539
- else:
540
- logger.info("使用CPU进行轨迹和热力图计算")
541
- # ... CPU操作 ...
542
- except Exception as e:
543
- logger.error(f"处理轨迹和热力图时出错: {str(e)}")
544
- raise
545
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
9
  import cv2
10
  from pathlib import Path
11
  import tempfile
 
12
  from tqdm import tqdm
 
 
 
 
 
 
 
 
 
 
13
 
14
  # 从环境变量获取密码
15
  APP_USERNAME = "admin" # 用户名保持固定
 
17
 
18
  app = FastAPI()
19
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
20
+ print(f"使用设备: {device}")
21
  model = YOLO('kunin-mice-pose.v0.1.5n.pt')
22
+ print("模型加载完成")
23
 
24
  # 定义认证状态
25
  class AuthState:
 
30
 
31
  def login(username, password):
32
  """登录验证"""
 
33
  if username == APP_USERNAME and password == APP_PASSWORD:
34
  auth_state.is_logged_in = True
 
35
  return gr.update(visible=False), gr.update(visible=True), "登录成功"
 
36
  return gr.update(visible=True), gr.update(visible=False), "用户名或密码错误"
37
 
38
+ @spaces.GPU(duration=120)
39
  def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8):
40
  """
41
  处理视频并进行小鼠检测
 
 
 
 
 
42
  """
43
+ print("开始处理视频...")
 
44
 
45
  if not auth_state.is_logged_in:
 
46
  return None, "请先登录"
47
 
48
+ print("创建临时输出文件...")
 
49
  with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
50
  output_path = tmp_file.name
51
 
52
+ print("读取视频信息...")
 
53
  cap = cv2.VideoCapture(video_path)
54
  fps = int(cap.get(cv2.CAP_PROP_FPS))
55
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
 
57
  total_frames = int(process_seconds * fps) if process_seconds else int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
58
  cap.release()
59
 
60
+ print(f"视频信息: {width}x{height} @ {fps}fps, 总帧数: {total_frames}")
61
 
62
+ print("初始化视频写入器...")
63
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
64
  video_writer = cv2.VideoWriter(
65
  output_path,
 
68
  (width, height)
69
  )
70
 
 
71
  base_size = min(width, height)
72
+ line_thickness = max(1, int(base_size * 0.002))
73
 
74
+ print("开始YOLO推理...")
75
  results = model.predict(
76
  source=video_path,
77
  device=device,
 
86
  vid_stride=1,
87
  max_det=max_det,
88
  retina_masks=True,
89
+ verbose=False
90
  )
91
 
 
92
  frame_count = 0
93
  detection_info = []
94
  all_positions = []
95
  heatmap = np.zeros((height, width), dtype=np.float32)
96
 
97
+ print("处理检测结果...")
98
+ progress_bar = tqdm(total=total_frames, desc="处理帧")
99
 
100
  for r in results:
101
  frame = r.plot()
102
 
 
103
  if hasattr(r, 'keypoints') and r.keypoints is not None:
104
  kpts = r.keypoints.data
105
  if isinstance(kpts, torch.Tensor):
106
  kpts = kpts.cpu().numpy()
107
 
108
+ if kpts.shape == (1, 8, 3):
109
  x, y = int(kpts[0, 0, 0]), int(kpts[0, 0, 1])
110
  all_positions.append([x, y])
111
 
 
117
  temp_heatmap = cv2.GaussianBlur(temp_heatmap, (kernel_size, kernel_size), sigma)
118
  heatmap += temp_heatmap
119
 
 
120
  frame_info = {
121
  "frame": frame_count + 1,
122
  "count": len(r.boxes),
 
136
  video_writer.write(frame)
137
 
138
  frame_count += 1
139
+ progress_bar.update(1)
140
 
141
  if process_seconds and frame_count >= total_frames:
142
  break
143
 
144
+ progress_bar.close()
145
+ print("视频处理完成")
146
+
147
  video_writer.release()
148
+
149
+ print("生成分析报告...")
 
150
  confidences = [float(det['confidence'].strip('%'))/100 for info in detection_info for det in info['detections']]
151
  hist, bins = np.histogram(confidences, bins=5)
152
 
 
169
  {confidence_report}
170
  """
171
 
172
def filter_trajectories(positions, width, height, max_jump_distance=100):
    """Clean a tracked trajectory: drop outliers, bridge jumps, then smooth.

    Processing steps:
      1. Points outside the frame (``0 <= x < width``, ``0 <= y < height``)
         are discarded.
      2. A point farther than ``max_jump_distance`` from the last kept point
         is treated as a jump. The function looks ahead for the next in-frame
         point that is within ``max_jump_distance`` of the last kept point,
         linearly interpolates towards it, and resumes *after* that point.
         If no such point exists, the jumping point is simply dropped.
      3. The result is smoothed with a centred 5-point moving average; the
         first and last two points are kept unsmoothed.

    Args:
        positions: Sequence of ``[x, y]`` coordinates in frame order.
        width: Frame width used for the in-frame check.
        height: Frame height used for the in-frame check.
        max_jump_distance: Largest Euclidean distance allowed between two
            consecutive kept points.

    Returns:
        A list of ``[x, y]`` integer pairs. Inputs with fewer than three
        points are returned unchanged.
    """
    if len(positions) < 3:
        return positions

    def in_frame(px, py):
        return 0 <= px < width and 0 <= py < height

    def distance_between(a, b):
        return np.sqrt((a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2)

    filtered_positions = []
    last_valid_pos = None
    # Index of the first sample that has not yet been consumed. After a jump
    # is bridged this moves past the bridge target, so the skipped samples and
    # the target itself are not processed (and appended) a second time.
    resume_index = 0

    for i, pos in enumerate(positions):
        if i < resume_index:
            continue

        x, y = pos
        if not in_frame(x, y):
            continue

        if last_valid_pos is None:
            filtered_positions.append(pos)
            last_valid_pos = pos
            continue

        distance = distance_between(pos, last_valid_pos)
        if distance <= max_jump_distance:
            filtered_positions.append(pos)
            last_valid_pos = pos
            continue

        # Jump detected: search forward for a point that is plausibly a
        # continuation of the track. The current point itself cannot qualify
        # (it is the one that jumped), so the search starts at i + 1.
        for k in range(i + 1, len(positions)):
            candidate = positions[k]
            if not in_frame(candidate[0], candidate[1]):
                continue
            if distance_between(candidate, last_valid_pos) > max_jump_distance:
                continue

            # The number of interpolation steps scales with the size of the
            # jump that was observed (kept from the original implementation).
            steps = max(2, int(distance / max_jump_distance))
            for j in range(1, steps):
                alpha = j / steps
                interp_x = int(last_valid_pos[0] * (1 - alpha) + candidate[0] * alpha)
                interp_y = int(last_valid_pos[1] * (1 - alpha) + candidate[1] * alpha)
                filtered_positions.append([interp_x, interp_y])

            filtered_positions.append(candidate)
            last_valid_pos = candidate
            resume_index = k + 1
            break

    window_size = 5
    half_window = window_size // 2

    if len(filtered_positions) < window_size:
        return filtered_positions

    # Head and tail are copied unsmoothed because a centred window does not
    # fit there. Note: the tail slice must use -half_window; writing
    # -window_size//2 evaluates to (-5)//2 == -3 and yields one extra point.
    smoothed_positions = list(filtered_positions[:half_window])
    for i in range(half_window, len(filtered_positions) - half_window):
        window = filtered_positions[i - half_window:i + half_window + 1]
        smoothed_x = int(np.mean([p[0] for p in window]))
        smoothed_y = int(np.mean([p[1] for p in window]))
        smoothed_positions.append([smoothed_x, smoothed_y])
    smoothed_positions.extend(filtered_positions[-half_window:])

    return smoothed_positions
233
 
234
+ print("生成轨迹图...")
235
+ trajectory_img = np.zeros((height, width, 3), dtype=np.uint8) + 255
236
  points = np.array(all_positions, dtype=np.int32)
237
  if len(points) > 1:
238
+ filtered_points = filter_trajectories(points.tolist(), width, height)
239
  points = np.array(filtered_points, dtype=np.int32)
240
 
241
  for i in range(len(points) - 1):
242
  ratio = i / (len(points) - 1)
243
+ color = (
244
+ int((1 - ratio) * 255),
245
+ 50,
246
+ int(ratio * 255)
247
+ )
248
+ cv2.line(trajectory_img, tuple(points[i]), tuple(points[i + 1]), color, 2)
 
 
 
249
 
250
+ cv2.circle(trajectory_img, tuple(points[0]), 8, (0, 255, 0), -1)
251
+ cv2.circle(trajectory_img, tuple(points[-1]), 8, (0, 0, 255), -1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
 
253
+ arrow_interval = max(len(points) // 20, 1)
254
+ for i in range(0, len(points) - arrow_interval, arrow_interval):
255
+ pt1 = tuple(points[i])
256
+ pt2 = tuple(points[i + arrow_interval])
257
+ angle = np.arctan2(pt2[1] - pt1[1], pt2[0] - pt1[0])
258
+ cv2.arrowedLine(trajectory_img, pt1, pt2, (100, 100, 100), 1, tipLength=0.2)
259
+
260
+ print("生成热力图...")
261
  if np.max(heatmap) > 0:
262
  heatmap_normalized = cv2.normalize(heatmap, None, 0, 255, cv2.NORM_MINMAX)
263
  heatmap_colored = cv2.applyColorMap(heatmap_normalized.astype(np.uint8), cv2.COLORMAP_JET)
264
+ alpha = 0.7
265
+ heatmap_colored = cv2.addWeighted(heatmap_colored, alpha, np.full_like(heatmap_colored, 255), 1-alpha, 0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
 
267
+ print("保存结果图像...")
268
  trajectory_path = output_path.replace('.mp4', '_trajectory.png')
269
  heatmap_path = output_path.replace('.mp4', '_heatmap.png')
270
  cv2.imwrite(trajectory_path, trajectory_img)
 
 
 
 
271
  cv2.imwrite(heatmap_path, heatmap_colored)
272
 
273
+ print("处理完成!")
274
+ return output_path, trajectory_path, heatmap_path, report
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
275
 
276
  # 创建 Gradio 界面
277
  with gr.Blocks() as demo:
278
+ gr.Markdown("# 🐭 小鼠行为分析 (Mice Behavior Analysis)")
279
 
280
  with gr.Group() as login_interface:
281
  username = gr.Textbox(label="用户名")
 
294
  value=20
295
  )
296
  conf_threshold = gr.Slider(
297
+ minimum=0.1,
298
  maximum=1.0,
299
  value=0.2,
300
  step=0.05,
 
307
  value=1,
308
  step=1,
309
  label="最大检测数量",
310
+ info="每帧最多检测的目标数量"
311
  )
312
  process_btn = gr.Button("开始处理")
313
 
 
315
  video_output = gr.Video(label="检测结果")
316
  with gr.Row():
317
  trajectory_output = gr.Image(label="运动轨迹")
 
 
318
  heatmap_output = gr.Image(label="热力图")
 
319
  report_output = gr.Textbox(label="分析报告")
320
 
321
  gr.Markdown("""
322
  ### 使用说明
323
  1. 上传视频文件
324
  2. 设置处理参数:
325
+ - 处理时长:需要分析的视频时长(秒)
326
  - 置信度阈值:检测的置信度要求(越高越严格)
327
  - 最大检测数量:每帧最多检测的目标数量
328
  3. 等待处理完成
 
345
  process_btn.click(
346
  fn=process_video,
347
  inputs=[video_input, process_seconds, conf_threshold, max_det],
348
+ outputs=[video_output, trajectory_output, heatmap_output, report_output]
 
349
  )
350
 
351
  if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
 
352
  demo.launch(server_name="0.0.0.0", server_port=7860)