import torch
import gradio as gr
from torchvision import transforms
from huggingface_hub import hf_hub_download
import segmentation_models_pytorch as smp
import numpy as np

# Number of output classes (matching label_colors.txt: background, oil, others, water)
NUM_CLASSES = 4

# Map class indices to RGB colors:
# background: black, oil: (255, 0, 124), others: (255, 204, 51), water: (51, 221, 255)
COLOR_MAPPING = {
    0: [0, 0, 0],
    1: [255, 0, 124],
    2: [255, 204, 51],
    3: [51, 221, 255],
}

def colorize_mask(mask):
    """
    Convert a 2D mask of class indices into a color image.

    Args:
        mask (np.ndarray): 2D array of class indices.

    Returns:
        np.ndarray: Color image (H x W x 3) with each class colored
            according to COLOR_MAPPING.
    """
    h, w = mask.shape
    color_mask = np.zeros((h, w, 3), dtype=np.uint8)
    for cls, color in COLOR_MAPPING.items():
        color_mask[mask == cls] = color
    return color_mask
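
# Quick sanity check (hypothetical values, not used by the app):
#   colorize_mask(np.array([[0, 1], [2, 3]]))
# returns a (2, 2, 3) uint8 array whose [0, 1] pixel is [255, 0, 124] (oil).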

# Download the trained weights from the Hugging Face Hub
model_path = hf_hub_download(
    repo_id="TheArchitect416/oil-spill-segmentation-model",
    filename="model.pth",
)

# Create the model with segmentation_models_pytorch.
# This must match the architecture used during training.
model = smp.Unet(
    encoder_name="resnet34",     # Encoder used during training.
    encoder_weights="imagenet",  # ImageNet-pretrained encoder weights.
    in_channels=3,               # RGB images.
    classes=NUM_CLASSES,         # Number of segmentation classes.
)

# Load the state dict, mapping all tensors to CPU
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()

# Preprocessing transforms (must match the ones used during training)
preprocess = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406),  # ImageNet means
                         std=(0.229, 0.224, 0.225)),  # ImageNet stds
])
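
# Quick shape check (hypothetical, not part of the app flow):
#   from PIL import Image
#   t = preprocess(Image.new("RGB", (1024, 768)))
#   assert t.shape == (3, 256, 256)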

# Inference function wired into the Gradio interface
def predict(image):
    """
    Accept a PIL image, preprocess it, run the model, and return the
    predicted segmentation mask as a color image.
    """
    # Preprocess the image
    input_tensor = preprocess(image).unsqueeze(0)  # shape: [1, 3, 256, 256]
    with torch.no_grad():
        output = model(input_tensor)  # shape: [1, NUM_CLASSES, 256, 256]
    # Get the predicted class for each pixel
    pred_mask = torch.argmax(output, dim=1).squeeze(0).cpu().numpy().astype(np.uint8)
    # Convert the 2D class-index mask to a color mask
    colored_mask = colorize_mask(pred_mask)
    return colored_mask
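
# Usage sketch (assumes a hypothetical file "sample.jpg"; not part of the app flow):
#   from PIL import Image
#   img = Image.open("sample.jpg").convert("RGB")
#   mask = predict(img)  # (256, 256, 3) uint8 color mask
#   # Resize back to the original resolution; NEAREST keeps class colors exact:
#   Image.fromarray(mask).resize(img.size, Image.NEAREST).save("mask.png")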

# Create the Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="numpy"),
    title="Oil Spill Segmentation",
    description="Segment oil spills in aerial images.",
)

# Log the Gradio version at startup (handy when debugging a Space)
print("Gradio version:", gr.__version__)

# Launch the interface
if __name__ == "__main__":
    iface.queue()
    iface.launch(server_name="0.0.0.0", server_port=7860)
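
# Running locally, `python app.py` serves the app at http://localhost:7860;
# on Hugging Face Spaces the same entry point is used, with port 7860 exposed.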