Hakureirm commited on
Commit
a1af4e9
·
verified ·
1 Parent(s): 7ae1738

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -65
app.py CHANGED
@@ -7,36 +7,36 @@ import gradio as gr
7
  import matplotlib.pyplot as plt
8
 
9
  # GPU 可用性检查 & 日志
10
- evice_is_cuda = torch.cuda.is_available()
11
  print(f"CUDA available: {use_cuda}")
12
  if use_cuda:
13
  print(f"GPU Device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
14
 
15
  # 加载模型并指定分割任务
16
- yolo_model = YOLO("fst-v1.2-n.onnx", task="segment")
17
  if use_cuda:
18
  try:
19
- yolo_model.model.to("cuda")
20
  except:
21
  pass
22
 
23
- @spaces.GPU(duration=600) # ZeroGPU 上运行,超时 600s
24
  def analyze_video(video_path, num_mice, time_range, window_size_sec=1, fps=30):
25
  """
26
- 分割 → 跟踪 → 计算“挣扎强度”,只分析指定时间范围
27
- 返回:标注后视频 & 挣扎曲线 Figure
28
  """
29
- # 打开视频,获取基本信息
30
  cap = cv2.VideoCapture(video_path)
31
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
32
  vid_fps = cap.get(cv2.CAP_PROP_FPS) or fps
33
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
34
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
35
  start_s, end_s = time_range
36
- start_frame = int(start_s * vid_fps)
37
- end_frame = int(end_s * vid_fps)
38
 
39
- # 跳转到起始帧
40
  cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
41
 
42
  # 输出视频初始化
@@ -44,8 +44,8 @@ def analyze_video(video_path, num_mice, time_range, window_size_sec=1, fps=30):
44
  out_path = "output.mp4"
45
  out = cv2.VideoWriter(out_path, fourcc, vid_fps, (width, height))
46
 
47
- prev_centroids = [None] * num_mice
48
- prev_masks = [None] * num_mice
49
  struggle_records = [[] for _ in range(num_mice)]
50
  frame_idx = start_frame
51
 
@@ -54,71 +54,79 @@ def analyze_video(video_path, num_mice, time_range, window_size_sec=1, fps=30):
54
  if not ret:
55
  break
56
 
57
- # 分割推理\ device = "cuda" if use_cuda else "cpu"
58
- results = yolo_model(frame, stream=True, device=device, conf=0.25)
 
59
  res = next(results)
60
 
61
- # 处理无检测帧\ if res.masks is None or res.masks.data is None:
 
62
  for mid in range(num_mice):
63
  struggle_records[mid].append(None)
64
  out.write(frame)
65
  frame_idx += 1
66
  continue
67
 
 
68
  masks = res.masks.data.cpu().numpy() # (N, H_model, W_model)
69
-
70
- # 对齐掩膜至帧尺寸
71
- aligned = []
72
  for m in masks:
73
  m_bin = (m > 0).astype(np.uint8)
74
  m_res = cv2.resize(m_bin, (width, height), interpolation=cv2.INTER_NEAREST)
75
- aligned.append(m_res)
76
- aligned = np.array(aligned)
77
 
78
- # 计算质心 & 分配 ID
79
- curr_cent = []
80
- for m in aligned:
81
  ys, xs = np.where(m > 0)
82
- curr_cent.append((int(xs.mean()), int(ys.mean())) if xs.size else None)
83
- assign = [-1] * len(curr_cent)
84
- unused = set(range(num_mice))
85
- for i, c in enumerate(curr_cent):
86
- if c is None: continue
87
- best_j, best_d = None, float('inf')
88
- for j in unused:
 
89
  pc = prev_centroids[j]
90
- if pc is None: continue
91
- d = (c[0]-pc[0])**2 + (c[1]-pc[1])**2
 
92
  if d < best_d:
93
  best_j, best_d = j, d
94
  if best_j is not None and best_d < 50**2:
95
- assign[i] = best_j
96
- unused.remove(best_j)
97
- for i in range(len(curr_cent)):
98
- if assign[i] < 0 and unused:
99
- assign[i] = unused.pop()
100
-
101
- # 计算挣扎强度 & 叠加
102
- for i, m in enumerate(aligned):
103
- mid = assign[i]
104
- if mid < 0: continue
105
- pm = prev_masks[mid]
106
- if pm is None:
 
107
  struggle_records[mid].append(None)
108
  else:
109
- diff = int(np.logical_xor(pm, m).sum())
110
- struggle_records[mid].append(diff)
111
 
 
112
  mask_rgb = np.stack([
113
  np.zeros_like(m),
114
  m * 255,
115
  np.zeros_like(m)
116
  ], axis=-1).astype(np.uint8)
117
  frame = cv2.addWeighted(frame, 1, mask_rgb, 0.5, 0)
118
- if curr_cent[i]:
119
- cv2.putText(frame, f"ID:{mid}", curr_cent[i], cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,255,0), 2)
120
 
121
- prev_centroids[mid] = curr_cent[i]
 
 
 
 
 
122
  prev_masks[mid] = m.copy()
123
 
124
  out.write(frame)
@@ -132,11 +140,15 @@ def analyze_video(video_path, num_mice, time_range, window_size_sec=1, fps=30):
132
  fig, ax = plt.subplots(figsize=(8,4))
133
  times = np.arange(start_s, end_s, win/vid_fps)
134
  for mid, rec in enumerate(struggle_records):
135
- sums = [sum(v if v is not None else 0 for v in rec[i*win:(i+1)*win]) for i in range(len(times))]
 
 
 
136
  ax.plot(times, sums, label=f"Mouse {mid}")
137
- first = next((i for i,v in enumerate(rec) if v is not None), None)
138
- if first is not None:
139
- ax.axvspan(start_s, start_s+first/vid_fps, alpha=0.3, color='gray')
 
140
  ax.set_xlabel("Time (s)")
141
  ax.set_ylabel("Struggle Intensity")
142
  ax.set_title("Mouse Struggle Over Time")
@@ -148,25 +160,26 @@ def analyze_video(video_path, num_mice, time_range, window_size_sec=1, fps=30):
148
  with gr.Blocks(title="Mice Struggle Analysis") as demo:
149
  gr.Markdown("上传视频,输入鼠标数量,选择分析时间范围,点击 Run")
150
  with gr.Row():
151
- video_in = gr.Video(label="Input Video")
152
- num_in = gr.Number(value=1, precision=0, label="Number of Mice")
153
- time_range = gr.RangeSlider(label="Analysis Time Range (s)", minimum=0, maximum=1, value=(0,1), step=1, disabled=True)
154
 
155
- # 上传视频后激活滑块并设置最大值
156
- def get_video_duration(path):
157
  cap = cv2.VideoCapture(path)
158
- fps = cap.get(cv2.CAP_PROP_FPS) or fps
159
- frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
160
- dur = int(frames / fps)
161
  cap.release()
162
- return gr.update(maximum=dur, value=(0, dur), disabled=False)
163
 
164
- video_in.change(fn=get_video_duration, inputs=video_in, outputs=time_range)
165
 
166
- run_btn = gr.Button("Run")
167
  output_video = gr.Video(label="Annotated Video")
168
  output_plot = gr.Plot(label="Struggle Plot")
169
- run_btn.click(fn=analyze_video, inputs=[video_in, num_in, time_range], outputs=[output_video, output_plot])
 
 
170
 
171
  if __name__ == "__main__":
172
  demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
 
7
  import matplotlib.pyplot as plt
8
 
9
  # GPU 可用性检查 & 日志
10
+ use_cuda = torch.cuda.is_available()
11
  print(f"CUDA available: {use_cuda}")
12
  if use_cuda:
13
  print(f"GPU Device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
14
 
15
  # 加载模型并指定分割任务
16
+ model = YOLO("fst-v1.2-n.onnx", task="segment")
17
  if use_cuda:
18
  try:
19
+ model.model.to("cuda")
20
  except:
21
  pass
22
 
23
+ @spaces.GPU(duration=600) # ZeroGPU 环境下执行该函数,超时 600s
24
  def analyze_video(video_path, num_mice, time_range, window_size_sec=1, fps=30):
25
  """
26
+ 分割 → 跟踪 → 计算挣扎强度,仅分析指定时间区间
27
+ 返回:标注后视频 & 绘制的挣扎强度曲线 (matplotlib Figure)
28
  """
29
+ # 打开视频并获取基本信息
30
  cap = cv2.VideoCapture(video_path)
31
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
32
  vid_fps = cap.get(cv2.CAP_PROP_FPS) or fps
33
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
34
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
35
  start_s, end_s = time_range
36
+ start_frame = min(int(start_s * vid_fps), total_frames)
37
+ end_frame = min(int(end_s * vid_fps), total_frames)
38
 
39
+ # 跳转到指定起始帧
40
  cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
41
 
42
  # 输出视频初始化
 
44
  out_path = "output.mp4"
45
  out = cv2.VideoWriter(out_path, fourcc, vid_fps, (width, height))
46
 
47
+ prev_centroids = [None] * num_mice
48
+ prev_masks = [None] * num_mice
49
  struggle_records = [[] for _ in range(num_mice)]
50
  frame_idx = start_frame
51
 
 
54
  if not ret:
55
  break
56
 
57
+ # 分割推理
58
+ device = "cuda" if use_cuda else "cpu"
59
+ results = model(frame, stream=True, device=device, conf=0.25)
60
  res = next(results)
61
 
62
+ # 无检测帧处理
63
+ if res.masks is None or res.masks.data is None:
64
  for mid in range(num_mice):
65
  struggle_records[mid].append(None)
66
  out.write(frame)
67
  frame_idx += 1
68
  continue
69
 
70
+ # 获取并对齐掩膜至帧尺寸
71
  masks = res.masks.data.cpu().numpy() # (N, H_model, W_model)
72
+ aligned_masks = []
 
 
73
  for m in masks:
74
  m_bin = (m > 0).astype(np.uint8)
75
  m_res = cv2.resize(m_bin, (width, height), interpolation=cv2.INTER_NEAREST)
76
+ aligned_masks.append(m_res)
77
+ aligned_masks = np.array(aligned_masks)
78
 
79
+ # 计算质心 & ID 分配 (nearest-centroid)
80
+ curr_centroids = []
81
+ for m in aligned_masks:
82
  ys, xs = np.where(m > 0)
83
+ curr_centroids.append((int(xs.mean()), int(ys.mean())) if xs.size else None)
84
+ assignments = [-1] * len(curr_centroids)
85
+ unused_ids = set(range(num_mice))
86
+ for i, c in enumerate(curr_centroids):
87
+ if c is None:
88
+ continue
89
+ best_j, best_d = None, float("inf")
90
+ for j in unused_ids:
91
  pc = prev_centroids[j]
92
+ if pc is None:
93
+ continue
94
+ d = (c[0] - pc[0])**2 + (c[1] - pc[1])**2
95
  if d < best_d:
96
  best_j, best_d = j, d
97
  if best_j is not None and best_d < 50**2:
98
+ assignments[i] = best_j
99
+ unused_ids.remove(best_j)
100
+ for i in range(len(curr_centroids)):
101
+ if assignments[i] < 0 and unused_ids:
102
+ assignments[i] = unused_ids.pop()
103
+
104
+ # 计算挣扎强度 & 可视化叠加
105
+ for i, m in enumerate(aligned_masks):
106
+ mid = assignments[i]
107
+ if mid < 0:
108
+ continue
109
+ prev_m = prev_masks[mid]
110
+ if prev_m is None:
111
  struggle_records[mid].append(None)
112
  else:
113
+ struggle = int(np.logical_xor(prev_m, m).sum())
114
+ struggle_records[mid].append(struggle)
115
 
116
+ # 构建三通道掩膜
117
  mask_rgb = np.stack([
118
  np.zeros_like(m),
119
  m * 255,
120
  np.zeros_like(m)
121
  ], axis=-1).astype(np.uint8)
122
  frame = cv2.addWeighted(frame, 1, mask_rgb, 0.5, 0)
 
 
123
 
124
+ centroid = curr_centroids[i]
125
+ if centroid:
126
+ cv2.putText(frame, f"ID:{mid}", centroid,
127
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,255,0), 2)
128
+
129
+ prev_centroids[mid] = curr_centroids[i]
130
  prev_masks[mid] = m.copy()
131
 
132
  out.write(frame)
 
140
  fig, ax = plt.subplots(figsize=(8,4))
141
  times = np.arange(start_s, end_s, win/vid_fps)
142
  for mid, rec in enumerate(struggle_records):
143
+ sums = []
144
+ for i in range(len(times)):
145
+ segment = rec[i*win:(i+1)*win]
146
+ sums.append(sum(v if v is not None else 0 for v in segment))
147
  ax.plot(times, sums, label=f"Mouse {mid}")
148
+ first_valid = next((i for i,v in enumerate(rec) if v is not None), None)
149
+ if first_valid is not None:
150
+ ax.axvspan(start_s, start_s+first_valid/vid_fps, alpha=0.3, color='gray')
151
+
152
  ax.set_xlabel("Time (s)")
153
  ax.set_ylabel("Struggle Intensity")
154
  ax.set_title("Mouse Struggle Over Time")
 
160
  with gr.Blocks(title="Mice Struggle Analysis") as demo:
161
  gr.Markdown("上传视频,输入鼠标数量,选择分析时间范围,点击 Run")
162
  with gr.Row():
163
+ video_input = gr.Video(label="Input Video")
164
+ num_input = gr.Number(value=1, precision=0, label="Number of Mice")
165
+ time_range = gr.RangeSlider(label="Analysis Time Range (s)", minimum=0, maximum=1, value=(0,1), step=1, disabled=True)
166
 
167
+ def enable_slider(path):
 
168
  cap = cv2.VideoCapture(path)
169
+ vid_fps = cap.get(cv2.CAP_PROP_FPS) or fps
170
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
171
+ duration = total_frames / vid_fps
172
  cap.release()
173
+ return gr.update(maximum=duration, value=(0,duration), disabled=False)
174
 
175
+ video_input.change(fn=enable_slider, inputs=video_input, outputs=time_range)
176
 
177
+ run_button = gr.Button("Run")
178
  output_video = gr.Video(label="Annotated Video")
179
  output_plot = gr.Plot(label="Struggle Plot")
180
+ run_button.click(fn=analyze_video,
181
+ inputs=[video_input, num_input, time_range],
182
+ outputs=[output_video, output_plot])
183
 
184
  if __name__ == "__main__":
185
  demo.launch(server_name="0.0.0.0", server_port=7860, share=False)