# File size: 1,277 bytes
# 5c38c6f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import io

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from PIL import Image
from torchvision import models, transforms

# Build the pretrained ResNet-50 once at import time so every request reuses
# the same module.
# NOTE: `pretrained=True` is deprecated since torchvision 0.13 in favour of
# `weights=models.ResNet50_Weights.IMAGENET1K_V1` (the same weights); it is
# kept so the file also runs on older torchvision releases.
# `nn.Module.eval()` returns the module itself, so the model is switched to
# evaluation mode (which changes the behaviour of layers such as BatchNorm)
# in the same expression that builds it.
model = models.resnet50(pretrained=True).eval()

# Input pipeline: resize so the shorter side is 256 px, take the central
# 224x224 crop, convert the PIL image to a float tensor, then standardise each
# RGB channel with the ImageNet mean/std expected by torchvision's pretrained
# weights.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# The ASGI application; HTTP routes are attached to it with decorators.
app = FastAPI()

def _predict_class_index(contents: bytes) -> int:
    """Decode an uploaded image and return the model's top-1 class index.

    This is synchronous, CPU-bound work (PIL decoding plus a ResNet-50
    forward pass), so ``predict`` runs it in a worker thread rather than on
    the event loop.

    Args:
        contents: Raw bytes of the uploaded file.

    Returns:
        Index of the highest-scoring entry in the output of ``model``.

    Raises:
        HTTPException: 400 if ``contents`` cannot be decoded as an image.
    """
    try:
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as exc:
        # PIL.UnidentifiedImageError (unknown format / empty data) and
        # truncated-file errors are OSError subclasses; oversized images raise
        # DecompressionBombError. These are client errors, not server faults.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from exc

    # Prepend a batch dimension of size 1 before the forward pass.
    input_batch = preprocess(image).unsqueeze(0)

    # torch.no_grad() is thread-local, so it must be entered here, in the
    # worker thread that actually runs the forward pass.
    with torch.no_grad():
        output = model(input_batch)

    # argmax along dim 1 picks the top-1 prediction for the single image in
    # the batch (same index, including ties, as torch.max(output, 1)).
    return int(output.argmax(dim=1)[0])


@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image with the pretrained ResNet-50.

    Returns:
        JSON ``{"predicted_class": <int>}`` holding the top-1 class index.
        The index is not mapped to a human-readable label here.

    Raises:
        HTTPException: 400 if the upload is not a decodable image.
    """
    contents = await file.read()
    # Decoding and inference block the CPU; running them in FastAPI's
    # threadpool keeps the event loop free to serve other requests.
    predicted_idx = await run_in_threadpool(_predict_class_index, contents)
    # TODO: map the index to a class name (e.g. via an ImageNet label list)
    # if callers need human-readable labels; only the index is returned now.
    return JSONResponse(content={"predicted_class": predicted_idx})

# Serve the app with uvicorn when this file is executed directly.
if __name__ == "__main__":
    from uvicorn import run as serve

    # "0.0.0.0" binds every network interface, so the service is reachable
    # from other hosts, not only from localhost.
    serve(app, host="0.0.0.0", port=8000)