Spaces:
Sleeping
Sleeping
import os | |
import gradio as gr | |
from ultralytics import YOLO | |
from fastapi import FastAPI | |
from PIL import Image | |
import torch | |
import spaces | |
import numpy as np | |
import cv2 | |
from pathlib import Path | |
import tempfile | |
import imageio | |
from tqdm import tqdm | |
import logging | |
import torch.nn.functional as F | |
# Root logging configuration, applied once at import time: INFO and above,
# each record rendered as "timestamp - LEVEL - message".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.INFO,
)
# Logger shared by every function in this module.
logger = logging.getLogger(__name__)
# Login credentials: the username is fixed, the password comes from the
# APP_PASSWORD environment variable.
APP_USERNAME = "admin"
# No hard-coded fallback. When APP_PASSWORD is unset or empty this is None, so
# no submitted password can ever equal it and the app fails closed instead of
# accepting a publicly known default.
APP_PASSWORD = os.getenv("APP_PASSWORD") or None
if APP_PASSWORD is None:
    logger.warning("未设置环境变量 APP_PASSWORD,所有登录都将被拒绝")
app = FastAPI()
# Inference device: CUDA when torch can see a GPU, otherwise CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Pose model weights loaded once at import time from the working directory.
model = YOLO('kunin-mice-pose.v0.1.5n.pt')
class AuthState:
    """Record whether the admin login has succeeded.

    NOTE(review): a single instance (``auth_state`` below) is shared by every
    Gradio session served by this process, so one successful login unlocks
    processing for all visitors. Per-session auth would need ``gr.State``
    wired through the UI.
    """

    def __init__(self) -> None:
        # Start locked; login() flips this to True on valid credentials.
        initially_logged_in: bool = False
        self.is_logged_in = initially_logged_in


# Module-wide authentication flag, read by process_video() and set by login().
auth_state = AuthState()
def login(username, password):
    """Check the submitted credentials against the configured admin account.

    Args:
        username: text from the username box (``None`` is treated as empty).
        password: text from the password box (``None`` is treated as empty).

    Returns:
        ``(login_group_update, main_group_update, message)`` for Gradio: on
        success the login group is hidden and the main group shown, on failure
        the reverse.

    Side effects:
        Sets ``auth_state.is_logged_in = True`` on success. NOTE(review): that
        flag is module-wide, so it is shared by every session of this process.
    """
    import hmac  # local import keeps the file's import block unchanged

    logger.info(f"用户尝试登录: {username}")
    expected_password = APP_PASSWORD
    # Compare UTF-8 bytes with compare_digest: constant-time, and unlike the
    # str overload it accepts non-ASCII input. Both fields are always checked
    # so the response time does not reveal which one was wrong.
    username_ok = hmac.compare_digest(
        (username or "").encode("utf-8"), APP_USERNAME.encode("utf-8")
    )
    # An unset or empty configured password must never be satisfiable.
    password_ok = bool(expected_password) and hmac.compare_digest(
        (password or "").encode("utf-8"), str(expected_password).encode("utf-8")
    )
    if username_ok and password_ok:
        auth_state.is_logged_in = True
        logger.info("登录成功")
        return gr.update(visible=False), gr.update(visible=True), "登录成功"
    logger.warning("登录失败:用户名或密码错误")
    return gr.update(visible=True), gr.update(visible=False), "用户名或密码错误"
# Gaussian blur applied to the hit-count image when rendering heatmaps
# (kernel side in pixels, sigma in pixels).
_HEATMAP_KSIZE = (31, 31)
_HEATMAP_SIGMA = 10


def _failure(message):
    """Return the 6-value Gradio result for an aborted run: no files, only ``message``."""
    return None, None, None, None, None, message


def _first_keypoint(result):
    """Return ``(x, y)`` of keypoint 0 of the first detection in ``result``, or None.

    NOTE(review): the first detection is assumed to be the animal to track
    (presumably the most confident one) — confirm for multi-animal videos.
    """
    keypoints = getattr(result, "keypoints", None)
    if keypoints is None:
        return None
    data = keypoints.data
    if isinstance(data, torch.Tensor):
        data = data.cpu().numpy()
    # Layout: [num_objects, num_keypoints, 3] with x, y in the first two slots.
    if data.ndim != 3 or data.shape[0] == 0 or data.shape[1] == 0 or data.shape[2] < 2:
        return None
    return float(data[0, 0, 0]), float(data[0, 0, 1])


def _filter_trajectory(positions, max_jump_distance=100.0, smooth_window=5):
    """Drop implausible jumps from a 2-D track, fill the gaps, then smooth it.

    A sample is rejected when it is farther than ``max_jump_distance`` pixels
    per elapsed sample from the last accepted one; the allowance grows with
    the gap so the track can be re-acquired. Rejected samples are replaced by
    linear interpolation between accepted neighbours (held constant at the
    ends). The result is smoothed with a centred moving average whose edges
    are padded by repetition, so the endpoints are not pulled toward (0, 0).

    Args:
        positions: sequence of ``(x, y)`` pixel coordinates in frame order.
        max_jump_distance: largest plausible displacement between consecutive
            samples, in pixels.
        smooth_window: moving-average length; skipped for shorter tracks.

    Returns:
        ``float64`` array of shape ``(len(positions), 2)``.
    """
    pts = np.asarray(positions, dtype=np.float64).reshape(-1, 2)
    n = len(pts)
    if n < 3:
        return pts

    accepted = np.zeros(n, dtype=bool)
    accepted[0] = True
    last = 0
    for i in range(1, n):
        if np.hypot(*(pts[i] - pts[last])) <= max_jump_distance * (i - last):
            accepted[i] = True
            last = i

    index = np.arange(n)
    kept = index[accepted]
    filled = np.column_stack([
        np.interp(index, kept, pts[accepted, 0]),
        np.interp(index, kept, pts[accepted, 1]),
    ])

    if smooth_window > 1 and n >= smooth_window:
        half = smooth_window // 2
        kernel = np.ones(smooth_window) / smooth_window
        padded = np.pad(filled, ((half, smooth_window - 1 - half), (0, 0)), mode="edge")
        filled = np.column_stack([
            np.convolve(padded[:, axis], kernel, mode="valid") for axis in range(2)
        ])
    return filled


def _draw_trajectory(points, width, height, marker=True):
    """Draw ``points`` as a blue-to-red polyline on a white BGR canvas.

    Segment ``j`` of ``n - 1`` gets color ``(B, G, R) = (255*(1-r), 50, 255*r)``
    with ``r = j / (n - 1)``. With ``marker`` the last point gets a filled red dot.
    """
    canvas = np.full((height, width, 3), 255, dtype=np.uint8)
    pts = np.rint(np.asarray(points, dtype=np.float64).reshape(-1, 2)).astype(int)
    n = len(pts)
    for j in range(n - 1):
        ratio = j / (n - 1)
        color = (int((1 - ratio) * 255), 50, int(ratio * 255))
        start = (int(pts[j, 0]), int(pts[j, 1]))
        end = (int(pts[j + 1, 0]), int(pts[j + 1, 1]))
        cv2.line(canvas, start, end, color, 2)
    if marker and n:
        cv2.circle(canvas, (int(pts[-1, 0]), int(pts[-1, 1])), 8, (0, 0, 255), -1)
    return canvas


def _add_points(counts, points):
    """Add one hit per in-bounds point to the 2-D ``counts`` image, in place."""
    pts = np.rint(np.asarray(points, dtype=np.float64).reshape(-1, 2)).astype(int)
    if not len(pts):
        return
    rows, cols = counts.shape
    xs, ys = pts[:, 0], pts[:, 1]
    inside = (xs >= 0) & (xs < cols) & (ys >= 0) & (ys < rows)
    # np.add.at accumulates repeated coordinates; fancy-index += would not.
    np.add.at(counts, (ys[inside], xs[inside]), 1.0)


def _render_heatmap(counts):
    """Blur a hit-count image into a density map and colorize it (BGR, JET over white).

    GaussianBlur is linear, so blurring the summed impulses gives the same
    density as blurring one impulse per point and adding the results. An
    all-zero density produces a plain white image.
    """
    density = cv2.GaussianBlur(counts, _HEATMAP_KSIZE, _HEATMAP_SIGMA)
    if np.max(density) <= 0:
        return np.full((*counts.shape, 3), 255, dtype=np.uint8)
    normalized = cv2.normalize(density, None, 0, 255, cv2.NORM_MINMAX)
    colored = cv2.applyColorMap(normalized.astype(np.uint8), cv2.COLORMAP_JET)
    return cv2.addWeighted(colored, 0.7, np.full_like(colored, 255), 0.3, 0)


def _build_animation_frames(points, width, height, max_frames=30):
    """Build RGB frames showing the trajectory and the heatmap growing over time.

    Evenly spaced prefixes of ``points`` (about ``max_frames``, the full track
    always last) are rendered. With no points, one white frame per animation
    is returned so the GIF writer never receives an empty list.
    """
    n = len(points)
    step = max(1, n // max_frames)
    cut_points = list(range(0, n, step))
    if n and cut_points[-1] != n - 1:
        cut_points.append(n - 1)

    counts = np.zeros((height, width), dtype=np.float32)
    trajectory_frames, heatmap_frames = [], []
    consumed = 0
    for i in cut_points:
        # Only newly revealed points are added, so each frame costs one blur.
        _add_points(counts, points[consumed:i + 1])
        consumed = i + 1
        # cv2 draws BGR; imageio writes the arrays as RGB.
        trajectory_frames.append(
            cv2.cvtColor(_draw_trajectory(points[:i + 1], width, height), cv2.COLOR_BGR2RGB)
        )
        heatmap_frames.append(cv2.cvtColor(_render_heatmap(counts), cv2.COLOR_BGR2RGB))

    if not trajectory_frames:
        blank = np.full((height, width, 3), 255, dtype=np.uint8)
        trajectory_frames.append(blank)
        heatmap_frames.append(blank.copy())
    return trajectory_frames, heatmap_frames


def _build_report(conf_threshold, max_det, process_seconds, per_frame_counts, confidences):
    """Format the text analysis report; safe for zero frames and zero detections."""
    if confidences:
        hist, bins = np.histogram(confidences, bins=5)
        total = len(confidences)
        confidence_report = "\n".join(
            f"置信度 {lo:.2f}-{hi:.2f}: {count:3d}个检测 ({count/total*100:.1f}%)"
            for lo, hi, count in zip(bins[:-1], bins[1:], hist)
        )
    else:
        confidence_report = "无检测结果"
    mean_count = float(np.mean(per_frame_counts)) if per_frame_counts else 0.0
    lines = [
        "视频分析报告:",
        "参数设置:",
        f"- 置信度阈值: {conf_threshold:.2f}",
        f"- 最大检测数量: {max_det}",
        f"- 处理时长: {process_seconds}秒",
        "分析结果:",
        f"- 处理帧数: {len(per_frame_counts)}",
        f"- 平均每帧检测到的老鼠数: {mean_count:.1f}",
        f"- 最大检测数: {max(per_frame_counts, default=0)}",
        f"- 最小检测数: {min(per_frame_counts, default=0)}",
        "置信度分布:",
        confidence_report,
    ]
    return "\n".join(lines) + "\n"


def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8):
    """
    Detect mice in a video and build the annotated video, trajectory and heatmap outputs.

    Args:
        video_path: path of the input video.
        process_seconds: seconds to process from the start; ``0``, ``None`` or a
            negative value processes the whole video.
        conf_threshold: detection confidence threshold (0-1).
        max_det: maximum number of detections per frame.

    Returns:
        6-tuple ``(annotated_video, trajectory_png, heatmap_png, trajectory_gif,
        heatmap_gif, report)`` matching the six Gradio outputs. When the run is
        refused or the video cannot be read, the five paths are ``None`` and the
        last item is the message.
    """
    logger.info(f"开始处理视频: {video_path}")
    logger.info(f"参数设置 - 处理时长: {process_seconds}秒, 置信度阈值: {conf_threshold}, 最大检测数: {max_det}")
    if not auth_state.is_logged_in:
        logger.warning("用户未登录,拒绝访问")
        return _failure("请先登录")
    if not video_path:
        logger.warning("未提供视频")
        return _failure("请先上传视频")

    logger.info("读取视频信息")
    cap = cv2.VideoCapture(video_path)
    try:
        opened = cap.isOpened()
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frame_total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        cap.release()
    if not opened or width <= 0 or height <= 0:
        logger.error(f"无法读取视频: {video_path}")
        return _failure("无法读取视频")
    if not (fps > 0):  # 0 or NaN when the container reports no frame rate
        fps = 30.0

    limit_frames = bool(process_seconds) and process_seconds > 0
    total_frames = max(1, int(process_seconds * fps)) if limit_frames else frame_total
    logger.info(f"视频信息 - FPS: {fps}, 分辨率: {width}x{height}, 总帧数: {total_frames}")

    logger.info("创建临时输出目录")
    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
        output_path = tmp_file.name

    # Stroke width for annotations: 0.2% of the shorter side, at least 1 px.
    line_thickness = max(1, int(min(width, height) * 0.002))

    raw_positions = []
    per_frame_counts = []
    confidences = []
    # Keep the float FPS so the output duration matches the input.
    video_writer = cv2.VideoWriter(
        output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height)
    )
    try:
        logger.info("开始YOLO模型推理")
        results = model.predict(
            source=video_path,
            device=device,
            conf=conf_threshold,
            save=False,
            show=False,
            stream=True,
            line_width=line_thickness,
            boxes=True,
            show_labels=True,
            show_conf=True,
            vid_stride=1,
            max_det=max_det,
            retina_masks=True,
            verbose=False  # silence YOLO's default per-frame logging
        )
        logger.info("开始处理检测结果")
        pbar_total = total_frames if total_frames > 0 else None
        with tqdm(total=pbar_total, desc="处理视频", unit="帧") as pbar:
            for r in results:
                # r.plot() has its own line_width argument, so pass the computed width.
                video_writer.write(r.plot(line_width=line_thickness))
                position = _first_keypoint(r)
                if position is not None:
                    raw_positions.append(position)
                boxes = r.boxes
                frame_confs = boxes.conf.tolist() if boxes is not None else []
                per_frame_counts.append(len(frame_confs))
                confidences.extend(frame_confs)
                pbar.update(1)
                if limit_frames and len(per_frame_counts) >= total_frames:
                    break
    finally:
        video_writer.release()
    logger.info(f"视频处理完成,共处理 {len(per_frame_counts)} 帧")

    report = _build_report(conf_threshold, max_det, process_seconds, per_frame_counts, confidences)
    track = _filter_trajectory(raw_positions)

    stem = os.path.splitext(output_path)[0]
    trajectory_path = f"{stem}_trajectory.png"
    heatmap_path = f"{stem}_heatmap.png"
    trajectory_gif_path = f"{stem}_trajectory.gif"
    heatmap_gif_path = f"{stem}_heatmap.gif"

    logger.info("开始生成轨迹图和热力图")
    cv2.imwrite(trajectory_path, _draw_trajectory(track, width, height, marker=False))
    heat_counts = np.zeros((height, width), dtype=np.float32)
    _add_points(heat_counts, track)
    # Always written (white when empty) so the returned path exists.
    cv2.imwrite(heatmap_path, _render_heatmap(heat_counts))
    logger.info("轨迹图和热力图生成完成")

    logger.info("开始生成GIF动画")
    trajectory_frames, heatmap_frames = _build_animation_frames(track, width, height)
    imageio.mimsave(trajectory_gif_path, trajectory_frames, duration=50)
    imageio.mimsave(heatmap_gif_path, heatmap_frames, duration=50)
    logger.info("GIF动画生成完成")

    logger.info("所有处理完成,准备返回结果")
    return output_path, trajectory_path, heatmap_path, trajectory_gif_path, heatmap_gif_path, report
def gaussian_blur_gpu(tensor, kernel_size=31, sigma=10):
    """Gaussian-blur a 2-D tensor with a square kernel on the tensor's own device.

    Args:
        tensor: 2-D tensor of shape (H, W).
        kernel_size: kernel side length; odd sizes keep the output at (H, W).
        sigma: Gaussian standard deviation in pixels.

    Returns:
        The blurred 2-D tensor with the same device and dtype as ``tensor``.
        Borders are zero-padded (conv2d ``padding``), so mass near the image
        edges is partially lost.
    """
    # Follow the input's device/dtype instead of a global device, so CPU
    # tensors on a CUDA machine (or float64/half inputs) do not mismatch.
    kernel = get_gaussian_kernel2d(kernel_size, sigma).to(device=tensor.device, dtype=tensor.dtype)
    kernel = kernel.reshape(1, 1, kernel_size, kernel_size)
    # conv2d expects (batch, channels, H, W); reshape also handles non-contiguous input.
    batched = tensor.reshape(1, 1, tensor.shape[0], tensor.shape[1])
    blurred = F.conv2d(batched, kernel, padding=kernel_size // 2)
    # Drop only batch/channel dims; squeeze() would also collapse an H or W of 1.
    return blurred[0, 0]
def get_gaussian_kernel2d(kernel_size, sigma):
    """Build a normalized, centred 2-D Gaussian kernel.

    Args:
        kernel_size: side length of the square kernel.
        sigma: standard deviation in pixels.

    Returns:
        float32 tensor of shape (kernel_size, kernel_size) summing to 1.
    """
    # Integer-spaced offsets centred on the middle pixel, e.g. -2..2 for size 5
    # (half-integers for even sizes), so the kernel is exactly symmetric.
    offsets = torch.arange(kernel_size, dtype=torch.float32) - (kernel_size - 1) / 2
    gauss_1d = torch.exp(-offsets.pow(2) / (2 * sigma ** 2))
    # exp(-(x^2 + y^2) / 2s^2) factorizes into the outer product of two 1-D Gaussians.
    kernel = torch.outer(gauss_1d, gauss_1d)
    return kernel / kernel.sum()
def draw_line_gpu(image, pt1, pt2, color, thickness=1):
    """Rasterize a straight segment into an (H, W, C) tensor, in place.

    Args:
        image: tensor of shape (H, W, C); modified in place.
        pt1, pt2: ``(x, y)`` endpoints in pixels, converted with ``int()``.
        color: tensor of C channel values, cast to the image's device/dtype.
        thickness: stroke width. Values above 1 stamp a square brush of side
            ``2 * (thickness // 2) + 1`` centred on each sample (even widths
            round up to the next odd width); 1 or less draws a 1-px line.

    Returns:
        None. Pixels that fall outside the image are silently clipped.
    """
    x1, y1 = map(int, pt1)
    x2, y2 = map(int, pt2)
    height, width = image.shape[0], image.shape[1]

    # One sample per pixel along the major axis (DDA), rounded to the nearest
    # pixel so both endpoints are hit exactly. A zero-length segment still
    # yields one sample, so a single point is drawn.
    steps = max(abs(x2 - x1), abs(y2 - y1))
    t = torch.linspace(0.0, 1.0, steps + 1, device=image.device)
    xs = torch.round(x1 + (x2 - x1) * t).long()
    ys = torch.round(y1 + (y2 - y1) * t).long()

    color = color.to(device=image.device, dtype=image.dtype)
    # Symmetric brush: offsets -radius..radius in both axes.
    radius = thickness // 2 if thickness > 1 else 0
    for dy in range(-radius, radius + 1):
        for dx in range(-radius, radius + 1):
            px = xs + dx
            py = ys + dy
            inside = (px >= 0) & (px < width) & (py >= 0) & (py < height)
            image[py[inside], px[inside]] = color
# Build the Gradio interface. Inside gr.Blocks the order and nesting of the
# component constructors below define the page layout.
with gr.Blocks() as demo:
    gr.Markdown("# 🐁 小鼠行为分析 (Mice Behavior Analysis)")
    # Login form; login() returns visibility updates that hide this group and
    # reveal main_interface on success.
    with gr.Group() as login_interface:
        username = gr.Textbox(label="用户名")
        password = gr.Textbox(label="密码", type="password")
        login_button = gr.Button("登录")
        login_msg = gr.Textbox(label="消息", interactive=False)
    # Main analysis UI, hidden until a successful login.
    with gr.Group(visible=False) as main_interface:
        gr.Markdown("上传视频来检测和分析小鼠行为 | Upload a video to detect and analyze mice behavior")
        with gr.Row():
            # Left column: inputs and processing parameters.
            with gr.Column():
                video_input = gr.Video(label="输入视频")
                process_seconds = gr.Number(
                    label="处理时长(秒,0表示处理整个视频)",
                    value=20
                )
                conf_threshold = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.2,
                    step=0.05,
                    label="置信度阈值",
                    info="越高越严格,建议范围0.2-0.5"
                )
                # NOTE: the UI default here (1) differs from process_video's
                # max_det default (8); the slider value is what gets passed.
                max_det = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=1,
                    step=1,
                    label="最大检测数量",
                    info="每帧最多检测的目标数量"
                )
                process_btn = gr.Button("开始处理")
            # Right column: the six outputs of process_video.
            with gr.Column():
                video_output = gr.Video(label="检测结果")
                with gr.Row():
                    trajectory_output = gr.Image(label="运动轨迹")
                    # NOTE(review): whether gr.Image animates a GIF file path
                    # depends on the Gradio version/type setting — confirm.
                    trajectory_gif_output = gr.Image(label="轨迹动画")
                with gr.Row():
                    heatmap_output = gr.Image(label="热力图")
                    heatmap_gif_output = gr.Image(label="热力图动画")
                report_output = gr.Textbox(label="分析报告")
        gr.Markdown("""
### 使用说明
1. 上传视频文件
2. 设置处理参数:
- 处理时间:需要分析的视频时长(秒)
- 置信度阈值:检测的置信度要求(越高越严格)
- 最大检测数量:每帧最多检测的目标数量
3. 等待处理完成
4. 查看检测结果视频和分析报告
### 注意事项
- 支持常见视频格式(mp4, avi 等)
- 建议视频分辨率不超过 1920x1080
- 处理时间与视频长度和分辨率相关
- 置信度建议范围:0.2-0.5
- 最大检测数量建议根据实际场景设置
""")
    # Wire the login form: outputs toggle the two groups and show a message.
    login_button.click(
        fn=login,
        inputs=[username, password],
        outputs=[login_interface, main_interface, login_msg]
    )
    # Output order must match process_video's 6-tuple return order.
    process_btn.click(
        fn=process_video,
        inputs=[video_input, process_seconds, conf_threshold, max_det],
        outputs=[video_output, trajectory_output, heatmap_output,
                 trajectory_gif_output, heatmap_gif_output, report_output]
    )
if __name__ == "__main__":
    # Startup log only: report which backend torch can see before serving.
    # No computation happens here, so there is nothing to guard with try/except.
    backend = "GPU" if torch.cuda.is_available() else "CPU"
    logger.info(f"使用{backend}进行轨迹和热力图计算")
    # Listen on all interfaces, port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)