Hakureirm commited on
Commit
11c09d2
·
verified ·
1 Parent(s): d080468

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -13
app.py CHANGED
@@ -100,25 +100,49 @@ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8)
100
 
101
  # 收集位置信息
102
  if hasattr(r, 'keypoints') and r.keypoints is not None:
 
 
 
 
103
  for kpts in r.keypoints:
104
  if isinstance(kpts, torch.Tensor):
105
  kpts = kpts.cpu().numpy()
106
- # 打印关键点格式以便调试
107
- # print(f"Keypoints shape: {kpts.shape}")
108
 
109
  # 确保关键点数据是正确的格式
110
  if isinstance(kpts, np.ndarray):
111
- if len(kpts.shape) == 2: # [n_points, 2/3]
112
- if kpts.shape[0] > 0: # 确保有关键点
113
- # 使用第一个关键点
114
- if kpts.shape[1] >= 2: # 确保有x,y坐标
115
- x, y = kpts[0, 0], kpts[0, 1]
116
  if isinstance(x, (int, float)) and isinstance(y, (int, float)):
117
  x, y = int(x), int(y)
118
  all_positions.append([x, y])
119
- # 更新热图
120
  if 0 <= x < width and 0 <= y < height:
121
- heatmap[y, x] += 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
  # 收集检测信息
124
  frame_info = {
@@ -177,11 +201,19 @@ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8)
177
  trajectory_img = np.zeros((height, width, 3), dtype=np.uint8) + 255 # 白色背景
178
  points = np.array(all_positions, dtype=np.int32)
179
  if len(points) > 1:
180
- # 绘制轨迹线
181
- cv2.polylines(trajectory_img, [points], False, (0, 0, 255), 2)
 
 
 
 
 
 
 
 
182
  # 绘制起点和终点
183
- cv2.circle(trajectory_img, tuple(points[0]), 5, (0, 255, 0), -1) # 绿色起点
184
- cv2.circle(trajectory_img, tuple(points[-1]), 5, (255, 0, 0), -1) # 红色终点
185
 
186
  # 生成热图
187
  heatmap_normalized = cv2.normalize(heatmap, None, 0, 255, cv2.NORM_MINMAX)
 
100
 
101
  # 收集位置信息
102
  if hasattr(r, 'keypoints') and r.keypoints is not None:
103
+ # 打印关键点对象信息
104
+ print(f"Keypoints type: {type(r.keypoints)}")
105
+ print(f"Keypoints data: {r.keypoints}")
106
+
107
  for kpts in r.keypoints:
108
  if isinstance(kpts, torch.Tensor):
109
  kpts = kpts.cpu().numpy()
110
+ print(f"Single keypoints shape: {kpts.shape}") # 打印形状
111
+ print(f"Single keypoints data: {kpts}") # 打印数据
112
 
113
  # 确保关键点数据是正确的格式
114
  if isinstance(kpts, np.ndarray):
115
+ if len(kpts.shape) == 3: # [num_objects, num_keypoints, 3]
116
+ for obj_kpts in kpts:
117
+ if len(obj_kpts) > 0:
118
+ x, y = obj_kpts[0][:2] # 使用第一个关键点的x,y坐标
 
119
  if isinstance(x, (int, float)) and isinstance(y, (int, float)):
120
  x, y = int(x), int(y)
121
  all_positions.append([x, y])
122
+ # 更新热图,使用高斯核来平滑
123
  if 0 <= x < width and 0 <= y < height:
124
+ # 创建高斯核心点
125
+ sigma = 5 # 调整这个值来改变热点大小
126
+ kernel_size = 15 # 必须是奇数
127
+ temp_heatmap = np.zeros((height, width), dtype=np.float32)
128
+ temp_heatmap[y, x] = 1
129
+ temp_heatmap = cv2.GaussianBlur(temp_heatmap, (kernel_size, kernel_size), sigma)
130
+ heatmap += temp_heatmap
131
+ elif len(kpts.shape) == 2: # [num_keypoints, 3]
132
+ if len(kpts) > 0:
133
+ x, y = kpts[0][:2] # 使用第一个关键点的x,y坐标
134
+ if isinstance(x, (int, float)) and isinstance(y, (int, float)):
135
+ x, y = int(x), int(y)
136
+ all_positions.append([x, y])
137
+ # 更新热图,使用高斯核来平滑
138
+ if 0 <= x < width and 0 <= y < height:
139
+ # 创建高斯核心点
140
+ sigma = 5 # 调整这个值来改变热点大小
141
+ kernel_size = 15 # 必须是奇数
142
+ temp_heatmap = np.zeros((height, width), dtype=np.float32)
143
+ temp_heatmap[y, x] = 1
144
+ temp_heatmap = cv2.GaussianBlur(temp_heatmap, (kernel_size, kernel_size), sigma)
145
+ heatmap += temp_heatmap
146
 
147
  # 收集检测信息
148
  frame_info = {
 
201
  trajectory_img = np.zeros((height, width, 3), dtype=np.uint8) + 255 # 白色背景
202
  points = np.array(all_positions, dtype=np.int32)
203
  if len(points) > 1:
204
+ # 绘制轨迹线,使用渐变色
205
+ for i in range(len(points) - 1):
206
+ ratio = i / (len(points) - 1)
207
+ color = (
208
+ int((1 - ratio) * 255), # B
209
+ 0, # G
210
+ int(ratio * 255) # R
211
+ )
212
+ cv2.line(trajectory_img, tuple(points[i]), tuple(points[i + 1]), color, 2)
213
+
214
  # 绘制起点和终点
215
+ cv2.circle(trajectory_img, tuple(points[0]), 8, (0, 255, 0), -1) # 绿色起点
216
+ cv2.circle(trajectory_img, tuple(points[-1]), 8, (255, 0, 0), -1) # 红色终点
217
 
218
  # 生成热图
219
  heatmap_normalized = cv2.normalize(heatmap, None, 0, 255, cv2.NORM_MINMAX)