Spaces:
Running
on
Zero
Running
on
Zero
import gradio as gr | |
from ultralytics import YOLO | |
from fastapi import FastAPI | |
from PIL import Image | |
import torch | |
import spaces | |
import numpy as np | |
import cv2 | |
from pathlib import Path | |
import tempfile | |
app = FastAPI() | |
device = 'cuda' if torch.cuda.is_available() else 'cpu' | |
model = YOLO('kunin-mice-pose.v0.1.0.pt') | |
# 添加认证函数 | |
def auth_user(username, password): | |
# 你可以添加多个用户 | |
valid_credentials = { | |
"kunin": "123456", | |
"user1": "password1", | |
# 可以添加更多用户 | |
} | |
return username in valid_credentials and valid_credentials[username] == password | |
def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8): | |
""" | |
处理视频并进行小鼠检测 | |
Args: | |
video_path: 输入视频路径 | |
process_seconds: 处理时长(秒) | |
conf_threshold: 置信度阈值(0-1) | |
max_det: 每帧最大检测数量 | |
""" | |
# 创建临时目录保存输出视频 | |
with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file: | |
output_path = tmp_file.name | |
# 获取视频信息 | |
cap = cv2.VideoCapture(video_path) | |
fps = int(cap.get(cv2.CAP_PROP_FPS)) | |
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) | |
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) | |
total_frames = int(process_seconds * fps) if process_seconds else int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) | |
cap.release() | |
# 创建视频写入器 | |
fourcc = cv2.VideoWriter_fourcc(*'mp4v') | |
video_writer = cv2.VideoWriter( | |
output_path, | |
fourcc, | |
fps, | |
(width, height) | |
) | |
# 设置推理参数并处理视频 | |
results = model.predict( | |
source=video_path, | |
device=device, | |
conf=conf_threshold, # 使用用户设置的置信度阈值 | |
save=False, | |
show=False, | |
stream=True, | |
line_width=2, | |
show_boxes=True, | |
show_labels=True, | |
show_conf=True, | |
vid_stride=1, | |
max_det=max_det, # 使用用户设置的最大检测数量 | |
) | |
# 处理结果 | |
frame_count = 0 | |
detection_info = [] | |
for r in results: | |
# 获取绘制了预测结果的帧 | |
frame = r.plot() | |
# 收集检测信息 | |
frame_info = { | |
"frame": frame_count + 1, | |
"count": len(r.boxes), | |
"detections": [] | |
} | |
for box in r.boxes: | |
conf = float(box.conf[0]) | |
cls = int(box.cls[0]) | |
cls_name = r.names[cls] | |
frame_info["detections"].append({ | |
"class": cls_name, | |
"confidence": f"{conf:.2%}" | |
}) | |
detection_info.append(frame_info) | |
# 写入视频 | |
video_writer.write(frame) | |
frame_count += 1 | |
if process_seconds and frame_count >= total_frames: | |
break | |
# 释放视频写入器 | |
video_writer.release() | |
# 生成分析报告 | |
report = f"""视频分析报告: | |
参数设置: | |
- 置信度阈值: {conf_threshold:.2f} | |
- 最大检测数量: {max_det} | |
- 处理时长: {process_seconds}秒 | |
分析结果: | |
- 处理帧数: {frame_count} | |
- 平均每帧检测到的老鼠数: {np.mean([info['count'] for info in detection_info]):.1f} | |
- 最大检测数: {max([info['count'] for info in detection_info])} | |
- 最小检测数: {min([info['count'] for info in detection_info])} | |
置信度分布: | |
{np.histogram([float(det['confidence'].strip('%'))/100 for info in detection_info for det in info['detections']], bins=5)[0].tolist()} | |
""" | |
return output_path, report | |
# 创建 Gradio 界面 | |
with gr.Blocks() as demo: | |
gr.Markdown("# 🐁 小鼠行为分析 (Mice Behavior Analysis)") | |
gr.Markdown("上传视频来检测和分析小鼠行为 | Upload a video to detect and analyze mice behavior") | |
with gr.Row(): | |
with gr.Column(): | |
video_input = gr.Video(label="输入视频") | |
process_seconds = gr.Number( | |
label="处理时长(秒,0表示处理整个视频)", | |
value=20 | |
) | |
conf_threshold = gr.Slider( | |
minimum=0.1, | |
maximum=1.0, | |
value=0.2, | |
step=0.05, | |
label="置信度阈值", | |
info="越高越严格,建议范围0.2-0.5" | |
) | |
max_det = gr.Slider( | |
minimum=1, | |
maximum=10, | |
value=8, | |
step=1, | |
label="最大检测数量", | |
info="每帧最多检测的目标数量" | |
) | |
process_btn = gr.Button("开始处理") | |
with gr.Column(): | |
video_output = gr.Video(label="检测结果") | |
report_output = gr.Textbox(label="分析报告") | |
process_btn.click( | |
fn=process_video, | |
inputs=[video_input, process_seconds, conf_threshold, max_det], | |
outputs=[video_output, report_output] | |
) | |
gr.Markdown(""" | |
### 使用说明 | |
1. 上传视频文件 | |
2. 设置处理参数: | |
- 处理时长:需要分析的视频时长(秒) | |
- 置信度阈值:检测的置信度要求(越高越严格) | |
- 最大检测数量:每帧最多检测的目标数量 | |
3. 等待处理完成 | |
4. 查看检测结果视频和分析报告 | |
### 注意事项 | |
- 支持常见视频格式(mp4, avi 等) | |
- 建议视频分辨率不超过 1920x1080 | |
- 处理时间与视频长度和分辨率相关 | |
- 置信度建议范围:0.2-0.5 | |
- 最大检测数量建议根据实际场景设置 | |
""") | |
if __name__ == "__main__": | |
demo.launch( | |
server_name="0.0.0.0", | |
server_port=7860, | |
auth=auth_user, | |
auth_message="请输入用户名和密码" | |
) |