File size: 3,224 Bytes
fd04d5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import torch
import segmentation_models_pytorch as smp
from torchvision import transforms
from PIL import Image
import io
import json
import base64
import numpy as np

# Number of segmentation classes the model predicts. It is passed as
# `classes=` to smp.Unet in the handler, so it must match the checkpoint in
# model.pth. It should also agree with the keys of COLOR_MAPPING below (0..3).
NUM_CLASSES = 4

# Input preprocessing pipeline (should match what was used during training):
# - Resize((256, 256)) resizes to exactly 256x256, so aspect ratio is not preserved.
# - ToTensor() converts the PIL image to a float tensor scaled to [0, 1]
#   (per torchvision docs).
# - Normalize() standardizes each RGB channel with the ImageNet mean/std.
# NOTE: the handler class also defines a method named `preprocess`. That method
# calls this module-level transform.
preprocess = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406),  # ImageNet means
                         std=(0.229, 0.224, 0.225))   # ImageNet stds
])

# Class index -> RGB color (0-255 per channel), used by colorize_mask() to
# visualize predictions. Keys should cover 0..NUM_CLASSES-1. Any class index
# without an entry here is rendered black.
COLOR_MAPPING = {
    0: [0, 0, 0],        # Background
    1: [255, 0, 124],    # Oil
    2: [255, 204, 51],   # Others
    3: [51, 221, 255]    # Water
}

def colorize_mask(mask, color_mapping=None):
    """Convert a 2D segmentation mask into an RGB image array.

    Args:
        mask: 2D array-like of integer class indices, shape (H, W).
        color_mapping: optional dict mapping class index -> [R, G, B]
            (values 0-255). Defaults to the module-level COLOR_MAPPING.

    Returns:
        np.ndarray of dtype uint8 and shape (H, W, 3). Pixels whose class
        has no entry in the mapping stay black (0, 0, 0).

    Raises:
        ValueError: if ``mask`` is not two-dimensional.
    """
    if color_mapping is None:
        color_mapping = COLOR_MAPPING
    mask = np.asarray(mask)
    # Unpacking enforces the 2D contract (raises ValueError otherwise).
    h, w = mask.shape
    color_mask = np.zeros((h, w, 3), dtype=np.uint8)
    for cls, color in color_mapping.items():
        # Boolean-mask assignment paints every pixel of this class at once.
        color_mask[mask == cls] = color
    return color_mask

class OilSpillSegmentationHandler:
    """Serve oil-spill segmentation predictions from a U-Net checkpoint.

    Request flow (see handle_request): JSON body -> base64-decoded image bytes
    -> preprocessed tensor -> per-pixel class mask -> colorized PNG, returned
    base64-encoded inside a JSON response.
    """

    def __init__(self):
        """Build the U-Net, load weights from ``model.pth`` and switch to eval mode.

        Uses CUDA when available, otherwise the CPU.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = smp.Unet(
            encoder_name="resnet34",  # Must match the encoder used in training
            encoder_weights=None,     # No pretrained download; weights come from model.pth
            in_channels=3,            # RGB input (images are converted to RGB in preprocess)
            classes=NUM_CLASSES
        )
        # NOTE(review): torch.load unpickles the file, so only load a trusted
        # model.pth. On torch >= 1.13 consider passing weights_only=True.
        state_dict = torch.load("model.pth", map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

    def preprocess(self, image_bytes):
        """Decode raw image bytes into a normalized model input.

        Args:
            image_bytes: encoded image file contents (any format PIL can open).

        Returns:
            (image_tensor, image): a batch of one preprocessed image on
            ``self.device``, and the decoded RGB PIL image.

        Raises:
            An error from PIL if the bytes are not a readable image.
        """
        # Close the source image deterministically. convert() returns a new,
        # independent image, so it stays valid after the with-block.
        with Image.open(io.BytesIO(image_bytes)) as source:
            image = source.convert("RGB")
        # `preprocess` below is the module-level torchvision transform, not this method.
        image_tensor = preprocess(image).unsqueeze(0).to(self.device)
        return image_tensor, image

    def inference(self, image_tensor):
        """Run the model and return the per-pixel argmax class mask.

        Args:
            image_tensor: model input batch. It presumably holds exactly one
                image, shape (1, 3, H, W), since the batch dim is squeezed away.

        Returns:
            np.ndarray of dtype uint8 holding class indices, shape (H, W).
        """
        with torch.no_grad():
            logits = self.model(image_tensor)
            # argmax over the class dimension (dim=1), then drop the batch dim.
            pred_mask = torch.argmax(logits, dim=1).squeeze(0)
        # uint8 is enough while NUM_CLASSES <= 256.
        return pred_mask.cpu().numpy().astype(np.uint8)

    def postprocess(self, pred_mask):
        """Convert a class-index mask into a colorized PIL image via colorize_mask()."""
        colorized_mask = colorize_mask(pred_mask)
        return Image.fromarray(colorized_mask)

    def handle_request(self, request_body):
        """Handle an API request end to end: decode, preprocess, infer, postprocess.

        Args:
            request_body: JSON text (str or bytes) of the form
                ``{"image": "<base64-encoded image file>"}``.

        Returns:
            JSON string ``{"output_image": "<base64 PNG>"}`` on success, or
            ``{"error": "<message>"}`` on any failure. Errors are reported in
            the payload and never raised.
        """
        try:
            data = json.loads(request_body)
            # Validate the payload shape up front. Without this, clients would
            # get a bare KeyError/TypeError string such as "'image'".
            if not isinstance(data, dict) or "image" not in data:
                raise ValueError("Request body must be a JSON object with an 'image' field")
            image_bytes = base64.b64decode(data["image"])
            image_tensor, _original_image = self.preprocess(image_bytes)
            pred_mask = self.inference(image_tensor)
            # NOTE(review): the mask is presumably at the 256x256 resolution set
            # by the module-level transform, not the original image size. If
            # clients need pixel-aligned output, resize it to
            # _original_image.size with nearest-neighbour resampling.
            output_image = self.postprocess(pred_mask)

            # Encode the PNG in memory and return it as base64 text.
            with io.BytesIO() as buffered:
                output_image.save(buffered, format="PNG")
                output_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

            return json.dumps({"output_image": output_b64})
        except Exception as e:
            # Deliberate best-effort API boundary: report the failure to the caller.
            return json.dumps({"error": str(e)})

# Instantiate the handler once at import time. Construction loads model.pth
# and moves the model to the selected device, so importing this module
# performs that file I/O (and fails if model.pth is missing).
handler = OilSpillSegmentationHandler()