# pytorch-images / app.py
# Author: Charles95 · commit 5c38c6f "Create app.py" (verified) · 1.28 kB
import io

import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from PIL import Image
from torchvision import models, transforms
# Load a ResNet-50 with ImageNet-pretrained weights.
# ``pretrained=True`` is deprecated since torchvision 0.13 in favour of the
# ``weights`` enum; ``ResNet50_Weights.IMAGENET1K_V1`` is the checkpoint the
# legacy flag resolves to, so the loaded parameters are unchanged. Older
# torchvision releases have no weights enum, hence the fallback.
_resnet50_weights = getattr(models, "ResNet50_Weights", None)
if _resnet50_weights is not None:
    model = models.resnet50(weights=_resnet50_weights.IMAGENET1K_V1)
else:
    model = models.resnet50(pretrained=True)
model.eval()  # inference mode: e.g. BatchNorm uses stored running statistics
# Per-channel normalization statistics (the standard ImageNet mean/std that
# the pretrained weights were trained with).
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every uploaded image, in order:
# resize so the shorter side is 256 px, take the central 224x224 crop,
# convert to a float tensor, then normalize each channel.
_preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
]
preprocess = transforms.Compose(_preprocess_steps)
# Create the FastAPI application instance.
app = FastAPI()


def _classify_image(image: Image.Image) -> int:
    """Run the module-level ``model`` on one RGB image; return the top class index.

    Blocking (CPU/GPU-bound), so async callers should dispatch it to a worker
    thread instead of calling it on the event loop. With the pretrained head
    the index is presumably an ImageNet-1k class id — confirm if labels are added.
    """
    # preprocess yields a (3, 224, 224) tensor; unsqueeze adds the batch dim.
    input_batch = preprocess(image).unsqueeze(0)
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        logits = model(input_batch)
    # argmax returns the first maximal index, same as torch.max(..., 1)[1].
    return int(logits.argmax(dim=1)[0])


@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return the predicted class index.

    Form field:
        file: the image to classify (any format Pillow can decode).

    Returns:
        JSON ``{"predicted_class": <int>}`` on success, or an HTTP 400 response
        with a ``detail`` message when the upload cannot be decoded as an image.
    """
    contents = await file.read()
    try:
        # Force 3-channel RGB so grayscale/RGBA/palette images match the
        # 3-channel normalization in ``preprocess``.
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except (OSError, Image.DecompressionBombError):
        # Pillow raises UnidentifiedImageError (an OSError subclass) for
        # non-image data and OSError for truncated files; previously these
        # surfaced to the client as an opaque HTTP 500.
        return JSONResponse(
            status_code=400,
            content={"detail": "Uploaded file is not a valid image."},
        )
    # Running the forward pass inline in this ``async def`` would stall the
    # event loop (and every concurrent request) for its whole duration.
    predicted_idx = await run_in_threadpool(_classify_image, image)
    # Only the class index is returned; mapping it to a label name could be
    # added here.
    return JSONResponse(content={"predicted_class": predicted_idx})
# Start the API server only when this file is executed directly
# (e.g. ``python app.py``); importing the module does not launch it.
# host "0.0.0.0" listens on all network interfaces.
if __name__ == "__main__":
    from uvicorn import run as _serve

    _serve(app, host="0.0.0.0", port=8000)