# Hugging Face Spaces page residue: "Spaces: Sleeping"
import hmac
import os
import tempfile
import warnings
from pathlib import Path

# Third-party imports are kept in their original relative order; some of them
# (e.g. ``spaces``) may be sensitive to import order.
import gradio as gr
from ultralytics import YOLO
from fastapi import FastAPI
from PIL import Image
import torch
import spaces
import numpy as np
import cv2
# Login credentials: the username is fixed, the password comes from the
# APP_PASSWORD environment variable.
APP_USERNAME = "admin"
APP_PASSWORD = os.getenv("APP_PASSWORD", "default_password")
if "APP_PASSWORD" not in os.environ:
    # The fallback password is visible to anyone who can read this source, so
    # running without APP_PASSWORD leaves the app effectively unprotected.
    warnings.warn(
        "APP_PASSWORD is not set; falling back to the built-in default password. "
        "Set the APP_PASSWORD environment variable to secure the login."
    )

# NOTE(review): this FastAPI instance is never mounted or served in the code
# visible here (the Gradio ``demo`` is launched directly); confirm it is needed.
app = FastAPI()
# Use the GPU when PyTorch can see one; passed to every model.predict call.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Pose-estimation weights, loaded once at import and reused by every call to
# process_video.
model = YOLO('kunin-mice-pose.v0.1.0.pt')
# Login state consulted by process_video and updated by login.
class AuthState:
    """Flag recording whether a successful login has happened.

    NOTE(review): exactly one instance (``auth_state``) lives at module level,
    so the flag is shared by every client served by this process. One user's
    login unlocks processing for all of them. Per-session state (for example
    ``gr.State``) would be needed to isolate users.
    """

    def __init__(self):
        # Every new instance starts logged out.
        self.is_logged_in: bool = False


# Module-wide singleton: login sets it, process_video reads it.
auth_state = AuthState()
def login(username, password):
    """Validate the submitted credentials and switch between the two panels.

    Args:
        username: Text from the username box (``None`` is treated as empty).
        password: Text from the password box (``None`` is treated as empty).

    Returns:
        ``(login_panel_update, main_panel_update, message)``. On success the
        login panel is hidden and the main panel shown. On failure the login
        panel stays visible, the main panel stays hidden, and an error message
        is returned.
    """
    # Compare UTF-8 bytes with hmac.compare_digest so the check does not leak
    # how much of the secret matched through timing. Encoding first also keeps
    # non-ASCII input (which compare_digest rejects as str) from raising.
    # Both comparisons always run, so timing doesn't reveal which field was right.
    user_ok = hmac.compare_digest(
        (username or "").encode("utf-8"), APP_USERNAME.encode("utf-8")
    )
    pass_ok = hmac.compare_digest(
        (password or "").encode("utf-8"), APP_PASSWORD.encode("utf-8")
    )
    if user_ok and pass_ok:
        # NOTE(review): auth_state is a single module-level object, so this
        # unlocks processing for every connected client, not only this session.
        auth_state.is_logged_in = True
        return gr.update(visible=False), gr.update(visible=True), "登录成功"
    return gr.update(visible=True), gr.update(visible=False), "用户名或密码错误"
# Frame rate assumed when the container reports none (0, NaN or inf).
_FALLBACK_FPS = 30.0
# Gaussian smoothing applied to the position heatmap.
_HEATMAP_KERNEL_SIZE = 15  # must be odd
_HEATMAP_SIGMA = 5  # larger values give bigger hot spots


def _first_keypoint_positions(result):
    """Return integer ``[x, y]`` pixel positions of the first keypoint of each detection.

    Coordinates are read through ``result.keypoints.xy`` (assumed shape
    ``(num_detections, num_keypoints, 2)``; a 2-D array is treated as a single
    detection). Iterating ``result.keypoints`` directly is expected to yield
    Ultralytics ``Keypoints`` objects rather than tensors/arrays, so it cannot
    be used to get raw coordinates.
    """
    keypoints = getattr(result, "keypoints", None)
    if keypoints is None:
        return []
    xy = keypoints.xy
    if isinstance(xy, torch.Tensor):
        xy = xy.cpu().numpy()
    xy = np.asarray(xy, dtype=np.float64)
    if xy.ndim == 2:
        xy = xy[None]
    if xy.ndim != 3:
        return []

    positions = []
    for obj_xy in xy:
        if len(obj_xy) == 0:
            continue
        fx, fy = obj_xy[0][:2]
        # NOTE(review): Ultralytics may report keypoints it could not locate as
        # (0, 0). Such points are skipped so they don't drag the trajectory to
        # the image corner; confirm for the installed version.
        if not (np.isfinite(fx) and np.isfinite(fy)) or (fx == 0 and fy == 0):
            continue
        positions.append([int(fx), int(fy)])
    return positions


def _draw_trajectory(positions, width, height):
    """Draw the recorded positions as a polyline on a white ``height x width`` BGR canvas.

    Segments fade from blue (early) to red (late). The first point gets a
    green dot and the last a red dot.

    NOTE(review): positions from every detection in every frame form one
    sequence, so with several mice the line jumps between animals. Per-ID
    tracking would be needed for per-animal trajectories.
    """
    canvas = np.full((height, width, 3), 255, dtype=np.uint8)
    points = [(int(x), int(y)) for x, y in positions]
    if len(points) > 1:
        last_index = len(points) - 1
        for i, (start, end) in enumerate(zip(points, points[1:])):
            ratio = i / last_index
            colour = (int((1 - ratio) * 255), 0, int(ratio * 255))  # (B, G, R)
            cv2.line(canvas, start, end, colour, 2)
        cv2.circle(canvas, points[0], 8, (0, 255, 0), -1)  # green start
        # OpenCV colours are BGR, so red is (0, 0, 255).
        cv2.circle(canvas, points[-1], 8, (0, 0, 255), -1)  # red end
    return canvas


def _render_heatmap(point_counts):
    """Turn per-pixel visit counts into a JET-coloured BGR heatmap image.

    GaussianBlur is linear, so blurring the summed counts once gives the same
    result as blurring a one-hot image per point and adding them. It costs a
    single full-frame pass instead of one pass per detection.
    """
    smoothed = cv2.GaussianBlur(
        point_counts, (_HEATMAP_KERNEL_SIZE, _HEATMAP_KERNEL_SIZE), _HEATMAP_SIGMA
    )
    normalized = cv2.normalize(smoothed, None, 0, 255, cv2.NORM_MINMAX)
    return cv2.applyColorMap(normalized.astype(np.uint8), cv2.COLORMAP_JET)


def _build_report(conf_threshold, max_det, process_seconds, frame_count,
                  per_frame_counts, confidences):
    """Format the plain-text analysis report.

    Args:
        conf_threshold, max_det, process_seconds: Parameters echoed to the user.
        frame_count: Number of frames processed.
        per_frame_counts: Number of detections in each processed frame.
        confidences: Confidence of every detection as a float.
    """
    hist, edges = np.histogram(confidences, bins=5)
    total = len(confidences)
    distribution_lines = []
    for count, low, high in zip(hist, edges[:-1], edges[1:]):
        # Avoid 0/0 when nothing was detected.
        share = count / total * 100 if total else 0.0
        distribution_lines.append(
            f"置信度 {low:.2f}-{high:.2f}: {count:3d}个检测 ({share:.1f}%)"
        )

    mean_count = float(np.mean(per_frame_counts)) if per_frame_counts else 0.0
    report_lines = [
        "视频分析报告:",
        "参数设置:",
        f"- 置信度阈值: {conf_threshold:.2f}",
        f"- 最大检测数量: {max_det}",
        f"- 处理时长: {process_seconds}秒",
        "分析结果:",
        f"- 处理帧数: {frame_count}",
        f"- 平均每帧检测到的老鼠数: {mean_count:.1f}",
        f"- 最大检测数: {max(per_frame_counts, default=0)}",
        f"- 最小检测数: {min(per_frame_counts, default=0)}",
        "置信度分布:",
        "\n".join(distribution_lines),
    ]
    return "\n".join(report_lines) + "\n"


def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8):
    """Detect mice in a video and produce an annotated video, trajectory, heatmap and report.

    Args:
        video_path: Path of the input video; may be ``None``/empty when nothing
            was uploaded.
        process_seconds: Seconds to process from the start of the video. A
            falsy or non-positive value processes the whole video.
        conf_threshold: Detection confidence threshold in [0, 1].
        max_det: Maximum number of detections per frame.

    Returns:
        ``(video_path, trajectory_png_path, heatmap_png_path, report_text)``.
        When processing cannot start (not logged in, or no/unreadable video),
        the three paths are ``None`` and the last item is a message. The tuple
        always has four items because the UI binds four outputs.
    """
    if not auth_state.is_logged_in:
        return None, None, None, "请先登录"
    if not video_path:
        return None, None, None, "请先上传视频"

    # Probe the stream properties; model.predict opens the file on its own.
    cap = cv2.VideoCapture(video_path)
    try:
        opened = cap.isOpened()
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    finally:
        cap.release()
    if not opened or width <= 0 or height <= 0:
        return None, None, None, "无法读取视频文件"
    # Keep fractional rates such as 29.97 (truncating changes playback speed
    # and the frame budget), and fall back when no usable rate is reported.
    if not (np.isfinite(fps) and fps > 0):
        fps = _FALLBACK_FPS

    # A positive duration caps how many frames are processed (at least one).
    max_frames = None
    if process_seconds and process_seconds > 0:
        max_frames = max(1, int(process_seconds * fps))

    # All outputs share the stem of one temporary file.
    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
        output_path = tmp_file.name
    stem = str(Path(output_path).with_suffix(''))
    trajectory_path = stem + '_trajectory.png'
    heatmap_path = stem + '_heatmap.png'

    # Line width scales with resolution: 0.2% of the shorter side, at least 1 px.
    line_thickness = max(1, int(min(width, height) * 0.002))

    results = model.predict(
        source=video_path,
        device=device,
        conf=conf_threshold,
        save=False,
        show=False,
        stream=True,  # iterate per-frame results lazily
        line_width=line_thickness,
        boxes=True,
        show_labels=True,
        show_conf=True,
        vid_stride=1,
        max_det=int(max_det),  # slider values may arrive as floats
        retina_masks=True,
    )

    frame_count = 0
    per_frame_counts = []
    confidences = []
    all_positions = []
    # Per-pixel visit counts, smoothed into the heatmap after the loop.
    point_counts = np.zeros((height, width), dtype=np.float32)

    video_writer = cv2.VideoWriter(
        output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height)
    )
    try:
        for r in results:
            for x, y in _first_keypoint_positions(r):
                all_positions.append([x, y])
                if 0 <= x < width and 0 <= y < height:
                    point_counts[y, x] += 1.0

            boxes = r.boxes
            if boxes is None:
                per_frame_counts.append(0)
            else:
                per_frame_counts.append(len(boxes))
                confidences.extend(float(c) for c in boxes.conf)

            video_writer.write(r.plot())
            frame_count += 1
            if max_frames is not None and frame_count >= max_frames:
                break
    finally:
        # Release even if inference fails so the output file is not left open.
        video_writer.release()

    cv2.imwrite(trajectory_path, _draw_trajectory(all_positions, width, height))
    cv2.imwrite(heatmap_path, _render_heatmap(point_counts))
    report = _build_report(
        conf_threshold, max_det, process_seconds, frame_count,
        per_frame_counts, confidences,
    )
    return output_path, trajectory_path, heatmap_path, report
# Gradio UI: a login panel plus an analysis panel that stays hidden until login.
# Components render in creation order, so statement order here is the layout.
with gr.Blocks() as demo:
    gr.Markdown("# 🐁 小鼠行为分析 (Mice Behavior Analysis)")

    # Login panel, visible from the start.
    with gr.Group() as login_interface:
        username = gr.Textbox(label="用户名")
        password = gr.Textbox(type="password", label="密码")
        login_button = gr.Button("登录")
        login_msg = gr.Textbox(interactive=False, label="消息")

    # Analysis panel, hidden until login() returns a visible update for it.
    with gr.Group(visible=False) as main_interface:
        gr.Markdown(
            "上传视频来检测和分析小鼠行为 | Upload a video to detect and analyze mice behavior"
        )
        with gr.Row():
            # Left column: the input video and processing parameters.
            with gr.Column():
                video_input = gr.Video(label="输入视频")
                process_seconds = gr.Number(value=20, label="处理时长(秒,0表示处理整个视频)")
                conf_threshold = gr.Slider(
                    label="置信度阈值",
                    info="越高越严格,建议范围0.2-0.5",
                    minimum=0.1,
                    maximum=1.0,
                    step=0.05,
                    value=0.2,
                )
                max_det = gr.Slider(
                    label="最大检测数量",
                    info="每帧最多检测的目标数量",
                    minimum=1,
                    maximum=10,
                    step=1,
                    value=8,
                )
                process_btn = gr.Button("开始处理")
            # Right column: annotated video, trajectory/heatmap images, report.
            with gr.Column():
                video_output = gr.Video(label="检测结果")
                with gr.Row():
                    trajectory_output = gr.Image(label="运动轨迹")
                    heatmap_output = gr.Image(label="热力图")
                report_output = gr.Textbox(label="分析报告")
        gr.Markdown("""
### 使用说明
1. 上传视频文件
2. 设置处理参数:
   - 处理时长:需要分析的视频时长(秒)
   - 置信度阈值:检测的置信度要求(越高越严格)
   - 最大检测数量:每帧最多检测的目标数量
3. 等待处理完成
4. 查看检测结果视频和分析报告
### 注意事项
- 支持常见视频格式(mp4, avi 等)
- 建议视频分辨率不超过 1920x1080
- 处理时间与视频长度和分辨率相关
- 置信度建议范围:0.2-0.5
- 最大检测数量建议根据实际场景设置
    """)

    # Event wiring (fn, inputs, outputs). Each output list must line up with the
    # handler's return tuple: login returns (login panel, main panel, message);
    # process_video returns (video, trajectory, heatmap, report).
    login_button.click(
        login,
        [username, password],
        [login_interface, main_interface, login_msg],
    )
    process_btn.click(
        process_video,
        [video_input, process_seconds, conf_threshold, max_det],
        [video_output, trajectory_output, heatmap_output, report_output],
    )
if __name__ == "__main__":
    # Serve the Gradio UI on every network interface, port 7860.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)