File size: 2,392 Bytes
59c3137 f53d612 59c3137 28a32c3 59c3137 f53d612 639e661 f53d612 639e661 f53d612 639e661 f53d612 639e661 cc3a466 639e661 cc3a466 387dfb8 cc3a466 639e661 cc3a466 639e661 cc3a466 639e661 cc3a466 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
import torch
from model import get_model
from torchvision.transforms import ToTensor
from PIL import Image
import io
import os
# Module-level configuration shared by the handler below.

# Class count handed to get_model(); presumably the three colour labels plus
# an implicit background class at index 0 -- TODO confirm against get_model.
NUM_CLASSES = 4

# Detections scoring below this value are left out of the response.
CONFIDENCE_THRESHOLD = 0.5

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class EndpointHandler:
    """Serve detection results for images delivered as raw bytes.

    The model is built and its weights are loaded once, at construction time.
    Calling the instance with a request payload returns the detections whose
    score reaches ``CONFIDENCE_THRESHOLD``.
    """

    def __init__(self, path: str = ""):
        """
        Build the model and load its weights from ``<path>/model.pt``.

        The checkpoint is read as a mapping and its ``"model_state_dict"``
        entry is loaded into the model, which is then put in eval mode.
        """
        self.model_weights_path = os.path.join(path, "model.pt")

        # Instantiate the network, then move it to the selected device.
        network = get_model(NUM_CLASSES)
        self.model = network.to(DEVICE)

        # NOTE(review): torch.load unpickles the file, so the checkpoint must
        # come from a trusted source.
        state = torch.load(self.model_weights_path, map_location=DEVICE)
        self.model.load_state_dict(state["model_state_dict"])
        self.model.eval()

        # Converts a PIL image into a tensor before inference.
        self.preprocess = ToTensor()

        # Label ids reported by the model -> human-readable names; ids that
        # are not listed here are reported as "unknown".
        self.label_map = {1: "yellow", 2: "red", 3: "blue"}

    def preprocess_frame(self, image_bytes):
        """
        Decode encoded image bytes into a one-image batch tensor on DEVICE.
        """
        buffer = io.BytesIO(image_bytes)
        rgb_image = Image.open(buffer).convert("RGB")
        # unsqueeze(0) adds the leading batch dimension.
        batch = self.preprocess(rgb_image).unsqueeze(0)
        return batch.to(DEVICE)

    def _describe(self, box, label, score):
        """
        Turn one raw detection into its response dictionary.
        """
        # Coordinates are truncated to integers.
        left, top, right, bottom = map(int, box)
        name = self.label_map.get(label, "unknown")
        return {
            "label": name,
            "score": round(score, 2),
            "box": {
                "xmin": left,
                "ymin": top,
                "xmax": right,
                "ymax": bottom
            }
        }

    def __call__(self, data):
        """
        Run detection on the image bytes stored under ``data["body"]``.

        Returns a list of ``{"label", "score", "box"}`` dictionaries, one per
        detection whose score is at least ``CONFIDENCE_THRESHOLD``. When the
        payload has no ``"body"`` entry, or anything raises while handling
        the request, a dictionary with a single ``"error"`` message is
        returned instead.
        """
        try:
            if "body" not in data:
                return {"error": "No image data provided in request."}

            batch = self.preprocess_frame(data["body"])

            # Inference only: gradients are not needed.
            with torch.no_grad():
                output = self.model(batch)

            # A single image was submitted, so only the first result is read.
            detection = output[0]
            boxes = detection["boxes"].cpu().tolist()
            labels = detection["labels"].cpu().tolist()
            scores = detection["scores"].cpu().tolist()

            return [
                self._describe(box, label, score)
                for box, label, score in zip(boxes, labels, scores)
                if score >= CONFIDENCE_THRESHOLD
            ]
        except Exception as e:
            # Request boundary: report the failure to the caller as data.
            return {"error": str(e)}