Hakureirm commited on
Commit
0b312c9
·
verified ·
1 Parent(s): d192c81

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -4
app.py CHANGED
@@ -16,7 +16,7 @@ APP_PASSWORD = os.getenv("APP_PASSWORD", "default_password") # 从环境变量
16
 
17
  app = FastAPI()
18
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
19
- model = YOLO('kunin-mice-pose.v0.1.2.pt')
20
 
21
  # 定义认证状态
22
  class AuthState:
@@ -173,10 +173,95 @@ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8)
173
  {confidence_report}
174
  """
175
 
176
- # 生成轨迹图
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
  trajectory_img = np.zeros((height, width, 3), dtype=np.uint8) + 255 # 白色背景
178
  points = np.array(all_positions, dtype=np.int32)
179
  if len(points) > 1:
 
 
 
 
180
  # 绘制轨迹线,使用渐变色
181
  for i in range(len(points) - 1):
182
  ratio = i / (len(points) - 1)
@@ -269,7 +354,7 @@ with gr.Blocks() as demo:
269
  ### 使用说明
270
  1. 上传视频文件
271
  2. 设置处理参数:
272
- - 处理时长:需要分析的视频时长(秒)
273
  - 置信度阈值:检测的置信度要求(越高越严格)
274
  - 最大检测数量:每帧最多检测的目标数量
275
  3. 等待处理完成
@@ -297,4 +382,4 @@ with gr.Blocks() as demo:
297
  )
298
 
299
  if __name__ == "__main__":
300
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
16
 
17
  app = FastAPI()
18
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
19
+ model = YOLO('kunin-mice-pose.v0.1.3.pt')
20
 
21
  # 定义认证状态
22
  class AuthState:
 
173
  {confidence_report}
174
  """
175
 
176
+ # 在生成轨迹图之前,添加异常点过滤和轨迹平滑
177
+ def filter_trajectories(positions, width, height, max_jump_distance=100):
178
+ """
179
+ 过滤轨迹中的异常点
180
+ Args:
181
+ positions: 位置列表 [[x1,y1], [x2,y2],...]
182
+ width: 视频宽度
183
+ height: 视频高度
184
+ max_jump_distance: 允许的最大跳跃距离
185
+ """
186
+ if len(positions) < 3:
187
+ return positions
188
+
189
+ filtered_positions = []
190
+ last_valid_pos = None
191
+
192
+ for i, pos in enumerate(positions):
193
+ x, y = pos
194
+
195
+ # 检查点是否在有效范围内
196
+ if not (0 <= x < width and 0 <= y < height):
197
+ continue
198
+
199
+ # 第一个有效点
200
+ if last_valid_pos is None:
201
+ filtered_positions.append(pos)
202
+ last_valid_pos = pos
203
+ continue
204
+
205
+ # 计算与上一个点的距离
206
+ distance = np.sqrt((x - last_valid_pos[0])**2 + (y - last_valid_pos[1])**2)
207
+
208
+ if distance > max_jump_distance:
209
+ # 如果距离太大,进行插值
210
+ if len(filtered_positions) > 0:
211
+ # 寻找下一个有效点
212
+ next_valid_pos = None
213
+ for next_pos in positions[i:]:
214
+ nx, ny = next_pos
215
+ if (0 <= nx < width and 0 <= ny < height):
216
+ next_distance = np.sqrt((nx - last_valid_pos[0])**2 + (ny - last_valid_pos[1])**2)
217
+ if next_distance <= max_jump_distance:
218
+ next_valid_pos = next_pos
219
+ break
220
+
221
+ if next_valid_pos is not None:
222
+ # 线性插值
223
+ steps = max(2, int(distance / max_jump_distance))
224
+ for j in range(1, steps):
225
+ alpha = j / steps
226
+ interp_x = int(last_valid_pos[0] * (1 - alpha) + next_valid_pos[0] * alpha)
227
+ interp_y = int(last_valid_pos[1] * (1 - alpha) + next_valid_pos[1] * alpha)
228
+ filtered_positions.append([interp_x, interp_y])
229
+ filtered_positions.append(next_valid_pos)
230
+ last_valid_pos = next_valid_pos
231
+ else:
232
+ filtered_positions.append(pos)
233
+ last_valid_pos = pos
234
+
235
+ # 使用移动平均平滑轨迹
236
+ window_size = 5
237
+ smoothed_positions = []
238
+
239
+ if len(filtered_positions) >= window_size:
240
+ # 添加开始的点
241
+ smoothed_positions.extend(filtered_positions[:window_size//2])
242
+
243
+ # 平滑中间的点
244
+ for i in range(window_size//2, len(filtered_positions) - window_size//2):
245
+ window = filtered_positions[i-window_size//2:i+window_size//2+1]
246
+ smoothed_x = int(np.mean([p[0] for p in window]))
247
+ smoothed_y = int(np.mean([p[1] for p in window]))
248
+ smoothed_positions.append([smoothed_x, smoothed_y])
249
+
250
+ # 添加结束的点
251
+ smoothed_positions.extend(filtered_positions[-window_size//2:])
252
+ else:
253
+ smoothed_positions = filtered_positions
254
+
255
+ return smoothed_positions
256
+
257
+ # 修改轨迹图生成部分
258
  trajectory_img = np.zeros((height, width, 3), dtype=np.uint8) + 255 # 白色背景
259
  points = np.array(all_positions, dtype=np.int32)
260
  if len(points) > 1:
261
+ # 过滤和平滑轨迹
262
+ filtered_points = filter_trajectories(points.tolist(), width, height)
263
+ points = np.array(filtered_points, dtype=np.int32)
264
+
265
  # 绘制轨迹线,使用渐变色
266
  for i in range(len(points) - 1):
267
  ratio = i / (len(points) - 1)
 
354
  ### 使用说明
355
  1. 上传视频文件
356
  2. 设置处理参数:
357
+ - 处理时长:需要分析的视频时长(秒��
358
  - 置信度阈值:检测的置信度要求(越高越严格)
359
  - 最大检测数量:每帧最多检测的目标数量
360
  3. 等待处理完成
 
382
  )
383
 
384
  if __name__ == "__main__":
385
+ demo.launch(server_name="0.0.0.0", server_port=7860)