Hakureirm commited on
Commit
eddf713
·
verified ·
1 Parent(s): a0c5b0b

Create mouse_tracker.py

Browse files
Files changed (1) hide show
  1. mouse_tracker.py +572 -0
mouse_tracker.py ADDED
@@ -0,0 +1,572 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ import pandas as pd
6
+ import collections
7
+ import tempfile
8
+ from ultralytics import YOLO
9
+ import math
10
+
11
+ class MouseTrackerAnalyzer:
12
+ """基于Ultralytics对象跟踪的鼠强迫游泳实验挣扎度分析器"""
13
+ def __init__(self, model_path, history_size=5, conf=0.25, iou=0.45, max_det=20, verbose=False):
14
+ # 初始化模型和参数
15
+ self.model = YOLO(model_path, task="segment", verbose=False)
16
+ self.history_size = history_size
17
+ self.verbose = verbose # 控制日志输出级别
18
+ self.struggle_threshold = 0.3 # 挣扎阈值
19
+
20
+ # 跟踪相关参数
21
+ self.conf = conf # 置信度阈值
22
+ self.iou = iou # IOU阈值
23
+ self.max_det = max_det # 最大检测数量
24
+
25
+ # 预设16种固定颜色 (BGR顺序)
26
+ self.colors = [
27
+ (255, 0, 0), # 红
28
+ (0, 255, 0), # 绿
29
+ (0, 0, 255), # 蓝
30
+ (255, 255, 0), # 青
31
+ (255, 0, 255), # 洋红
32
+ (0, 255, 255), # 黄
33
+ (128, 0, 0), # 深红
34
+ (128, 0, 128), # 紫
35
+ (0, 128, 128), # 青绿
36
+ (192, 192, 192),# 银
37
+ (128, 128, 128),# 灰
38
+ (255, 128, 0), # 橙
39
+ (255, 0, 128), # 粉
40
+ (0, 128, 255), # 浅蓝
41
+ (128, 255, 0), # 黄绿
42
+ (0, 255, 128) # 浅绿
43
+ ]
44
+ # 追踪相关
45
+ self.prev_masks = {} # 上一帧各 ID 二值掩码
46
+ self.histories = {} # 各 ID 分数历史队列
47
+ self.track_ids = set() # 所有被跟踪的ID
48
+
49
+ # 视频处理状态
50
+ self.cap = None
51
+ self.writer = None
52
+ self.frame_id = 0
53
+ self.results = [] # 存储每帧结果
54
+ self.start_frame = 0
55
+ self.end_frame = 0
56
+
57
+ def init_video(self, video_path, output_path=None, start_frame=0, end_frame=None):
58
+ """初始化视频处理"""
59
+ # 打开视频并初始化写出器
60
+ self.cap = cv2.VideoCapture(video_path)
61
+ if not self.cap.isOpened():
62
+ raise IOError(f"无法打开视频 {video_path}")
63
+
64
+ # 获取视频属性
65
+ width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
66
+ height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
67
+ fps = self.cap.get(cv2.CAP_PROP_FPS) or 30
68
+ self.fps = max(fps, 1.0) # 保存帧率到实例变量,确保至少为1
69
+ total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
70
+
71
+ if self.verbose:
72
+ print(f"视频尺寸: {width}x{height}, 帧率: {fps}, 总帧数: {total_frames}")
73
+
74
+ # 设置帧范围
75
+ self.start_frame = start_frame
76
+ self.end_frame = end_frame if end_frame is not None else total_frames - 1
77
+
78
+ # 确保帧范围有效
79
+ if self.start_frame < 0:
80
+ self.start_frame = 0
81
+ if self.end_frame >= total_frames:
82
+ self.end_frame = total_frames - 1
83
+ if self.start_frame > self.end_frame:
84
+ self.start_frame, self.end_frame = self.end_frame, self.start_frame
85
+
86
+ # 将视频定位到起始帧
87
+ if self.start_frame > 0:
88
+ self.cap.set(cv2.CAP_PROP_POS_FRAMES, self.start_frame)
89
+
90
+ # 如果输出为视频则初始化 VideoWriter
91
+ if output_path and output_path.lower().endswith(('.mp4', '.avi')):
92
+ # 使用标准编码器
93
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
94
+ # 创建VideoWriter
95
+ self.writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
96
+ if self.writer.isOpened():
97
+ print(f"成功创建输出视频: {output_path}, 尺寸: {width}x{height}")
98
+ else:
99
+ print(f"警告: 无法创建输出视频 {output_path}")
100
+
101
+ # 重置状态
102
+ self.frame_id = self.start_frame
103
+ self.results = []
104
+ self.prev_masks.clear()
105
+ self.histories.clear()
106
+ self.track_ids.clear()
107
+
108
+ if self.verbose:
109
+ print(f"视频初始化完成: 总帧数 {total_frames}, 分析范围 {self.start_frame}-{self.end_frame}")
110
+
111
+ return total_frames, self.start_frame, self.end_frame
112
+
113
+ def process_frame(self, frame, frame_id):
114
+ """处理单帧,返回可视化帧和本帧结果列表"""
115
+ if self.verbose and frame_id % 10 == 0:
116
+ print(f"process_frame: 处理帧 {frame_id}")
117
+
118
+ try:
119
+ # 使用YOLO模型跟踪对象
120
+ results = self.model.track(
121
+ frame,
122
+ persist=True, # 保持跟踪ID的持久性
123
+ conf=self.conf,
124
+ iou=self.iou,
125
+ max_det=self.max_det,
126
+ verbose=False
127
+ )
128
+
129
+ # 检查是否有检测结果
130
+ frame_results = []
131
+
132
+ if results[0].boxes is None or len(results[0].boxes) == 0:
133
+ if self.verbose and frame_id % 50 == 0:
134
+ print("没有检测到任何对象")
135
+ return frame.copy(), []
136
+
137
+ # 处理检测结果
138
+ if hasattr(results[0], 'masks') and results[0].masks is not None:
139
+ # 获取掩码和跟踪ID
140
+ masks = results[0].masks.data.cpu().numpy()
141
+ track_ids = results[0].boxes.id
142
+
143
+ if track_ids is None:
144
+ if self.verbose and frame_id % 50 == 0:
145
+ print("没有获取到跟踪ID")
146
+ return frame.copy(), []
147
+
148
+ track_ids = track_ids.int().cpu().numpy()
149
+
150
+ if self.verbose and frame_id % 50 == 0:
151
+ print(f"检测到 {len(masks)} 个掩码,{len(track_ids)} 个跟踪ID")
152
+
153
+ # 更新跟踪ID集合
154
+ for track_id in track_ids:
155
+ self.track_ids.add(int(track_id))
156
+
157
+ # 处理每个跟踪对象
158
+ for i, (mask, track_id) in enumerate(zip(masks, track_ids)):
159
+ track_id = int(track_id)
160
+
161
+ # 二值化掩码
162
+ bin_mask = (mask > 0.2).astype(np.uint8)
163
+
164
+ # 应用形态学操作清理掩码
165
+ kernel = np.ones((5,5), np.uint8)
166
+ bin_mask = cv2.morphologyEx(bin_mask, cv2.MORPH_CLOSE, kernel)
167
+
168
+ # 调整掩码尺寸到与原始帧相同
169
+ if bin_mask.shape != (frame.shape[0], frame.shape[1]):
170
+ bin_mask = cv2.resize(bin_mask, (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST)
171
+
172
+ # 计算挣扎度
173
+ if track_id in self.prev_masks:
174
+ prev_mask = self.prev_masks[track_id]
175
+ # 确保比较的掩码尺寸一致
176
+ if prev_mask.shape != bin_mask.shape:
177
+ prev_mask = cv2.resize(prev_mask, (bin_mask.shape[1], bin_mask.shape[0]), interpolation=cv2.INTER_NEAREST)
178
+ inter = np.logical_and(prev_mask > 0, bin_mask > 0).sum()
179
+ union = np.logical_or(prev_mask > 0, bin_mask > 0).sum()
180
+ iou = inter / union if union > 0 else 0
181
+ score = 1 - iou
182
+ if self.verbose and frame_id % 50 == 0:
183
+ print(f"跟踪ID {track_id} 挣扎分数: {score:.4f} (IoU: {iou:.4f})")
184
+ else:
185
+ score = 0.0
186
+ if self.verbose and frame_id % 50 == 0:
187
+ print(f"跟踪ID {track_id} 初始帧,分数为0")
188
+
189
+ # 保存当前掩码和历史
190
+ self.prev_masks[track_id] = bin_mask
191
+
192
+ if track_id not in self.histories:
193
+ self.histories[track_id] = collections.deque(maxlen=self.history_size)
194
+ self.histories[track_id].append(score)
195
+
196
+ # 计算挣扎状态
197
+ is_struggling = score >= self.struggle_threshold
198
+
199
+ # 计算质心
200
+ ys, xs = np.where(bin_mask > 0)
201
+ if len(xs) > 0:
202
+ centroid = (int(xs.mean()), int(ys.mean()))
203
+ else:
204
+ # 如果掩码为空,使用边界框中心点
205
+ box = results[0].boxes[i].xyxy.cpu().numpy()[0]
206
+ centroid = (int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2))
207
+
208
+ # 添加到帧结果
209
+ frame_results.append({
210
+ 'id': track_id,
211
+ 'score': float(score),
212
+ 'centroid': centroid,
213
+ 'is_struggling': is_struggling
214
+ })
215
+ else:
216
+ if self.verbose and frame_id % 50 == 0:
217
+ print("没有检测到任何掩码")
218
+ return frame.copy(), []
219
+
220
+ # 可视化 - 在这里创建最终的标注帧
221
+ annotated = frame.copy()
222
+
223
+ # 绘制掩码和ID
224
+ for result in frame_results:
225
+ track_id = result['id']
226
+ color = self.colors[track_id % len(self.colors)]
227
+
228
+ # 绘制掩码
229
+ if track_id in self.prev_masks:
230
+ mask = self.prev_masks[track_id]
231
+ # 确保掩码与帧大小一致
232
+ if mask.shape != (frame.shape[0], frame.shape[1]):
233
+ mask = cv2.resize(mask, (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST)
234
+ mask_overlay = np.zeros_like(frame)
235
+ mask_overlay[mask > 0] = color
236
+
237
+ # 使用更精确的掩码边缘
238
+ contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
239
+ cv2.drawContours(annotated, contours, -1, color, 2)
240
+
241
+ # 使用addWeighted进行混合
242
+ cv2.addWeighted(annotated, 1.0, mask_overlay, 0.4, 0, annotated)
243
+
244
+ # 在质心位置绘制ID和挣扎状态
245
+ centroid = result['centroid']
246
+ status_text = "Struggle" if result['is_struggling'] else "Static"
247
+ cv2.putText(annotated, f"ID:{track_id} {status_text}",
248
+ (centroid[0], centroid[1]),
249
+ cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)
250
+
251
+ # 在顶部创建黑色半透明条,显示总结信息
252
+ cv2.rectangle(annotated, (0, 0), (frame.shape[1], 40), (0, 0, 0), -1)
253
+
254
+ # 计算挣扎中的老鼠数量
255
+ struggling_count = sum(1 for r in frame_results if r['is_struggling'])
256
+ total_count = len(frame_results)
257
+
258
+ # 显示统计信息
259
+ cv2.putText(annotated, f"Total: {total_count} Struggling: {struggling_count}",
260
+ (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
261
+
262
+ # 最后,由于OpenCV以BGR格式工作,但可能需要RGB格式,
263
+ # 确保返回的图像是BGR格式(视频写入用BGR,显示用RGB)
264
+ if annotated.dtype != np.uint8:
265
+ annotated = annotated.astype(np.uint8)
266
+
267
+ return annotated, frame_results
268
+
269
+ except Exception as e:
270
+ import traceback
271
+ if self.verbose:
272
+ print(f"处理帧时出错: {str(e)}")
273
+ traceback.print_exc()
274
+ # 返回原始帧和空结果
275
+ return frame.copy(), []
276
+
277
+ def process_video(self, video_path, output_path=None, start_frame=0, end_frame=None, callback=None):
278
+ """处理整段视频,可选的回调函数用于更新进度"""
279
+ # 初始化视频
280
+ total_frames, start, end = self.init_video(video_path, output_path, start_frame, end_frame)
281
+ self.results = [] # 确保结果列表被清空
282
+
283
+ frame_id = start
284
+ processed_frames = 0
285
+ frames_to_process = end - start + 1
286
+ last_progress = -1
287
+
288
+ # 临时保存一帧,用于调试
289
+ debug_frame_saved = False
290
+
291
+ while frame_id <= end:
292
+ ret, frame = self.cap.read()
293
+ if not ret:
294
+ break
295
+
296
+ # 处理当前帧
297
+ annotated, frame_res = self.process_frame(frame, frame_id)
298
+ self.results.append(frame_res) # 将当前帧结果存入results列表
299
+
300
+ # 保存第一帧用于调试
301
+ if not debug_frame_saved and len(frame_res) > 0:
302
+ debug_frame_path = os.path.join(os.path.dirname(output_path), "debug_frame.jpg")
303
+ cv2.imwrite(debug_frame_path, annotated)
304
+ print(f"调试: 保存了标注帧到 {debug_frame_path}")
305
+ debug_frame_saved = True
306
+
307
+ # 写入输出视频
308
+ if self.writer:
309
+ # 确保帧是BGR格式
310
+ if len(annotated.shape) == 3 and annotated.shape[2] == 3:
311
+ # 如果需要,将RGB转换回BGR (OpenCV使用BGR)
312
+ # 默认应该已经是BGR,但为了确保
313
+ if frame_id == start:
314
+ print(f"调试: 写入标注帧到视频,形状: {annotated.shape}")
315
+
316
+ try:
317
+ self.writer.write(annotated)
318
+ except Exception as e:
319
+ print(f"调试: 写入帧到视频时出错: {str(e)}")
320
+ import traceback
321
+ traceback.print_exc()
322
+
323
+ # 更新进度和回调
324
+ processed_frames += 1
325
+ progress = int(100 * processed_frames / frames_to_process)
326
+
327
+ if progress != last_progress and callback:
328
+ callback(progress, annotated, frame_res)
329
+ last_progress = progress
330
+
331
+ frame_id += 1
332
+
333
+ # 释放资源
334
+ self.cap.release()
335
+ if self.writer:
336
+ self.writer.release()
337
+ print(f"调试: 视频写入完成,保存到: {output_path}")
338
+
339
+ return self.results
340
+
341
+ def save_results(self, csv_path):
342
+ """导出分析结果到 CSV"""
343
+ import csv
344
+ with open(csv_path, 'w', newline='') as f:
345
+ writer = csv.writer(f)
346
+ writer.writerow(['frame_id', 'mouse_id', 'score', 'is_struggling'])
347
+ for fid, frs in enumerate(self.results):
348
+ for fr in frs:
349
+ writer.writerow([
350
+ fid + self.start_frame,
351
+ fr['id'],
352
+ f"{fr['score']:.4f}",
353
+ 1 if fr.get('is_struggling', False) else 0
354
+ ])
355
+
356
+ def generate_time_series_plot(self, threshold=None):
357
+ """生成时序图分析"""
358
+ try:
359
+ print(f"Starting to generate time series plot with {len(self.results)} frames of data")
360
+
361
+ if not self.results or len(self.results) < 10:
362
+ print("Not enough data for time series plot (need at least 10 frames)")
363
+ return None
364
+
365
+ # 使用传入的阈值或默认阈值
366
+ if threshold is None:
367
+ threshold = self.struggle_threshold
368
+
369
+ # 使用保存的帧率,确保不会出现除以零的情况
370
+ fps = getattr(self, 'fps', None)
371
+ if fps is None or fps <= 0:
372
+ fps = 30 # 使用默认帧率
373
+ print(f"Warning: Invalid frame rate detected, using default: {fps} fps")
374
+ else:
375
+ print(f"Using frame rate: {fps} fps")
376
+
377
+ # 处理数据
378
+ frames = []
379
+ mouse_data = {}
380
+ mouse_positions = {} # 用于存储每只老鼠的平均X坐标
381
+
382
+ for frame_id, frame_results in enumerate(self.results):
383
+ frames.append(frame_id + self.start_frame) # 使用真实帧号
384
+ for result in frame_results:
385
+ mouse_id = result['id']
386
+ if mouse_id not in mouse_data:
387
+ mouse_data[mouse_id] = {'frames': [], 'seconds': [], 'scores': [], 'struggling': []}
388
+ mouse_positions[mouse_id] = [] # 初始化X坐标列表
389
+
390
+ frame_num = frame_id + self.start_frame
391
+ second = frame_num / fps # 转换为秒
392
+
393
+ mouse_data[mouse_id]['frames'].append(frame_num)
394
+ mouse_data[mouse_id]['seconds'].append(second)
395
+ mouse_data[mouse_id]['scores'].append(result['score'])
396
+ mouse_data[mouse_id]['struggling'].append(1 if result.get('is_struggling', False) else 0)
397
+
398
+ # 记录质心的X坐标
399
+ if 'centroid' in result:
400
+ mouse_positions[mouse_id].append(result['centroid'][0])
401
+
402
+ print(f"Processed data for {len(mouse_data)} mice")
403
+ if not mouse_data:
404
+ print("No valid mouse data to plot")
405
+ return None
406
+
407
+ # 计算每只老鼠的平均X坐标并按从左到右排序
408
+ avg_positions = {}
409
+ for mouse_id, positions in mouse_positions.items():
410
+ if positions:
411
+ avg_positions[mouse_id] = sum(positions) / len(positions)
412
+ else:
413
+ avg_positions[mouse_id] = float('inf') # 如果没有位置数据,放到最后
414
+
415
+ # 按从左到右排序老鼠ID
416
+ sorted_mice = sorted(mouse_data.keys(), key=lambda mid: avg_positions.get(mid, float('inf')))
417
+ print(f"Mice sorted from left to right: {sorted_mice}")
418
+
419
+ # 对数据进行平滑处理
420
+ def smooth_data(data, window_size=5):
421
+ """使用移动平均平滑数据"""
422
+ if len(data) < window_size:
423
+ return data
424
+ smoothed = []
425
+ for i in range(len(data)):
426
+ start = max(0, i - window_size // 2)
427
+ end = min(len(data), i + window_size // 2 + 1)
428
+ window = data[start:end]
429
+ smoothed.append(sum(window) / len(window))
430
+ return smoothed
431
+
432
+ # 创建子图
433
+ num_mice = len(mouse_data)
434
+ fig, axes = plt.subplots(num_mice, 1, figsize=(12, 4*num_mice), sharex=True)
435
+
436
+ # 如果只有一只鼠,确保axes是列表
437
+ if num_mice == 1:
438
+ axes = [axes]
439
+
440
+ # 绘制每只老鼠的挣扎得分曲线,按从左到右的顺序
441
+ for idx, mouse_id in enumerate(sorted_mice):
442
+ data = mouse_data[mouse_id]
443
+ ax = axes[idx]
444
+
445
+ # 平滑数据
446
+ smoothed_scores = smooth_data(data['scores'], window_size=5)
447
+
448
+ # 绘制曲线
449
+ ax.plot(data['seconds'], smoothed_scores, label=f"Smoothed", color='blue', linewidth=2)
450
+ ax.plot(data['seconds'], data['scores'], label=f"Raw", color='lightblue', alpha=0.5, linewidth=1)
451
+
452
+ # 标记挣扎区域
453
+ for i, is_struggling in enumerate(data['struggling']):
454
+ if is_struggling:
455
+ ax.axvspan(data['seconds'][i]-0.5/fps, data['seconds'][i]+0.5/fps, alpha=0.1, color='red')
456
+
457
+ # 绘制阈值线
458
+ ax.axhline(y=threshold, color='r', linestyle='--', label=f"Threshold ({threshold:.2f})")
459
+
460
+ # 设置图表
461
+ ax.set_ylabel('Struggle Score')
462
+ position_text = f"(Position: Left #{sorted_mice.index(mouse_id)+1})" if mouse_id in avg_positions else ""
463
+ ax.set_title(f'Mouse {mouse_id} Struggle Score {position_text}')
464
+ ax.legend(loc='upper right')
465
+ ax.grid(True)
466
+
467
+ # 设置Y轴范围0-1
468
+ ax.set_ylim(-0.05, 1.05)
469
+
470
+ # 设置共享的X轴标签
471
+ axes[-1].set_xlabel('Time (seconds)')
472
+
473
+ # 动态调整x轴范围,精确到0.1秒
474
+ if frames:
475
+ start_time = self.start_frame / fps
476
+ end_time = max(frames) / fps
477
+ # 扩展一点范围以便更好地显示
478
+ axes[-1].set_xlim(start_time, end_time)
479
+
480
+ # 设置次要刻度(细网格线)
481
+ tick_interval = 0.1 # 保持0.1秒的细网格
482
+ minor_ticks = np.arange(start_time, end_time + tick_interval, tick_interval)
483
+ axes[-1].set_xticks(minor_ticks, minor=True)
484
+
485
+ # 设置主要刻度(标签和粗网格线)- 整秒
486
+ major_start = math.ceil(start_time)
487
+ major_end = math.floor(end_time)
488
+ major_ticks = np.arange(major_start, major_end + 1, 1.0) # 整秒刻度
489
+ axes[-1].set_xticks(major_ticks)
490
+ axes[-1].set_xticklabels([f"{int(t)}" for t in major_ticks]) # 整数秒标签
491
+
492
+ # 设置网格
493
+ axes[-1].grid(True, which='both')
494
+ axes[-1].grid(which='minor', alpha=0.2)
495
+ axes[-1].grid(which='major', alpha=0.5)
496
+
497
+ plt.tight_layout()
498
+
499
+ # 保存图表到临时文件并返回路径
500
+ temp_file = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
501
+ plt.savefig(temp_file.name, dpi=150, bbox_inches='tight')
502
+ plt.close()
503
+
504
+ print(f"Time series plot saved to: {temp_file.name}")
505
+ return temp_file.name
506
+
507
+ except Exception as e:
508
+ import traceback
509
+ print(f"Error generating time series plot: {str(e)}")
510
+ traceback.print_exc()
511
+ return None
512
+
513
+ if __name__ == "__main__":
514
+ import argparse
515
+
516
+ parser = argparse.ArgumentParser(description="鼠强迫游泳实验挣扎度分析")
517
+ parser.add_argument('--video', type=str, required=True, help='输入视频路径')
518
+ parser.add_argument('--model', type=str, required=True, help='模型文件路径')
519
+ parser.add_argument('--output', type=str, help='输出视频路径')
520
+ parser.add_argument('--csv', type=str, help='输出CSV结果路径')
521
+ parser.add_argument('--conf', type=float, default=0.25, help='置信度阈值')
522
+ parser.add_argument('--iou', type=float, default=0.45, help='IOU阈值')
523
+ parser.add_argument('--max-det', type=int, default=20, help='最大检测数量')
524
+ parser.add_argument('--threshold', type=float, default=0.3, help='挣扎阈值')
525
+ parser.add_argument('--start', type=int, default=0, help='起始帧')
526
+ parser.add_argument('--end', type=int, default=None, help='结束帧')
527
+ parser.add_argument('--verbose', action='store_true', help='详细输出')
528
+
529
+ args = parser.parse_args()
530
+
531
+ # 设置输出路径
532
+ if not args.output:
533
+ video_name = os.path.splitext(os.path.basename(args.video))[0]
534
+ args.output = os.path.join(os.path.dirname(args.video), f"{video_name}_out.mp4")
535
+
536
+ if not args.csv:
537
+ video_name = os.path.splitext(os.path.basename(args.video))[0]
538
+ args.csv = os.path.join(os.path.dirname(args.video), f"{video_name}_results.csv")
539
+
540
+ # 创建分析器并处理
541
+ analyzer = MouseTrackerAnalyzer(
542
+ model_path=args.model,
543
+ conf=args.conf,
544
+ iou=args.iou,
545
+ max_det=args.max_det,
546
+ verbose=args.verbose
547
+ )
548
+ analyzer.struggle_threshold = args.threshold
549
+
550
+ # 进度回调函数
551
+ def progress_callback(progress, frame, results):
552
+ print(f"处理进度: {progress}%, 检测到 {len(results)} 个对象")
553
+
554
+ # 处理视频
555
+ analyzer.process_video(
556
+ video_path=args.video,
557
+ output_path=args.output,
558
+ start_frame=args.start,
559
+ end_frame=args.end,
560
+ callback=progress_callback
561
+ )
562
+
563
+ # 保存结果
564
+ analyzer.save_results(args.csv)
565
+
566
+ # 生成分析图表
567
+ plot_path = analyzer.generate_time_series_plot()
568
+ if plot_path:
569
+ print(f"挣扎度时序分析图已保存到: {plot_path}")
570
+
571
+ print(f"分析完成,视频已保存到: {args.output}")
572
+ print(f"结果数据已保存到: {args.csv}")