Hakureirm commited on
Commit
0ce6f4c
·
1 Parent(s): 49dfadb

Migrate to ZeroGPU and update requirements for compatibility

Browse files
Files changed (1) hide show
  1. app.py +10 -4
app.py CHANGED
@@ -4,12 +4,19 @@ from fastapi import FastAPI, File, UploadFile
4
  from PIL import Image
5
  import numpy as np
6
  import io
 
 
7
 
8
  # 初始化 FastAPI 和模型
9
  app = FastAPI()
10
- model = YOLO('NailongKiller.yolo11n.pt')
11
 
 
 
 
 
 
12
  def predict(img):
 
13
  results = model.predict(img)
14
  return results[0].plot()
15
 
@@ -33,16 +40,15 @@ demo = gr.Interface(
33
  # API 端点
34
  @app.post("/detect/")
35
  async def detect_api(file: UploadFile = File(...)):
36
- # 读取上传的图片
37
  contents = await file.read()
38
  image = Image.open(io.BytesIO(contents))
39
  image_np = np.array(image)
40
 
41
- # 运行推理
 
42
  results = model.predict(image_np)
43
  result = results[0]
44
 
45
- # 返回检测结果
46
  detections = []
47
  for box in result.boxes:
48
  detection = {
 
4
  from PIL import Image
5
  import numpy as np
6
  import io
7
+ import torch
8
+ import spaces # 导入 spaces 模块
9
 
10
  # 初始化 FastAPI 和模型
11
  app = FastAPI()
 
12
 
13
+ # 检查 GPU 是否可用,并选择设备
14
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
15
+ model = YOLO('NailongKiller.yolo11n.pt').to(device)
16
+
17
+ @spaces.GPU # 使用装饰器标记需要 GPU 的函数
18
  def predict(img):
19
+ img = img.to(device)
20
  results = model.predict(img)
21
  return results[0].plot()
22
 
 
40
  # API 端点
41
  @app.post("/detect/")
42
  async def detect_api(file: UploadFile = File(...)):
 
43
  contents = await file.read()
44
  image = Image.open(io.BytesIO(contents))
45
  image_np = np.array(image)
46
 
47
+ image_np = torch.from_numpy(image_np).to(device)
48
+
49
  results = model.predict(image_np)
50
  result = results[0]
51
 
 
52
  detections = []
53
  for box in result.boxes:
54
  detection = {