Hakureirm commited on
Commit
a0c5b0b
·
verified ·
1 Parent(s): a1af4e9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +257 -171
app.py CHANGED
@@ -1,185 +1,271 @@
1
- import spaces # 必须最先 import,用于 ZeroGPU 装饰
2
  import cv2
3
  import numpy as np
4
- import torch
5
- from ultralytics import YOLO # pip install ultralytics
6
  import gradio as gr
7
- import matplotlib.pyplot as plt
 
8
 
9
- # GPU 可用性检查 & 日志
10
- use_cuda = torch.cuda.is_available()
11
- print(f"CUDA available: {use_cuda}")
12
- if use_cuda:
13
- print(f"GPU Device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
 
14
 
15
- # 加载模型并指定分割任务
16
- model = YOLO("fst-v1.2-n.onnx", task="segment")
17
- if use_cuda:
18
- try:
19
- model.model.to("cuda")
20
- except:
21
- pass
22
-
23
- @spaces.GPU(duration=600) # ZeroGPU 环境下执行该函数,超时 600s
24
- def analyze_video(video_path, num_mice, time_range, window_size_sec=1, fps=30):
25
- """
26
- 分割 → 跟踪 → 计算挣扎强度,仅分析指定时间区间
27
- 返回:标注后视频 & 绘制的挣扎强度曲线 (matplotlib Figure)
28
- """
29
- # 打开视频并获取基本信息
30
  cap = cv2.VideoCapture(video_path)
31
- total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
32
- vid_fps = cap.get(cv2.CAP_PROP_FPS) or fps
33
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
34
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
35
- start_s, end_s = time_range
36
- start_frame = min(int(start_s * vid_fps), total_frames)
37
- end_frame = min(int(end_s * vid_fps), total_frames)
38
-
39
- # 跳转到指定起始帧
40
- cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
41
-
42
- # 输出视频初始化
43
- fourcc = cv2.VideoWriter_fourcc(*"mp4v")
44
- out_path = "output.mp4"
45
- out = cv2.VideoWriter(out_path, fourcc, vid_fps, (width, height))
46
-
47
- prev_centroids = [None] * num_mice
48
- prev_masks = [None] * num_mice
49
- struggle_records = [[] for _ in range(num_mice)]
50
- frame_idx = start_frame
51
-
52
- while frame_idx <= end_frame:
53
- ret, frame = cap.read()
54
- if not ret:
55
- break
56
-
57
- # 分割推理
58
- device = "cuda" if use_cuda else "cpu"
59
- results = model(frame, stream=True, device=device, conf=0.25)
60
- res = next(results)
61
-
62
- # 无检测帧处理
63
- if res.masks is None or res.masks.data is None:
64
- for mid in range(num_mice):
65
- struggle_records[mid].append(None)
66
- out.write(frame)
67
- frame_idx += 1
68
- continue
69
-
70
- # 获取并对齐掩膜至帧尺寸
71
- masks = res.masks.data.cpu().numpy() # (N, H_model, W_model)
72
- aligned_masks = []
73
- for m in masks:
74
- m_bin = (m > 0).astype(np.uint8)
75
- m_res = cv2.resize(m_bin, (width, height), interpolation=cv2.INTER_NEAREST)
76
- aligned_masks.append(m_res)
77
- aligned_masks = np.array(aligned_masks)
78
-
79
- # 计算质心 & ID 分配 (nearest-centroid)
80
- curr_centroids = []
81
- for m in aligned_masks:
82
- ys, xs = np.where(m > 0)
83
- curr_centroids.append((int(xs.mean()), int(ys.mean())) if xs.size else None)
84
- assignments = [-1] * len(curr_centroids)
85
- unused_ids = set(range(num_mice))
86
- for i, c in enumerate(curr_centroids):
87
- if c is None:
88
- continue
89
- best_j, best_d = None, float("inf")
90
- for j in unused_ids:
91
- pc = prev_centroids[j]
92
- if pc is None:
93
- continue
94
- d = (c[0] - pc[0])**2 + (c[1] - pc[1])**2
95
- if d < best_d:
96
- best_j, best_d = j, d
97
- if best_j is not None and best_d < 50**2:
98
- assignments[i] = best_j
99
- unused_ids.remove(best_j)
100
- for i in range(len(curr_centroids)):
101
- if assignments[i] < 0 and unused_ids:
102
- assignments[i] = unused_ids.pop()
103
-
104
- # 计算挣扎强度 & 可视化叠加
105
- for i, m in enumerate(aligned_masks):
106
- mid = assignments[i]
107
- if mid < 0:
108
- continue
109
- prev_m = prev_masks[mid]
110
- if prev_m is None:
111
- struggle_records[mid].append(None)
112
- else:
113
- struggle = int(np.logical_xor(prev_m, m).sum())
114
- struggle_records[mid].append(struggle)
115
-
116
- # 构建三通道掩膜
117
- mask_rgb = np.stack([
118
- np.zeros_like(m),
119
- m * 255,
120
- np.zeros_like(m)
121
- ], axis=-1).astype(np.uint8)
122
- frame = cv2.addWeighted(frame, 1, mask_rgb, 0.5, 0)
123
-
124
- centroid = curr_centroids[i]
125
- if centroid:
126
- cv2.putText(frame, f"ID:{mid}", centroid,
127
- cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,255,0), 2)
128
-
129
- prev_centroids[mid] = curr_centroids[i]
130
- prev_masks[mid] = m.copy()
131
-
132
- out.write(frame)
133
- frame_idx += 1
134
-
135
  cap.release()
136
- out.release()
 
 
 
 
 
 
137
 
138
- # 绘制挣扎曲线
139
- win = int(window_size_sec * vid_fps)
140
- fig, ax = plt.subplots(figsize=(8,4))
141
- times = np.arange(start_s, end_s, win/vid_fps)
142
- for mid, rec in enumerate(struggle_records):
143
- sums = []
144
- for i in range(len(times)):
145
- segment = rec[i*win:(i+1)*win]
146
- sums.append(sum(v if v is not None else 0 for v in segment))
147
- ax.plot(times, sums, label=f"Mouse {mid}")
148
- first_valid = next((i for i,v in enumerate(rec) if v is not None), None)
149
- if first_valid is not None:
150
- ax.axvspan(start_s, start_s+first_valid/vid_fps, alpha=0.3, color='gray')
151
-
152
- ax.set_xlabel("Time (s)")
153
- ax.set_ylabel("Struggle Intensity")
154
- ax.set_title("Mouse Struggle Over Time")
155
- ax.legend()
156
-
157
- return out_path, fig
 
 
 
 
 
 
 
 
 
 
 
158
 
159
- # Gradio 前端
160
- with gr.Blocks(title="Mice Struggle Analysis") as demo:
161
- gr.Markdown("上传视频,输入鼠标数量,选择分析时间范围,点击 Run")
162
- with gr.Row():
163
- video_input = gr.Video(label="Input Video")
164
- num_input = gr.Number(value=1, precision=0, label="Number of Mice")
165
- time_range = gr.RangeSlider(label="Analysis Time Range (s)", minimum=0, maximum=1, value=(0,1), step=1, disabled=True)
 
 
166
 
167
- def enable_slider(path):
168
- cap = cv2.VideoCapture(path)
169
- vid_fps = cap.get(cv2.CAP_PROP_FPS) or fps
170
- total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
171
- duration = total_frames / vid_fps
172
- cap.release()
173
- return gr.update(maximum=duration, value=(0,duration), disabled=False)
 
 
 
 
174
 
175
- video_input.change(fn=enable_slider, inputs=video_input, outputs=time_range)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
 
177
- run_button = gr.Button("Run")
178
- output_video = gr.Video(label="Annotated Video")
179
- output_plot = gr.Plot(label="Struggle Plot")
180
- run_button.click(fn=analyze_video,
181
- inputs=[video_input, num_input, time_range],
182
- outputs=[output_video, output_plot])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
 
 
184
  if __name__ == "__main__":
185
- demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
  import cv2
3
  import numpy as np
 
 
4
  import gradio as gr
5
+ import tempfile
6
+ from mouse_tracker import MouseTrackerAnalyzer
7
 
8
+ # 全局变量
9
+ analyzer = None
10
+ video_file_path = None
11
+ model_file_path = None
12
+ total_frames = 0
13
+ output_path = None
14
 
15
+ # 从视频中提取特定帧
16
+ def extract_frame(video_path, frame_num):
17
+ """从视频中提取特定帧"""
18
+ if not video_path:
19
+ return None
20
+
 
 
 
 
 
 
 
 
 
21
  cap = cv2.VideoCapture(video_path)
22
+ if not cap.isOpened():
23
+ return None
24
+
25
+ # 设置帧位置
26
+ cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
27
+
28
+ # 读取帧
29
+ ret, frame = cap.read()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  cap.release()
31
+
32
+ if not ret:
33
+ return None
34
+
35
+ # 转换为RGB格式
36
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
37
+ return frame_rgb
38
 
39
+ # 选择视频文件
40
+ def select_video(video_file):
41
+ global video_file_path, total_frames
42
+
43
+ if not video_file:
44
+ return None, "Please select a video file", gr.Slider(minimum=0, maximum=0, value=0), gr.Slider(minimum=0, maximum=0, value=0)
45
+
46
+ video_file_path = video_file
47
+
48
+ # 获取视频总帧数
49
+ cap = cv2.VideoCapture(video_file_path)
50
+ if not cap.isOpened():
51
+ return None, "Cannot open video file", gr.Slider(minimum=0, maximum=0, value=0), gr.Slider(minimum=0, maximum=0, value=0)
52
+
53
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
54
+
55
+ # 提取第一帧
56
+ ret, first_frame = cap.read()
57
+ cap.release()
58
+
59
+ if not ret:
60
+ return None, "Cannot read video frame", gr.Slider(minimum=0, maximum=0, value=0), gr.Slider(minimum=0, maximum=0, value=0)
61
+
62
+ # 转为RGB
63
+ first_frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
64
+
65
+ # 更新帧滑块
66
+ start_slider = gr.Slider(minimum=0, maximum=total_frames-1, value=0, step=1)
67
+ end_slider = gr.Slider(minimum=0, maximum=total_frames-1, value=total_frames-1, step=1)
68
+
69
+ return first_frame_rgb, f"Video loaded successfully, total frames: {total_frames}", start_slider, end_slider
70
 
71
+ # 选择模型文件
72
+ def select_model(model_file):
73
+ global model_file_path
74
+
75
+ if model_file is None:
76
+ return "Please select a model file"
77
+
78
+ model_file_path = model_file
79
+ return f"Model selected: {os.path.basename(model_file_path)}"
80
 
81
+ # 预览帧
82
+ def preview_frame(video_file, frame_num):
83
+ if not video_file:
84
+ return None, "Please select a video first"
85
+
86
+ # 从视频提取帧
87
+ frame = extract_frame(video_file, frame_num)
88
+ if frame is None:
89
+ return None, "Cannot read specified frame"
90
+
91
+ return frame, f"Frame {frame_num}"
92
 
93
+ # 开始分析
94
+ def start_analysis(video, model, conf, iou, max_det, start_frame, end_frame, threshold):
95
+ global analyzer, output_path
96
+
97
+ if not video or not model:
98
+ return None, None, "Please select a video and model file"
99
+
100
+ if start_frame >= end_frame:
101
+ return None, None, "Start frame must be less than end frame"
102
+
103
+ # 创建输出路径
104
+ video_name = os.path.splitext(os.path.basename(video))[0]
105
+ output_path = os.path.join(os.path.dirname(video), f"{video_name}_out.mp4")
106
+ csv_path = os.path.join(os.path.dirname(video), f"{video_name}_results.csv")
107
+
108
+ try:
109
+ # 创建分析器
110
+ analyzer = MouseTrackerAnalyzer(
111
+ model_path=model,
112
+ conf=conf,
113
+ iou=iou,
114
+ max_det=max_det,
115
+ verbose=True # 开启详细日志
116
+ )
117
+ analyzer.struggle_threshold = threshold
118
+
119
+ # 处理视频的进度回调
120
+ def progress_update(progress, frame, results):
121
+ print(f"Processing: {progress}%, Objects detected: {len(results)}")
122
+
123
+ print(f"Processing video: {video}")
124
+ print(f"Output path: {output_path}")
125
+ print(f"Parameters: conf={conf}, iou={iou}, max_det={max_det}, threshold={threshold}")
126
+
127
+ # 提取视频帧数范围并分析
128
+ results = analyzer.process_video(
129
+ video_path=video,
130
+ output_path=output_path,
131
+ start_frame=start_frame,
132
+ end_frame=end_frame,
133
+ callback=progress_update
134
+ )
135
+
136
+ # 保存结果到CSV
137
+ print(f"Saving results to CSV: {csv_path}")
138
+ analyzer.save_results(csv_path)
139
+ print(f"Results saved to CSV with {len(analyzer.results)} frames of data")
140
+
141
+ # 生成分析图表
142
+ print("Generating time series plot...")
143
+ if len(analyzer.results) == 0:
144
+ print("WARNING: No results available for plotting!")
145
+ plot_path = None
146
+ else:
147
+ plot_path = analyzer.generate_time_series_plot()
148
+ if plot_path and os.path.exists(plot_path):
149
+ print(f"Plot generated and saved to: {plot_path}, size: {os.path.getsize(plot_path)/1024:.2f}KB")
150
+ else:
151
+ print(f"Failed to generate plot or plot file does not exist!")
152
+ plot_path = None
153
+
154
+ # 检查输出文件是否存在
155
+ if os.path.exists(output_path):
156
+ file_size = os.path.getsize(output_path) / (1024 * 1024) # MB
157
+ print(f"Output video size: {file_size:.2f}MB")
158
+
159
+ # 处理debug帧
160
+ debug_frame_path = os.path.join(os.path.dirname(output_path), "debug_frame.jpg")
161
+ if os.path.exists(debug_frame_path):
162
+ print(f"Debug frame saved at: {debug_frame_path}")
163
+
164
+ if plot_path and os.path.exists(plot_path):
165
+ print(f"Plot file exists at: {plot_path}, size: {os.path.getsize(plot_path)/1024:.2f}KB")
166
+
167
+ # 确保返回正确的文件路径
168
+ status_message = "Analysis complete. "
169
+
170
+ if os.path.exists(output_path):
171
+ status_message += f"Video saved."
172
+ else:
173
+ status_message += "WARNING: Output video not found. "
174
+
175
+ if plot_path and os.path.exists(plot_path):
176
+ status_message += f" Time series plot generated."
177
+ else:
178
+ status_message += " WARNING: Failed to generate time series plot."
179
+
180
+ status_message += f" Results saved to: {csv_path}"
181
+
182
+ return output_path, plot_path, status_message
183
+ except Exception as e:
184
+ import traceback
185
+ traceback.print_exc()
186
+ return None, None, f"Processing error: {str(e)}"
187
 
188
+ # 创建Gradio界面
189
+ def create_interface():
190
+ with gr.Blocks(title="Mouse Struggle Analysis - Object Tracking") as app:
191
+ gr.Markdown("# Mouse Forced Swim Test Struggle Analysis (Object Tracking)")
192
+
193
+ with gr.Row():
194
+ with gr.Column(scale=1):
195
+ # 视频和模型选择
196
+ video_input = gr.Video(label="Input Video")
197
+ model_input = gr.File(label="Model File (.pt format recommended)")
198
+
199
+ # 参数设置
200
+ with gr.Row():
201
+ conf = gr.Slider(minimum=0.1, maximum=0.9, value=0.25, step=0.05, label="Confidence Threshold")
202
+ iou = gr.Slider(minimum=0.1, maximum=0.9, value=0.45, step=0.05, label="IoU Threshold")
203
+
204
+ with gr.Row():
205
+ max_det = gr.Slider(minimum=1, maximum=50, value=20, step=1, label="Max Detections")
206
+ threshold = gr.Slider(minimum=0, maximum=1, value=0.3, step=0.01, label="Struggle Threshold")
207
+
208
+ # 帧选择
209
+ start_frame = gr.Slider(minimum=0, maximum=999999, value=0, step=1, label="Start Frame")
210
+ end_frame = gr.Slider(minimum=0, maximum=999999, value=999999, step=1, label="End Frame")
211
+
212
+ # 预览按钮
213
+ preview_btn = gr.Button("Preview Frame")
214
+
215
+ # 开始分析
216
+ start_btn = gr.Button("Start Analysis", variant="primary")
217
+
218
+ with gr.Column(scale=2):
219
+ # 显示区域
220
+ with gr.Tab("Preview"):
221
+ # 图像预览
222
+ preview_image = gr.Image(label="Preview Image", type="numpy", height=400)
223
+ status_text = gr.Textbox(label="Status", interactive=False)
224
+ gr.Markdown("""
225
+ ### Instructions:
226
+ 1. Select a video and model file (.pt format segmentation model like yolov8n-seg.pt recommended)
227
+ 2. Adjust parameters
228
+ - Confidence Threshold: Minimum confidence for object detection, lower values detect more potential objects
229
+ - IoU Threshold: For filtering overlapping detections
230
+ - Max Detections: Maximum number of objects to detect per frame
231
+ - Struggle Threshold: Minimum score to classify as struggle state
232
+ 3. Set frame range
233
+ 4. Click "Start Analysis" button
234
+
235
+ The system will automatically track mice and analyze their struggle behavior, no need to manually define regions
236
+ """)
237
+
238
+ with gr.Tab("Results"):
239
+ with gr.Row():
240
+ output_video = gr.Video(label="Analysis Result Video")
241
+ result_plot = gr.Image(label="Struggle Score Time Series")
242
+
243
+ result_status = gr.Textbox(label="Analysis Status", interactive=False)
244
+
245
+ # 绑定事件
246
+ video_input.change(select_video, inputs=[video_input], outputs=[preview_image, status_text, start_frame, end_frame])
247
+ model_input.change(select_model, inputs=[model_input], outputs=[status_text])
248
+
249
+ preview_btn.click(preview_frame, inputs=[video_input, start_frame], outputs=[preview_image, status_text])
250
+
251
+ start_btn.click(
252
+ start_analysis,
253
+ inputs=[video_input, model_input, conf, iou, max_det, start_frame, end_frame, threshold],
254
+ outputs=[output_video, result_plot, result_status]
255
+ )
256
+
257
+ return app
258
 
259
+ # 启动应用
260
  if __name__ == "__main__":
261
+ # 清除可能干扰的代理设置
262
+ if 'http_proxy' in os.environ:
263
+ del os.environ['http_proxy']
264
+ if 'https_proxy' in os.environ:
265
+ del os.environ['https_proxy']
266
+ if 'all_proxy' in os.environ:
267
+ del os.environ['all_proxy']
268
+
269
+ app = create_interface()
270
+ # 使用简化的启动配置
271
+ app.launch(server_name="127.0.0.1", server_port=7860, share=False)