torinriley commited on
Commit
59c3137
1 Parent(s): 7547739

Update helper.py

Browse files
Files changed (1) hide show
  1. helper.py +57 -0
helper.py CHANGED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision import transforms
3
+ from PIL import Image
4
+ import io
5
+
6
+ MODEL_PATH = "model_checkpoint.pt"
7
+ NUM_CLASSES = 4
8
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
9
+
10
+ # Load Faster R-CNN model
11
+ def load_model(model_path, num_classes):
12
+ from torchvision.models.detection import fasterrcnn_resnet50_fpn
13
+ model = fasterrcnn_resnet50_fpn(pretrained=False, num_classes=num_classes)
14
+ checkpoint = torch.load(model_path, map_location=DEVICE)
15
+ model.load_state_dict(checkpoint["model_state_dict"])
16
+ model.to(DEVICE)
17
+ model.eval()
18
+ return model
19
+
20
+ transform = transforms.Compose([
21
+ transforms.Resize((640, 640)),
22
+ transforms.ToTensor(),
23
+ ])
24
+
25
+ model = load_model(MODEL_PATH, NUM_CLASSES)
26
+
27
+ def detect_objects(image_bytes):
28
+ image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
29
+ input_tensor = transform(image).unsqueeze(0).to(DEVICE)
30
+
31
+ with torch.no_grad():
32
+ predictions = model(input_tensor)
33
+
34
+ boxes = predictions[0]["boxes"].cpu().tolist()
35
+ labels = predictions[0]["labels"].cpu().tolist()
36
+ scores = predictions[0]["scores"].cpu().tolist()
37
+
38
+ confidence_threshold = 0.5
39
+ results = [
40
+ {"box": box, "label": label, "score": score}
41
+ for box, label, score in zip(boxes, labels, scores)
42
+ if score > confidence_threshold
43
+ ]
44
+
45
+ return {"predictions": results}
46
+
47
+ def inference(payload):
48
+ try:
49
+ if "image" not in payload:
50
+ return {"error": "No image provided. Please send an image."}
51
+
52
+ image_bytes = payload["image"].encode("latin1")
53
+
54
+ results = detect_objects(image_bytes)
55
+ return results
56
+ except Exception as e:
57
+ return {"error": str(e)}