torinriley committed
Commit 639e661
1 Parent(s): 1e11062

Update handler.py

Files changed (1)
  1. handler.py +61 -55
handler.py CHANGED
@@ -4,58 +4,64 @@ from torchvision import transforms
 from PIL import Image
 import io
 
-BASE_DIR = os.path.dirname(os.path.abspath(__file__))
-MODEL_FILENAME = "model.pt"
-MODEL_PATH = os.path.join(BASE_DIR, MODEL_FILENAME)
-
-NUM_CLASSES = 4
-DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-
-def load_model(model_path, num_classes):
-    from torchvision.models.detection import fasterrcnn_resnet50_fpn
-    model = fasterrcnn_resnet50_fpn(pretrained=False, num_classes=num_classes)
-    checkpoint = torch.load(model_path, map_location=DEVICE)
-    model.load_state_dict(checkpoint["model_state_dict"])
-    model.to(DEVICE)
-    model.eval()
-    return model
-
-model = load_model(MODEL_PATH, NUM_CLASSES)
-
-transform = transforms.Compose([
-    transforms.Resize((640, 640)),
-    transforms.ToTensor(),
-])
-
-def detect_objects(image_bytes):
-    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
-    input_tensor = transform(image).unsqueeze(0).to(DEVICE)
-
-    with torch.no_grad():
-        predictions = model(input_tensor)
-
-    boxes = predictions[0]["boxes"].cpu().tolist()
-    labels = predictions[0]["labels"].cpu().tolist()
-    scores = predictions[0]["scores"].cpu().tolist()
-
-    confidence_threshold = 0.5
-    results = [
-        {"box": box, "label": label, "score": score}
-        for box, label, score in zip(boxes, labels, scores)
-        if score > confidence_threshold
-    ]
-
-    return {"predictions": results}
-
-def inference(payload):
-    import base64
-    try:
-        if "image" not in payload:
-            return {"error": "No image provided. Please send a Base64-encoded image."}
-
-        image_bytes = base64.b64decode(payload["image"])
-
-        results = detect_objects(image_bytes)
-        return results
-    except Exception as e:
-        return {"error": str(e)}
+# Import your Faster R-CNN model definition
+from model import get_model
+
+class EndpointHandler:
+    def __init__(self, path: str = ""):
+        """
+        Initialize the handler. Load the Faster R-CNN model.
+        """
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.model_weights_path = os.path.join(path, "model.pt")  # Adjust for your file name
+
+        # Load the model
+        self.model = get_model(num_classes=4)  # Modify for your num_classes
+        print(f"Loading weights from: {self.model_weights_path}")
+        checkpoint = torch.load(self.model_weights_path, map_location=self.device)
+        self.model.load_state_dict(checkpoint["model_state_dict"])
+        self.model.to(self.device)
+        self.model.eval()
+
+        # Define image preprocessing
+        self.transform = transforms.Compose([
+            transforms.Resize((640, 640)),  # Adjust size to match your training setup
+            transforms.ToTensor(),
+        ])
+
+    def __call__(self, data):
+        """
+        Process the incoming request and return object detection predictions.
+        """
+        try:
+            # Expect input data to include an "image" field
+            if "image" not in data:
+                return [{"error": "No 'image' provided in request."}]
+
+            # Recover the raw image bytes from the string payload (latin-1 round-trip, not Base64)
+            image_bytes = data["image"].encode("latin1")
+            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
+
+            # Preprocess the image
+            input_tensor = self.transform(image).unsqueeze(0).to(self.device)
+
+            # Run inference
+            with torch.no_grad():
+                outputs = self.model(input_tensor)
+
+            # Extract results
+            boxes = outputs[0]["boxes"].cpu().tolist()
+            labels = outputs[0]["labels"].cpu().tolist()
+            scores = outputs[0]["scores"].cpu().tolist()
+
+            # Confidence threshold
+            threshold = 0.5
+            predictions = [
+                {"box": box, "label": label, "score": score}
+                for box, label, score in zip(boxes, labels, scores)
+                if score > threshold
+            ]
+
+            return [{"predictions": predictions}]
+        except Exception as e:
+            return [{"error": str(e)}]
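For reference, a minimal local smoke test of the new EndpointHandler might look like the sketch below. The path="." argument and the file name test.jpg are placeholders for illustration; the sketch assumes model.py (providing get_model) and model.pt sit next to handler.py. Note that __call__ expects the raw image bytes passed as a latin-1 decoded string under the "image" key, so the payload is built that way here.

# Hypothetical local smoke test for the EndpointHandler above.
from handler import EndpointHandler

# Assumes handler.py, model.py (get_model), and model.pt live in the current directory.
handler = EndpointHandler(path=".")

# Read a sample image (test.jpg is a placeholder) and pass its raw bytes as a
# latin-1 string, mirroring the data["image"].encode("latin1") round-trip in __call__.
with open("test.jpg", "rb") as f:
    payload = {"image": f.read().decode("latin1")}

result = handler(payload)
print(result)  # [{"predictions": [...]}] on success, [{"error": "..."}] otherwise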