Hakureirm commited on
Commit
72f5340
·
verified ·
1 Parent(s): 04b4cc9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -5
app.py CHANGED
@@ -32,7 +32,7 @@ def login(username, password):
32
  return gr.update(visible=False), gr.update(visible=True), "登录成功"
33
  return gr.update(visible=True), gr.update(visible=False), "用户名或密码错误"
34
 
35
- @spaces.GPU
36
  def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8):
37
  """
38
  处理视频并进行小鼠检测
@@ -91,10 +91,27 @@ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8)
91
  frame_count = 0
92
  detection_info = []
93
 
 
 
 
 
94
  for r in results:
95
- # 获取绘制了预测结果的帧
96
  frame = r.plot()
97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  # 收集检测信息
99
  frame_info = {
100
  "frame": frame_count + 1,
@@ -124,21 +141,51 @@ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8)
124
  video_writer.release()
125
 
126
  # 生成分析报告
 
 
 
 
 
 
 
 
127
  report = f"""视频分析报告:
128
  参数设置:
129
  - 置信度阈值: {conf_threshold:.2f}
130
  - 最大检测数量: {max_det}
131
  - 处理时长: {process_seconds}秒
 
132
  分析结果:
133
  - 处理帧数: {frame_count}
134
  - 平均每帧检测到的老鼠数: {np.mean([info['count'] for info in detection_info]):.1f}
135
  - 最大检测数: {max([info['count'] for info in detection_info])}
136
  - 最小检测数: {min([info['count'] for info in detection_info])}
 
137
  置信度分布:
138
- {np.histogram([float(det['confidence'].strip('%'))/100 for info in detection_info for det in info['detections']], bins=5)[0].tolist()}
139
  """
140
 
141
- return output_path, report
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
  # 创建 Gradio 界面
144
  with gr.Blocks() as demo:
@@ -182,6 +229,9 @@ with gr.Blocks() as demo:
182
 
183
  with gr.Column():
184
  video_output = gr.Video(label="检测结果")
 
 
 
185
  report_output = gr.Textbox(label="分析报告")
186
 
187
  gr.Markdown("""
@@ -212,7 +262,7 @@ with gr.Blocks() as demo:
212
  process_btn.click(
213
  fn=process_video,
214
  inputs=[video_input, process_seconds, conf_threshold, max_det],
215
- outputs=[video_output, report_output]
216
  )
217
 
218
  if __name__ == "__main__":
 
32
  return gr.update(visible=False), gr.update(visible=True), "登录成功"
33
  return gr.update(visible=True), gr.update(visible=False), "用户名或密码错误"
34
 
35
+ @spaces.GPU(duration=300)
36
  def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8):
37
  """
38
  处理视频并进行小鼠检测
 
91
  frame_count = 0
92
  detection_info = []
93
 
94
+ # 用于记录轨迹和热图数据
95
+ all_positions = []
96
+ heatmap = np.zeros((height, width), dtype=np.float32)
97
+
98
  for r in results:
 
99
  frame = r.plot()
100
 
101
+ # 收集位置信息
102
+ if hasattr(r, 'keypoints') and r.keypoints is not None:
103
+ for kpts in r.keypoints:
104
+ if isinstance(kpts, torch.Tensor):
105
+ kpts = kpts.cpu().numpy()
106
+ # 使用第一个关键点(比如头部)作为位置参考
107
+ if len(kpts) > 0:
108
+ pos = kpts[0][:2] # 取x,y坐标
109
+ all_positions.append(pos)
110
+ # 更新热图
111
+ x, y = int(pos[0]), int(pos[1])
112
+ if 0 <= x < width and 0 <= y < height:
113
+ heatmap[y, x] += 1
114
+
115
  # 收集检测信息
116
  frame_info = {
117
  "frame": frame_count + 1,
 
141
  video_writer.release()
142
 
143
  # 生成分析报告
144
+ confidences = [float(det['confidence'].strip('%'))/100 for info in detection_info for det in info['detections']]
145
+ hist, bins = np.histogram(confidences, bins=5)
146
+
147
+ confidence_report = "\n".join([
148
+ f"置信度 {bins[i]:.2f}-{bins[i+1]:.2f}: {hist[i]:3d}个检测 ({hist[i]/len(confidences)*100:.1f}%)"
149
+ for i in range(len(hist))
150
+ ])
151
+
152
  report = f"""视频分析报告:
153
  参数设置:
154
  - 置信度阈值: {conf_threshold:.2f}
155
  - 最大检测数量: {max_det}
156
  - 处理时长: {process_seconds}秒
157
+
158
  分析结果:
159
  - 处理帧数: {frame_count}
160
  - 平均每帧检测到的老鼠数: {np.mean([info['count'] for info in detection_info]):.1f}
161
  - 最大检测数: {max([info['count'] for info in detection_info])}
162
  - 最小检测数: {min([info['count'] for info in detection_info])}
163
+
164
  置信度分布:
165
+ {confidence_report}
166
  """
167
 
168
+ # 生成轨迹图
169
+ trajectory_img = np.zeros((height, width, 3), dtype=np.uint8) + 255 # 白色背景
170
+ points = np.array(all_positions, dtype=np.int32)
171
+ if len(points) > 1:
172
+ # 绘制轨迹线
173
+ cv2.polylines(trajectory_img, [points], False, (0, 0, 255), 2)
174
+ # 绘制起点和终点
175
+ cv2.circle(trajectory_img, tuple(points[0]), 5, (0, 255, 0), -1) # 绿色起点
176
+ cv2.circle(trajectory_img, tuple(points[-1]), 5, (255, 0, 0), -1) # 红色终点
177
+
178
+ # 生成热图
179
+ heatmap_normalized = cv2.normalize(heatmap, None, 0, 255, cv2.NORM_MINMAX)
180
+ heatmap_colored = cv2.applyColorMap(heatmap_normalized.astype(np.uint8), cv2.COLORMAP_JET)
181
+
182
+ # 保存图像
183
+ trajectory_path = output_path.replace('.mp4', '_trajectory.png')
184
+ heatmap_path = output_path.replace('.mp4', '_heatmap.png')
185
+ cv2.imwrite(trajectory_path, trajectory_img)
186
+ cv2.imwrite(heatmap_path, heatmap_colored)
187
+
188
+ return output_path, trajectory_path, heatmap_path, report
189
 
190
  # 创建 Gradio 界面
191
  with gr.Blocks() as demo:
 
229
 
230
  with gr.Column():
231
  video_output = gr.Video(label="检测结果")
232
+ with gr.Row():
233
+ trajectory_output = gr.Image(label="运动轨迹")
234
+ heatmap_output = gr.Image(label="热力图")
235
  report_output = gr.Textbox(label="分析报告")
236
 
237
  gr.Markdown("""
 
262
  process_btn.click(
263
  fn=process_video,
264
  inputs=[video_input, process_seconds, conf_threshold, max_det],
265
+ outputs=[video_output, trajectory_output, heatmap_output, report_output]
266
  )
267
 
268
  if __name__ == "__main__":