Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -32,7 +32,7 @@ def login(username, password):
|
|
32 |
return gr.update(visible=False), gr.update(visible=True), "登录成功"
|
33 |
return gr.update(visible=True), gr.update(visible=False), "用户名或密码错误"
|
34 |
|
35 |
-
@spaces.GPU
|
36 |
def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8):
|
37 |
"""
|
38 |
处理视频并进行小鼠检测
|
@@ -91,10 +91,27 @@ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8)
|
|
91 |
frame_count = 0
|
92 |
detection_info = []
|
93 |
|
|
|
|
|
|
|
|
|
94 |
for r in results:
|
95 |
-
# 获取绘制了预测结果的帧
|
96 |
frame = r.plot()
|
97 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
# 收集检测信息
|
99 |
frame_info = {
|
100 |
"frame": frame_count + 1,
|
@@ -124,21 +141,51 @@ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8)
|
|
124 |
video_writer.release()
|
125 |
|
126 |
# 生成分析报告
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
report = f"""视频分析报告:
|
128 |
参数设置:
|
129 |
- 置信度阈值: {conf_threshold:.2f}
|
130 |
- 最大检测数量: {max_det}
|
131 |
- 处理时长: {process_seconds}秒
|
|
|
132 |
分析结果:
|
133 |
- 处理帧数: {frame_count}
|
134 |
- 平均每帧检测到的老鼠数: {np.mean([info['count'] for info in detection_info]):.1f}
|
135 |
- 最大检测数: {max([info['count'] for info in detection_info])}
|
136 |
- 最小检测数: {min([info['count'] for info in detection_info])}
|
|
|
137 |
置信度分布:
|
138 |
-
{
|
139 |
"""
|
140 |
|
141 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
|
143 |
# 创建 Gradio 界面
|
144 |
with gr.Blocks() as demo:
|
@@ -182,6 +229,9 @@ with gr.Blocks() as demo:
|
|
182 |
|
183 |
with gr.Column():
|
184 |
video_output = gr.Video(label="检测结果")
|
|
|
|
|
|
|
185 |
report_output = gr.Textbox(label="分析报告")
|
186 |
|
187 |
gr.Markdown("""
|
@@ -212,7 +262,7 @@ with gr.Blocks() as demo:
|
|
212 |
process_btn.click(
|
213 |
fn=process_video,
|
214 |
inputs=[video_input, process_seconds, conf_threshold, max_det],
|
215 |
-
outputs=[video_output, report_output]
|
216 |
)
|
217 |
|
218 |
if __name__ == "__main__":
|
|
|
32 |
return gr.update(visible=False), gr.update(visible=True), "登录成功"
|
33 |
return gr.update(visible=True), gr.update(visible=False), "用户名或密码错误"
|
34 |
|
35 |
+
@spaces.GPU(duration=300)
|
36 |
def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8):
|
37 |
"""
|
38 |
处理视频并进行小鼠检测
|
|
|
91 |
frame_count = 0
|
92 |
detection_info = []
|
93 |
|
94 |
+
# 用于记录轨迹和热图数据
|
95 |
+
all_positions = []
|
96 |
+
heatmap = np.zeros((height, width), dtype=np.float32)
|
97 |
+
|
98 |
for r in results:
|
|
|
99 |
frame = r.plot()
|
100 |
|
101 |
+
# 收集位置信息
|
102 |
+
if hasattr(r, 'keypoints') and r.keypoints is not None:
|
103 |
+
for kpts in r.keypoints:
|
104 |
+
if isinstance(kpts, torch.Tensor):
|
105 |
+
kpts = kpts.cpu().numpy()
|
106 |
+
# 使用第一个关键点(比如头部)作为位置参考
|
107 |
+
if len(kpts) > 0:
|
108 |
+
pos = kpts[0][:2] # 取x,y坐标
|
109 |
+
all_positions.append(pos)
|
110 |
+
# 更新热图
|
111 |
+
x, y = int(pos[0]), int(pos[1])
|
112 |
+
if 0 <= x < width and 0 <= y < height:
|
113 |
+
heatmap[y, x] += 1
|
114 |
+
|
115 |
# 收集检测信息
|
116 |
frame_info = {
|
117 |
"frame": frame_count + 1,
|
|
|
141 |
video_writer.release()
|
142 |
|
143 |
# 生成分析报告
|
144 |
+
confidences = [float(det['confidence'].strip('%'))/100 for info in detection_info for det in info['detections']]
|
145 |
+
hist, bins = np.histogram(confidences, bins=5)
|
146 |
+
|
147 |
+
confidence_report = "\n".join([
|
148 |
+
f"置信度 {bins[i]:.2f}-{bins[i+1]:.2f}: {hist[i]:3d}个检测 ({hist[i]/len(confidences)*100:.1f}%)"
|
149 |
+
for i in range(len(hist))
|
150 |
+
])
|
151 |
+
|
152 |
report = f"""视频分析报告:
|
153 |
参数设置:
|
154 |
- 置信度阈值: {conf_threshold:.2f}
|
155 |
- 最大检测数量: {max_det}
|
156 |
- 处理时长: {process_seconds}秒
|
157 |
+
|
158 |
分析结果:
|
159 |
- 处理帧数: {frame_count}
|
160 |
- 平均每帧检测到的老鼠数: {np.mean([info['count'] for info in detection_info]):.1f}
|
161 |
- 最大检测数: {max([info['count'] for info in detection_info])}
|
162 |
- 最小检测数: {min([info['count'] for info in detection_info])}
|
163 |
+
|
164 |
置信度分布:
|
165 |
+
{confidence_report}
|
166 |
"""
|
167 |
|
168 |
+
# 生成轨迹图
|
169 |
+
trajectory_img = np.zeros((height, width, 3), dtype=np.uint8) + 255 # 白色背景
|
170 |
+
points = np.array(all_positions, dtype=np.int32)
|
171 |
+
if len(points) > 1:
|
172 |
+
# 绘制轨迹线
|
173 |
+
cv2.polylines(trajectory_img, [points], False, (0, 0, 255), 2)
|
174 |
+
# 绘制起点和终点
|
175 |
+
cv2.circle(trajectory_img, tuple(points[0]), 5, (0, 255, 0), -1) # 绿色起点
|
176 |
+
cv2.circle(trajectory_img, tuple(points[-1]), 5, (255, 0, 0), -1) # 红色终点
|
177 |
+
|
178 |
+
# 生成热图
|
179 |
+
heatmap_normalized = cv2.normalize(heatmap, None, 0, 255, cv2.NORM_MINMAX)
|
180 |
+
heatmap_colored = cv2.applyColorMap(heatmap_normalized.astype(np.uint8), cv2.COLORMAP_JET)
|
181 |
+
|
182 |
+
# 保存图像
|
183 |
+
trajectory_path = output_path.replace('.mp4', '_trajectory.png')
|
184 |
+
heatmap_path = output_path.replace('.mp4', '_heatmap.png')
|
185 |
+
cv2.imwrite(trajectory_path, trajectory_img)
|
186 |
+
cv2.imwrite(heatmap_path, heatmap_colored)
|
187 |
+
|
188 |
+
return output_path, trajectory_path, heatmap_path, report
|
189 |
|
190 |
# 创建 Gradio 界面
|
191 |
with gr.Blocks() as demo:
|
|
|
229 |
|
230 |
with gr.Column():
|
231 |
video_output = gr.Video(label="检测结果")
|
232 |
+
with gr.Row():
|
233 |
+
trajectory_output = gr.Image(label="运动轨迹")
|
234 |
+
heatmap_output = gr.Image(label="热力图")
|
235 |
report_output = gr.Textbox(label="分析报告")
|
236 |
|
237 |
gr.Markdown("""
|
|
|
262 |
process_btn.click(
|
263 |
fn=process_video,
|
264 |
inputs=[video_input, process_seconds, conf_threshold, max_det],
|
265 |
+
outputs=[video_output, trajectory_output, heatmap_output, report_output]
|
266 |
)
|
267 |
|
268 |
if __name__ == "__main__":
|