import gradio as gr
from ultralytics import YOLO
from fastapi import FastAPI, File, UploadFile
from PIL import Image
import numpy as np
import io
import torch
import spaces  # import the spaces module (Hugging Face Spaces ZeroGPU support)

# Initialize FastAPI and the model
app = FastAPI()

# Check whether a GPU is available and select the device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = YOLO('NailongKiller.yolo11n.pt').to(device)

@spaces.GPU  # decorator marking the function that needs a GPU
def predict(img):
    # Gradio passes the image in as a numpy array; YOLO's predict() accepts
    # numpy arrays directly, so no manual device transfer is needed here.
    results = model.predict(img)
    return results[0].plot()

# Gradio interface
demo = gr.Interface(
    predict,
    inputs=[
        gr.Image(label="Input Image")
    ],
    outputs=[
        gr.Image(label="Detection Result", type="numpy")
    ],
    title="🐉 奶龙杀手 (NailongKiller)",
    description="上传图片来检测奶龙 | Upload an image to detect Nailong",
    examples=[
        ["example1.jpg"]
    ],
    cache_examples=True
)

# API endpoint
@app.post("/detect/")
async def detect_api(file: UploadFile = File(...)):
    contents = await file.read()
    # Decode the upload and force a 3-channel RGB image
    image = Image.open(io.BytesIO(contents)).convert("RGB")
    image_np = np.array(image)

    # YOLO's predict() handles numpy arrays directly; no manual tensor conversion is needed
    results = model.predict(image_np)
    result = results[0]
    
    detections = []
    for box in result.boxes:
        detection = {
            "bbox": box.xyxy[0].tolist(),
            "confidence": float(box.conf[0]),
            "class": int(box.cls[0])
        }
        detections.append(detection)
    
    return {"detections": detections}

# Mount Gradio onto the FastAPI app
app = gr.mount_gradio_app(app, demo, path="/")

# Launch the application
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
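
# A minimal client-side sketch of calling the /detect/ endpoint (an assumption:
# the server above is running on localhost:7860, and "test.jpg" is a
# hypothetical local image file):
#
#   import requests
#   with open("test.jpg", "rb") as f:
#       resp = requests.post(
#           "http://localhost:7860/detect/",
#           files={"file": ("test.jpg", f, "image/jpeg")},
#       )
#   # Response shape: {"detections": [{"bbox": [...], "confidence": ..., "class": ...}, ...]}
#   print(resp.json())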