File size: 1,797 Bytes
1e11062
59c3137
 
 
 
 
1e11062
 
 
 
 
59c3137
 
 
 
 
 
 
 
 
 
 
1e11062
 
59c3137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e11062
59c3137
 
1e11062
 
 
 
59c3137
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import os
import torch
from torchvision import transforms
from PIL import Image
import io

# The checkpoint is "model.pt" located next to this source file, so the lookup
# does not depend on the process's current working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_FILENAME = "model.pt"
MODEL_PATH = os.path.join(BASE_DIR, MODEL_FILENAME)

# Passed to the detector constructor as ``num_classes``.
# NOTE(review): torchvision's Faster R-CNN counts the background class in
# num_classes, so this presumably means 3 object classes + background — confirm.
NUM_CLASSES = 4

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

def load_model(model_path, num_classes, device=None):
    """Build a Faster R-CNN (ResNet-50 FPN) detector and load trained weights.

    Args:
        model_path: Path to a checkpoint written with ``torch.save``. Either a
            dict holding the weights under ``"model_state_dict"`` or a bare
            state dict.
        num_classes: Number of output classes the checkpoint was trained with.
        device: Device to map the weights onto and run the model on.
            Defaults to the module-level ``DEVICE`` when omitted.

    Returns:
        The model with weights loaded, moved to *device*, in eval mode.

    Raises:
        FileNotFoundError: If *model_path* does not exist.
        RuntimeError: If the state dict does not match the architecture.
    """
    from torchvision.models.detection import fasterrcnn_resnet50_fpn

    if device is None:
        device = DEVICE

    # pretrained=False skips the COCO detection weights; all parameters are
    # overwritten by the checkpoint below anyway.
    # NOTE(review): torchvision still defaults to ImageNet *backbone* weights
    # here, which may trigger a download on first use. Disabling them appears
    # to also change the backbone's norm layers, so confirm before touching.
    model = fasterrcnn_resnet50_fpn(pretrained=False, num_classes=num_classes)

    # torch.load unpickles the file: only point this at trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)

    # Training scripts often wrap the weights together with optimizer state,
    # epoch counters, etc.; otherwise treat the file as a bare state dict.
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint

    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model

# Loaded once at import time; every call to detect_objects reuses this model.
model = load_model(MODEL_PATH, NUM_CLASSES)

# Preprocessing applied to each incoming image before it reaches the model:
# a fixed 640x640 resize (aspect ratio is not preserved), then conversion to a
# channels-first float tensor.
transform = transforms.Compose(
    [
        transforms.Resize(size=(640, 640)),
        transforms.ToTensor(),
    ]
)

def detect_objects(image_bytes, confidence_threshold=0.5):
    """Run the detector on an encoded image and return confident detections.

    Args:
        image_bytes: Raw encoded image data in any format PIL can open.
        confidence_threshold: Detections whose score is at or below this
            value are dropped. Defaults to 0.5.

    Returns:
        ``{"predictions": [...]}`` where each entry is a dict with ``"box"``
        ([x1, y1, x2, y2] in pixel coordinates of the ORIGINAL image),
        ``"label"`` (class index) and ``"score"`` (detection confidence).
    """
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    orig_width, orig_height = image.size  # PIL reports (width, height)
    input_tensor = transform(image).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        predictions = model(input_tensor)

    # ``transform`` resizes to a fixed size, so the model's boxes live in the
    # resized frame. Map them back so they line up with the caller's image.
    resized_height, resized_width = input_tensor.shape[-2:]
    scale_x = orig_width / resized_width
    scale_y = orig_height / resized_height

    prediction = predictions[0]  # batch of one image
    boxes = prediction["boxes"].cpu().tolist()
    labels = prediction["labels"].cpu().tolist()
    scores = prediction["scores"].cpu().tolist()

    results = [
        {
            "box": [x1 * scale_x, y1 * scale_y, x2 * scale_x, y2 * scale_y],
            "label": label,
            "score": score,
        }
        for (x1, y1, x2, y2), label, score in zip(boxes, labels, scores)
        if score > confidence_threshold
    ]

    return {"predictions": results}

def inference(payload):
    """Decode a Base64 image from *payload* and run object detection on it.

    Args:
        payload: Mapping with an ``"image"`` entry holding the Base64-encoded
            image. A ``data:<mime>;base64,`` URL prefix is accepted and
            stripped before decoding.

    Returns:
        The ``detect_objects`` result on success, otherwise
        ``{"error": message}``. Any ``Exception`` is logged and reported in
        the returned dict rather than raised.
    """
    import base64
    import logging
    from collections.abc import Mapping

    try:
        # Reject non-mapping payloads explicitly: ``"image" in payload`` on a
        # str would do a substring test, and on None it raises TypeError.
        if not isinstance(payload, Mapping) or "image" not in payload:
            return {"error": "No image provided. Please send a Base64-encoded image."}

        encoded = payload["image"]
        # Clients often send data URLs ("data:image/png;base64,...").
        # b64decode silently discards ':', ';' and ',' and would decode the
        # remaining prefix letters as image bytes, corrupting the data.
        if isinstance(encoded, str) and encoded.startswith("data:"):
            encoded = encoded.partition(",")[2]

        image_bytes = base64.b64decode(encoded)
        return detect_objects(image_bytes)
    except Exception as e:
        # Request boundary: keep the traceback for operators, return only the
        # message to the caller.
        logging.getLogger(__name__).exception("Object detection inference failed")
        return {"error": str(e)}