Update app.py
Browse files
app.py
CHANGED
@@ -3,12 +3,24 @@ import cv2
|
|
3 |
import numpy as np
|
4 |
import gradio as gr
|
5 |
import tempfile
|
|
|
6 |
from mouse_tracker import MouseTrackerAnalyzer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
# 全局变量
|
9 |
analyzer = None
|
10 |
video_file_path = None
|
11 |
-
model_file_path = "
|
12 |
total_frames = 0
|
13 |
output_path = None
|
14 |
|
@@ -82,7 +94,17 @@ def preview_frame(video_file, frame_num):
|
|
82 |
return frame, f"帧 {frame_num}"
|
83 |
|
84 |
# 开始分析
|
85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
global analyzer, output_path, model_file_path
|
87 |
|
88 |
if not video:
|
@@ -97,6 +119,26 @@ def start_analysis(video, conf, iou, max_det, start_frame, end_frame, threshold)
|
|
97 |
csv_path = os.path.join(os.path.dirname(video), f"{video_name}_results.csv")
|
98 |
|
99 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
# 创建分析器
|
101 |
analyzer = MouseTrackerAnalyzer(
|
102 |
model_path=model_file_path,
|
@@ -187,8 +229,13 @@ def create_interface():
|
|
187 |
# 只保留视频选择,移除模型选择
|
188 |
video_input = gr.Video(label="输入视频")
|
189 |
|
190 |
-
#
|
191 |
-
|
|
|
|
|
|
|
|
|
|
|
192 |
|
193 |
# 参数设置
|
194 |
with gr.Row():
|
@@ -251,6 +298,17 @@ def create_interface():
|
|
251 |
|
252 |
# 启动应用
|
253 |
if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
254 |
|
255 |
# 检查模型文件是否存在
|
256 |
if not os.path.exists(model_file_path):
|
@@ -259,5 +317,11 @@ if __name__ == "__main__":
|
|
259 |
print(f"使用模型: {model_file_path}")
|
260 |
|
261 |
app = create_interface()
|
262 |
-
|
263 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
import numpy as np
|
4 |
import gradio as gr
|
5 |
import tempfile
|
6 |
+
import torch
|
7 |
from mouse_tracker import MouseTrackerAnalyzer
|
8 |
+
import huggingface_hub
|
9 |
+
from huggingface_hub import hf_hub_download
|
10 |
+
|
11 |
+
# 检查是否在Hugging Face Spaces环境中
|
12 |
+
try:
|
13 |
+
import spaces
|
14 |
+
is_spaces = True
|
15 |
+
print("检测到Hugging Face Spaces环境")
|
16 |
+
except ImportError:
|
17 |
+
is_spaces = False
|
18 |
+
print("在本地环境运行")
|
19 |
|
20 |
# 全局变量
|
21 |
analyzer = None
|
22 |
video_file_path = None
|
23 |
+
model_file_path = "weights/fst-v1.2-n.onnx" # 直接指定模型文件路径
|
24 |
total_frames = 0
|
25 |
output_path = None
|
26 |
|
|
|
94 |
return frame, f"帧 {frame_num}"
|
95 |
|
96 |
# 开始分析
|
97 |
+
# 为HF Spaces环境添加GPU装饰器
|
98 |
+
if is_spaces:
|
99 |
+
@spaces.GPU(duration=120) # 申请GPU资源,持续120秒
|
100 |
+
def start_analysis(video, conf, iou, max_det, start_frame, end_frame, threshold):
|
101 |
+
return _start_analysis_impl(video, conf, iou, max_det, start_frame, end_frame, threshold)
|
102 |
+
else:
|
103 |
+
def start_analysis(video, conf, iou, max_det, start_frame, end_frame, threshold):
|
104 |
+
return _start_analysis_impl(video, conf, iou, max_det, start_frame, end_frame, threshold)
|
105 |
+
|
106 |
+
# 实际的分析实现
|
107 |
+
def _start_analysis_impl(video, conf, iou, max_det, start_frame, end_frame, threshold):
|
108 |
global analyzer, output_path, model_file_path
|
109 |
|
110 |
if not video:
|
|
|
119 |
csv_path = os.path.join(os.path.dirname(video), f"{video_name}_results.csv")
|
120 |
|
121 |
try:
|
122 |
+
# 检查设备
|
123 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
124 |
+
print(f"使用设备: {device}")
|
125 |
+
|
126 |
+
# 确保模型文件存在
|
127 |
+
if not os.path.exists(model_file_path):
|
128 |
+
# 如果在Hugging Face Spaces环境中,尝试从Hub下载模型
|
129 |
+
if is_spaces:
|
130 |
+
try:
|
131 |
+
print(f"尝试从Hugging Face Hub下载模型: {os.path.basename(model_file_path)}")
|
132 |
+
model_file_path = hf_hub_download(
|
133 |
+
repo_id="YOUR_HF_USERNAME/YOUR_REPO_NAME", # 替换为您的仓库
|
134 |
+
filename="weights/fst-v1.2-n.onnx"
|
135 |
+
)
|
136 |
+
print(f"模型已下载到: {model_file_path}")
|
137 |
+
except Exception as e:
|
138 |
+
print(f"从Hub下载模型失败: {str(e)}")
|
139 |
+
else:
|
140 |
+
print(f"警告: 模型文件 {model_file_path} 不存在!")
|
141 |
+
|
142 |
# 创建分析器
|
143 |
analyzer = MouseTrackerAnalyzer(
|
144 |
model_path=model_file_path,
|
|
|
229 |
# 只保留视频选择,移除模型选择
|
230 |
video_input = gr.Video(label="输入视频")
|
231 |
|
232 |
+
# 显示当前使用的模型和设备信息
|
233 |
+
device_info = "GPU" if torch.cuda.is_available() else "CPU"
|
234 |
+
model_info = gr.Textbox(
|
235 |
+
label="系统信息",
|
236 |
+
value=f"使用模型: {os.path.basename(model_file_path)} | 计算设备: {device_info}",
|
237 |
+
interactive=False
|
238 |
+
)
|
239 |
|
240 |
# 参数设置
|
241 |
with gr.Row():
|
|
|
298 |
|
299 |
# 启动应用
|
300 |
if __name__ == "__main__":
|
301 |
+
# 清除可能干扰的代理设置
|
302 |
+
if 'http_proxy' in os.environ:
|
303 |
+
del os.environ['http_proxy']
|
304 |
+
if 'https_proxy' in os.environ:
|
305 |
+
del os.environ['https_proxy']
|
306 |
+
if 'all_proxy' in os.environ:
|
307 |
+
del os.environ['all_proxy']
|
308 |
+
|
309 |
+
# 检查设备和模型
|
310 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
311 |
+
print(f"使用设备: {device}")
|
312 |
|
313 |
# 检查模型文件是否存在
|
314 |
if not os.path.exists(model_file_path):
|
|
|
317 |
print(f"使用模型: {model_file_path}")
|
318 |
|
319 |
app = create_interface()
|
320 |
+
|
321 |
+
# 根据环境决定启动方式
|
322 |
+
if is_spaces:
|
323 |
+
# Hugging Face Spaces环境中的启动方式
|
324 |
+
app.launch()
|
325 |
+
else:
|
326 |
+
# 本地环境的启动方式
|
327 |
+
app.launch(server_name="127.0.0.1", server_port=7860, share=False)
|