Hakureirm commited on
Commit
2150861
·
1 Parent(s): 6136f56

Migrate to ZeroGPU and update requirements for compatibility

Browse files
Files changed (1) hide show
  1. app.py +4 -28
app.py CHANGED
@@ -5,7 +5,7 @@ from PIL import Image
5
  import numpy as np
6
  import io
7
  import torch
8
- import spaces # 导入 spaces 模块
9
 
10
  # 初始化 FastAPI 和模型
11
  app = FastAPI()
@@ -14,10 +14,10 @@ app = FastAPI()
14
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
15
  model = YOLO('NailongKiller.yolo11n.pt').to(device)
16
 
17
- @spaces.GPU # 使用装饰器标记需要 GPU 的函数
18
  def predict(img):
19
- # numpy 数组转换为 PyTorch 张量
20
- img_tensor = torch.from_numpy(img).to(device)
21
  results = model.predict(img_tensor)
22
  return results[0].plot()
23
 
@@ -38,30 +38,6 @@ demo = gr.Interface(
38
  cache_examples=True
39
  )
40
 
41
- # API 端点
42
- @app.post("/detect/")
43
- async def detect_api(file: UploadFile = File(...)):
44
- contents = await file.read()
45
- image = Image.open(io.BytesIO(contents))
46
- image_np = np.array(image)
47
-
48
- # 将图像移动到 GPU
49
- image_tensor = torch.from_numpy(image_np).to(device)
50
-
51
- results = model.predict(image_tensor)
52
- result = results[0]
53
-
54
- detections = []
55
- for box in result.boxes:
56
- detection = {
57
- "bbox": box.xyxy[0].tolist(),
58
- "confidence": float(box.conf[0]),
59
- "class": int(box.cls[0])
60
- }
61
- detections.append(detection)
62
-
63
- return {"detections": detections}
64
-
65
  # 挂载 Gradio 到 FastAPI
66
  app = gr.mount_gradio_app(app, demo, path="/")
67
 
 
5
  import numpy as np
6
  import io
7
  import torch
8
+ import spaces
9
 
10
  # 初始化 FastAPI 和模型
11
  app = FastAPI()
 
14
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
15
  model = YOLO('NailongKiller.yolo11n.pt').to(device)
16
 
17
+ @spaces.GPU
18
  def predict(img):
19
+ img_resized = np.array(Image.fromarray(img).resize((640, 640)))
20
+ img_tensor = torch.from_numpy(img_resized).permute(2, 0, 1).unsqueeze(0).to(device)
21
  results = model.predict(img_tensor)
22
  return results[0].plot()
23
 
 
38
  cache_examples=True
39
  )
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  # 挂载 Gradio 到 FastAPI
42
  app = gr.mount_gradio_app(app, demo, path="/")
43