Hakureirm commited on
Commit
b9cb794
·
verified ·
1 Parent(s): 9adcbdd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -6
app.py CHANGED
@@ -3,12 +3,24 @@ import cv2
3
  import numpy as np
4
  import gradio as gr
5
  import tempfile
 
6
  from mouse_tracker import MouseTrackerAnalyzer
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  # 全局变量
9
  analyzer = None
10
  video_file_path = None
11
- model_file_path = "./fst-v1.2-n.onnx" # 直接指定模型文件路径
12
  total_frames = 0
13
  output_path = None
14
 
@@ -82,7 +94,17 @@ def preview_frame(video_file, frame_num):
82
  return frame, f"帧 {frame_num}"
83
 
84
  # 开始分析
85
- def start_analysis(video, conf, iou, max_det, start_frame, end_frame, threshold):
 
 
 
 
 
 
 
 
 
 
86
  global analyzer, output_path, model_file_path
87
 
88
  if not video:
@@ -97,6 +119,26 @@ def start_analysis(video, conf, iou, max_det, start_frame, end_frame, threshold)
97
  csv_path = os.path.join(os.path.dirname(video), f"{video_name}_results.csv")
98
 
99
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  # 创建分析器
101
  analyzer = MouseTrackerAnalyzer(
102
  model_path=model_file_path,
@@ -187,8 +229,13 @@ def create_interface():
187
  # 只保留视频选择,移除模型选择
188
  video_input = gr.Video(label="输入视频")
189
 
190
- # 显示当前使用的模型
191
- model_info = gr.Textbox(label="模型信息", value=f"使用模型: {os.path.basename(model_file_path)}", interactive=False)
 
 
 
 
 
192
 
193
  # 参数设置
194
  with gr.Row():
@@ -251,6 +298,17 @@ def create_interface():
251
 
252
  # 启动应用
253
  if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
 
254
 
255
  # 检查模型文件是否存在
256
  if not os.path.exists(model_file_path):
@@ -259,5 +317,11 @@ if __name__ == "__main__":
259
  print(f"使用模型: {model_file_path}")
260
 
261
  app = create_interface()
262
- # 使用简化的启动配置
263
- app.launch(server_name="0.0.0.0", server_port=7860, share=True)
 
 
 
 
 
 
 
3
  import numpy as np
4
  import gradio as gr
5
  import tempfile
6
+ import torch
7
  from mouse_tracker import MouseTrackerAnalyzer
8
+ import huggingface_hub
9
+ from huggingface_hub import hf_hub_download
10
+
11
+ # 检查是否在Hugging Face Spaces环境中
12
+ try:
13
+ import spaces
14
+ is_spaces = True
15
+ print("检测到Hugging Face Spaces环境")
16
+ except ImportError:
17
+ is_spaces = False
18
+ print("在本地环境运行")
19
 
20
  # 全局变量
21
  analyzer = None
22
  video_file_path = None
23
+ model_file_path = "weights/fst-v1.2-n.onnx" # 直接指定模型文件路径
24
  total_frames = 0
25
  output_path = None
26
 
 
94
  return frame, f"帧 {frame_num}"
95
 
96
  # 开始分析
97
+ # 为HF Spaces环境添加GPU装饰器
98
+ if is_spaces:
99
+ @spaces.GPU(duration=120) # 申请GPU资源,持续120秒
100
+ def start_analysis(video, conf, iou, max_det, start_frame, end_frame, threshold):
101
+ return _start_analysis_impl(video, conf, iou, max_det, start_frame, end_frame, threshold)
102
+ else:
103
+ def start_analysis(video, conf, iou, max_det, start_frame, end_frame, threshold):
104
+ return _start_analysis_impl(video, conf, iou, max_det, start_frame, end_frame, threshold)
105
+
106
+ # 实际的分析实现
107
+ def _start_analysis_impl(video, conf, iou, max_det, start_frame, end_frame, threshold):
108
  global analyzer, output_path, model_file_path
109
 
110
  if not video:
 
119
  csv_path = os.path.join(os.path.dirname(video), f"{video_name}_results.csv")
120
 
121
  try:
122
+ # 检查设备
123
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
124
+ print(f"使用设备: {device}")
125
+
126
+ # 确保模型文件存在
127
+ if not os.path.exists(model_file_path):
128
+ # 如果在Hugging Face Spaces环境中,尝试从Hub下载模型
129
+ if is_spaces:
130
+ try:
131
+ print(f"尝试从Hugging Face Hub下载模型: {os.path.basename(model_file_path)}")
132
+ model_file_path = hf_hub_download(
133
+ repo_id="YOUR_HF_USERNAME/YOUR_REPO_NAME", # 替换为您的仓库
134
+ filename="weights/fst-v1.2-n.onnx"
135
+ )
136
+ print(f"模型已下载到: {model_file_path}")
137
+ except Exception as e:
138
+ print(f"从Hub下载模型失败: {str(e)}")
139
+ else:
140
+ print(f"警告: 模型文件 {model_file_path} 不存在!")
141
+
142
  # 创建分析器
143
  analyzer = MouseTrackerAnalyzer(
144
  model_path=model_file_path,
 
229
  # 只保留视频选择,移除模型选择
230
  video_input = gr.Video(label="输入视频")
231
 
232
+ # 显示当前使用的模型和设备信息
233
+ device_info = "GPU" if torch.cuda.is_available() else "CPU"
234
+ model_info = gr.Textbox(
235
+ label="系统信息",
236
+ value=f"使用模型: {os.path.basename(model_file_path)} | 计算设备: {device_info}",
237
+ interactive=False
238
+ )
239
 
240
  # 参数设置
241
  with gr.Row():
 
298
 
299
  # 启动应用
300
  if __name__ == "__main__":
301
+ # 清除可能干扰的代理设置
302
+ if 'http_proxy' in os.environ:
303
+ del os.environ['http_proxy']
304
+ if 'https_proxy' in os.environ:
305
+ del os.environ['https_proxy']
306
+ if 'all_proxy' in os.environ:
307
+ del os.environ['all_proxy']
308
+
309
+ # 检查设备和模型
310
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
311
+ print(f"使用设备: {device}")
312
 
313
  # 检查模型文件是否存在
314
  if not os.path.exists(model_file_path):
 
317
  print(f"使用模型: {model_file_path}")
318
 
319
  app = create_interface()
320
+
321
+ # 根据环境决定启动方式
322
+ if is_spaces:
323
+ # Hugging Face Spaces环境中的启动方式
324
+ app.launch()
325
+ else:
326
+ # 本地环境的启动方式
327
+ app.launch(server_name="127.0.0.1", server_port=7860, share=False)