🩺 UNet Model for COVID-19 CT Scan Segmentation

📌 Model Overview

This U-Net-based model automatically segments COVID-19-infected lung regions in CT scans. It extends the classic U-Net with attention gates on the skip connections, which reweight encoder features so that the decoder focuses on infected regions and suppresses irrelevant background.

  • Architecture: UNet + Attention Gates
  • Dataset: COVID-19 CT scans from Coronacases.org, Radiopaedia.org, and Zenodo Repository
  • Task: Image Segmentation (Lung Infection)
  • Metrics: Dice Coefficient, IoU, Hausdorff Distance, ASSD
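The card does not show how the attention gates are wired. A common design is the additive attention gate from Attention U-Net: it computes a per-pixel coefficient α from the skip-connection features x and a gating signal g coming from the decoder, then rescales x by α before the skip features are concatenated. The NumPy forward pass below is a sketch under stated assumptions, not this model's implementation: g is assumed to be already resized to x's spatial size, the 1×1 convolutions are written as per-pixel matrix products, and the weight names (w_x, w_g, b, w_psi, b_psi) are hypothetical.

```python
import numpy as np


def sigmoid(z: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-z))


def attention_gate(x, g, w_x, w_g, b, w_psi, b_psi):
    """Additive attention gate, forward pass only (illustrative sketch).

    x:     skip-connection features, shape (H, W, Cx)
    g:     gating signal, assumed already resized to (H, W, Cg)
    w_x:   (Cx, Ci) and w_g: (Cg, Ci) play the role of 1x1 convolutions
    b:     (Ci,) bias; w_psi: (Ci, 1) and scalar b_psi project to one channel
    All weight names are hypothetical. Returns gated features (H, W, Cx)
    and the attention map (H, W, 1).
    """
    # A 1x1 convolution is a matrix product over the channel axis at every pixel
    q = np.maximum(x @ w_x + g @ w_g + b, 0.0)  # ReLU(W_x x + W_g g + b)
    alpha = sigmoid(q @ w_psi + b_psi)  # attention coefficients in [0, 1]
    return x * alpha, alpha  # alpha broadcasts over the Cx channels


rng = np.random.default_rng(0)
x = rng.normal(size=(8, 8, 16))
g = rng.normal(size=(8, 8, 32))
gated, alpha = attention_gate(
    x, g,
    w_x=rng.normal(size=(16, 4)), w_g=rng.normal(size=(32, 4)), b=np.zeros(4),
    w_psi=rng.normal(size=(4, 1)), b_psi=0.0,
)
print(gated.shape, alpha.shape)  # (8, 8, 16) (8, 8, 1)
```

In a trained network the gate learns to push α toward 0 on background and toward 1 on lesion pixels, so only the relevant skip features reach the decoder.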

📊 Training Details

  • Dataset Size: 20 CT scans (512 × 512 × 301 slices)
  • Preprocessing:
    • Normalization of pixel intensities to [0, 1]
    • HU Thresholding: [-1000, 1500]
    • Image resizing to 128 × 128 pixels
    • Binarization of masks (0 = background, 1 = infected regions)
  • Augmentation:
    • Rotations: ±5 degrees
    • Elastic transformations, Gaussian blur
    • Brightness/contrast variations
    • Final dataset: 2,252 CT slices
  • Training:
    • Optimizer: Adam (learning rate = 1e-4)
    • Loss Function: Weighted BCE-Dice Loss + Surface Loss
    • Batch Size: 16
    • Epochs: 25
    • Training Platform: NVIDIA Tesla T4 (Google Colab Pro)
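The preprocessing steps above can be sketched in NumPy as follows. This is not the authors' code: the order of operations (clip to the HU window, rescale to [0, 1], resize, binarize), the 4×4 area averaging used to go from 512×512 to 128×128, and re-thresholding the downsampled mask at 0.5 are all assumptions.

```python
import numpy as np

HU_MIN, HU_MAX = -1000.0, 1500.0  # HU window from the card
TARGET = 128                      # output side length from the card


def block_mean(a: np.ndarray, size: int) -> np.ndarray:
    """Area-average a square 2D array down to (size, size), e.g. 512 -> 128."""
    h, w = a.shape
    if h != w or h % size:
        raise ValueError("sketch assumes a square side that is a multiple of size")
    f = h // size
    return a.reshape(size, f, size, f).mean(axis=(1, 3))


def preprocess_slice(hu_slice: np.ndarray) -> np.ndarray:
    """Clip to the HU window, rescale to [0, 1], and downsample to 128 x 128."""
    clipped = np.clip(hu_slice.astype(np.float32), HU_MIN, HU_MAX)
    scaled = (clipped - HU_MIN) / (HU_MAX - HU_MIN)
    return block_mean(scaled, TARGET)


def preprocess_mask(mask: np.ndarray) -> np.ndarray:
    """Binarize (any label > 0 counts as infection), downsample, re-threshold at 0.5."""
    binary = (mask > 0).astype(np.float32)
    return (block_mean(binary, TARGET) >= 0.5).astype(np.uint8)


rng = np.random.default_rng(0)
ct = rng.uniform(-2000, 3000, size=(512, 512))  # synthetic slice in HU
mask = np.zeros((512, 512), dtype=np.uint8)
mask[100:200, 100:200] = 1
x, y = preprocess_slice(ct), preprocess_mask(mask)
print(x.shape, float(x.min()) >= 0.0, float(x.max()) <= 1.0, int(y.sum()))  # (128, 128) True True 625
```

Clipping before rescaling keeps air, soft tissue, and bone inside a fixed range, so the same HU value always maps to the same input intensity across scans.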
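The card names the loss ("Weighted BCE-Dice Loss + Surface Loss") but not its weights. One plausible reading, sketched in NumPy below, is a weighted sum of binary cross-entropy, soft Dice loss, and a boundary (surface) loss that averages the predicted probabilities against a signed distance map of the ground truth (negative inside the lesion, positive outside). The weights alpha and beta are hypothetical, and the brute-force distance computation only suits small masks; for 128×128 masks, scipy.ndimage.distance_transform_edt is the practical choice.

```python
import numpy as np


def signed_distance_map(gt: np.ndarray) -> np.ndarray:
    """Signed distance to the lesion boundary, brute force (small masks only).

    Background pixels get their distance to the nearest lesion pixel (>= 1);
    lesion pixels get -(distance to the nearest background pixel - 1) (<= 0).
    """
    fg = gt.astype(bool).ravel()
    if fg.all() or not fg.any():
        return np.zeros(gt.shape)  # no boundary, so the surface term is zero
    coords = np.indices(gt.shape).reshape(2, -1).T  # (N, 2) pixel coordinates

    def nearest(targets: np.ndarray) -> np.ndarray:
        diff = coords[:, None, :] - coords[None, targets, :]
        return np.sqrt((diff ** 2).sum(axis=-1)).min(axis=1)

    phi = np.where(fg, -(nearest(~fg) - 1.0), nearest(fg))
    return phi.reshape(gt.shape)


def bce_dice_surface_loss(y_true, y_pred, alpha=0.5, beta=0.01, eps=1e-7):
    """alpha * BCE + (1 - alpha) * (1 - Dice) + beta * surface; weights are assumptions."""
    t = y_true.astype(np.float64)
    p = np.clip(y_pred.astype(np.float64), eps, 1.0 - eps)
    bce = -np.mean(t * np.log(p) + (1.0 - t) * np.log(1.0 - p))
    dice = (2.0 * np.sum(t * p) + eps) / (np.sum(t) + np.sum(p) + eps)
    # phi is negative inside the lesion and grows with distance outside it
    surface = np.mean(signed_distance_map(t) * p)
    return alpha * bce + (1.0 - alpha) * (1.0 - dice) + beta * surface


gt = np.zeros((16, 16))
gt[4:10, 4:10] = 1
print(bce_dice_surface_loss(gt, gt) < bce_dice_surface_loss(gt, np.roll(gt, 4, axis=1)))  # True
```

The region terms (BCE, Dice) drive overall overlap, while the surface term penalizes probability placed far from the true boundary in proportion to that distance.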

🚀 Model Performance

| Metric | Non-Augmented Model | Augmented Model |
|---|---|---|
| Dice Coefficient | 0.8502 | 0.8658 |
| IoU (Mean) | 0.7445 | 0.8316 |
| ASSD (Average Symmetric Surface Distance) | 0.3907 | 0.3888 |
| Hausdorff Distance | 8.4853 | 9.8995 |
| ROC AUC Score | 0.91 | 1.00 |
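For reference, the overlap and distance metrics in the table can be computed on a pair of binary masks as below. This NumPy sketch is an assumption about how they were measured: boundaries are foreground pixels with at least one 4-connected background neighbour, distances are in pixels (no voxel spacing), both masks are assumed non-empty, and the brute-force pairwise distances suit small 2D masks only.

```python
import numpy as np


def boundary(mask: np.ndarray) -> np.ndarray:
    """Foreground pixels with at least one 4-connected background neighbour."""
    m = np.pad(mask.astype(bool), 1, constant_values=False)
    core = m[1:-1, 1:-1]
    interior = core & m[:-2, 1:-1] & m[2:, 1:-1] & m[1:-1, :-2] & m[1:-1, 2:]
    return core & ~interior


def segmentation_metrics(gt: np.ndarray, pred: np.ndarray) -> dict:
    """Dice, IoU, Hausdorff distance and ASSD (pixels); assumes non-empty masks."""
    gt, pred = gt.astype(bool), pred.astype(bool)
    inter = np.logical_and(gt, pred).sum()
    union = np.logical_or(gt, pred).sum()
    dice = 2.0 * inter / (gt.sum() + pred.sum())
    iou = inter / union
    # Pairwise Euclidean distances between the two boundary point sets
    a, b = np.argwhere(boundary(gt)), np.argwhere(boundary(pred))
    d = np.sqrt(((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1))
    d_ab, d_ba = d.min(axis=1), d.min(axis=0)  # nearest-neighbour distances each way
    hausdorff = max(d_ab.max(), d_ba.max())
    assd = (d_ab.sum() + d_ba.sum()) / (len(d_ab) + len(d_ba))
    return {"dice": dice, "iou": iou, "hausdorff": hausdorff, "assd": assd}


gt = np.zeros((32, 32), dtype=bool)
gt[8:16, 8:16] = True
pred = np.zeros_like(gt)
pred[8:16, 10:18] = True  # same square shifted right by 2 pixels
m = segmentation_metrics(gt, pred)
print({k: round(float(v), 4) for k, v in m.items()})
# {'dice': 0.75, 'iou': 0.6, 'hausdorff': 2.0, 'assd': 1.0}
```

Dice and IoU reward overlap, whereas Hausdorff reports the single worst boundary error and ASSD the average one, which is why the two groups of metrics can move in different directions.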

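📌 Key Findings:
✔ Augmentation improved the overlap metrics (Dice 0.8502 → 0.8658, mean IoU 0.7445 → 0.8316) and slightly lowered ASSD (0.3907 → 0.3888), although the Hausdorff distance rose (8.4853 → 9.8995)
✔ Attention U-Net outperformed the other segmentation models evaluated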


📥 How to Use the Model

1️⃣ Load the Model

TensorFlow/Keras

import os

# ✅ Select the Keras 3 backend (optional). This must run before Keras is imported,
#    otherwise it has no effect; "jax" requires JAX, so use "tensorflow" if JAX is not installed.
os.environ["KERAS_BACKEND"] = "jax"

from huggingface_hub import hf_hub_download
from keras import ops
from keras.models import load_model
from keras.optimizers import Adam
from keras.saving import register_keras_serializable

# ✅ Register the custom objects referenced by the saved model
@register_keras_serializable()
def dice_coef(y_true, y_pred, smooth=1e-6):
    # Soft Dice over all pixels in the batch: (2|A∩B| + s) / (|A| + |B| + s)
    y_true_f = ops.reshape(ops.cast(y_true, "float32"), (-1,))
    y_pred_f = ops.reshape(ops.cast(y_pred, "float32"), (-1,))
    intersection = ops.sum(y_true_f * y_pred_f)
    return (2.0 * intersection + smooth) / (ops.sum(y_true_f) + ops.sum(y_pred_f) + smooth)

@register_keras_serializable()
def gl_sl(*args, **kwargs):
    pass  # Placeholder so the saved training-loss name resolves; it is not used for inference

# ✅ Download the model from Hugging Face
model_path = hf_hub_download(repo_id="amal90888/unet-segmentation-model", filename="unet_model.keras")

# ✅ Load architecture and weights only (compile=False skips the training loss and optimizer state)
unet = load_model(model_path, custom_objects={"dice_coef": dice_coef, "gl_sl": gl_sl}, compile=False)

# ✅ Recompile only if you need evaluate() or fine-tuning; plain BCE stands in for
#    the weighted BCE-Dice + surface loss used during training
unet.compile(optimizer=Adam(learning_rate=1e-4), loss="binary_crossentropy", metrics=["accuracy", dice_coef])

print("✅ Model loaded and recompiled successfully!")
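After loading, predicting a mask for one slice might look like the sketch below. It assumes, without confirmation from the card, that the network takes a batch of shape (N, 128, 128, 1) with intensities already preprocessed to [0, 1] and outputs per-pixel sigmoid probabilities; the helper name segment_slice and the 0.5 threshold are illustrative. The helper only needs an object with a Keras-style predict method, so it can be called with the unet loaded above.

```python
import numpy as np


def segment_slice(model, slice_01: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """Return a binary infection mask (uint8) for one preprocessed 2D slice.

    Hypothetical helper: slice_01 is assumed to be already clipped, normalized
    to [0, 1], and resized to the model's input size (assumed 128 x 128).
    """
    batch = slice_01.astype(np.float32)[np.newaxis, ..., np.newaxis]  # (1, H, W, 1)
    probs = np.asarray(model.predict(batch, verbose=0))  # assumed (1, H, W, 1) probabilities
    return (probs[0, ..., 0] >= threshold).astype(np.uint8)


class _EchoModel:
    """Stand-in with a Keras-like predict(); returns its input as 'probabilities'."""

    def predict(self, x, verbose=0):
        return x


demo = segment_slice(_EchoModel(), np.linspace(0.0, 1.0, 128 * 128).reshape(128, 128))
print(demo.shape, int(demo.sum()))  # (128, 128) 8192

# With the real model: mask = segment_slice(unet, preprocessed_slice)
```

Raising the threshold trades sensitivity for precision; 0.5 is only a neutral default for a sigmoid output.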