File size: 2,047 Bytes
59c3137
 
 
 
70e7526
59c3137
639e661
 
 
 
 
 
 
 
387dfb8
 
 
dd95c76
639e661
 
 
 
 
387dfb8
639e661
387dfb8
639e661
 
 
 
 
387dfb8
639e661
 
387dfb8
 
 
 
639e661
387dfb8
639e661
 
387dfb8
639e661
 
387dfb8
639e661
387dfb8
639e661
387dfb8
 
 
 
639e661
387dfb8
639e661
387dfb8
639e661
 
 
 
 
387dfb8
639e661
387dfb8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import io
import logging
import os

import torch
from PIL import Image
from torchvision import transforms

from model import get_model

class EndpointHandler:
    """Inference-endpoint handler that runs a Faster R-CNN detector on one image.

    The handler is constructed once (loading weights from ``<path>/model.pt``)
    and then called per request with a mapping whose ``"body"`` entry holds the
    raw encoded image bytes.
    """

    def __init__(
        self,
        path: str = "",
        num_classes: int = 4,
        score_threshold: float = 0.5,
        input_size=(640, 640),
    ):
        """Load the detector weights and build the preprocessing pipeline.

        Args:
            path: Directory containing ``model.pt``.
            num_classes: Forwarded to ``get_model``; must match the checkpoint.
            score_threshold: Detections with ``score <= score_threshold`` are
                dropped from responses.
            input_size: ``(height, width)`` every image is resized to before
                inference (``transforms.Resize`` ordering).
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_weights_path = os.path.join(path, "model.pt")
        self.score_threshold = score_threshold

        # Load model
        self.model = get_model(num_classes=num_classes)
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code -- only point this at a trusted checkpoint. Consider
        # weights_only=True (torch >= 1.13) if the checkpoint holds nothing but
        # tensors and primitive containers.
        checkpoint = torch.load(self.model_weights_path, map_location=self.device)
        # Accept both a training checkpoint ({"model_state_dict": ...}) and a
        # bare state dict saved with torch.save(model.state_dict(), ...).
        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            state_dict = checkpoint["model_state_dict"]
        else:
            state_dict = checkpoint
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

        # Image preprocessing: fixed-size resize, then HWC uint8 -> CHW float in [0, 1].
        self.transform = transforms.Compose([
            transforms.Resize(tuple(input_size)),
            transforms.ToTensor(),
        ])

    @staticmethod
    def _postprocess(prediction, scale_x, scale_y, threshold):
        """Convert one model prediction dict into JSON-friendly detections.

        Keeps detections with ``score > threshold`` and multiplies each box's
        x coordinates by ``scale_x`` and y coordinates by ``scale_y``.
        Assumes boxes are ``[x1, y1, x2, y2]`` (torchvision detection
        convention) -- TODO confirm against ``get_model``.

        Args:
            prediction: Mapping with ``"boxes"``, ``"labels"`` and ``"scores"``
                tensors for a single image.
            scale_x: Factor mapping input-tensor x coordinates to the original image.
            scale_y: Factor mapping input-tensor y coordinates to the original image.
            threshold: Minimum (exclusive) score for a detection to be kept.

        Returns:
            A list of ``{"box": [x1, y1, x2, y2], "label": ..., "score": ...}`` dicts.
        """
        boxes = prediction["boxes"].cpu().tolist()
        labels = prediction["labels"].cpu().tolist()
        scores = prediction["scores"].cpu().tolist()
        return [
            {
                "box": [x1 * scale_x, y1 * scale_y, x2 * scale_x, y2 * scale_y],
                "label": label,
                "score": score,
            }
            for (x1, y1, x2, y2), label, score in zip(boxes, labels, scores)
            if score > threshold
        ]

    def __call__(self, data):
        """Run object detection on one encoded image.

        Args:
            data: Request mapping; ``data["body"]`` holds the encoded image
                bytes (any format PIL can open).

        Returns:
            ``{"predictions": [...]}`` with boxes expressed in the ORIGINAL
            image's pixel coordinates, or ``{"error": message}`` when the body
            is empty or processing fails.
        """
        try:
            image_bytes = data.get("body", b"")
            if not image_bytes:
                return {"error": "No image data provided in request."}

            with Image.open(io.BytesIO(image_bytes)) as raw:
                image = raw.convert("RGB")
            orig_width, orig_height = image.size  # PIL reports (width, height)

            input_tensor = self.transform(image).unsqueeze(0).to(self.device)
            in_height, in_width = input_tensor.shape[-2:]

            with torch.no_grad():
                predictions = self.model(input_tensor)

            # The model only sees the resized tensor, so its boxes live in that
            # coordinate space; map them back onto the image the client sent.
            results = self._postprocess(
                predictions[0],
                scale_x=orig_width / in_width,
                scale_y=orig_height / in_height,
                threshold=self.score_threshold,
            )
            return {"predictions": results}
        except Exception as e:
            # Top-level request boundary: keep the error-dict response
            # contract, but record the traceback instead of discarding it.
            logging.getLogger(__name__).exception("Object detection request failed")
            return {"error": str(e)}