Hakureirm commited on
Commit
26be8c4
·
verified ·
1 Parent(s): 457c927

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -45
app.py CHANGED
@@ -1,48 +1,54 @@
1
  import gradio as gr
2
  from ultralytics import YOLO
3
  from fastapi import FastAPI
4
- import cv2
5
  import torch
6
  import spaces
7
  import numpy as np
 
8
  from pathlib import Path
 
9
 
10
  app = FastAPI()
11
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
12
- model = YOLO('kunin-mice-pose.v0.1.0.pt') # 使用你的小鼠检测模型
13
 
14
  @spaces.GPU
15
- def process_video(video_path, process_seconds=20):
16
- # 创建临时输出路径
17
- output_path = Path("temp_output.mp4")
 
 
 
 
 
 
 
 
 
18
 
19
  # 获取视频信息
20
  cap = cv2.VideoCapture(video_path)
21
  fps = int(cap.get(cv2.CAP_PROP_FPS))
22
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
23
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
24
- total_frames = int(process_seconds * fps)
 
25
 
26
  # 创建视频写入器
27
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
28
  video_writer = cv2.VideoWriter(
29
- str(output_path),
30
  fourcc,
31
  fps,
32
  (width, height)
33
  )
34
 
35
- # 统计信息
36
- frame_count = 0
37
- total_detections = 0
38
- max_mice = 0
39
- detection_stats = []
40
-
41
- # 处理视频
42
  results = model.predict(
43
  source=video_path,
44
  device=device,
45
- conf=0.5,
46
  save=False,
47
  show=False,
48
  stream=True,
@@ -51,60 +57,109 @@ def process_video(video_path, process_seconds=20):
51
  show_labels=True,
52
  show_conf=True,
53
  vid_stride=1,
 
54
  )
55
 
 
 
 
 
56
  for r in results:
57
- # 获取当前帧的检测结果
58
  frame = r.plot()
59
- num_mice = len(r.boxes)
60
- max_mice = max(max_mice, num_mice)
61
- total_detections += num_mice
62
- detection_stats.append(num_mice)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
  # 写入视频
65
  video_writer.write(frame)
66
 
67
  frame_count += 1
68
- if frame_count >= total_frames:
69
  break
70
 
71
- # 释放资源
72
  video_writer.release()
73
- cap.release()
74
 
75
- # 生成统计信息
76
- avg_mice = total_detections / frame_count if frame_count > 0 else 0
77
- output_text = f"""小鼠检测统计:
 
 
78
  - 处理时长: {process_seconds}秒
79
- - 总帧数: {frame_count}
80
- - 最大检测数量: {max_mice}只
81
- - 平均检测数量: {avg_mice:.1f}
82
- - 检测帧率: {fps} FPS"""
83
-
84
- # 如果有检测到的帧
85
- if detection_stats:
86
- output_text += "\n\n帧检测分布:"
87
- for count in range(max(detection_stats) + 1):
88
- frames = detection_stats.count(count)
89
- percentage = frames / frame_count * 100
90
- output_text += f"\n{count}只小鼠: {frames}帧 ({percentage:.1f}%)"
91
 
92
- return str(output_path), output_text
93
 
94
  demo = gr.Interface(
95
  fn=process_video,
96
  inputs=[
97
  gr.Video(label="输入视频"),
98
- gr.Slider(minimum=1, maximum=60, value=20, step=1, label="处理时长(秒)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  ],
100
  outputs=[
101
  gr.Video(label="检测结果"),
102
- gr.Textbox(label="检测统计")
103
  ],
104
- title="🐁 小鼠行为检测系统",
105
- description="上传视频来检测小鼠 | Upload a video to detect mice",
106
- # examples=[["example.mp4", 20]],
107
- # cache_examples=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  )
109
 
110
  if __name__ == "__main__":
 
1
  import gradio as gr
2
  from ultralytics import YOLO
3
  from fastapi import FastAPI
4
+ from PIL import Image
5
  import torch
6
  import spaces
7
  import numpy as np
8
+ import cv2
9
  from pathlib import Path
10
+ import tempfile
11
 
12
  app = FastAPI()
13
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
14
+ model = YOLO('kunin-mice-pose.v0.1.0.pt')
15
 
16
  @spaces.GPU
17
+ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8):
18
+ """
19
+ 处理视频并进行小鼠检测
20
+ Args:
21
+ video_path: 输入视频路径
22
+ process_seconds: 处理时长(秒)
23
+ conf_threshold: 置信度阈值(0-1)
24
+ max_det: 每帧最大检测数量
25
+ """
26
+ # 创建临时目录保存输出视频
27
+ with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
28
+ output_path = tmp_file.name
29
 
30
  # 获取视频信息
31
  cap = cv2.VideoCapture(video_path)
32
  fps = int(cap.get(cv2.CAP_PROP_FPS))
33
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
34
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
35
+ total_frames = int(process_seconds * fps) if process_seconds else int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
36
+ cap.release()
37
 
38
  # 创建视频写入器
39
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
40
  video_writer = cv2.VideoWriter(
41
+ output_path,
42
  fourcc,
43
  fps,
44
  (width, height)
45
  )
46
 
47
+ # 设置推理参数并处理视频
 
 
 
 
 
 
48
  results = model.predict(
49
  source=video_path,
50
  device=device,
51
+ conf=conf_threshold, # 使用用户设置的置信度阈值
52
  save=False,
53
  show=False,
54
  stream=True,
 
57
  show_labels=True,
58
  show_conf=True,
59
  vid_stride=1,
60
+ max_det=max_det, # 使用用户设置的最大检测数量
61
  )
62
 
63
+ # 处理结果
64
+ frame_count = 0
65
+ detection_info = []
66
+
67
  for r in results:
68
+ # 获取绘制了预测结果的帧
69
  frame = r.plot()
70
+
71
+ # 收集检测信息
72
+ frame_info = {
73
+ "frame": frame_count + 1,
74
+ "count": len(r.boxes),
75
+ "detections": []
76
+ }
77
+
78
+ for box in r.boxes:
79
+ conf = float(box.conf[0])
80
+ cls = int(box.cls[0])
81
+ cls_name = r.names[cls]
82
+ frame_info["detections"].append({
83
+ "class": cls_name,
84
+ "confidence": f"{conf:.2%}"
85
+ })
86
+
87
+ detection_info.append(frame_info)
88
 
89
  # 写入视频
90
  video_writer.write(frame)
91
 
92
  frame_count += 1
93
+ if process_seconds and frame_count >= total_frames:
94
  break
95
 
96
+ # 释放视频写入器
97
  video_writer.release()
 
98
 
99
+ # 生成分析报告
100
+ report = f"""视频分析报告:
101
+ 参数设置:
102
+ - 置信度阈值: {conf_threshold:.2f}
103
+ - 最大检测数量: {max_det}
104
  - 处理时长: {process_seconds}秒
105
+
106
+ 分析结果:
107
+ - 处理帧数: {frame_count}
108
+ - 平均每帧检测到的老鼠数: {np.mean([info['count'] for info in detection_info]):.1f}
109
+ - 最大检测数: {max([info['count'] for info in detection_info])}
110
+ - 最小检测数: {min([info['count'] for info in detection_info])}
111
+
112
+ 置信度分布:
113
+ {np.histogram([float(det['confidence'].strip('%'))/100 for info in detection_info for det in info['detections']], bins=5)[0].tolist()}
114
+ """
 
 
115
 
116
+ return output_path, report
117
 
118
  demo = gr.Interface(
119
  fn=process_video,
120
  inputs=[
121
  gr.Video(label="输入视频"),
122
+ gr.Number(label="处理时长(秒,0表示处理整个视频)", value=20),
123
+ gr.Slider(
124
+ minimum=0.1,
125
+ maximum=1.0,
126
+ value=0.2,
127
+ step=0.05,
128
+ label="置信度阈值",
129
+ info="越高越严格,建议范围0.2-0.5"
130
+ ),
131
+ gr.Slider(
132
+ minimum=1,
133
+ maximum=10,
134
+ value=8,
135
+ step=1,
136
+ label="最大检测数量",
137
+ info="每帧最多检测的目标数量"
138
+ )
139
  ],
140
  outputs=[
141
  gr.Video(label="检测结果"),
142
+ gr.Textbox(label="分析报告")
143
  ],
144
+ title="🐁 小鼠行为分析 (Mice Behavior Analysis)",
145
+ description="上传视频来检测和分析小鼠行为 | Upload a video to detect and analyze mice behavior",
146
+ article="""
147
+ ### 使用说明
148
+ 1. 上传视频文件
149
+ 2. 设置处理参数:
150
+ - 处理时长:需要分析的视频时长(秒)
151
+ - 置信度阈值:检测的置信度要求(越高越严格)
152
+ - 最大检测数量:每帧最多检测的目标数量
153
+ 3. 等待处理完成
154
+ 4. 查看检测结果视频和分析报告
155
+
156
+ ### 注意事项
157
+ - 支持常见视频格式(mp4, avi 等)
158
+ - 建议视频分辨率不超过 1920x1080
159
+ - 处理时间与视频长度和分辨率相关
160
+ - 置信度建议范围:0.2-0.5
161
+ - 最大检测数量建议根据实际场景设置
162
+ """
163
  )
164
 
165
  if __name__ == "__main__":