🧠 UNet Fibril Segmentation Model
A UNet-based deep learning model trained for semantic segmentation of fibrillar structures in single-molecule fluorescence microscopy images. It is designed to identify and segment amyloid fibrils, which are central to research on neurodegenerative diseases such as Alzheimer’s.
🔬 Model Overview
- Architecture: UNet
- Encoder: ResNet34 (pretrained on ImageNet)
- Input Channels: 1 (grayscale)
- Output: Binary mask of fibril regions
- Loss Function: Combined BCEWithLogitsLoss and Dice loss
- Framework: PyTorch + segmentation_models.pytorch
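The overview lists a combined BCEWithLogitsLoss and Dice loss but does not give the weighting or the exact Dice formulation. Below is a minimal sketch under those unknowns. It assumes an equal-weight sum and a soft Dice computed on sigmoid probabilities. The class name `BCEDiceLoss` and the `bce_weight`/`smooth` parameters are hypothetical and are not taken from the training code.

```python
import torch
import torch.nn as nn


class BCEDiceLoss(nn.Module):
    """Hypothetical combined loss: weighted BCE-with-logits plus soft Dice loss.

    Expects raw logits and {0, 1} targets, both of shape (N, 1, H, W).
    """

    def __init__(self, bce_weight: float = 0.5, smooth: float = 1.0):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()
        self.bce_weight = bce_weight  # assumed weighting; not documented in this card
        self.smooth = smooth          # keeps the Dice ratio defined for empty masks

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        bce = self.bce(logits, targets)
        probs = torch.sigmoid(logits)
        # Soft Dice per image, summed over channel and spatial dimensions
        dims = (1, 2, 3)
        intersection = (probs * targets).sum(dim=dims)
        denom = probs.sum(dim=dims) + targets.sum(dim=dims)
        dice = (2.0 * intersection + self.smooth) / (denom + self.smooth)
        dice_loss = 1.0 - dice.mean()
        return self.bce_weight * bce + (1.0 - self.bce_weight) * dice_loss


if __name__ == "__main__":
    loss_fn = BCEDiceLoss()
    logits = torch.randn(2, 1, 64, 64)
    targets = (torch.rand(2, 1, 64, 64) > 0.5).float()
    print(loss_fn(logits, targets).item())
```

BCE gives a stable per-pixel gradient, while the Dice term counteracts the heavy background/foreground imbalance typical of thin fibrils. segmentation_models.pytorch also provides `smp.losses.DiceLoss(mode="binary")`, which could replace the hand-written Dice term.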
🧠 Use Case
The model is built for biomedical researchers and computer vision practitioners working in:
- Neuroscience research (e.g., Alzheimer’s, Parkinson’s)
- Amyloid aggregation studies
- Single-molecule fluorescence microscopy
- Self-supervised denoising + segmentation pipelines
🧪 Dataset
The model was trained on a curated dataset of fluorescence microscopy images annotated for fibrillar structures. Images were grayscale, 512x512 or 256x256 pixels in size, and were manually labeled using Fiji/ImageJ or custom annotation tools.
Note: If you're a researcher and would like to contribute more annotated data or collaborate on a dataset release, please reach out.
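The dataset description does not include a loading script. The sketch below turns one image/mask pair into training arrays at the 256x256 size used in the inference example. It assumes 8-bit grayscale images and 8-bit masks exported from Fiji/ImageJ in which every nonzero pixel marks a fibril. The helper name `prepare_pair` is hypothetical. 16-bit microscopy data may need a different conversion path than `convert("L")`.

```python
import numpy as np
from PIL import Image


def prepare_pair(image: Image.Image, mask: Image.Image, size: int = 256):
    """Hypothetical helper: convert an image and its annotation mask into
    float32 arrays of shape (size, size).

    Assumes an 8-bit mask where any nonzero pixel is labeled as fibril.
    """
    # Bilinear resampling for intensities
    image = image.convert("L").resize((size, size), resample=Image.BILINEAR)
    # Nearest-neighbour resampling keeps the mask strictly binary
    mask = mask.convert("L").resize((size, size), resample=Image.NEAREST)
    image_arr = np.asarray(image, dtype=np.float32) / 255.0  # scale to [0, 1]
    mask_arr = (np.asarray(mask) > 0).astype(np.float32)     # {0, 1} labels
    return image_arr, mask_arr
```

Using nearest-neighbour for the mask avoids interpolated edge values that would otherwise need re-thresholding.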
📦 Files
- unet_fibril_seg_model.pth: Trained PyTorch weights
- inference.py: Inference script for running the model
- preprocessing.py: Image normalization and transforms
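The normalization in preprocessing.py is not reproduced in this card. For fluorescence images, percentile-based contrast stretching is a common choice because it is robust to a few saturated or hot pixels. Below is a minimal sketch of that technique, assuming 1st/99th percentiles. Whether preprocessing.py uses this method is not stated, and `percentile_normalize` is a hypothetical name. At inference time, apply whatever normalization was actually used during training.

```python
import numpy as np


def percentile_normalize(image: np.ndarray, low: float = 1.0, high: float = 99.0,
                         eps: float = 1e-8) -> np.ndarray:
    """Hypothetical normalization: map the `low` and `high` intensity
    percentiles to 0 and 1, then clip to [0, 1].

    Accepts 8-bit or 16-bit arrays; returns float32.
    """
    image = image.astype(np.float32)
    lo, hi = np.percentile(image, [low, high])
    scaled = (image - lo) / (hi - lo + eps)  # eps guards against flat images
    return np.clip(scaled, 0.0, 1.0).astype(np.float32)
```

Operating on the raw array rather than a PIL image keeps the full dynamic range of 16-bit camera data until the final rescale.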
🖼️ Example
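```python
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
import segmentation_models_pytorch as smp

# Load model: the architecture must match the one used for training
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights=None,  # trained weights are loaded below, so skip the ImageNet download
    in_channels=1,
    classes=1,
)
model.load_state_dict(torch.load("unet_fibril_seg_model.pth", map_location="cpu"))
model.eval()

# Load image as single-channel grayscale
img = Image.open("test_image.png").convert("L")
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # height and width must be divisible by 32 for this encoder
    transforms.ToTensor(),          # float tensor in [0, 1], shape (1, H, W)
])
input_tensor = transform(img).unsqueeze(0)  # add batch dimension: (1, 1, 256, 256)

# Predict: the model outputs logits, so apply a sigmoid to get probabilities
with torch.no_grad():
    logits = model(input_tensor)
pred_mask = torch.sigmoid(logits).squeeze().cpu().numpy()
binary_mask = (pred_mask > 0.5).astype(np.uint8)

# Save output as a 0/255 PNG (256x256, not the original image size)
Image.fromarray(binary_mask * 255).save("predicted_mask.png")
```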
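To compare a predicted mask with a manual annotation, Dice and IoU (Jaccard) are the standard overlap scores for binary segmentation. Below is a minimal NumPy sketch. The function name `mask_scores` is hypothetical. Both inputs are assumed to be binary arrays of the same shape, for example `binary_mask` from the example and an annotation resized to 256x256.

```python
import numpy as np


def mask_scores(pred: np.ndarray, target: np.ndarray) -> dict:
    """Hypothetical helper: Dice and IoU between two binary masks of equal shape."""
    pred = pred.astype(bool)
    target = target.astype(bool)
    intersection = np.logical_and(pred, target).sum()
    union = np.logical_or(pred, target).sum()
    total = pred.sum() + target.sum()
    # Convention: two empty masks count as a perfect match
    dice = 1.0 if total == 0 else 2.0 * intersection / total
    iou = 1.0 if union == 0 else intersection / union
    return {"dice": float(dice), "iou": float(iou)}


if __name__ == "__main__":
    a = np.zeros((4, 4), dtype=np.uint8)
    a[:2, :] = 1  # top half
    b = np.zeros((4, 4), dtype=np.uint8)
    b[:, :2] = 1  # left half
    print(mask_scores(a, b))  # {'dice': 0.5, 'iou': 0.3333333333333333}
```

Dice weights the overlap twice and is always at least as large as IoU for the same pair of masks. Reporting both makes results comparable with either convention.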