import torch
import gradio as gr
from torchvision import transforms
from huggingface_hub import hf_hub_download
import segmentation_models_pytorch as smp
import numpy as np

# Number of output classes (label_colors.txt defines 4 classes).
NUM_CLASSES = 4

# Mapping from class index to RGB color:
#   0 = background (0, 0, 0)      1 = oil   (255, 0, 124)
#   2 = others (255, 204, 51)     3 = water (51, 221, 255)
COLOR_MAPPING = {
    0: [0, 0, 0],
    1: [255, 0, 124],
    2: [255, 204, 51],
    3: [51, 221, 255],
}


def colorize_mask(mask):
    """Convert a 2D class-index mask into an RGB color image.

    Args:
        mask (np.ndarray): 2D array of class indices.

    Returns:
        np.ndarray: uint8 image of shape (H, W, 3). Pixels whose class index
        appears in COLOR_MAPPING get that color; any other value stays black.
    """
    h, w = mask.shape
    color_mask = np.zeros((h, w, 3), dtype=np.uint8)
    for cls, color in COLOR_MAPPING.items():
        color_mask[mask == cls] = color
    return color_mask


# Download the trained state dict from the Hugging Face Hub repository.
model_path = hf_hub_download(
    repo_id="TheArchitect416/oil-spill-segmentation-model",
    filename="model.pth",
)

# Build the network; this must match the architecture used during training.
# encoder_weights=None: the strict load_state_dict call below overwrites every
# parameter and buffer, so fetching ImageNet weights here would only add a
# redundant download at startup (and a failure point when offline).
model = smp.Unet(
    encoder_name="resnet34",  # Encoder used during training.
    encoder_weights=None,
    in_channels=3,  # RGB images.
    classes=NUM_CLASSES,  # Number of segmentation classes.
)

# Load the trained weights (keys are used as-is, strict matching) onto the CPU.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints from a trusted repository; on torch >= 1.13 consider
# passing weights_only=True.
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()

# Preprocessing; should match what was used during training.
preprocess = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406),  # ImageNet means
                         std=(0.229, 0.224, 0.225)),  # ImageNet stds
])


def predict(image):
    """Segment a PIL image and return its colorized class mask.

    Args:
        image (PIL.Image.Image | None): Input image from the Gradio component.
            None when the user submits without providing an image.

    Returns:
        np.ndarray | None: uint8 RGB mask (H, W, 3) at the model's output
        resolution (presumably 256x256, the preprocessing size, not the
        original image size), or None when no image was given.
    """
    if image is None:
        # Nothing uploaded: clear the output instead of crashing in preprocess.
        return None

    # Normalize and the model expect exactly 3 channels; RGBA, grayscale or
    # palette images would otherwise fail or be misinterpreted.
    image = image.convert("RGB")

    input_tensor = preprocess(image).unsqueeze(0)  # shape: [1, 3, 256, 256]
    with torch.no_grad():
        output = model(input_tensor)

    # Dim 1 holds the per-class scores; pick the best class for each pixel.
    pred_mask = torch.argmax(output, dim=1).squeeze(0).cpu().numpy().astype(np.uint8)

    return colorize_mask(pred_mask)


# Gradio interface.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="numpy"),
    title="Oil Spill Segmentation",
    description="Segment oil spills in aerial images.",
)

print("Gradio version:", gr.__version__)

# Launch the interface.
if __name__ == "__main__":
    iface.queue()
    iface.launch(server_name="0.0.0.0", server_port=7860)