Update app.py
Browse files
app.py
CHANGED
@@ -1,185 +1,271 @@
|
|
1 |
-
import
|
2 |
import cv2
|
3 |
import numpy as np
|
4 |
-
import torch
|
5 |
-
from ultralytics import YOLO # pip install ultralytics
|
6 |
import gradio as gr
|
7 |
-
import
|
|
|
8 |
|
9 |
-
#
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
|
|
14 |
|
15 |
-
#
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
pass
|
22 |
-
|
23 |
-
@spaces.GPU(duration=600) # ZeroGPU 环境下执行该函数,超时 600s
|
24 |
-
def analyze_video(video_path, num_mice, time_range, window_size_sec=1, fps=30):
|
25 |
-
"""
|
26 |
-
分割 → 跟踪 → 计算挣扎强度,仅分析指定时间区间
|
27 |
-
返回:标注后视频 & 绘制的挣扎强度曲线 (matplotlib Figure)
|
28 |
-
"""
|
29 |
-
# 打开视频并获取基本信息
|
30 |
cap = cv2.VideoCapture(video_path)
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
# 跳转到指定起始帧
|
40 |
-
cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
|
41 |
-
|
42 |
-
# 输出视频初始化
|
43 |
-
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
|
44 |
-
out_path = "output.mp4"
|
45 |
-
out = cv2.VideoWriter(out_path, fourcc, vid_fps, (width, height))
|
46 |
-
|
47 |
-
prev_centroids = [None] * num_mice
|
48 |
-
prev_masks = [None] * num_mice
|
49 |
-
struggle_records = [[] for _ in range(num_mice)]
|
50 |
-
frame_idx = start_frame
|
51 |
-
|
52 |
-
while frame_idx <= end_frame:
|
53 |
-
ret, frame = cap.read()
|
54 |
-
if not ret:
|
55 |
-
break
|
56 |
-
|
57 |
-
# 分割推理
|
58 |
-
device = "cuda" if use_cuda else "cpu"
|
59 |
-
results = model(frame, stream=True, device=device, conf=0.25)
|
60 |
-
res = next(results)
|
61 |
-
|
62 |
-
# 无检测帧处理
|
63 |
-
if res.masks is None or res.masks.data is None:
|
64 |
-
for mid in range(num_mice):
|
65 |
-
struggle_records[mid].append(None)
|
66 |
-
out.write(frame)
|
67 |
-
frame_idx += 1
|
68 |
-
continue
|
69 |
-
|
70 |
-
# 获取并对齐掩膜至帧尺寸
|
71 |
-
masks = res.masks.data.cpu().numpy() # (N, H_model, W_model)
|
72 |
-
aligned_masks = []
|
73 |
-
for m in masks:
|
74 |
-
m_bin = (m > 0).astype(np.uint8)
|
75 |
-
m_res = cv2.resize(m_bin, (width, height), interpolation=cv2.INTER_NEAREST)
|
76 |
-
aligned_masks.append(m_res)
|
77 |
-
aligned_masks = np.array(aligned_masks)
|
78 |
-
|
79 |
-
# 计算质心 & ID 分配 (nearest-centroid)
|
80 |
-
curr_centroids = []
|
81 |
-
for m in aligned_masks:
|
82 |
-
ys, xs = np.where(m > 0)
|
83 |
-
curr_centroids.append((int(xs.mean()), int(ys.mean())) if xs.size else None)
|
84 |
-
assignments = [-1] * len(curr_centroids)
|
85 |
-
unused_ids = set(range(num_mice))
|
86 |
-
for i, c in enumerate(curr_centroids):
|
87 |
-
if c is None:
|
88 |
-
continue
|
89 |
-
best_j, best_d = None, float("inf")
|
90 |
-
for j in unused_ids:
|
91 |
-
pc = prev_centroids[j]
|
92 |
-
if pc is None:
|
93 |
-
continue
|
94 |
-
d = (c[0] - pc[0])**2 + (c[1] - pc[1])**2
|
95 |
-
if d < best_d:
|
96 |
-
best_j, best_d = j, d
|
97 |
-
if best_j is not None and best_d < 50**2:
|
98 |
-
assignments[i] = best_j
|
99 |
-
unused_ids.remove(best_j)
|
100 |
-
for i in range(len(curr_centroids)):
|
101 |
-
if assignments[i] < 0 and unused_ids:
|
102 |
-
assignments[i] = unused_ids.pop()
|
103 |
-
|
104 |
-
# 计算挣扎强度 & 可视化叠加
|
105 |
-
for i, m in enumerate(aligned_masks):
|
106 |
-
mid = assignments[i]
|
107 |
-
if mid < 0:
|
108 |
-
continue
|
109 |
-
prev_m = prev_masks[mid]
|
110 |
-
if prev_m is None:
|
111 |
-
struggle_records[mid].append(None)
|
112 |
-
else:
|
113 |
-
struggle = int(np.logical_xor(prev_m, m).sum())
|
114 |
-
struggle_records[mid].append(struggle)
|
115 |
-
|
116 |
-
# 构建三通道掩膜
|
117 |
-
mask_rgb = np.stack([
|
118 |
-
np.zeros_like(m),
|
119 |
-
m * 255,
|
120 |
-
np.zeros_like(m)
|
121 |
-
], axis=-1).astype(np.uint8)
|
122 |
-
frame = cv2.addWeighted(frame, 1, mask_rgb, 0.5, 0)
|
123 |
-
|
124 |
-
centroid = curr_centroids[i]
|
125 |
-
if centroid:
|
126 |
-
cv2.putText(frame, f"ID:{mid}", centroid,
|
127 |
-
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,255,0), 2)
|
128 |
-
|
129 |
-
prev_centroids[mid] = curr_centroids[i]
|
130 |
-
prev_masks[mid] = m.copy()
|
131 |
-
|
132 |
-
out.write(frame)
|
133 |
-
frame_idx += 1
|
134 |
-
|
135 |
cap.release()
|
136 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
137 |
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
158 |
|
159 |
-
#
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
|
|
|
|
166 |
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
|
|
|
|
|
|
|
|
174 |
|
175 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
176 |
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
183 |
|
|
|
184 |
if __name__ == "__main__":
|
185 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
import cv2
|
3 |
import numpy as np
|
|
|
|
|
4 |
import gradio as gr
|
5 |
+
import tempfile
|
6 |
+
from mouse_tracker import MouseTrackerAnalyzer
|
7 |
|
8 |
+
# 全局变量
|
9 |
+
analyzer = None
|
10 |
+
video_file_path = None
|
11 |
+
model_file_path = None
|
12 |
+
total_frames = 0
|
13 |
+
output_path = None
|
14 |
|
15 |
+
# 从视频中提取特定帧
|
16 |
+
def extract_frame(video_path, frame_num):
|
17 |
+
"""从视频中提取特定帧"""
|
18 |
+
if not video_path:
|
19 |
+
return None
|
20 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
cap = cv2.VideoCapture(video_path)
|
22 |
+
if not cap.isOpened():
|
23 |
+
return None
|
24 |
+
|
25 |
+
# 设置帧位置
|
26 |
+
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
|
27 |
+
|
28 |
+
# 读取帧
|
29 |
+
ret, frame = cap.read()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
cap.release()
|
31 |
+
|
32 |
+
if not ret:
|
33 |
+
return None
|
34 |
+
|
35 |
+
# 转换为RGB格式
|
36 |
+
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
37 |
+
return frame_rgb
|
38 |
|
39 |
+
# 选择视频文件
|
40 |
+
def select_video(video_file):
|
41 |
+
global video_file_path, total_frames
|
42 |
+
|
43 |
+
if not video_file:
|
44 |
+
return None, "Please select a video file", gr.Slider(minimum=0, maximum=0, value=0), gr.Slider(minimum=0, maximum=0, value=0)
|
45 |
+
|
46 |
+
video_file_path = video_file
|
47 |
+
|
48 |
+
# 获取视频总帧数
|
49 |
+
cap = cv2.VideoCapture(video_file_path)
|
50 |
+
if not cap.isOpened():
|
51 |
+
return None, "Cannot open video file", gr.Slider(minimum=0, maximum=0, value=0), gr.Slider(minimum=0, maximum=0, value=0)
|
52 |
+
|
53 |
+
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
54 |
+
|
55 |
+
# 提取第一帧
|
56 |
+
ret, first_frame = cap.read()
|
57 |
+
cap.release()
|
58 |
+
|
59 |
+
if not ret:
|
60 |
+
return None, "Cannot read video frame", gr.Slider(minimum=0, maximum=0, value=0), gr.Slider(minimum=0, maximum=0, value=0)
|
61 |
+
|
62 |
+
# 转为RGB
|
63 |
+
first_frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
|
64 |
+
|
65 |
+
# 更新帧滑块
|
66 |
+
start_slider = gr.Slider(minimum=0, maximum=total_frames-1, value=0, step=1)
|
67 |
+
end_slider = gr.Slider(minimum=0, maximum=total_frames-1, value=total_frames-1, step=1)
|
68 |
+
|
69 |
+
return first_frame_rgb, f"Video loaded successfully, total frames: {total_frames}", start_slider, end_slider
|
70 |
|
71 |
+
# 选择模型文件
|
72 |
+
def select_model(model_file):
|
73 |
+
global model_file_path
|
74 |
+
|
75 |
+
if model_file is None:
|
76 |
+
return "Please select a model file"
|
77 |
+
|
78 |
+
model_file_path = model_file
|
79 |
+
return f"Model selected: {os.path.basename(model_file_path)}"
|
80 |
|
81 |
+
# 预览帧
|
82 |
+
def preview_frame(video_file, frame_num):
|
83 |
+
if not video_file:
|
84 |
+
return None, "Please select a video first"
|
85 |
+
|
86 |
+
# 从视频提取帧
|
87 |
+
frame = extract_frame(video_file, frame_num)
|
88 |
+
if frame is None:
|
89 |
+
return None, "Cannot read specified frame"
|
90 |
+
|
91 |
+
return frame, f"Frame {frame_num}"
|
92 |
|
93 |
+
# 开始分析
|
94 |
+
def start_analysis(video, model, conf, iou, max_det, start_frame, end_frame, threshold):
|
95 |
+
global analyzer, output_path
|
96 |
+
|
97 |
+
if not video or not model:
|
98 |
+
return None, None, "Please select a video and model file"
|
99 |
+
|
100 |
+
if start_frame >= end_frame:
|
101 |
+
return None, None, "Start frame must be less than end frame"
|
102 |
+
|
103 |
+
# 创建输出路径
|
104 |
+
video_name = os.path.splitext(os.path.basename(video))[0]
|
105 |
+
output_path = os.path.join(os.path.dirname(video), f"{video_name}_out.mp4")
|
106 |
+
csv_path = os.path.join(os.path.dirname(video), f"{video_name}_results.csv")
|
107 |
+
|
108 |
+
try:
|
109 |
+
# 创建分析器
|
110 |
+
analyzer = MouseTrackerAnalyzer(
|
111 |
+
model_path=model,
|
112 |
+
conf=conf,
|
113 |
+
iou=iou,
|
114 |
+
max_det=max_det,
|
115 |
+
verbose=True # 开启详细日志
|
116 |
+
)
|
117 |
+
analyzer.struggle_threshold = threshold
|
118 |
+
|
119 |
+
# 处理视频的进度回调
|
120 |
+
def progress_update(progress, frame, results):
|
121 |
+
print(f"Processing: {progress}%, Objects detected: {len(results)}")
|
122 |
+
|
123 |
+
print(f"Processing video: {video}")
|
124 |
+
print(f"Output path: {output_path}")
|
125 |
+
print(f"Parameters: conf={conf}, iou={iou}, max_det={max_det}, threshold={threshold}")
|
126 |
+
|
127 |
+
# 提取视频帧数范围并分析
|
128 |
+
results = analyzer.process_video(
|
129 |
+
video_path=video,
|
130 |
+
output_path=output_path,
|
131 |
+
start_frame=start_frame,
|
132 |
+
end_frame=end_frame,
|
133 |
+
callback=progress_update
|
134 |
+
)
|
135 |
+
|
136 |
+
# 保存结果到CSV
|
137 |
+
print(f"Saving results to CSV: {csv_path}")
|
138 |
+
analyzer.save_results(csv_path)
|
139 |
+
print(f"Results saved to CSV with {len(analyzer.results)} frames of data")
|
140 |
+
|
141 |
+
# 生成分析图表
|
142 |
+
print("Generating time series plot...")
|
143 |
+
if len(analyzer.results) == 0:
|
144 |
+
print("WARNING: No results available for plotting!")
|
145 |
+
plot_path = None
|
146 |
+
else:
|
147 |
+
plot_path = analyzer.generate_time_series_plot()
|
148 |
+
if plot_path and os.path.exists(plot_path):
|
149 |
+
print(f"Plot generated and saved to: {plot_path}, size: {os.path.getsize(plot_path)/1024:.2f}KB")
|
150 |
+
else:
|
151 |
+
print(f"Failed to generate plot or plot file does not exist!")
|
152 |
+
plot_path = None
|
153 |
+
|
154 |
+
# 检查输出文件是否存在
|
155 |
+
if os.path.exists(output_path):
|
156 |
+
file_size = os.path.getsize(output_path) / (1024 * 1024) # MB
|
157 |
+
print(f"Output video size: {file_size:.2f}MB")
|
158 |
+
|
159 |
+
# 处理debug帧
|
160 |
+
debug_frame_path = os.path.join(os.path.dirname(output_path), "debug_frame.jpg")
|
161 |
+
if os.path.exists(debug_frame_path):
|
162 |
+
print(f"Debug frame saved at: {debug_frame_path}")
|
163 |
+
|
164 |
+
if plot_path and os.path.exists(plot_path):
|
165 |
+
print(f"Plot file exists at: {plot_path}, size: {os.path.getsize(plot_path)/1024:.2f}KB")
|
166 |
+
|
167 |
+
# 确保返回正确的文件路径
|
168 |
+
status_message = "Analysis complete. "
|
169 |
+
|
170 |
+
if os.path.exists(output_path):
|
171 |
+
status_message += f"Video saved."
|
172 |
+
else:
|
173 |
+
status_message += "WARNING: Output video not found. "
|
174 |
+
|
175 |
+
if plot_path and os.path.exists(plot_path):
|
176 |
+
status_message += f" Time series plot generated."
|
177 |
+
else:
|
178 |
+
status_message += " WARNING: Failed to generate time series plot."
|
179 |
+
|
180 |
+
status_message += f" Results saved to: {csv_path}"
|
181 |
+
|
182 |
+
return output_path, plot_path, status_message
|
183 |
+
except Exception as e:
|
184 |
+
import traceback
|
185 |
+
traceback.print_exc()
|
186 |
+
return None, None, f"Processing error: {str(e)}"
|
187 |
|
188 |
+
# 创建Gradio界面
|
189 |
+
def create_interface():
|
190 |
+
with gr.Blocks(title="Mouse Struggle Analysis - Object Tracking") as app:
|
191 |
+
gr.Markdown("# Mouse Forced Swim Test Struggle Analysis (Object Tracking)")
|
192 |
+
|
193 |
+
with gr.Row():
|
194 |
+
with gr.Column(scale=1):
|
195 |
+
# 视频和模型选择
|
196 |
+
video_input = gr.Video(label="Input Video")
|
197 |
+
model_input = gr.File(label="Model File (.pt format recommended)")
|
198 |
+
|
199 |
+
# 参数设置
|
200 |
+
with gr.Row():
|
201 |
+
conf = gr.Slider(minimum=0.1, maximum=0.9, value=0.25, step=0.05, label="Confidence Threshold")
|
202 |
+
iou = gr.Slider(minimum=0.1, maximum=0.9, value=0.45, step=0.05, label="IoU Threshold")
|
203 |
+
|
204 |
+
with gr.Row():
|
205 |
+
max_det = gr.Slider(minimum=1, maximum=50, value=20, step=1, label="Max Detections")
|
206 |
+
threshold = gr.Slider(minimum=0, maximum=1, value=0.3, step=0.01, label="Struggle Threshold")
|
207 |
+
|
208 |
+
# 帧选择
|
209 |
+
start_frame = gr.Slider(minimum=0, maximum=999999, value=0, step=1, label="Start Frame")
|
210 |
+
end_frame = gr.Slider(minimum=0, maximum=999999, value=999999, step=1, label="End Frame")
|
211 |
+
|
212 |
+
# 预览按钮
|
213 |
+
preview_btn = gr.Button("Preview Frame")
|
214 |
+
|
215 |
+
# 开始分析
|
216 |
+
start_btn = gr.Button("Start Analysis", variant="primary")
|
217 |
+
|
218 |
+
with gr.Column(scale=2):
|
219 |
+
# 显示区域
|
220 |
+
with gr.Tab("Preview"):
|
221 |
+
# 图像预览
|
222 |
+
preview_image = gr.Image(label="Preview Image", type="numpy", height=400)
|
223 |
+
status_text = gr.Textbox(label="Status", interactive=False)
|
224 |
+
gr.Markdown("""
|
225 |
+
### Instructions:
|
226 |
+
1. Select a video and model file (.pt format segmentation model like yolov8n-seg.pt recommended)
|
227 |
+
2. Adjust parameters
|
228 |
+
- Confidence Threshold: Minimum confidence for object detection, lower values detect more potential objects
|
229 |
+
- IoU Threshold: For filtering overlapping detections
|
230 |
+
- Max Detections: Maximum number of objects to detect per frame
|
231 |
+
- Struggle Threshold: Minimum score to classify as struggle state
|
232 |
+
3. Set frame range
|
233 |
+
4. Click "Start Analysis" button
|
234 |
+
|
235 |
+
The system will automatically track mice and analyze their struggle behavior, no need to manually define regions
|
236 |
+
""")
|
237 |
+
|
238 |
+
with gr.Tab("Results"):
|
239 |
+
with gr.Row():
|
240 |
+
output_video = gr.Video(label="Analysis Result Video")
|
241 |
+
result_plot = gr.Image(label="Struggle Score Time Series")
|
242 |
+
|
243 |
+
result_status = gr.Textbox(label="Analysis Status", interactive=False)
|
244 |
+
|
245 |
+
# 绑定事件
|
246 |
+
video_input.change(select_video, inputs=[video_input], outputs=[preview_image, status_text, start_frame, end_frame])
|
247 |
+
model_input.change(select_model, inputs=[model_input], outputs=[status_text])
|
248 |
+
|
249 |
+
preview_btn.click(preview_frame, inputs=[video_input, start_frame], outputs=[preview_image, status_text])
|
250 |
+
|
251 |
+
start_btn.click(
|
252 |
+
start_analysis,
|
253 |
+
inputs=[video_input, model_input, conf, iou, max_det, start_frame, end_frame, threshold],
|
254 |
+
outputs=[output_video, result_plot, result_status]
|
255 |
+
)
|
256 |
+
|
257 |
+
return app
|
258 |
|
259 |
+
# 启动应用
|
260 |
if __name__ == "__main__":
|
261 |
+
# 清除可能干扰的代理设置
|
262 |
+
if 'http_proxy' in os.environ:
|
263 |
+
del os.environ['http_proxy']
|
264 |
+
if 'https_proxy' in os.environ:
|
265 |
+
del os.environ['https_proxy']
|
266 |
+
if 'all_proxy' in os.environ:
|
267 |
+
del os.environ['all_proxy']
|
268 |
+
|
269 |
+
app = create_interface()
|
270 |
+
# 使用简化的启动配置
|
271 |
+
app.launch(server_name="127.0.0.1", server_port=7860, share=False)
|