Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -9,18 +9,7 @@ import numpy as np
|
|
9 |
import cv2
|
10 |
from pathlib import Path
|
11 |
import tempfile
|
12 |
-
import imageio
|
13 |
from tqdm import tqdm
|
14 |
-
import logging
|
15 |
-
import torch.nn.functional as F
|
16 |
-
|
17 |
-
# 新增: 配置logging
|
18 |
-
logging.basicConfig(
|
19 |
-
level=logging.INFO,
|
20 |
-
format='%(asctime)s - %(levelname)s - %(message)s',
|
21 |
-
datefmt='%Y-%m-%d %H:%M:%S'
|
22 |
-
)
|
23 |
-
logger = logging.getLogger(__name__)
|
24 |
|
25 |
# 从环境变量获取密码
|
26 |
APP_USERNAME = "admin" # 用户名保持固定
|
@@ -28,7 +17,9 @@ APP_PASSWORD = os.getenv("APP_PASSWORD", "default_password") # 从环境变量
|
|
28 |
|
29 |
app = FastAPI()
|
30 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
|
31 |
model = YOLO('kunin-mice-pose.v0.1.5n.pt')
|
|
|
32 |
|
33 |
# 定义认证状态
|
34 |
class AuthState:
|
@@ -39,38 +30,26 @@ auth_state = AuthState()
|
|
39 |
|
40 |
def login(username, password):
|
41 |
"""登录验证"""
|
42 |
-
logger.info(f"用户尝试登录: {username}")
|
43 |
if username == APP_USERNAME and password == APP_PASSWORD:
|
44 |
auth_state.is_logged_in = True
|
45 |
-
logger.info("登录成功")
|
46 |
return gr.update(visible=False), gr.update(visible=True), "登录成功"
|
47 |
-
logger.warning("登录失败:用户名或密码错误")
|
48 |
return gr.update(visible=True), gr.update(visible=False), "用户名或密码错误"
|
49 |
|
50 |
-
@spaces.GPU(duration=
|
51 |
def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8):
|
52 |
"""
|
53 |
处理视频并进行小鼠检测
|
54 |
-
Args:
|
55 |
-
video_path: 输入视频路径
|
56 |
-
process_seconds: 处理时长(秒)
|
57 |
-
conf_threshold: 置信度阈值(0-1)
|
58 |
-
max_det: 每帧最大检测数量
|
59 |
"""
|
60 |
-
|
61 |
-
logger.info(f"参数设置 - 处理时长: {process_seconds}秒, 置信度阈值: {conf_threshold}, 最大检测数: {max_det}")
|
62 |
|
63 |
if not auth_state.is_logged_in:
|
64 |
-
logger.warning("用户未登录,拒绝访问")
|
65 |
return None, "请先登录"
|
66 |
|
67 |
-
|
68 |
-
logger.info("创建临时输出目录")
|
69 |
with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
|
70 |
output_path = tmp_file.name
|
71 |
|
72 |
-
|
73 |
-
logger.info("读取视频信息")
|
74 |
cap = cv2.VideoCapture(video_path)
|
75 |
fps = int(cap.get(cv2.CAP_PROP_FPS))
|
76 |
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
@@ -78,9 +57,9 @@ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8)
|
|
78 |
total_frames = int(process_seconds * fps) if process_seconds else int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
79 |
cap.release()
|
80 |
|
81 |
-
|
82 |
|
83 |
-
|
84 |
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
85 |
video_writer = cv2.VideoWriter(
|
86 |
output_path,
|
@@ -89,11 +68,10 @@ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8)
|
|
89 |
(width, height)
|
90 |
)
|
91 |
|
92 |
-
# 计算基于分辨率的线宽
|
93 |
base_size = min(width, height)
|
94 |
-
line_thickness = max(1, int(base_size * 0.002))
|
95 |
|
96 |
-
|
97 |
results = model.predict(
|
98 |
source=video_path,
|
99 |
device=device,
|
@@ -108,28 +86,26 @@ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8)
|
|
108 |
vid_stride=1,
|
109 |
max_det=max_det,
|
110 |
retina_masks=True,
|
111 |
-
verbose=False
|
112 |
)
|
113 |
|
114 |
-
logger.info("开始处理检测结果")
|
115 |
frame_count = 0
|
116 |
detection_info = []
|
117 |
all_positions = []
|
118 |
heatmap = np.zeros((height, width), dtype=np.float32)
|
119 |
|
120 |
-
|
121 |
-
|
122 |
|
123 |
for r in results:
|
124 |
frame = r.plot()
|
125 |
|
126 |
-
# 收集位置信息
|
127 |
if hasattr(r, 'keypoints') and r.keypoints is not None:
|
128 |
kpts = r.keypoints.data
|
129 |
if isinstance(kpts, torch.Tensor):
|
130 |
kpts = kpts.cpu().numpy()
|
131 |
|
132 |
-
if kpts.shape == (1, 8, 3):
|
133 |
x, y = int(kpts[0, 0, 0]), int(kpts[0, 0, 1])
|
134 |
all_positions.append([x, y])
|
135 |
|
@@ -141,7 +117,6 @@ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8)
|
|
141 |
temp_heatmap = cv2.GaussianBlur(temp_heatmap, (kernel_size, kernel_size), sigma)
|
142 |
heatmap += temp_heatmap
|
143 |
|
144 |
-
# 收集检测信息
|
145 |
frame_info = {
|
146 |
"frame": frame_count + 1,
|
147 |
"count": len(r.boxes),
|
@@ -161,16 +136,17 @@ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8)
|
|
161 |
video_writer.write(frame)
|
162 |
|
163 |
frame_count += 1
|
164 |
-
|
165 |
|
166 |
if process_seconds and frame_count >= total_frames:
|
167 |
break
|
168 |
|
169 |
-
|
|
|
|
|
170 |
video_writer.release()
|
171 |
-
|
172 |
-
|
173 |
-
# 生成分析报告
|
174 |
confidences = [float(det['confidence'].strip('%'))/100 for info in detection_info for det in info['detections']]
|
175 |
hist, bins = np.histogram(confidences, bins=5)
|
176 |
|
@@ -193,267 +169,113 @@ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8)
|
|
193 |
{confidence_report}
|
194 |
"""
|
195 |
|
196 |
-
def
|
197 |
-
"""GPU加速版本的轨迹过滤"""
|
198 |
if len(positions) < 3:
|
199 |
return positions
|
200 |
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
# 计算相邻点之间的距离
|
205 |
-
diffs = points[1:] - points[:-1]
|
206 |
-
distances = torch.norm(diffs, dim=1)
|
207 |
-
|
208 |
-
# 找出需要插值的位置
|
209 |
-
mask = distances > max_jump_distance
|
210 |
-
valid_indices = (~mask).nonzero().squeeze()
|
211 |
-
|
212 |
-
if len(valid_indices) < 2:
|
213 |
-
return positions
|
214 |
-
|
215 |
-
# 使用GPU进行插值
|
216 |
-
filtered_points = []
|
217 |
-
last_valid_idx = 0
|
218 |
|
219 |
-
for i in
|
220 |
-
|
221 |
-
next_idx = valid_indices[i+1].item()
|
222 |
|
223 |
-
|
|
|
224 |
|
225 |
-
if
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
236 |
|
237 |
-
|
|
|
238 |
|
239 |
-
|
240 |
-
|
241 |
-
points_tensor = torch.tensor(filtered_points, device=device)
|
242 |
-
kernel_size = 5
|
243 |
-
padding = kernel_size // 2
|
244 |
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
padding=padding
|
251 |
-
).squeeze()
|
252 |
-
smoothed_y = F.conv1d(
|
253 |
-
points_tensor[:, 1].view(1, 1, -1),
|
254 |
-
weights,
|
255 |
-
padding=padding
|
256 |
-
).squeeze()
|
257 |
|
258 |
-
|
259 |
-
|
|
|
260 |
|
261 |
-
return
|
262 |
|
263 |
-
|
264 |
-
trajectory_img =
|
265 |
points = np.array(all_positions, dtype=np.int32)
|
266 |
if len(points) > 1:
|
267 |
-
filtered_points =
|
268 |
points = np.array(filtered_points, dtype=np.int32)
|
269 |
|
270 |
for i in range(len(points) - 1):
|
271 |
ratio = i / (len(points) - 1)
|
272 |
-
color =
|
273 |
-
int((1 - ratio) * 255),
|
274 |
-
50,
|
275 |
-
int(ratio * 255)
|
276 |
-
|
277 |
-
|
278 |
-
# 使用GPU绘制线段
|
279 |
-
pt1, pt2 = points[i], points[i + 1]
|
280 |
-
draw_line_gpu(trajectory_img, pt1, pt2, color, 2)
|
281 |
|
282 |
-
trajectory_img
|
283 |
-
|
284 |
-
# 修改热力图生成部分
|
285 |
-
def generate_heatmap(points, width, height):
|
286 |
-
"""简化版本的热力图生成函数"""
|
287 |
-
logger.info(f"开始生成热力图,共 {len(points)} 个点")
|
288 |
-
|
289 |
-
# 使用numpy生成热力图
|
290 |
-
heatmap = np.zeros((height, width), dtype=np.float32)
|
291 |
-
|
292 |
-
for point in points:
|
293 |
-
try:
|
294 |
-
# 确保是整数坐标
|
295 |
-
x = min(max(0, int(point[0])), width-1)
|
296 |
-
y = min(max(0, int(point[1])), height-1)
|
297 |
-
|
298 |
-
# 在点位置标记1
|
299 |
-
heatmap[y, x] = 1.0
|
300 |
-
|
301 |
-
# 直接使用OpenCV的高斯模糊
|
302 |
-
temp_heatmap = cv2.GaussianBlur(heatmap, (31, 31), 10)
|
303 |
-
heatmap = temp_heatmap
|
304 |
-
|
305 |
-
except Exception as e:
|
306 |
-
logger.warning(f"处理点 {point} 时出错: {e}")
|
307 |
-
continue
|
308 |
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
if np.max(heatmap) > 0:
|
318 |
heatmap_normalized = cv2.normalize(heatmap, None, 0, 255, cv2.NORM_MINMAX)
|
319 |
heatmap_colored = cv2.applyColorMap(heatmap_normalized.astype(np.uint8), cv2.COLORMAP_JET)
|
320 |
-
|
321 |
-
heatmap_colored = cv2.addWeighted(heatmap_colored,
|
322 |
-
else:
|
323 |
-
# 如果没有检测到点,返回白色图像
|
324 |
-
heatmap_colored = np.full((height, width, 3), 255, dtype=np.uint8)
|
325 |
-
|
326 |
-
logger.info("热力图生成完成")
|
327 |
-
|
328 |
-
trajectory_frames = []
|
329 |
-
heatmap_frames = []
|
330 |
-
|
331 |
-
base_trajectory = np.zeros((height, width, 3), dtype=np.uint8) + 255
|
332 |
-
base_heatmap = np.zeros((height, width), dtype=np.float32)
|
333 |
-
|
334 |
-
frame_interval = max(1, len(filtered_points) // 30)
|
335 |
-
|
336 |
-
for i in range(0, len(filtered_points), frame_interval):
|
337 |
-
current_points = filtered_points[:i+1]
|
338 |
-
|
339 |
-
frame_trajectory = base_trajectory.copy()
|
340 |
-
if len(current_points) > 1:
|
341 |
-
points = np.array(current_points, dtype=np.int32)
|
342 |
-
for j in range(len(points) - 1):
|
343 |
-
ratio = j / (len(current_points) - 1)
|
344 |
-
color = (
|
345 |
-
int((1 - ratio) * 255),
|
346 |
-
50,
|
347 |
-
int(ratio * 255)
|
348 |
-
)
|
349 |
-
cv2.line(frame_trajectory, tuple(points[j]), tuple(points[j + 1]), color, 2)
|
350 |
-
|
351 |
-
cv2.circle(frame_trajectory, tuple(points[-1]), 8, (0, 0, 255), -1)
|
352 |
-
trajectory_frames.append(frame_trajectory)
|
353 |
-
|
354 |
-
frame_heatmap = base_heatmap.copy()
|
355 |
-
for x, y in current_points:
|
356 |
-
if 0 <= x < width and 0 <= y < height:
|
357 |
-
temp_heatmap = np.zeros((height, width), dtype=np.float32)
|
358 |
-
temp_heatmap[y, x] = 1
|
359 |
-
temp_heatmap = cv2.GaussianBlur(temp_heatmap, (31, 31), 10)
|
360 |
-
frame_heatmap += temp_heatmap
|
361 |
-
|
362 |
-
if np.max(frame_heatmap) > 0:
|
363 |
-
frame_heatmap_norm = cv2.normalize(frame_heatmap, None, 0, 255, cv2.NORM_MINMAX)
|
364 |
-
frame_heatmap_color = cv2.applyColorMap(frame_heatmap_norm.astype(np.uint8), cv2.COLORMAP_JET)
|
365 |
-
frame_heatmap_color = cv2.addWeighted(frame_heatmap_color, 0.7, np.full_like(frame_heatmap_color, 255), 0.3, 0)
|
366 |
-
heatmap_frames.append(frame_heatmap_color)
|
367 |
-
|
368 |
-
logger.info("开始生成轨迹图和热力图")
|
369 |
-
trajectory_gif_path = output_path.replace('.mp4', '_trajectory.gif')
|
370 |
-
heatmap_gif_path = output_path.replace('.mp4', '_heatmap.gif')
|
371 |
-
|
372 |
-
imageio.mimsave(trajectory_gif_path, trajectory_frames, duration=50)
|
373 |
-
imageio.mimsave(heatmap_gif_path, heatmap_frames, duration=50)
|
374 |
|
|
|
375 |
trajectory_path = output_path.replace('.mp4', '_trajectory.png')
|
376 |
heatmap_path = output_path.replace('.mp4', '_heatmap.png')
|
377 |
cv2.imwrite(trajectory_path, trajectory_img)
|
378 |
-
if np.max(heatmap) > 0:
|
379 |
-
heatmap_normalized = cv2.normalize(heatmap, None, 0, 255, cv2.NORM_MINMAX)
|
380 |
-
heatmap_colored = cv2.applyColorMap(heatmap_normalized.astype(np.uint8), cv2.COLORMAP_JET)
|
381 |
-
heatmap_colored = cv2.addWeighted(heatmap_colored, 0.7, np.full_like(heatmap_colored, 255), 0.3, 0)
|
382 |
cv2.imwrite(heatmap_path, heatmap_colored)
|
383 |
|
384 |
-
|
385 |
-
|
386 |
-
imageio.mimsave(trajectory_gif_path, trajectory_frames, duration=50)
|
387 |
-
imageio.mimsave(heatmap_gif_path, heatmap_frames, duration=50)
|
388 |
-
logger.info("GIF动画生成完成")
|
389 |
-
|
390 |
-
logger.info("所有处理完成,准备返回结果")
|
391 |
-
return output_path, trajectory_path, heatmap_path, trajectory_gif_path, heatmap_gif_path, report
|
392 |
-
|
393 |
-
def gaussian_blur_gpu(tensor, kernel_size=31, sigma=10):
|
394 |
-
"""GPU版本的高斯模糊"""
|
395 |
-
channels = 1
|
396 |
-
kernel = get_gaussian_kernel2d(kernel_size, sigma).to(device)
|
397 |
-
kernel = kernel.view(1, 1, kernel_size, kernel_size)
|
398 |
-
tensor = tensor.view(1, 1, tensor.shape[0], tensor.shape[1])
|
399 |
-
|
400 |
-
return F.conv2d(tensor, kernel, padding=kernel_size//2).squeeze()
|
401 |
-
|
402 |
-
def get_gaussian_kernel2d(kernel_size, sigma):
|
403 |
-
"""生成2D高斯核"""
|
404 |
-
kernel_x = torch.linspace(-kernel_size//2, kernel_size//2, kernel_size)
|
405 |
-
x, y = torch.meshgrid(kernel_x, kernel_x, indexing='ij')
|
406 |
-
kernel = torch.exp(-(x.pow(2) + y.pow(2)) / (2 * sigma ** 2))
|
407 |
-
return kernel / kernel.sum()
|
408 |
-
|
409 |
-
def draw_line_gpu(image, pt1, pt2, color, thickness=1):
|
410 |
-
"""GPU版本的线段绘制"""
|
411 |
-
x1, y1 = map(int, pt1) # 确保是整数
|
412 |
-
x2, y2 = map(int, pt2) # 确保是整数
|
413 |
-
dx = abs(x2 - x1)
|
414 |
-
dy = abs(y2 - y1)
|
415 |
-
|
416 |
-
# 防止除零错误
|
417 |
-
steps = max(dx, dy)
|
418 |
-
if steps == 0:
|
419 |
-
# 如果是同一个点,直接画点
|
420 |
-
if 0 <= x1 < image.shape[1] and 0 <= y1 < image.shape[0]:
|
421 |
-
image[y1, x1] = color
|
422 |
-
return
|
423 |
-
|
424 |
-
x_inc = (x2 - x1) / steps
|
425 |
-
y_inc = (y2 - y1) / steps
|
426 |
-
|
427 |
-
x = x1
|
428 |
-
y = y1
|
429 |
-
|
430 |
-
points = torch.zeros((int(steps) + 1, 2), device=device)
|
431 |
-
for i in range(int(steps) + 1):
|
432 |
-
points[i] = torch.tensor([x, y])
|
433 |
-
x += x_inc
|
434 |
-
y += y_inc
|
435 |
-
|
436 |
-
points = points.long() # 转换为整数类型
|
437 |
-
valid_points = (points[:, 0] >= 0) & (points[:, 0] < image.shape[1]) & \
|
438 |
-
(points[:, 1] >= 0) & (points[:, 1] < image.shape[0])
|
439 |
-
points = points[valid_points]
|
440 |
-
|
441 |
-
color = color.to(image.dtype)
|
442 |
-
|
443 |
-
if thickness > 1:
|
444 |
-
for dx in range(-thickness//2, thickness//2 + 1):
|
445 |
-
for dy in range(-thickness//2, thickness//2 + 1):
|
446 |
-
offset_points = points + torch.tensor([dx, dy], device=device, dtype=torch.long)
|
447 |
-
valid_offset = (offset_points[:, 0] >= 0) & (offset_points[:, 0] < image.shape[1]) & \
|
448 |
-
(offset_points[:, 1] >= 0) & (offset_points[:, 1] < image.shape[0])
|
449 |
-
offset_points = offset_points[valid_offset]
|
450 |
-
image[offset_points[:, 1], offset_points[:, 0]] = color
|
451 |
-
else:
|
452 |
-
image[points[:, 1], points[:, 0]] = color
|
453 |
|
454 |
# 创建 Gradio 界面
|
455 |
with gr.Blocks() as demo:
|
456 |
-
gr.Markdown("#
|
457 |
|
458 |
with gr.Group() as login_interface:
|
459 |
username = gr.Textbox(label="用户名")
|
@@ -472,7 +294,7 @@ with gr.Blocks() as demo:
|
|
472 |
value=20
|
473 |
)
|
474 |
conf_threshold = gr.Slider(
|
475 |
-
minimum=0.
|
476 |
maximum=1.0,
|
477 |
value=0.2,
|
478 |
step=0.05,
|
@@ -485,7 +307,7 @@ with gr.Blocks() as demo:
|
|
485 |
value=1,
|
486 |
step=1,
|
487 |
label="最大检测数量",
|
488 |
-
info="
|
489 |
)
|
490 |
process_btn = gr.Button("开始处理")
|
491 |
|
@@ -493,17 +315,14 @@ with gr.Blocks() as demo:
|
|
493 |
video_output = gr.Video(label="检测结果")
|
494 |
with gr.Row():
|
495 |
trajectory_output = gr.Image(label="运动轨迹")
|
496 |
-
trajectory_gif_output = gr.Image(label="轨迹动画")
|
497 |
-
with gr.Row():
|
498 |
heatmap_output = gr.Image(label="热力图")
|
499 |
-
heatmap_gif_output = gr.Image(label="热力图动画")
|
500 |
report_output = gr.Textbox(label="分析报告")
|
501 |
|
502 |
gr.Markdown("""
|
503 |
### 使用说明
|
504 |
1. 上传视频文件
|
505 |
2. 设置处理参数:
|
506 |
-
-
|
507 |
- 置信度阈值:检测的置信度要求(越高越严格)
|
508 |
- 最大检测数量:每帧最多检测的目标数量
|
509 |
3. 等待处理完成
|
@@ -526,20 +345,8 @@ with gr.Blocks() as demo:
|
|
526 |
process_btn.click(
|
527 |
fn=process_video,
|
528 |
inputs=[video_input, process_seconds, conf_threshold, max_det],
|
529 |
-
outputs=[video_output, trajectory_output, heatmap_output,
|
530 |
-
trajectory_gif_output, heatmap_gif_output, report_output]
|
531 |
)
|
532 |
|
533 |
if __name__ == "__main__":
|
534 |
-
try:
|
535 |
-
# GPU相关操作
|
536 |
-
if torch.cuda.is_available():
|
537 |
-
logger.info("使用GPU进行轨迹和热力图计算")
|
538 |
-
# ... GPU操作 ...
|
539 |
-
else:
|
540 |
-
logger.info("使用CPU进行轨迹和热力图计算")
|
541 |
-
# ... CPU操作 ...
|
542 |
-
except Exception as e:
|
543 |
-
logger.error(f"处理轨迹和热力图时出错: {str(e)}")
|
544 |
-
raise
|
545 |
demo.launch(server_name="0.0.0.0", server_port=7860)
|
|
|
9 |
import cv2
|
10 |
from pathlib import Path
|
11 |
import tempfile
|
|
|
12 |
from tqdm import tqdm
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
# 从环境变量获取密码
|
15 |
APP_USERNAME = "admin" # 用户名保持固定
|
|
|
17 |
|
18 |
app = FastAPI()
|
19 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
20 |
+
print(f"使用设备: {device}")
|
21 |
model = YOLO('kunin-mice-pose.v0.1.5n.pt')
|
22 |
+
print("模型加载完成")
|
23 |
|
24 |
# 定义认证状态
|
25 |
class AuthState:
|
|
|
30 |
|
31 |
def login(username, password):
|
32 |
"""登录验证"""
|
|
|
33 |
if username == APP_USERNAME and password == APP_PASSWORD:
|
34 |
auth_state.is_logged_in = True
|
|
|
35 |
return gr.update(visible=False), gr.update(visible=True), "登录成功"
|
|
|
36 |
return gr.update(visible=True), gr.update(visible=False), "用户名或密码错误"
|
37 |
|
38 |
+
@spaces.GPU(duration=120)
|
39 |
def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8):
|
40 |
"""
|
41 |
处理视频并进行小鼠检测
|
|
|
|
|
|
|
|
|
|
|
42 |
"""
|
43 |
+
print("开始处理视频...")
|
|
|
44 |
|
45 |
if not auth_state.is_logged_in:
|
|
|
46 |
return None, "请先登录"
|
47 |
|
48 |
+
print("创建临时输出文件...")
|
|
|
49 |
with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
|
50 |
output_path = tmp_file.name
|
51 |
|
52 |
+
print("读取视频信息...")
|
|
|
53 |
cap = cv2.VideoCapture(video_path)
|
54 |
fps = int(cap.get(cv2.CAP_PROP_FPS))
|
55 |
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
|
|
57 |
total_frames = int(process_seconds * fps) if process_seconds else int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
58 |
cap.release()
|
59 |
|
60 |
+
print(f"视频信息: {width}x{height} @ {fps}fps, 总帧数: {total_frames}")
|
61 |
|
62 |
+
print("初始化视频写入器...")
|
63 |
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
64 |
video_writer = cv2.VideoWriter(
|
65 |
output_path,
|
|
|
68 |
(width, height)
|
69 |
)
|
70 |
|
|
|
71 |
base_size = min(width, height)
|
72 |
+
line_thickness = max(1, int(base_size * 0.002))
|
73 |
|
74 |
+
print("开始YOLO推理...")
|
75 |
results = model.predict(
|
76 |
source=video_path,
|
77 |
device=device,
|
|
|
86 |
vid_stride=1,
|
87 |
max_det=max_det,
|
88 |
retina_masks=True,
|
89 |
+
verbose=False
|
90 |
)
|
91 |
|
|
|
92 |
frame_count = 0
|
93 |
detection_info = []
|
94 |
all_positions = []
|
95 |
heatmap = np.zeros((height, width), dtype=np.float32)
|
96 |
|
97 |
+
print("处理检测结果...")
|
98 |
+
progress_bar = tqdm(total=total_frames, desc="处理帧")
|
99 |
|
100 |
for r in results:
|
101 |
frame = r.plot()
|
102 |
|
|
|
103 |
if hasattr(r, 'keypoints') and r.keypoints is not None:
|
104 |
kpts = r.keypoints.data
|
105 |
if isinstance(kpts, torch.Tensor):
|
106 |
kpts = kpts.cpu().numpy()
|
107 |
|
108 |
+
if kpts.shape == (1, 8, 3):
|
109 |
x, y = int(kpts[0, 0, 0]), int(kpts[0, 0, 1])
|
110 |
all_positions.append([x, y])
|
111 |
|
|
|
117 |
temp_heatmap = cv2.GaussianBlur(temp_heatmap, (kernel_size, kernel_size), sigma)
|
118 |
heatmap += temp_heatmap
|
119 |
|
|
|
120 |
frame_info = {
|
121 |
"frame": frame_count + 1,
|
122 |
"count": len(r.boxes),
|
|
|
136 |
video_writer.write(frame)
|
137 |
|
138 |
frame_count += 1
|
139 |
+
progress_bar.update(1)
|
140 |
|
141 |
if process_seconds and frame_count >= total_frames:
|
142 |
break
|
143 |
|
144 |
+
progress_bar.close()
|
145 |
+
print("视频处理完成")
|
146 |
+
|
147 |
video_writer.release()
|
148 |
+
|
149 |
+
print("生成分析报告...")
|
|
|
150 |
confidences = [float(det['confidence'].strip('%'))/100 for info in detection_info for det in info['detections']]
|
151 |
hist, bins = np.histogram(confidences, bins=5)
|
152 |
|
|
|
169 |
{confidence_report}
|
170 |
"""
|
171 |
|
172 |
+
def filter_trajectories(positions, width, height, max_jump_distance=100):
|
|
|
173 |
if len(positions) < 3:
|
174 |
return positions
|
175 |
|
176 |
+
filtered_positions = []
|
177 |
+
last_valid_pos = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
178 |
|
179 |
+
for i, pos in enumerate(positions):
|
180 |
+
x, y = pos
|
|
|
181 |
|
182 |
+
if not (0 <= x < width and 0 <= y < height):
|
183 |
+
continue
|
184 |
|
185 |
+
if last_valid_pos is None:
|
186 |
+
filtered_positions.append(pos)
|
187 |
+
last_valid_pos = pos
|
188 |
+
continue
|
189 |
+
|
190 |
+
distance = np.sqrt((x - last_valid_pos[0])**2 + (y - last_valid_pos[1])**2)
|
191 |
+
|
192 |
+
if distance > max_jump_distance:
|
193 |
+
if len(filtered_positions) > 0:
|
194 |
+
next_valid_pos = None
|
195 |
+
for next_pos in positions[i:]:
|
196 |
+
nx, ny = next_pos
|
197 |
+
if (0 <= nx < width and 0 <= ny < height):
|
198 |
+
next_distance = np.sqrt((nx - last_valid_pos[0])**2 + (ny - last_valid_pos[1])**2)
|
199 |
+
if next_distance <= max_jump_distance:
|
200 |
+
next_valid_pos = next_pos
|
201 |
+
break
|
202 |
+
|
203 |
+
if next_valid_pos is not None:
|
204 |
+
steps = max(2, int(distance / max_jump_distance))
|
205 |
+
for j in range(1, steps):
|
206 |
+
alpha = j / steps
|
207 |
+
interp_x = int(last_valid_pos[0] * (1 - alpha) + next_valid_pos[0] * alpha)
|
208 |
+
interp_y = int(last_valid_pos[1] * (1 - alpha) + next_valid_pos[1] * alpha)
|
209 |
+
filtered_positions.append([interp_x, interp_y])
|
210 |
+
filtered_positions.append(next_valid_pos)
|
211 |
+
last_valid_pos = next_valid_pos
|
212 |
+
else:
|
213 |
+
filtered_positions.append(pos)
|
214 |
+
last_valid_pos = pos
|
215 |
|
216 |
+
window_size = 5
|
217 |
+
smoothed_positions = []
|
218 |
|
219 |
+
if len(filtered_positions) >= window_size:
|
220 |
+
smoothed_positions.extend(filtered_positions[:window_size//2])
|
|
|
|
|
|
|
221 |
|
222 |
+
for i in range(window_size//2, len(filtered_positions) - window_size//2):
|
223 |
+
window = filtered_positions[i-window_size//2:i+window_size//2+1]
|
224 |
+
smoothed_x = int(np.mean([p[0] for p in window]))
|
225 |
+
smoothed_y = int(np.mean([p[1] for p in window]))
|
226 |
+
smoothed_positions.append([smoothed_x, smoothed_y])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
227 |
|
228 |
+
smoothed_positions.extend(filtered_positions[-window_size//2:])
|
229 |
+
else:
|
230 |
+
smoothed_positions = filtered_positions
|
231 |
|
232 |
+
return smoothed_positions
|
233 |
|
234 |
+
print("生成轨迹图...")
|
235 |
+
trajectory_img = np.zeros((height, width, 3), dtype=np.uint8) + 255
|
236 |
points = np.array(all_positions, dtype=np.int32)
|
237 |
if len(points) > 1:
|
238 |
+
filtered_points = filter_trajectories(points.tolist(), width, height)
|
239 |
points = np.array(filtered_points, dtype=np.int32)
|
240 |
|
241 |
for i in range(len(points) - 1):
|
242 |
ratio = i / (len(points) - 1)
|
243 |
+
color = (
|
244 |
+
int((1 - ratio) * 255),
|
245 |
+
50,
|
246 |
+
int(ratio * 255)
|
247 |
+
)
|
248 |
+
cv2.line(trajectory_img, tuple(points[i]), tuple(points[i + 1]), color, 2)
|
|
|
|
|
|
|
249 |
|
250 |
+
cv2.circle(trajectory_img, tuple(points[0]), 8, (0, 255, 0), -1)
|
251 |
+
cv2.circle(trajectory_img, tuple(points[-1]), 8, (0, 0, 255), -1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
252 |
|
253 |
+
arrow_interval = max(len(points) // 20, 1)
|
254 |
+
for i in range(0, len(points) - arrow_interval, arrow_interval):
|
255 |
+
pt1 = tuple(points[i])
|
256 |
+
pt2 = tuple(points[i + arrow_interval])
|
257 |
+
angle = np.arctan2(pt2[1] - pt1[1], pt2[0] - pt1[0])
|
258 |
+
cv2.arrowedLine(trajectory_img, pt1, pt2, (100, 100, 100), 1, tipLength=0.2)
|
259 |
+
|
260 |
+
print("生成热力图...")
|
261 |
if np.max(heatmap) > 0:
|
262 |
heatmap_normalized = cv2.normalize(heatmap, None, 0, 255, cv2.NORM_MINMAX)
|
263 |
heatmap_colored = cv2.applyColorMap(heatmap_normalized.astype(np.uint8), cv2.COLORMAP_JET)
|
264 |
+
alpha = 0.7
|
265 |
+
heatmap_colored = cv2.addWeighted(heatmap_colored, alpha, np.full_like(heatmap_colored, 255), 1-alpha, 0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
266 |
|
267 |
+
print("保存结果图像...")
|
268 |
trajectory_path = output_path.replace('.mp4', '_trajectory.png')
|
269 |
heatmap_path = output_path.replace('.mp4', '_heatmap.png')
|
270 |
cv2.imwrite(trajectory_path, trajectory_img)
|
|
|
|
|
|
|
|
|
271 |
cv2.imwrite(heatmap_path, heatmap_colored)
|
272 |
|
273 |
+
print("处理完成!")
|
274 |
+
return output_path, trajectory_path, heatmap_path, report
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
275 |
|
276 |
# 创建 Gradio 界面
|
277 |
with gr.Blocks() as demo:
|
278 |
+
gr.Markdown("# 🐭 小鼠行为分析 (Mice Behavior Analysis)")
|
279 |
|
280 |
with gr.Group() as login_interface:
|
281 |
username = gr.Textbox(label="用户名")
|
|
|
294 |
value=20
|
295 |
)
|
296 |
conf_threshold = gr.Slider(
|
297 |
+
minimum=0.1,
|
298 |
maximum=1.0,
|
299 |
value=0.2,
|
300 |
step=0.05,
|
|
|
307 |
value=1,
|
308 |
step=1,
|
309 |
label="最大检测数量",
|
310 |
+
info="每帧最多检测的目标数量"
|
311 |
)
|
312 |
process_btn = gr.Button("开始处理")
|
313 |
|
|
|
315 |
video_output = gr.Video(label="检测结果")
|
316 |
with gr.Row():
|
317 |
trajectory_output = gr.Image(label="运动轨迹")
|
|
|
|
|
318 |
heatmap_output = gr.Image(label="热力图")
|
|
|
319 |
report_output = gr.Textbox(label="分析报告")
|
320 |
|
321 |
gr.Markdown("""
|
322 |
### 使用说明
|
323 |
1. 上传视频文件
|
324 |
2. 设置处理参数:
|
325 |
+
- 处理时长:需要分析的视频时长(秒)
|
326 |
- 置信度阈值:检测的置信度要求(越高越严格)
|
327 |
- 最大检测数量:每帧最多检测的目标数量
|
328 |
3. 等待处理完成
|
|
|
345 |
process_btn.click(
|
346 |
fn=process_video,
|
347 |
inputs=[video_input, process_seconds, conf_threshold, max_det],
|
348 |
+
outputs=[video_output, trajectory_output, heatmap_output, report_output]
|
|
|
349 |
)
|
350 |
|
351 |
if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
352 |
demo.launch(server_name="0.0.0.0", server_port=7860)
|