File size: 2,021 Bytes
1e11062
59c3137
 
 
 
 
639e661
 
 
 
 
 
 
 
dd95c76
639e661
 
dd95c76
639e661
 
 
 
 
 
 
 
dd95c76
639e661
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import io
import os
import traceback

import torch
from PIL import Image
from torchvision import transforms

from model import get_model

class EndpointHandler:
    """Inference endpoint wrapping the detection model built by ``get_model``.

    The checkpoint is loaded once at construction time. Each call decodes an
    image from the request payload, runs the model on a resized copy, and
    returns the detections whose score exceeds ``score_threshold``, with boxes
    mapped back onto the pixel grid of the submitted image.
    """

    def __init__(
        self,
        path: str = "",
        num_classes: int = 4,
        score_threshold: float = 0.5,
        input_size: tuple = (640, 640),
    ):
        """
        Initialize the handler. Load the Faster R-CNN model.

        Args:
            path: Directory containing the ``model.pt`` checkpoint.
            num_classes: Value passed to ``get_model``; must match the
                checkpoint's ``model_state_dict``.
            score_threshold: Detections with ``score <= score_threshold`` are
                dropped from the response.
            input_size: ``(height, width)`` the image is resized to before
                inference.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_weights_path = os.path.join(path, "model.pt")
        self.score_threshold = score_threshold
        self.input_size = tuple(input_size)

        # Load the model
        self.model = get_model(num_classes=num_classes)
        print(f"Loading weights from: {self.model_weights_path}")
        # NOTE: torch.load unpickles the file -- only point this at trusted checkpoints.
        checkpoint = torch.load(self.model_weights_path, map_location=self.device)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.model.to(self.device)
        self.model.eval()

        # Define image preprocessing (Resize takes a (height, width) sequence).
        self.transform = transforms.Compose([
            transforms.Resize(self.input_size),
            transforms.ToTensor(),
        ])

    def __call__(self, data):
        """
        Process the incoming request and return object detection predictions.

        Args:
            data: Request mapping. ``data["image"]`` holds the image as raw
                ``bytes``/``bytearray``, as a ``str`` whose characters are the
                encoded bytes decoded with latin-1, or as a PIL image.

        Returns:
            ``[{"predictions": [...]}]`` where each prediction is
            ``{"box": [x1, y1, x2, y2], "label": ..., "score": ...}`` with the
            box expressed in the pixel coordinates of the submitted image, or
            ``[{"error": message}]`` if anything fails.
        """
        try:
            if "image" not in data:
                return [{"error": "No 'image' provided in request."}]

            image = self._decode_image(data["image"])
            orig_width, orig_height = image.size

            input_tensor = self.transform(image).unsqueeze(0).to(self.device)

            with torch.no_grad():
                outputs = self.model(input_tensor)

            # The model only saw the resized image, so its boxes are in
            # input_size coordinates; scale them back to the submitted image.
            in_height, in_width = self.input_size
            predictions = self._to_predictions(
                outputs[0],
                orig_width / in_width,
                orig_height / in_height,
                self.score_threshold,
            )

            return [{"predictions": predictions}]
        except Exception as e:
            # Service boundary: report the failure to the client, but keep the
            # traceback in the server log instead of discarding it.
            traceback.print_exc()
            return [{"error": str(e)}]

    @staticmethod
    def _decode_image(payload):
        """Return an RGB PIL image decoded from a request payload."""
        if isinstance(payload, Image.Image):
            return payload.convert("RGB")
        if isinstance(payload, str):
            # Legacy transport: the raw bytes were shipped as a latin-1 string.
            payload = payload.encode("latin1")
        # convert() materializes a new image, so the source can be closed here.
        with Image.open(io.BytesIO(payload)) as img:
            return img.convert("RGB")

    @staticmethod
    def _to_predictions(output, scale_x, scale_y, threshold):
        """Turn one model output dict into JSON-serializable predictions.

        Assumes torchvision-style detection output: ``boxes`` rows are
        ``(x1, y1, x2, y2)``; they are multiplied by ``scale_x``/``scale_y``.
        Only detections with ``score > threshold`` are kept.
        """
        boxes = output["boxes"].cpu()
        scale = torch.tensor([scale_x, scale_y, scale_x, scale_y])
        boxes = (boxes * scale).tolist()
        labels = output["labels"].cpu().tolist()
        scores = output["scores"].cpu().tolist()

        return [
            {"box": box, "label": label, "score": score}
            for box, label, score in zip(boxes, labels, scores)
            if score > threshold
        ]