---
library_name: transformers
tags:
- vision
- cell-biology
- dino
pipeline_tag: image-feature-extraction
model-index:
- name: cellrepDINO
  results: []
---

# cellrepDINO Model

This is a custom DINO model that extracts rich, compact vector representations (embeddings) of cell microscopy images. The forward method of the cellrepDINO model returns embeddings that can be used for downstream tasks such as perturbation prediction, mechanism of action (MoA) classification, and nuclear size and shape estimation. To apply the embeddings to a task, train a simple linear or logistic regression model on them.
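
As a minimal sketch of this linear-probe workflow, the block below fits a multinomial logistic regression (softmax) classifier on precomputed embeddings using plain NumPy gradient descent. The function names `fit_logistic_probe` and `predict`, the hyperparameters, and the toy data are hypothetical placeholders. In practice `X` would hold cellrepDINO embeddings from the Example Usage section (for instance after calling `.cpu().numpy()` on them), and a library implementation such as scikit-learn's `LogisticRegression` is a reasonable alternative.

```python
import numpy as np


def fit_logistic_probe(X, y, num_classes, lr=0.1, epochs=500, l2=1e-4):
    """Fit a softmax (multinomial logistic) probe with full-batch gradient descent.

    X: (n_samples, dim) embeddings; y: (n_samples,) integer labels in [0, num_classes).
    Returns (W, b, mean, std) so new embeddings can be standardized the same way.
    """
    mean = X.mean(axis=0)
    std = X.std(axis=0) + 1e-8            # avoid division by zero for constant features
    Xs = (X - mean) / std
    n, d = Xs.shape
    W = np.zeros((d, num_classes))
    b = np.zeros(num_classes)
    Y = np.eye(num_classes)[y]            # one-hot targets, shape (n, num_classes)
    for _ in range(epochs):
        logits = Xs @ W + b
        logits -= logits.max(axis=1, keepdims=True)   # numerical stability
        probs = np.exp(logits)
        probs /= probs.sum(axis=1, keepdims=True)
        grad_logits = (probs - Y) / n                 # gradient of mean cross-entropy
        W -= lr * (Xs.T @ grad_logits + l2 * W)       # L2 penalty 0.5 * l2 * ||W||^2
        b -= lr * grad_logits.sum(axis=0)
    return W, b, mean, std


def predict(X, W, b, mean, std):
    """Return the predicted class index for each row of X."""
    return np.argmax(((X - mean) / std) @ W + b, axis=1)


# Toy stand-in for real embeddings: two well-separated Gaussian clusters in 16 dims.
rng = np.random.default_rng(0)
X_train = np.vstack([rng.normal(-2.0, 1.0, (50, 16)), rng.normal(2.0, 1.0, (50, 16))])
y_train = np.array([0] * 50 + [1] * 50)

W, b, mean, std = fit_logistic_probe(X_train, y_train, num_classes=2)
accuracy = (predict(X_train, W, b, mean, std) == y_train).mean()
print(f"training accuracy: {accuracy:.2f}")
```

With real embeddings, evaluate on a held-out split rather than the training set. For continuous targets such as nuclear size, the same pattern applies with a linear (for example, ridge) regression in place of the softmax.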

## Model Details

- Architecture: DINOv2
- Type: Vision Transformer
- Input Size: 518x518
- Patch Size: 14
- Image Size: 1024
- Center Crop: 518
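
The sizes above determine how many patch tokens the model produces per image. The sketch below works through that arithmetic: 518 / 14 = 37 patches per side, so 37 × 37 = 1369 patch tokens (plus one class token, assuming the model uses no register tokens). It also computes a centered 518-pixel crop box under the assumption that "Image Size: 1024" means a 1024 × 1024 image is center-cropped without resizing; the processor may resize first, so check its configuration. The function names are hypothetical.

```python
from typing import Tuple


def token_counts(input_size: int = 518, patch_size: int = 14) -> Tuple[int, int]:
    """Return (patches_per_side, num_patch_tokens) for a square input."""
    if input_size % patch_size != 0:
        raise ValueError("input_size must be divisible by patch_size")
    per_side = input_size // patch_size    # 518 // 14 = 37
    return per_side, per_side * per_side   # 37 * 37 = 1369


def center_crop_box(image_size: int = 1024, crop: int = 518) -> Tuple[int, int, int, int]:
    """Return the (left, top, right, bottom) box of a centered square crop.

    Assumes the image is image_size x image_size pixels before cropping.
    """
    offset = (image_size - crop) // 2      # (1024 - 518) // 2 = 253
    return offset, offset, offset + crop, offset + crop


per_side, num_patches = token_counts()
print(per_side, num_patches)
print(center_crop_box())
```

The box uses the same (left, upper, right, lower) order as PIL's `Image.crop`, so it could be passed directly to that method.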

## Setup

Clone the repository with `git clone --filter=blob:none https://huggingface.co/lhphillips/cellrepDINO`, then `cd` into the cloned `cellrepDINO` directory and run `pip install -e .` to install the package in editable mode.

## Example Usage

Several types of embeddings can be extracted. We recommend either the mean or median of the patch token embeddings, or the class token embedding.

The code below computes the mean of the patch token embeddings. To use the median instead, replace `batch_outputs['x_norm_patchtokens'].mean(dim=1)` with `batch_outputs['x_norm_patchtokens'].median(dim=1).values` (with a `dim` argument, `torch.median` returns a `(values, indices)` tuple, so `.values` is needed).

To get the class token embeddings, use `batch_embeddings = batch_outputs['x_norm_clstoken']`.

```python
from transformers import AutoModel, AutoProcessor
from PIL import Image
import torch

# Set up device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load model and processor (the repository ships custom code, so trust_remote_code=True is required)
model = AutoModel.from_pretrained("LPhilllips/cellrepDINO", trust_remote_code=True)
processor = AutoProcessor.from_pretrained("LPhilllips/cellrepDINO", trust_remote_code=True)

# Move model to device and switch to inference mode
model = model.to(device)
model.eval()

# Load multiple images
image_paths = ["image1.png", "image2.png"]
images = [Image.open(path) for path in image_paths]

# Preprocess the batch of images and move any tensors to the device
batch_inputs = processor.preprocess(images=images, return_tensors="pt")
batch_inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch_inputs.items()}

# Generate embeddings for the batch without tracking gradients
with torch.no_grad():
    batch_outputs = model(**batch_inputs)
    # Average over the patch-token dimension to get one embedding per image
    batch_embeddings = batch_outputs['x_norm_patchtokens'].mean(dim=1)
```
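
The three embedding options (mean or median of the patch tokens, or the class token) can be wrapped in a small helper, sketched below. `pool_embeddings` is a hypothetical name, and the tensor shapes in its docstring are assumptions based on the usual DINOv2 output convention: `x_norm_patchtokens` of shape `(batch, num_patches, dim)` and `x_norm_clstoken` of shape `(batch, dim)`. The dummy tensors stand in for the real `batch_outputs`.

```python
import torch


def pool_embeddings(outputs: dict, method: str = "mean") -> torch.Tensor:
    """Reduce model outputs to one embedding per image.

    Assumes outputs["x_norm_patchtokens"] has shape (batch, num_patches, dim)
    and outputs["x_norm_clstoken"] has shape (batch, dim).
    """
    if method == "mean":
        return outputs["x_norm_patchtokens"].mean(dim=1)
    if method == "median":
        # torch.median with dim= returns a (values, indices) named tuple
        return outputs["x_norm_patchtokens"].median(dim=1).values
    if method == "cls":
        return outputs["x_norm_clstoken"]
    raise ValueError(f"unknown pooling method: {method!r}")


# Dummy outputs with the assumed shapes, standing in for batch_outputs.
dummy = {
    "x_norm_patchtokens": torch.randn(2, 1369, 8),
    "x_norm_clstoken": torch.randn(2, 8),
}
for method in ("mean", "median", "cls"):
    print(method, tuple(pool_embeddings(dummy, method).shape))
```

With the real model, `pool_embeddings(batch_outputs, "median")` would take the place of the last line of the example. Calling `.cpu().numpy()` on the result gives an array that NumPy- or scikit-learn-based probes can consume.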