---
license: mit
tags:
- pytorch
- unet
- image-segmentation
- biomedical
- fluorescence-microscopy
- amyloid
- fibril
- research
library_name: segmentation_models.pytorch
inference: false
datasets: []
language: []
model-index:
- name: UNet Fibril Segmentation
  results: []
---

# 🧠 UNet Fibril Segmentation Model

A **UNet-based deep learning model** trained for **semantic segmentation of fibrillar structures** in **single-molecule fluorescence microscopy images**. It is designed to identify and segment **amyloid fibrils**, which are central to research on neurodegenerative diseases such as Alzheimer's disease.

---

## 🔬 Model Overview

- **Architecture**: [UNet](https://arxiv.org/abs/1505.04597)
- **Encoder**: ResNet34 (pretrained on ImageNet)
- **Input channels**: 1 (grayscale)
- **Output**: 1-channel logit map, thresholded to a binary mask of fibril regions
- **Loss function**: combined BCEWithLogitsLoss + Dice loss
- **Framework**: PyTorch + [segmentation_models.pytorch](https://github.com/qubvel/segmentation_models.pytorch)

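The card lists a combined BCEWithLogitsLoss + Dice objective but does not say how the two terms are weighted. The following is a minimal sketch that assumes an equal 0.5/0.5 weighting and a soft Dice term computed on sigmoid probabilities with a smoothing constant of 1. `BCEDiceLoss`, `bce_weight`, and `smooth` are hypothetical names and are not taken from the actual training code.

```python
import torch
import torch.nn as nn


class BCEDiceLoss(nn.Module):
    """Hypothetical combined loss: bce_weight * BCE + (1 - bce_weight) * soft Dice.

    Assumption: the 0.5/0.5 weighting and the smoothing constant are
    illustrative, not the values used to train this model.
    """

    def __init__(self, bce_weight: float = 0.5, smooth: float = 1.0):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()
        self.bce_weight = bce_weight
        self.smooth = smooth

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # BCE term operates directly on raw logits (numerically stable).
        bce = self.bce(logits, target)
        # Soft Dice term operates on probabilities, per image, then averaged.
        probs = torch.sigmoid(logits)
        dims = tuple(range(1, probs.ndim))  # sum over C, H, W
        intersection = (probs * target).sum(dim=dims)
        denom = probs.sum(dim=dims) + target.sum(dim=dims)
        dice = (2.0 * intersection + self.smooth) / (denom + self.smooth)
        dice_loss = 1.0 - dice.mean()
        return self.bce_weight * bce + (1.0 - self.bce_weight) * dice_loss


if __name__ == "__main__":
    criterion = BCEDiceLoss()
    logits = torch.randn(2, 1, 256, 256)                  # model output (N, 1, H, W)
    target = (torch.rand(2, 1, 256, 256) > 0.9).float()   # binary fibril mask
    print(criterion(logits, target).item())
```

Passing raw logits to `BCEWithLogitsLoss` is more numerically stable than applying a sigmoid first and using `BCELoss`. `segmentation_models_pytorch.losses` also provides a ready-made `DiceLoss` that could replace the hand-written Dice term.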

---

## 🧠 Use Case

The model is built for biomedical researchers and computer vision practitioners working in:

- **Neuroscience research** (e.g., Alzheimer's, Parkinson's)
- **Amyloid aggregation studies**
- **Single-molecule fluorescence microscopy**
- **Self-supervised denoising + segmentation pipelines**

---

## 🧪 Dataset

The model was trained on a curated dataset of **fluorescence microscopy images** annotated for fibrillar structures. The images are grayscale, either `512x512` or `256x256` pixels, and were labeled manually using Fiji/ImageJ or custom annotation tools.

*Note: If you're a researcher and would like to contribute more annotated data or collaborate on a dataset release, please reach out.*

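Both training sizes (`512x512` and `256x256`) are multiples of 32. This matters because a UNet with a ResNet34 encoder downsamples by a factor of 32. For other sizes, the decoder's skip connections do not line up, and segmentation_models.pytorch typically raises an error. For images of arbitrary size, one option is to reflect-pad to the next multiple of 32 and crop the prediction back, rather than resizing. This is an assumption, not part of the released scripts. `pad_to_multiple` and `crop_to` are hypothetical helper names.

```python
import torch
import torch.nn.functional as F


def pad_to_multiple(x: torch.Tensor, multiple: int = 32):
    """Reflect-pad an (N, C, H, W) tensor on the bottom/right so H and W are
    multiples of `multiple`. Returns the padded tensor and the original (H, W).

    Note: reflect padding requires each pad amount to be smaller than the
    corresponding input dimension.
    """
    h, w = x.shape[-2:]
    pad_h = (multiple - h % multiple) % multiple
    pad_w = (multiple - w % multiple) % multiple
    # F.pad takes (left, right, top, bottom) for the last two dimensions.
    padded = F.pad(x, (0, pad_w, 0, pad_h), mode="reflect")
    return padded, (h, w)


def crop_to(x: torch.Tensor, size):
    """Crop the spatial dims of an (N, C, H, W) tensor back to size = (H, W)."""
    h, w = size
    return x[..., :h, :w]


if __name__ == "__main__":
    img = torch.rand(1, 1, 300, 450)  # e.g. a non-standard microscopy crop
    padded, orig_size = pad_to_multiple(img)
    print(padded.shape)  # torch.Size([1, 1, 320, 480])
    # With a loaded model: logits = crop_to(model(padded), orig_size)
    print(crop_to(padded, orig_size).shape)  # torch.Size([1, 1, 300, 450])
```

Padding only on the bottom and right keeps the original pixels in the top-left corner, so cropping the output back to the original size is a simple slice.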

---

## 📦 Files

- `unet_fibril_seg_model.pth`: trained PyTorch weights (`state_dict`)
- `inference.py`: inference script for running the model
- `preprocessing.py`: image normalization and transforms

---

## 🖼️ Example

```python
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
import segmentation_models_pytorch as smp

# Build the architecture. encoder_weights=None skips downloading ImageNet
# weights, since every parameter is overwritten by the checkpoint below.
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights=None,
    in_channels=1,
    classes=1,
)
model.load_state_dict(torch.load("unet_fibril_seg_model.pth", map_location="cpu"))
model.eval()

# Load a grayscale image, resize it to 256x256 (a multiple of 32), and convert
# it to a float tensor in [0, 1] with shape (1, 1, 256, 256).
# If preprocessing.py applies further normalization, apply it here as well.
img = Image.open("test_image.png").convert("L")
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])
input_tensor = transform(img).unsqueeze(0)

# Predict: the model outputs logits, so apply a sigmoid before thresholding.
with torch.no_grad():
    logits = model(input_tensor)
prob_mask = torch.sigmoid(logits).squeeze().numpy()
binary_mask = (prob_mask > 0.5).astype(np.uint8)

# Save the mask as 0/255, resized back to the input image's original size.
mask_img = Image.fromarray(binary_mask * 255)
mask_img.resize(img.size, resample=Image.NEAREST).save("predicted_mask.png")
```

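To compare predicted masks against manual annotations, such as masks drawn in Fiji/ImageJ, the Dice coefficient and IoU are common overlap metrics. The following is a minimal sketch, not an evaluation protocol described in this card. `dice_iou` is a hypothetical helper, and both masks are assumed to be binary arrays of the same shape.

```python
import numpy as np


def dice_iou(pred: np.ndarray, target: np.ndarray, eps: float = 1e-7):
    """Dice coefficient and IoU between two binary masks of the same shape.

    Nonzero values count as foreground, so 0/1 and 0/255 masks both work.
    """
    if pred.shape != target.shape:
        raise ValueError(f"shape mismatch: {pred.shape} vs {target.shape}")
    p = pred.astype(bool)
    t = target.astype(bool)
    intersection = np.logical_and(p, t).sum()
    union = np.logical_or(p, t).sum()
    total = p.sum() + t.sum()
    # eps avoids division by zero; two empty masks score 1.0
    dice = (2.0 * intersection + eps) / (total + eps)
    iou = (intersection + eps) / (union + eps)
    return float(dice), float(iou)


if __name__ == "__main__":
    pred = np.array([[1, 1, 0], [0, 1, 0]], dtype=np.uint8)
    target = np.array([[1, 0, 0], [0, 1, 1]], dtype=np.uint8)
    dice, iou = dice_iou(pred, target)
    print(f"Dice={dice:.3f}  IoU={iou:.3f}")  # Dice=0.667  IoU=0.500
```

A saved prediction can be loaded as an array with `np.array(Image.open("predicted_mask.png"))` and passed directly to `dice_iou`, because the helper treats any nonzero pixel as foreground.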