Hakureirm commited on
Commit
f8b6390
·
verified ·
1 Parent(s): 77d9dbb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -5
app.py CHANGED
@@ -8,7 +8,8 @@ import numpy as np
8
 
9
  app = FastAPI()
10
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
11
- model = YOLO('nailong_yolo11.onnx').to(device)
 
12
 
13
  @spaces.GPU
14
  def predict(img):
@@ -25,8 +26,8 @@ def predict(img):
25
  # 转换为numpy数组并进行预测
26
  img_array = np.array(input_image)
27
 
28
- # 进行预测
29
- results = model.predict(img_array)
30
  result = results[0]
31
 
32
  # 获取预测结果
@@ -36,7 +37,7 @@ def predict(img):
36
  info = {
37
  "detected": len(result.boxes) > 0,
38
  "count": len(result.boxes),
39
- "detections": [] # 存储每个检测目标的详细信息
40
  }
41
 
42
  if info["detected"]:
@@ -44,7 +45,7 @@ def predict(img):
44
  for box in result.boxes:
45
  conf = float(box.conf[0])
46
  cls = int(box.cls[0])
47
- cls_name = result.names[cls] # 获取类别名称
48
 
49
  detection_info = {
50
  "class": cls_name,
 
8
 
9
  app = FastAPI()
10
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
11
+ # 移除 .to(device),在预测时指定设备
12
+ model = YOLO('nailong_yolo11.onnx')
13
 
14
  @spaces.GPU
15
  def predict(img):
 
26
  # 转换为numpy数组并进行预测
27
  img_array = np.array(input_image)
28
 
29
+ # 在predict时指定device
30
+ results = model.predict(img_array, device=device)
31
  result = results[0]
32
 
33
  # 获取预测结果
 
37
  info = {
38
  "detected": len(result.boxes) > 0,
39
  "count": len(result.boxes),
40
+ "detections": []
41
  }
42
 
43
  if info["detected"]:
 
45
  for box in result.boxes:
46
  conf = float(box.conf[0])
47
  cls = int(box.cls[0])
48
+ cls_name = result.names[cls]
49
 
50
  detection_info = {
51
  "class": cls_name,