Hakureirm committed on
Commit 7dd4b60 · verified · 1 Parent(s): 989b2b7

Update app.py

Files changed (1):
  1. app.py +31 -64
app.py CHANGED
@@ -10,6 +10,7 @@ import cv2
 from pathlib import Path
 import tempfile
 import imageio
+from tqdm import tqdm  # new: import tqdm
 
 # Read the password from an environment variable
 APP_USERNAME = "admin"  # the username stays fixed
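For context, the configuration this hunk touches keeps the username fixed and reads the password from the environment. A minimal sketch of that pattern — the variable name `APP_PASSWORD` is assumed here, the diff only shows the comment:

```python
import os

APP_USERNAME = "admin"  # username stays fixed
# Assumed variable name -- not shown in the diff itself.
APP_PASSWORD = os.environ.get("APP_PASSWORD", "")
```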
@@ -71,7 +72,7 @@ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8)
     base_size = min(width, height)
     line_thickness = max(1, int(base_size * 0.002))  # 0.2% of the shorter side
 
-    # Set inference parameters and process the video
+    # changed: set inference parameters and process the video, with verbose output disabled
     results = model.predict(
         source=video_path,
         device=device,
@@ -79,24 +80,25 @@ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8)
         save=False,
         show=False,
         stream=True,
-        line_width=line_thickness,  # line width
-        boxes=True,                 # draw bounding boxes
+        line_width=line_thickness,
+        boxes=True,
         show_labels=True,
         show_conf=True,
         vid_stride=1,
         max_det=max_det,
-        retina_masks=True,  # finer rendering
-        verbose=False
+        retina_masks=True,
+        verbose=False  # silence YOLO's default log output
     )
 
     # Process the results
     frame_count = 0
     detection_info = []
-
-    # Buffers for trajectory and heatmap data
     all_positions = []
     heatmap = np.zeros((height, width), dtype=np.float32)
 
+    # new: create a progress bar
+    pbar = tqdm(total=total_frames, desc="处理视频", unit="帧")
+
     for r in results:
         frame = r.plot()
@@ -106,16 +108,13 @@ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8)
         if isinstance(kpts, torch.Tensor):
             kpts = kpts.cpu().numpy()
 
-        # Take the first keypoint of the first detected object (usually the head)
         if kpts.shape == (1, 8, 3):  # [num_objects, num_keypoints, xyz]
-            x, y = int(kpts[0, 0, 0]), int(kpts[0, 0, 1])  # use the first keypoint
+            x, y = int(kpts[0, 0, 0]), int(kpts[0, 0, 1])
             all_positions.append([x, y])
 
-            # Update the heatmap, smoothing with a Gaussian kernel
             if 0 <= x < width and 0 <= y < height:
-                # Stamp a Gaussian at the point
-                sigma = 10        # adjust to change the hotspot size
-                kernel_size = 31  # must be odd
+                sigma = 10
+                kernel_size = 31
                 temp_heatmap = np.zeros((height, width), dtype=np.float32)
                 temp_heatmap[y, x] = 1
                 temp_heatmap = cv2.GaussianBlur(temp_heatmap, (kernel_size, kernel_size), sigma)
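The heatmap update stamps a single pixel and blurs it into a Gaussian blob; note that `kernel_size` must be odd for `cv2.GaussianBlur`. A standalone sketch with an assumed frame size:

```python
import cv2
import numpy as np

height, width = 480, 640                     # assumed frame size
heatmap = np.zeros((height, width), dtype=np.float32)

def stamp(heatmap, x, y, sigma=10, kernel_size=31):
    temp = np.zeros_like(heatmap)
    temp[y, x] = 1.0                         # single hot pixel
    # The blur spreads the unit mass into a Gaussian blob around (x, y).
    heatmap += cv2.GaussianBlur(temp, (kernel_size, kernel_size), sigma)
    return heatmap

heatmap = stamp(heatmap, x=320, y=240)
```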
@@ -138,17 +137,17 @@ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8)
             })
 
         detection_info.append(frame_info)
-
-        # Write the frame to the output video
         video_writer.write(frame)
 
         frame_count += 1
+        pbar.update(1)  # advance the progress bar
+
         if process_seconds and frame_count >= total_frames:
             break
 
-    # Release the video writer
+    pbar.close()  # close the progress bar
     video_writer.release()
-
+
     # Generate the analysis report
     confidences = [float(det['confidence'].strip('%'))/100 for info in detection_info for det in info['detections']]
     hist, bins = np.histogram(confidences, bins=5)
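The commit's main addition is this manual tqdm bar around the frame loop: created with the expected total, advanced once per frame, and closed explicitly because the loop can break early. The same pattern in isolation:

```python
from tqdm import tqdm

total_frames = 500                           # assumed; computed from fps upstream
pbar = tqdm(total=total_frames, desc="processing", unit="frame")
for frame_idx in range(100_000):             # stand-in for the results stream
    pbar.update(1)                           # advance by one frame
    if frame_idx + 1 >= total_frames:
        break                                # early exit, like process_seconds
pbar.close()                                 # close explicitly after a break
```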
@@ -174,7 +173,6 @@ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8)
     {confidence_report}
     """
 
-    # Before drawing the trajectory plot, filter outliers and smooth the track
     def filter_trajectories(positions, width, height, max_jump_distance=100):
         """
         Filter outlier points from the trajectory
@@ -193,23 +191,18 @@ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8)
         for i, pos in enumerate(positions):
             x, y = pos
 
-            # Check that the point lies within the valid range
             if not (0 <= x < width and 0 <= y < height):
                 continue
 
-            # First valid point
             if last_valid_pos is None:
                 filtered_positions.append(pos)
                 last_valid_pos = pos
                 continue
 
-            # Distance to the previous point
             distance = np.sqrt((x - last_valid_pos[0])**2 + (y - last_valid_pos[1])**2)
 
             if distance > max_jump_distance:
-                # If the jump is too large, interpolate
                 if len(filtered_positions) > 0:
-                    # Look for the next valid point
                     next_valid_pos = None
                     for next_pos in positions[i:]:
                         nx, ny = next_pos
@@ -220,7 +213,6 @@ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8)
                             break
 
                     if next_valid_pos is not None:
-                        # Linear interpolation
                         steps = max(2, int(distance / max_jump_distance))
                         for j in range(1, steps):
                             alpha = j / steps
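The interpolation step inserts evenly spaced points whenever two consecutive valid positions are farther apart than `max_jump_distance`. A minimal sketch of that logic:

```python
def interpolate(p0, p1, distance, max_jump_distance=100):
    # Number of segments grows with the size of the jump, at least 2.
    steps = max(2, int(distance / max_jump_distance))
    return [
        [int(p0[0] * (1 - a) + p1[0] * a), int(p0[1] * (1 - a) + p1[1] * a)]
        for a in (j / steps for j in range(1, steps))
    ]

print(interpolate([0, 0], [300, 0], distance=300))   # [[100, 0], [200, 0]]
```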
@@ -233,22 +225,18 @@ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8)
                 filtered_positions.append(pos)
                 last_valid_pos = pos
 
-        # Smooth the trajectory with a moving average
         window_size = 5
         smoothed_positions = []
 
         if len(filtered_positions) >= window_size:
-            # Keep the leading points
             smoothed_positions.extend(filtered_positions[:window_size//2])
 
-            # Smooth the middle points
             for i in range(window_size//2, len(filtered_positions) - window_size//2):
                 window = filtered_positions[i-window_size//2:i+window_size//2+1]
                 smoothed_x = int(np.mean([p[0] for p in window]))
                 smoothed_y = int(np.mean([p[1] for p in window]))
                 smoothed_positions.append([smoothed_x, smoothed_y])
 
-            # Keep the trailing points
             smoothed_positions.extend(filtered_positions[-window_size//2:])
         else:
             smoothed_positions = filtered_positions
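The smoothing pass is a plain centered moving average: points within half a window of either end pass through unchanged, and interior points are replaced by the window mean. A vectorized variant using `np.convolve`, assuming an odd `window_size`:

```python
import numpy as np

def smooth(positions, window_size=5):
    pts = np.asarray(positions, dtype=float)
    if len(pts) < window_size:
        return pts.astype(int).tolist()      # too short to smooth
    half = window_size // 2
    kernel = np.ones(window_size) / window_size
    xs = np.convolve(pts[:, 0], kernel, mode="valid")   # len(pts) - 2*half values
    ys = np.convolve(pts[:, 1], kernel, mode="valid")
    out = pts.astype(int).tolist()
    out[half:len(pts) - half] = np.stack([xs, ys], axis=1).astype(int).tolist()
    return out
```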
@@ -256,17 +244,14 @@ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8)
         return smoothed_positions
 
     # Modified trajectory-plot generation
-    trajectory_img = np.zeros((height, width, 3), dtype=np.uint8) + 255  # white background
+    trajectory_img = np.zeros((height, width, 3), dtype=np.uint8) + 255
     points = np.array(all_positions, dtype=np.int32)
     if len(points) > 1:
-        # Filter and smooth the trajectory
         filtered_points = filter_trajectories(points.tolist(), width, height)
         points = np.array(filtered_points, dtype=np.int32)
 
-        # Draw the trajectory line with a color gradient
        for i in range(len(points) - 1):
             ratio = i / (len(points) - 1)
-            # Gradient from blue to red
             color = (
                 int((1 - ratio) * 255),  # B
                 50,                      # G
@@ -274,60 +259,48 @@ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8)
             )
             cv2.line(trajectory_img, tuple(points[i]), tuple(points[i + 1]), color, 2)
 
-        # Draw the start and end points
-        cv2.circle(trajectory_img, tuple(points[0]), 8, (0, 255, 0), -1)   # green start
-        cv2.circle(trajectory_img, tuple(points[-1]), 8, (0, 0, 255), -1)  # red end
+        cv2.circle(trajectory_img, tuple(points[0]), 8, (0, 255, 0), -1)
+        cv2.circle(trajectory_img, tuple(points[-1]), 8, (0, 0, 255), -1)
 
-        # Add a direction arrow every few frames
-        arrow_interval = max(len(points) // 20, 1)  # controls the number of arrows
+        arrow_interval = max(len(points) // 20, 1)
         for i in range(0, len(points) - arrow_interval, arrow_interval):
             pt1 = tuple(points[i])
             pt2 = tuple(points[i + arrow_interval])
-            # Arrow direction
             angle = np.arctan2(pt2[1] - pt1[1], pt2[0] - pt1[0])
-            # Draw the arrow
             cv2.arrowedLine(trajectory_img, pt1, pt2, (100, 100, 100), 1, tipLength=0.2)
 
-    # Render the heatmap
-    if np.max(heatmap) > 0:  # make sure there is data
+    if np.max(heatmap) > 0:
         heatmap_normalized = cv2.normalize(heatmap, None, 0, 255, cv2.NORM_MINMAX)
         heatmap_colored = cv2.applyColorMap(heatmap_normalized.astype(np.uint8), cv2.COLORMAP_JET)
-        # Add some transparency
         alpha = 0.7
         heatmap_colored = cv2.addWeighted(heatmap_colored, alpha, np.full_like(heatmap_colored, 255), 1-alpha, 0)
 
-    # Prepare the GIF animation frames
     trajectory_frames = []
     heatmap_frames = []
 
-    # Create the base images
     base_trajectory = np.zeros((height, width, 3), dtype=np.uint8) + 255
     base_heatmap = np.zeros((height, width), dtype=np.float32)
 
-    # Emit one animation frame every N video frames
-    frame_interval = max(1, len(filtered_points) // 50)  # controls the GIF frame count
+    frame_interval = max(1, len(filtered_points) // 50)
 
     for i in range(0, len(filtered_points), frame_interval):
         current_points = filtered_points[:i+1]
 
-        # Trajectory animation frame - color scheme changed to match the static plot
         frame_trajectory = base_trajectory.copy()
         if len(current_points) > 1:
             points = np.array(current_points, dtype=np.int32)
             for j in range(len(points) - 1):
-                ratio = j / (len(current_points) - 1)  # changed: use the current total length
+                ratio = j / (len(current_points) - 1)
                 color = (
-                    int((1 - ratio) * 255),  # B
-                    50,                      # G
-                    int(ratio * 255)         # R
+                    int((1 - ratio) * 255),
+                    50,
+                    int(ratio * 255)
                 )
                 cv2.line(frame_trajectory, tuple(points[j]), tuple(points[j + 1]), color, 2)
 
-            # Draw the current position
             cv2.circle(frame_trajectory, tuple(points[-1]), 8, (0, 0, 255), -1)
             trajectory_frames.append(frame_trajectory)
 
-        # Heatmap animation frame - color scheme changed to match the static plot
         frame_heatmap = base_heatmap.copy()
         for x, y in current_points:
             if 0 <= x < width and 0 <= y < height:
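Both the static plot and the animation use the same blue-to-red gradient; since OpenCV works in BGR channel order, (255, 50, 0) is the blue end and (0, 50, 255) the red end. A self-contained rendering of that scheme on a white canvas:

```python
import cv2
import numpy as np

canvas = np.full((240, 320, 3), 255, dtype=np.uint8)        # white background
points = np.array([[20, 200], [100, 60], [200, 120], [300, 30]], dtype=np.int32)
for i in range(len(points) - 1):
    ratio = i / (len(points) - 1)
    color = (int((1 - ratio) * 255), 50, int(ratio * 255))  # B, G, R
    cv2.line(canvas, tuple(points[i]), tuple(points[i + 1]), color, 2)
cv2.circle(canvas, tuple(points[0]), 8, (0, 255, 0), -1)    # green start
cv2.circle(canvas, tuple(points[-1]), 8, (0, 0, 255), -1)   # red end
```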
@@ -339,18 +312,15 @@ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8)
         if np.max(frame_heatmap) > 0:
             frame_heatmap_norm = cv2.normalize(frame_heatmap, None, 0, 255, cv2.NORM_MINMAX)
             frame_heatmap_color = cv2.applyColorMap(frame_heatmap_norm.astype(np.uint8), cv2.COLORMAP_JET)
-            # White background, matching the static plot
             frame_heatmap_color = cv2.addWeighted(frame_heatmap_color, 0.7, np.full_like(frame_heatmap_color, 255), 0.3, 0)
             heatmap_frames.append(frame_heatmap_color)
 
-    # Save the GIF animations - changed in this commit
-    trajectory_gif_path = output_path.replace('.mp4', '_trajectory.gif')  # .gif suffix
-    heatmap_gif_path = output_path.replace('.mp4', '_heatmap.gif')        # .gif suffix
+    trajectory_gif_path = output_path.replace('.mp4', '_trajectory.gif')
+    heatmap_gif_path = output_path.replace('.mp4', '_heatmap.gif')
 
-    imageio.mimsave(trajectory_gif_path, trajectory_frames, duration=50)  # 50 ms per frame
+    imageio.mimsave(trajectory_gif_path, trajectory_frames, duration=50)
     imageio.mimsave(heatmap_gif_path, heatmap_frames, duration=50)
 
-    # Save the static images
     trajectory_path = output_path.replace('.mp4', '_trajectory.png')
     heatmap_path = output_path.replace('.mp4', '_heatmap.png')
     cv2.imwrite(trajectory_path, trajectory_img)
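Two caveats on the GIF export, neither visible in the diff itself: the unit of imageio's `duration` argument has varied across releases (seconds per frame in older v2 GIF writers, milliseconds in newer ones), so `duration=50` may not mean 50 ms everywhere; and OpenCV frames are BGR while imageio expects RGB, so colors come out channel-swapped unless converted. A sketch that sidesteps both:

```python
import cv2
import imageio
import numpy as np

frames_bgr = [np.full((64, 64, 3), v, dtype=np.uint8) for v in (0, 128, 255)]
# Convert BGR -> RGB so imageio writes the intended colors.
frames_rgb = [cv2.cvtColor(f, cv2.COLOR_BGR2RGB) for f in frames_bgr]
imageio.mimsave("demo.gif", frames_rgb, fps=20)   # ~50 ms/frame on the legacy v2 API
```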
@@ -362,14 +332,12 @@ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8)
 with gr.Blocks() as demo:
     gr.Markdown("# 🐁 小鼠行为分析 (Mice Behavior Analysis)")
 
-    # Login UI
     with gr.Group() as login_interface:
         username = gr.Textbox(label="用户名")
         password = gr.Textbox(label="密码", type="password")
         login_button = gr.Button("登录")
         login_msg = gr.Textbox(label="消息", interactive=False)
 
-    # Main UI
     with gr.Group(visible=False) as main_interface:
         gr.Markdown("上传视频来检测和分析小鼠行为 | Upload a video to detect and analyze mice behavior")
@@ -412,7 +380,7 @@ with gr.Blocks() as demo:
     ### 使用说明
     1. 上传视频文件
     2. 设置处理参数:
-       - 处理时长:需要分析的视频时长(秒
+       - 处理时长:需要分析的视频时长(秒)
        - 置信度阈值:检测的置信度要求(越高越严格)
        - 最大检测数量:每帧最多检测的目标数量
     3. 等待处理完成
@@ -426,7 +394,6 @@ with gr.Blocks() as demo:
     - 最大检测数量建议根据实际场景设置
     """)
 
-    # Wire up the event handlers
     login_button.click(
         fn=login,
         inputs=[username, password],
@@ -441,4 +408,4 @@ with gr.Blocks() as demo:
     )
 
 if __name__ == "__main__":
-    demo.launch(server_name="0.0.0.0", server_port=7860)
+    demo.launch(server_name="0.0.0.0", server_port=7860)
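`server_name="0.0.0.0"` binds the app to all interfaces on Gradio's default port 7860 (the final hunk appears to be a whitespace-only change to this line). Worth noting: Gradio's `launch` also takes a built-in `auth` argument that could replace the hand-rolled login screen; one possible alternative, with placeholder credentials:

```python
demo.launch(server_name="0.0.0.0", server_port=7860, auth=("admin", "secret"))
```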