spot-yolo-new / handler.py
ryanhuangtw's picture
Upload 2 files
6278f13 verified
raw
history blame
904 Bytes
import torch
from transformers import AutoModelForObjectDetection, AutoProcessor
class CustomHandler:
    """Serve an object-detection model through a four-step request pipeline.

    The expected call order is ``initialize`` -> ``preprocess`` ->
    ``inference`` -> ``postprocess``. ``preprocess`` and ``inference`` raise
    ``RuntimeError`` if ``initialize`` has not populated the component
    they need.
    """

    def __init__(self):
        # Both stay None until initialize() loads them.
        self.model = None
        self.processor = None

    def initialize(self, model_dir):
        """Load the model and its processor from ``model_dir``.

        Args:
            model_dir: Location handed unchanged to both ``from_pretrained``
                calls.
        """
        self.model = AutoModelForObjectDetection.from_pretrained(model_dir)
        self.processor = AutoProcessor.from_pretrained(model_dir)

    def _require(self, attr):
        """Return ``self.<attr>``, failing loudly if it is still ``None``."""
        component = getattr(self, attr)
        if component is None:
            # Without this guard the caller would only see
            # "'NoneType' object is not callable".
            raise RuntimeError(
                f"CustomHandler.{attr} is not loaded; call initialize() first"
            )
        return component

    def preprocess(self, request):
        """Turn the request payload into model-ready tensors.

        Args:
            request: Mapping whose ``"inputs"`` entry holds the image data to
                pass to the processor.

        Returns:
            Whatever the processor returns for
            ``processor(images=..., return_tensors="pt")``.

        Raises:
            RuntimeError: If the processor has not been loaded.
            ValueError: If ``"inputs"`` is missing or ``None``.
        """
        processor = self._require("processor")
        images = request.get("inputs")
        if images is None:
            # Reject early instead of forwarding None into the processor.
            raise ValueError('request must contain a non-null "inputs" entry')
        return processor(images=images, return_tensors="pt")

    def inference(self, inputs):
        """Run the model on preprocessed inputs with gradients disabled.

        Args:
            inputs: Mapping that is unpacked as keyword arguments to the model.

        Returns:
            The raw model output.

        Raises:
            RuntimeError: If the model has not been loaded.
        """
        model = self._require("model")
        with torch.no_grad():
            return model(**inputs)

    def postprocess(self, outputs):
        """Convert raw model output into a plain-Python response.

        Args:
            outputs: Object exposing a ``logits`` tensor.

        Returns:
            ``{"predictions": ...}`` where the value is the softmax of
            ``outputs.logits`` over its last dimension, as nested lists.
        """
        # NOTE(review): only ``logits`` is read here; any box coordinates the
        # model may also produce are not part of the response - confirm that
        # this is intended.
        probabilities = outputs.logits.softmax(-1).tolist()
        return {"predictions": probabilities}