mice-pose-gpu / app.py
Hakureirm's picture
Update app.py
b33582e verified
raw
history blame
5.89 kB
import gradio as gr
from ultralytics import YOLO
from fastapi import FastAPI
from PIL import Image
import torch
import spaces
import numpy as np
import cv2
from pathlib import Path
import tempfile
app = FastAPI()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = YOLO('kunin-mice-pose.v0.1.0.pt')
# 添加认证函数
def auth_user(username, password):
# 你可以添加多个用户
valid_credentials = {
"kunin": "123456",
"user1": "password1",
# 可以添加更多用户
}
return username in valid_credentials and valid_credentials[username] == password
@spaces.GPU
def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8):
"""
处理视频并进行小鼠检测
Args:
video_path: 输入视频路径
process_seconds: 处理时长(秒)
conf_threshold: 置信度阈值(0-1)
max_det: 每帧最大检测数量
"""
# 创建临时目录保存输出视频
with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
output_path = tmp_file.name
# 获取视频信息
cap = cv2.VideoCapture(video_path)
fps = int(cap.get(cv2.CAP_PROP_FPS))
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
total_frames = int(process_seconds * fps) if process_seconds else int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
cap.release()
# 创建视频写入器
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
video_writer = cv2.VideoWriter(
output_path,
fourcc,
fps,
(width, height)
)
# 设置推理参数并处理视频
results = model.predict(
source=video_path,
device=device,
conf=conf_threshold, # 使用用户设置的置信度阈值
save=False,
show=False,
stream=True,
line_width=2,
show_boxes=True,
show_labels=True,
show_conf=True,
vid_stride=1,
max_det=max_det, # 使用用户设置的最大检测数量
)
# 处理结果
frame_count = 0
detection_info = []
for r in results:
# 获取绘制了预测结果的帧
frame = r.plot()
# 收集检测信息
frame_info = {
"frame": frame_count + 1,
"count": len(r.boxes),
"detections": []
}
for box in r.boxes:
conf = float(box.conf[0])
cls = int(box.cls[0])
cls_name = r.names[cls]
frame_info["detections"].append({
"class": cls_name,
"confidence": f"{conf:.2%}"
})
detection_info.append(frame_info)
# 写入视频
video_writer.write(frame)
frame_count += 1
if process_seconds and frame_count >= total_frames:
break
# 释放视频写入器
video_writer.release()
# 生成分析报告
report = f"""视频分析报告:
参数设置:
- 置信度阈值: {conf_threshold:.2f}
- 最大检测数量: {max_det}
- 处理时长: {process_seconds}
分析结果:
- 处理帧数: {frame_count}
- 平均每帧检测到的老鼠数: {np.mean([info['count'] for info in detection_info]):.1f}
- 最大检测数: {max([info['count'] for info in detection_info])}
- 最小检测数: {min([info['count'] for info in detection_info])}
置信度分布:
{np.histogram([float(det['confidence'].strip('%'))/100 for info in detection_info for det in info['detections']], bins=5)[0].tolist()}
"""
return output_path, report
# 创建 Gradio 界面
with gr.Blocks() as demo:
gr.Markdown("# 🐁 小鼠行为分析 (Mice Behavior Analysis)")
gr.Markdown("上传视频来检测和分析小鼠行为 | Upload a video to detect and analyze mice behavior")
with gr.Row():
with gr.Column():
video_input = gr.Video(label="输入视频")
process_seconds = gr.Number(
label="处理时长(秒,0表示处理整个视频)",
value=20
)
conf_threshold = gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.2,
step=0.05,
label="置信度阈值",
info="越高越严格,建议范围0.2-0.5"
)
max_det = gr.Slider(
minimum=1,
maximum=10,
value=8,
step=1,
label="最大检测数量",
info="每帧最多检测的目标数量"
)
process_btn = gr.Button("开始处理")
with gr.Column():
video_output = gr.Video(label="检测结果")
report_output = gr.Textbox(label="分析报告")
process_btn.click(
fn=process_video,
inputs=[video_input, process_seconds, conf_threshold, max_det],
outputs=[video_output, report_output]
)
gr.Markdown("""
### 使用说明
1. 上传视频文件
2. 设置处理参数:
- 处理时长:需要分析的视频时长(秒)
- 置信度阈值:检测的置信度要求(越高越严格)
- 最大检测数量:每帧最多检测的目标数量
3. 等待处理完成
4. 查看检测结果视频和分析报告
### 注意事项
- 支持常见视频格式(mp4, avi 等)
- 建议视频分辨率不超过 1920x1080
- 处理时间与视频长度和分辨率相关
- 置信度建议范围:0.2-0.5
- 最大检测数量建议根据实际场景设置
""")
if __name__ == "__main__":
demo.launch(
server_name="0.0.0.0",
server_port=7860,
auth=auth_user,
auth_message="请输入用户名和密码"
)