🧠 UNet Fibril Segmentation Model


A UNet-based deep learning model trained for semantic segmentation of fibrillar structures in single-molecule fluorescence microscopy images. It is designed specifically to identify and segment amyloid fibrils, which are central to research on neurodegenerative diseases such as Alzheimer’s.


🔬 Model Overview

  • Architecture: UNet
  • Encoder: ResNet34 (pretrained on ImageNet)
  • Input Channels: 1 (grayscale)
  • Output: Binary mask of fibril regions
  • Loss Function: combined BCEWithLogitsLoss + Dice loss
  • Framework: PyTorch + segmentation_models.pytorch
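The card names the loss only as a combination of BCEWithLogitsLoss and Dice loss, without stating how the two are weighted. The NumPy sketch below shows one common way to combine them. It assumes an equal 0.5/0.5 weighting and a soft Dice term computed on sigmoid probabilities; the `bce_weight` and `eps` values are assumptions, not the settings used to train this model.

```python
import numpy as np


def bce_with_logits(logits: np.ndarray, targets: np.ndarray) -> float:
    """Mean binary cross-entropy on raw logits, using the numerically stable form."""
    # max(x, 0) - x*y + log(1 + exp(-|x|)) equals BCE(sigmoid(x), y) without overflow
    loss = np.maximum(logits, 0) - logits * targets + np.log1p(np.exp(-np.abs(logits)))
    return float(loss.mean())


def soft_dice_loss(logits: np.ndarray, targets: np.ndarray, eps: float = 1e-6) -> float:
    """1 - soft Dice coefficient between sigmoid probabilities and binary targets."""
    # tanh form of the sigmoid avoids overflow for large negative logits
    probs = 0.5 * (1.0 + np.tanh(0.5 * logits))
    intersection = (probs * targets).sum()
    dice = (2.0 * intersection + eps) / (probs.sum() + targets.sum() + eps)
    return float(1.0 - dice)


def combined_loss(logits: np.ndarray, targets: np.ndarray, bce_weight: float = 0.5) -> float:
    """Weighted sum of BCE and Dice loss (ASSUMED weighting, not from the model card)."""
    return bce_weight * bce_with_logits(logits, targets) + (1.0 - bce_weight) * soft_dice_loss(logits, targets)
```

In a PyTorch training loop, the same objective could be assembled from `torch.nn.BCEWithLogitsLoss()` and `smp.losses.DiceLoss(mode="binary")`, which also accepts raw logits by default.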

🧠 Use Case

The model is built for biomedical researchers and computer vision practitioners working in:

  • Neuroscience research (e.g., Alzheimer’s, Parkinson’s)
  • Amyloid aggregation studies
  • Single-molecule fluorescence microscopy
  • Self-supervised denoising + segmentation pipelines

🧪 Dataset

The model was trained on a curated dataset of fluorescence microscopy images annotated for fibrillar structures. The images are grayscale, either 512x512 or 256x256 pixels, and were labeled manually using Fiji/ImageJ or custom annotation tools.
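Both training sizes are multiples of 32, which smp's `Unet` requires with the default five-stage ResNet34 encoder, since each stage halves the spatial resolution. For images of other sizes, one option besides resizing is to pad up to the next multiple of 32 and crop the prediction back afterwards. The sketch below assumes 2-D grayscale NumPy arrays; the helper names `pad_to_multiple` and `crop_to_original` are hypothetical and not part of this repository.

```python
import numpy as np

ENCODER_STRIDE = 32  # 2**5: the ResNet34 encoder downsamples five times


def pad_to_multiple(image: np.ndarray, multiple: int = ENCODER_STRIDE):
    """Reflect-pad a 2-D image on the bottom/right so H and W are multiples of `multiple`."""
    h, w = image.shape
    pad_h = (-h) % multiple  # 0 when already divisible
    pad_w = (-w) % multiple
    padded = np.pad(image, ((0, pad_h), (0, pad_w)), mode="reflect")
    return padded, (h, w)


def crop_to_original(mask: np.ndarray, original_size) -> np.ndarray:
    """Undo pad_to_multiple by cropping the predicted mask to the original H x W."""
    h, w = original_size
    return mask[:h, :w]
```

Reflect padding avoids introducing an artificial dark border that the model might respond to; constant zero padding is a simpler alternative.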

Note: If you're a researcher and would like to contribute more annotated data or collaborate on a dataset release, please reach out.


📦 Files

  • unet_fibril_seg_model.pth — Trained PyTorch weights
  • inference.py — Inference script for running the model
  • preprocessing.py — Image normalization and transforms
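The exact normalization in `preprocessing.py` is not described here. As an illustration only, the sketch below shows percentile-based intensity scaling, a common choice for fluorescence microscopy images (including 16-bit data) because it is robust to a few very bright pixels. The function name and the 1st/99.8th percentile bounds are assumptions, not the repository's actual settings.

```python
import numpy as np


def normalize_percentile(image: np.ndarray, low: float = 1.0, high: float = 99.8,
                         eps: float = 1e-8) -> np.ndarray:
    """HYPOTHETICAL preprocessing: scale intensities so the [low, high] percentiles map to [0, 1]."""
    image = image.astype(np.float32)
    lo, hi = np.percentile(image, [low, high])
    scaled = (image - lo) / (hi - lo + eps)
    # Clip outliers beyond the chosen percentiles and return float32 for PyTorch
    return np.clip(scaled, 0.0, 1.0).astype(np.float32)
```

The output could then be converted with `torch.from_numpy(x)[None, None]` to obtain a `(1, 1, H, W)` input tensor.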

🖼️ Example

import torch
import numpy as np
from PIL import Image
from torchvision import transforms
import segmentation_models_pytorch as smp

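# Load model. encoder_weights=None skips the ImageNet download, since the
# trained weights loaded below overwrite every parameter anyway.
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights=None,
    in_channels=1,
    classes=1,
)
model.load_state_dict(torch.load("unet_fibril_seg_model.pth", map_location="cpu"))
model.eval()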

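# Load image as single-channel grayscale. ToTensor scales 8-bit pixels to [0, 1];
# if preprocessing.py applies a different normalization, use it here instead.
img = Image.open("test_image.png").convert("L")
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # height and width must be divisible by 32
    transforms.ToTensor(),
])
input_tensor = transform(img).unsqueeze(0)  # shape: (1, 1, 256, 256)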

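# Predict: the model outputs raw logits, so apply a sigmoid to get per-pixel
# fibril probabilities, then threshold at 0.5 for a binary mask
with torch.no_grad():
    pred = model(input_tensor)
    pred_mask = torch.sigmoid(pred).squeeze().numpy()
    binary_mask = (pred_mask > 0.5).astype(np.uint8)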

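# Save output at the input image's original resolution
# (nearest-neighbor resampling keeps the mask strictly binary)
Image.fromarray(binary_mask * 255).resize(img.size, Image.NEAREST).save("predicted_mask.png")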
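If you have manual annotations for your own images, a quick way to check how well the predicted mask agrees with them is to compute the Dice coefficient and IoU (Jaccard index). The NumPy sketch below is a minimal version of these two standard metrics; the function name `dice_and_iou` is hypothetical, and treating two empty masks as perfect agreement is a convention chosen here.

```python
import numpy as np


def dice_and_iou(pred: np.ndarray, target: np.ndarray):
    """Dice coefficient and IoU between two binary masks of the same shape."""
    pred = pred.astype(bool)
    target = target.astype(bool)
    intersection = np.logical_and(pred, target).sum()
    union = np.logical_or(pred, target).sum()
    total = pred.sum() + target.sum()
    if total == 0:
        return 1.0, 1.0  # both masks empty: count as perfect agreement
    dice = 2.0 * intersection / total
    iou = intersection / union
    return float(dice), float(iou)
```

For example, `dice_and_iou(np.array(Image.open("predicted_mask.png")) > 0, ground_truth > 0)`, where `ground_truth` is a hypothetical annotation array with the same height and width as the saved mask.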