File size: 3,060 Bytes
bdb9d5d
 
 
 
7b8e27e
 
bdb9d5d
7b8e27e
 
bdb9d5d
b4542c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b8e27e
bdb9d5d
7b8e27e
 
b4542c5
7b8e27e
b4542c5
5889000
 
 
7b8e27e
5889000
7b8e27e
bdb9d5d
 
 
7b8e27e
 
bdb9d5d
7b8e27e
 
 
bdb9d5d
 
7b8e27e
bdb9d5d
7b8e27e
 
b4542c5
7b8e27e
 
b4542c5
7b8e27e
bdb9d5d
7b8e27e
 
b4542c5
7b8e27e
b4542c5
 
 
 
bdb9d5d
7b8e27e
bdb9d5d
 
 
7b8e27e
 
 
bdb9d5d
 
d780a6e
 
5889000
 
0ba9aa3
d780a6e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch
import gradio as gr
from torchvision import transforms
from huggingface_hub import hf_hub_download
import segmentation_models_pytorch as smp
import numpy as np

# Number of segmentation classes the model predicts (per label_colors.txt).
# Keep this equal to the number of entries in COLOR_MAPPING below.
NUM_CLASSES = 4

# Class index -> RGB color used to render the predicted mask.
# Index order: 0 background (black), 1 oil, 2 others, 3 water.
COLOR_MAPPING = dict(enumerate([
    [0, 0, 0],
    [255, 0, 124],
    [255, 204, 51],
    [51, 221, 255],
]))

def colorize_mask(mask, color_mapping=None):
    """
    Convert a mask of class indices into an RGB color image.

    Args:
        mask (np.ndarray): Integer array of class indices, typically 2D (H x W).
            Any array-like is accepted and converted with ``np.asarray``.
        color_mapping (dict[int, Sequence[int]] | None): Class index -> RGB
            triple. Defaults to the module-level COLOR_MAPPING.

    Returns:
        np.ndarray: uint8 array of shape ``mask.shape + (3,)``. Pixels whose
        class index has no entry in the mapping stay black (0, 0, 0).
    """
    if color_mapping is None:
        color_mapping = COLOR_MAPPING

    mask = np.asarray(mask)
    # Start all-black so unmapped class indices render as background.
    color_mask = np.zeros(mask.shape + (3,), dtype=np.uint8)
    for cls, color in color_mapping.items():
        color_mask[mask == cls] = color
    return color_mask

# Download the trained model state dictionary from the Hugging Face Hub.
model_path = hf_hub_download(repo_id="TheArchitect416/oil-spill-segmentation-model", filename="model.pth")

# Build the network with segmentation_models_pytorch.
# The architecture (encoder, input channels, class count) must match training.
# encoder_weights=None: the strict load_state_dict below overwrites every
# parameter, so fetching ImageNet-pretrained encoder weights here would only
# add a redundant network download (and a failure point when offline).
model = smp.Unet(
    encoder_name="resnet34",  # Encoder used during training.
    encoder_weights=None,     # Weights come entirely from the checkpoint.
    in_channels=3,            # RGB images.
    classes=NUM_CLASSES,      # Number of segmentation classes.
)

# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot run arbitrary code on load (torch >= 1.13).
model.load_state_dict(torch.load(model_path, map_location="cpu", weights_only=True))
model.eval()

# ImageNet per-channel statistics used for input normalization.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)

# Inference-time input pipeline; it should mirror the training-time preprocessing.
# Steps: resize to 256x256, convert to a float tensor scaled to [0, 1], then
# normalize each channel with the ImageNet statistics.
_preprocess_steps = [
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
]
preprocess = transforms.Compose(_preprocess_steps)

# Define the inference function
def predict(image):
    """
    Run oil-spill segmentation on a single image.

    Args:
        image (PIL.Image.Image | None): Image delivered by the
            ``gr.Image(type="pil")`` input; None when nothing was uploaded.

    Returns:
        np.ndarray | None: 256 x 256 x 3 uint8 color mask, with one color per
        predicted class (see COLOR_MAPPING). Returns None when no image was
        given, leaving the output empty instead of raising inside preprocessing.
    """
    if image is None:
        return None

    # Normalize() and the 3-channel encoder need exactly RGB input. Uploaded
    # PNGs are often RGBA, and grayscale ("L") or palette ("P") images have
    # other channel counts, all of which would crash the preprocessing.
    image = image.convert("RGB")

    input_tensor = preprocess(image).unsqueeze(0)  # shape: [1, 3, 256, 256]

    with torch.no_grad():
        logits = model(input_tensor)  # shape: [1, NUM_CLASSES, 256, 256]

    # Most likely class per pixel -> 2D uint8 index mask.
    pred_mask = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy().astype(np.uint8)

    return colorize_mask(pred_mask)

# Gradio UI: one image in (as PIL, consumed by predict) and one color mask out
# (a numpy RGB array produced by predict).
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="numpy"),
    title="Oil Spill Segmentation",
    description="Segment oil spills in aerial images.",
)

# Report the installed Gradio version whenever this module is loaded.
print("Gradio version:", gr.__version__)


def _serve():
    """Enable request queuing, then serve the UI on all interfaces at port 7860."""
    iface.queue()
    iface.launch(server_name="0.0.0.0", server_port=7860)


# Launch the interface only when executed as a script.
if __name__ == "__main__":
    _serve()