Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -12,6 +12,7 @@ import tempfile
|
|
12 |
import imageio
|
13 |
from tqdm import tqdm
|
14 |
import logging
|
|
|
15 |
|
16 |
# 新增: 配置logging
|
17 |
logging.basicConfig(
|
@@ -49,7 +50,7 @@ def login(username, password):
|
|
49 |
@spaces.GPU(duration=300)
|
50 |
def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8):
|
51 |
"""
|
52 |
-
|
53 |
Args:
|
54 |
video_path: 输入视频路径
|
55 |
process_seconds: 处理时长(秒)
|
@@ -167,7 +168,7 @@ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8)
|
|
167 |
|
168 |
pbar.close() # 关闭进度条
|
169 |
video_writer.release()
|
170 |
-
logger.info(f"
|
171 |
|
172 |
# 生成分析报告
|
173 |
confidences = [float(det['confidence'].strip('%'))/100 for info in detection_info for det in info['detections']]
|
@@ -183,118 +184,114 @@ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8)
|
|
183 |
- 置信度阈值: {conf_threshold:.2f}
|
184 |
- 最大检测数量: {max_det}
|
185 |
- 处理时长: {process_seconds}秒
|
186 |
-
|
187 |
分析结果:
|
188 |
- 处理帧数: {frame_count}
|
189 |
- 平均每帧检测到的老鼠数: {np.mean([info['count'] for info in detection_info]):.1f}
|
190 |
- 最大检测数: {max([info['count'] for info in detection_info])}
|
191 |
- 最小检测数: {min([info['count'] for info in detection_info])}
|
192 |
-
|
193 |
置信度分布:
|
194 |
{confidence_report}
|
195 |
"""
|
196 |
|
197 |
-
def
|
198 |
-
"""
|
199 |
-
过滤轨迹中的异常点
|
200 |
-
Args:
|
201 |
-
positions: 位置列表 [[x1,y1], [x2,y2],...]
|
202 |
-
width: 视频宽度
|
203 |
-
height: 视频高度
|
204 |
-
max_jump_distance: 允许��最大跳跃距离
|
205 |
-
"""
|
206 |
if len(positions) < 3:
|
207 |
return positions
|
208 |
|
209 |
-
|
210 |
-
|
211 |
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
222 |
|
223 |
-
|
224 |
|
225 |
-
if
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
if next_valid_pos is not None:
|
237 |
-
steps = max(2, int(distance / max_jump_distance))
|
238 |
-
for j in range(1, steps):
|
239 |
-
alpha = j / steps
|
240 |
-
interp_x = int(last_valid_pos[0] * (1 - alpha) + next_valid_pos[0] * alpha)
|
241 |
-
interp_y = int(last_valid_pos[1] * (1 - alpha) + next_valid_pos[1] * alpha)
|
242 |
-
filtered_positions.append([interp_x, interp_y])
|
243 |
-
filtered_positions.append(next_valid_pos)
|
244 |
-
last_valid_pos = next_valid_pos
|
245 |
-
else:
|
246 |
-
filtered_positions.append(pos)
|
247 |
-
last_valid_pos = pos
|
248 |
|
249 |
-
|
250 |
-
smoothed_positions = []
|
251 |
|
252 |
-
|
253 |
-
|
|
|
|
|
|
|
254 |
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
260 |
|
261 |
-
|
262 |
-
|
263 |
-
smoothed_positions = filtered_positions
|
264 |
|
265 |
-
return
|
266 |
|
267 |
# 修改轨迹图生成部分
|
268 |
-
trajectory_img =
|
269 |
points = np.array(all_positions, dtype=np.int32)
|
270 |
if len(points) > 1:
|
271 |
-
filtered_points =
|
272 |
points = np.array(filtered_points, dtype=np.int32)
|
273 |
|
274 |
for i in range(len(points) - 1):
|
275 |
ratio = i / (len(points) - 1)
|
276 |
-
color = (
|
277 |
int((1 - ratio) * 255), # B
|
278 |
50, # G
|
279 |
int(ratio * 255) # R
|
280 |
-
)
|
281 |
-
|
|
|
|
|
|
|
282 |
|
283 |
-
|
284 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
285 |
|
286 |
-
|
287 |
-
for i in range(0, len(points) - arrow_interval, arrow_interval):
|
288 |
-
pt1 = tuple(points[i])
|
289 |
-
pt2 = tuple(points[i + arrow_interval])
|
290 |
-
angle = np.arctan2(pt2[1] - pt1[1], pt2[0] - pt1[0])
|
291 |
-
cv2.arrowedLine(trajectory_img, pt1, pt2, (100, 100, 100), 1, tipLength=0.2)
|
292 |
-
|
293 |
-
if np.max(heatmap) > 0:
|
294 |
-
heatmap_normalized = cv2.normalize(heatmap, None, 0, 255, cv2.NORM_MINMAX)
|
295 |
-
heatmap_colored = cv2.applyColorMap(heatmap_normalized.astype(np.uint8), cv2.COLORMAP_JET)
|
296 |
-
alpha = 0.7
|
297 |
-
heatmap_colored = cv2.addWeighted(heatmap_colored, alpha, np.full_like(heatmap_colored, 255), 1-alpha, 0)
|
298 |
|
299 |
trajectory_frames = []
|
300 |
heatmap_frames = []
|
@@ -357,6 +354,62 @@ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8)
|
|
357 |
logger.info("所有处理完成,准备返回结果")
|
358 |
return output_path, trajectory_path, heatmap_path, trajectory_gif_path, heatmap_gif_path, report
|
359 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
360 |
# 创建 Gradio 界面
|
361 |
with gr.Blocks() as demo:
|
362 |
gr.Markdown("# 🐁 小鼠行为分析 (Mice Behavior Analysis)")
|
|
|
12 |
import imageio
|
13 |
from tqdm import tqdm
|
14 |
import logging
|
15 |
+
import torch.nn.functional as F
|
16 |
|
17 |
# 新增: 配置logging
|
18 |
logging.basicConfig(
|
|
|
50 |
@spaces.GPU(duration=300)
|
51 |
def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8):
|
52 |
"""
|
53 |
+
处理视频并进行��鼠检测
|
54 |
Args:
|
55 |
video_path: 输入视频路径
|
56 |
process_seconds: 处理时长(秒)
|
|
|
168 |
|
169 |
pbar.close() # 关闭进度条
|
170 |
video_writer.release()
|
171 |
+
logger.info(f"视频处��完成,共处理 {frame_count} 帧")
|
172 |
|
173 |
# 生成分析报告
|
174 |
confidences = [float(det['confidence'].strip('%'))/100 for info in detection_info for det in info['detections']]
|
|
|
184 |
- 置信度阈值: {conf_threshold:.2f}
|
185 |
- 最大检测数量: {max_det}
|
186 |
- 处理时长: {process_seconds}秒
|
|
|
187 |
分析结果:
|
188 |
- 处理帧数: {frame_count}
|
189 |
- 平均每帧检测到的老鼠数: {np.mean([info['count'] for info in detection_info]):.1f}
|
190 |
- 最大检测数: {max([info['count'] for info in detection_info])}
|
191 |
- 最小检测数: {min([info['count'] for info in detection_info])}
|
|
|
192 |
置信度分布:
|
193 |
{confidence_report}
|
194 |
"""
|
195 |
|
196 |
+
def filter_trajectories_gpu(positions, width, height, max_jump_distance=100):
|
197 |
+
"""GPU加速版本的轨迹过滤"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
198 |
if len(positions) < 3:
|
199 |
return positions
|
200 |
|
201 |
+
# 转换为GPU张量
|
202 |
+
points = torch.tensor(positions, device=device, dtype=torch.float32)
|
203 |
|
204 |
+
# 计算相邻点之间的距离
|
205 |
+
diffs = points[1:] - points[:-1]
|
206 |
+
distances = torch.norm(diffs, dim=1)
|
207 |
+
|
208 |
+
# 找出需要插值的位置
|
209 |
+
mask = distances > max_jump_distance
|
210 |
+
valid_indices = (~mask).nonzero().squeeze()
|
211 |
+
|
212 |
+
if len(valid_indices) < 2:
|
213 |
+
return positions
|
214 |
+
|
215 |
+
# 使用GPU进行插值
|
216 |
+
filtered_points = []
|
217 |
+
last_valid_idx = 0
|
218 |
+
|
219 |
+
for i in range(len(valid_indices)-1):
|
220 |
+
curr_idx = valid_indices[i].item()
|
221 |
+
next_idx = valid_indices[i+1].item()
|
222 |
|
223 |
+
filtered_points.append(points[curr_idx].tolist())
|
224 |
|
225 |
+
if next_idx - curr_idx > 1:
|
226 |
+
# 线性插值
|
227 |
+
steps = max(2, int((next_idx - curr_idx)))
|
228 |
+
interp_points = torch.linspace(0, 1, steps)
|
229 |
+
start_point = points[curr_idx]
|
230 |
+
end_point = points[next_idx]
|
231 |
+
|
232 |
+
interpolated = start_point[None] * (1 - interp_points[:, None]) + \
|
233 |
+
end_point[None] * interp_points[:, None]
|
234 |
+
|
235 |
+
filtered_points.extend(interpolated[1:-1].tolist())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
236 |
|
237 |
+
filtered_points.append(points[valid_indices[-1]].tolist())
|
|
|
238 |
|
239 |
+
# 平滑处理
|
240 |
+
if len(filtered_points) >= 5:
|
241 |
+
points_tensor = torch.tensor(filtered_points, device=device)
|
242 |
+
kernel_size = 5
|
243 |
+
padding = kernel_size // 2
|
244 |
|
245 |
+
# 使用1D卷积进行平滑
|
246 |
+
weights = torch.ones(1, 1, kernel_size, device=device) / kernel_size
|
247 |
+
smoothed_x = F.conv1d(
|
248 |
+
points_tensor[:, 0].view(1, 1, -1),
|
249 |
+
weights,
|
250 |
+
padding=padding
|
251 |
+
).squeeze()
|
252 |
+
smoothed_y = F.conv1d(
|
253 |
+
points_tensor[:, 1].view(1, 1, -1),
|
254 |
+
weights,
|
255 |
+
padding=padding
|
256 |
+
).squeeze()
|
257 |
|
258 |
+
smoothed_points = torch.stack([smoothed_x, smoothed_y], dim=1)
|
259 |
+
return smoothed_points.cpu().numpy().tolist()
|
|
|
260 |
|
261 |
+
return filtered_points
|
262 |
|
263 |
# 修改轨迹图生成部分
|
264 |
+
trajectory_img = torch.ones((height, width, 3), device=device) * 255
|
265 |
points = np.array(all_positions, dtype=np.int32)
|
266 |
if len(points) > 1:
|
267 |
+
filtered_points = filter_trajectories_gpu(points.tolist(), width, height)
|
268 |
points = np.array(filtered_points, dtype=np.int32)
|
269 |
|
270 |
for i in range(len(points) - 1):
|
271 |
ratio = i / (len(points) - 1)
|
272 |
+
color = torch.tensor([
|
273 |
int((1 - ratio) * 255), # B
|
274 |
50, # G
|
275 |
int(ratio * 255) # R
|
276 |
+
], device=device)
|
277 |
+
|
278 |
+
# 使用GPU绘制线段
|
279 |
+
pt1, pt2 = points[i], points[i + 1]
|
280 |
+
draw_line_gpu(trajectory_img, pt1, pt2, color, 2)
|
281 |
|
282 |
+
trajectory_img = trajectory_img.cpu().numpy().astype(np.uint8)
|
283 |
+
|
284 |
+
if torch.cuda.is_available():
|
285 |
+
heatmap = torch.zeros((height, width), device=device)
|
286 |
+
for x, y in filtered_points:
|
287 |
+
if 0 <= x < width and 0 <= y < height:
|
288 |
+
temp_heatmap = torch.zeros((height, width), device=device)
|
289 |
+
temp_heatmap[y, x] = 1
|
290 |
+
# 使用GPU的高斯模糊
|
291 |
+
temp_heatmap = gaussian_blur_gpu(temp_heatmap, kernel_size=31, sigma=10)
|
292 |
+
heatmap += temp_heatmap
|
293 |
|
294 |
+
heatmap = heatmap.cpu().numpy()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
295 |
|
296 |
trajectory_frames = []
|
297 |
heatmap_frames = []
|
|
|
354 |
logger.info("所有处理完成,准备返回结果")
|
355 |
return output_path, trajectory_path, heatmap_path, trajectory_gif_path, heatmap_gif_path, report
|
356 |
|
357 |
+
def gaussian_blur_gpu(tensor, kernel_size=31, sigma=10):
|
358 |
+
"""GPU版本的高斯模糊"""
|
359 |
+
channels = 1
|
360 |
+
kernel = get_gaussian_kernel2d(kernel_size, sigma).to(device)
|
361 |
+
kernel = kernel.view(1, 1, kernel_size, kernel_size)
|
362 |
+
tensor = tensor.view(1, 1, tensor.shape[0], tensor.shape[1])
|
363 |
+
|
364 |
+
return F.conv2d(tensor, kernel, padding=kernel_size//2).squeeze()
|
365 |
+
|
366 |
+
def get_gaussian_kernel2d(kernel_size, sigma):
|
367 |
+
"""生成2D高斯核"""
|
368 |
+
kernel_x = torch.linspace(-kernel_size//2, kernel_size//2, kernel_size)
|
369 |
+
x, y = torch.meshgrid(kernel_x, kernel_x)
|
370 |
+
kernel = torch.exp(-(x.pow(2) + y.pow(2)) / (2 * sigma ** 2))
|
371 |
+
return kernel / kernel.sum()
|
372 |
+
|
373 |
+
def draw_line_gpu(image, pt1, pt2, color, thickness=1):
|
374 |
+
"""GPU版本的线段绘制"""
|
375 |
+
x1, y1 = pt1
|
376 |
+
x2, y2 = pt2
|
377 |
+
dx = abs(x2 - x1)
|
378 |
+
dy = abs(y2 - y1)
|
379 |
+
|
380 |
+
if dx > dy:
|
381 |
+
steps = dx
|
382 |
+
else:
|
383 |
+
steps = dy
|
384 |
+
|
385 |
+
x_inc = (x2 - x1) / steps
|
386 |
+
y_inc = (y2 - y1) / steps
|
387 |
+
|
388 |
+
x = x1
|
389 |
+
y = y1
|
390 |
+
|
391 |
+
points = torch.zeros((int(steps) + 1, 2), device=device)
|
392 |
+
for i in range(int(steps) + 1):
|
393 |
+
points[i] = torch.tensor([x, y])
|
394 |
+
x += x_inc
|
395 |
+
y += y_inc
|
396 |
+
|
397 |
+
points = points.long()
|
398 |
+
valid_points = (points[:, 0] >= 0) & (points[:, 0] < image.shape[1]) & \
|
399 |
+
(points[:, 1] >= 0) & (points[:, 1] < image.shape[0])
|
400 |
+
points = points[valid_points]
|
401 |
+
|
402 |
+
if thickness > 1:
|
403 |
+
for dx in range(-thickness//2, thickness//2 + 1):
|
404 |
+
for dy in range(-thickness//2, thickness//2 + 1):
|
405 |
+
offset_points = points + torch.tensor([dx, dy], device=device)
|
406 |
+
valid_offset = (offset_points[:, 0] >= 0) & (offset_points[:, 0] < image.shape[1]) & \
|
407 |
+
(offset_points[:, 1] >= 0) & (offset_points[:, 1] < image.shape[0])
|
408 |
+
offset_points = offset_points[valid_offset]
|
409 |
+
image[offset_points[:, 1], offset_points[:, 0]] = color
|
410 |
+
else:
|
411 |
+
image[points[:, 1], points[:, 0]] = color
|
412 |
+
|
413 |
# 创建 Gradio 界面
|
414 |
with gr.Blocks() as demo:
|
415 |
gr.Markdown("# 🐁 小鼠行为分析 (Mice Behavior Analysis)")
|