Introduction
Virtual cells represent a promising paradigm for understanding the mechanisms, behavior, and transitions of cells. Realizing virtual cells requires accurate modeling of cellular dynamics from large-scale, multi-modal single-cell data. However, experiment-specific technical noise and intrinsic biological heterogeneity pose major challenges to current models. While much recent work has focused on data normalization and preprocessing, relatively little effort has been devoted to improving model architecture. To address this gap, we present scDifformer, a context-aware transformer model augmented with a denoising diffusion module and a dedicated post-training phase. By combining the strengths of transformer and diffusion models, scDifformer achieves state-of-the-art (SOTA) performance in cell type annotation across diverse datasets. It also robustly resolves immune cell identities across multiple tissues, accurately recovering key marker genes, functional pathways, and cross-tissue differentiation trajectories. Finally, by integrating scDifformer with a graph neural network, we extend it to spatial transcriptomics, where it significantly improves spot-level deconvolution accuracy. Altogether, scDifformer provides a scalable and biologically grounded framework for modeling heterogeneous single-cell data and a strong foundation for developing high-fidelity, multi-modal virtual cell models.
Three-stage training paradigm of scDifformer
Conventional approaches follow a two-stage paradigm: pre-training followed by fine-tuning. scDifformer adds a third stage in between. Instead of fine-tuning the pre-trained model directly, it first continues training the pre-trained weights in a post-training stage based on diffusion probabilistic models, and only then fine-tunes for downstream tasks. This post-training stage further reduces noise and refines the learned representations through incremental parameter updates within the scDifformer framework. As a result, the integrated architecture learns richer contextualized embeddings for downstream scRNA-seq analysis tasks.
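The post-training stage is described only at a high level above. As a hedged illustration of the generic denoising diffusion (DDPM-style) objective that such a stage typically builds on, the sketch below noises a batch of cell embeddings with the closed-form forward process and computes the noise-prediction loss. The linear beta schedule, the embedding shapes, and the denoiser callable are assumptions for illustration; this is not scDifformer's actual post-training code.

```python
import numpy as np


def linear_beta_schedule(num_steps, beta_start=1e-4, beta_end=0.02):
    """Assumed linear noise schedule (a common DDPM default)."""
    return np.linspace(beta_start, beta_end, num_steps)


def q_sample(x0, t, alphas_cumprod, noise):
    """Closed-form forward process: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise."""
    abar = alphas_cumprod[t][:, None]  # one timestep per sample, broadcast over features
    return np.sqrt(abar) * x0 + np.sqrt(1.0 - abar) * noise


def diffusion_loss(x0, denoiser, alphas_cumprod, rng):
    """Noise-prediction (epsilon) MSE loss for embeddings x0 of shape (batch, dim)."""
    batch = x0.shape[0]
    t = rng.integers(0, len(alphas_cumprod), size=batch)  # random timestep per sample
    noise = rng.standard_normal(x0.shape)
    x_t = q_sample(x0, t, alphas_cumprod, noise)
    predicted_noise = denoiser(x_t, t)  # hypothetical network eps_theta(x_t, t)
    return float(np.mean((predicted_noise - noise) ** 2))


def zero_denoiser(x_t, t):
    """Placeholder that always predicts zero noise (stands in for a trained network)."""
    return np.zeros_like(x_t)


# Usage with random stand-in embeddings (hypothetical: 8 cells x 16-dim embeddings)
rng = np.random.default_rng(0)
betas = linear_beta_schedule(1000)
alphas_cumprod = np.cumprod(1.0 - betas)
embeddings = rng.standard_normal((8, 16))
loss = diffusion_loss(embeddings, zero_denoiser, alphas_cumprod, rng)
```

In a real post-training loop, the denoiser would be a trainable network and this loss would be minimized by gradient descent; the closed-form q_sample avoids simulating every intermediate noising step.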
Usage
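You can use the following code template to load a trained scDifformer checkpoint and fine-tune it for cell classification, such as cell type annotation, with the Hugging Face Trainer: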
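# BertForSequenceClassification, Trainer and TrainingArguments come from Hugging Face transformers
import os
import pickle

from transformers import BertForSequenceClassification, Trainer, TrainingArguments

# The names below are placeholders you must define before running this template:
# pretrain_model_path, num_labels, max_lr, logging_steps, lr_schedule_fn, warmup_steps,
# weight_decay, batch_size, epochs, output_dir, train_split, eval_split (tokenized datasets
# with a "length" column), compute_metrics, and the DataCollatorForAging collator class.

# Reload the pre-trained weights into a BERT sequence (cell) classification model
model = BertForSequenceClassification.from_pretrained(
    pretrain_model_path,
    num_labels=num_labels,
    output_attentions=False,
    output_hidden_states=False,
).to("cuda")

# Set training arguments
training_args = {
    "learning_rate": max_lr,
    "do_train": True,
    "do_eval": True,
    "evaluation_strategy": "epoch",  # renamed to "eval_strategy" in newer transformers releases
    "save_strategy": "epoch",  # must match the evaluation strategy for load_best_model_at_end
    "logging_steps": logging_steps,
    "group_by_length": True,  # batch samples of similar length to reduce padding
    "length_column_name": "length",
    "disable_tqdm": False,
    "lr_scheduler_type": lr_schedule_fn,
    "warmup_steps": warmup_steps,
    "weight_decay": weight_decay,
    "per_device_train_batch_size": batch_size,
    "per_device_eval_batch_size": batch_size,
    "num_train_epochs": epochs,
    "load_best_model_at_end": True,
    "output_dir": output_dir,
}
training_args_init = TrainingArguments(**training_args)

trainer = Trainer(
    model=model,
    args=training_args_init,
    data_collator=DataCollatorForAging(),
    train_dataset=train_split,
    eval_dataset=eval_split,
    compute_metrics=compute_metrics,
)

# Train the model
trainer.train()

# Predict on the evaluation split, then save predictions, metrics, and the final model
predictions = trainer.predict(eval_split)
with open(os.path.join(output_dir, "predictions.pickle"), "wb") as fp:
    pickle.dump(predictions, fp)
trainer.save_metrics("eval", predictions.metrics)
trainer.save_model(output_dir)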
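The template passes compute_metrics to the Trainer but does not define it. Below is one possible definition, sketched under the assumption that the classifier returns a single logits array of shape (n_cells, num_labels) and integer cell-type labels. Reporting accuracy and macro-averaged F1 is an illustrative choice, not necessarily the metrics the authors use. The Trainer calls this function with an EvalPrediction object, which unpacks into predictions and label IDs.

```python
import numpy as np


def compute_metrics(eval_pred):
    """Accuracy and macro-F1 from (logits, label_ids); accepts transformers' EvalPrediction."""
    logits, labels = eval_pred
    if isinstance(logits, tuple):  # some models return extra outputs; keep the logits
        logits = logits[0]
    preds = np.argmax(logits, axis=-1)
    labels = np.asarray(labels)

    accuracy = float(np.mean(preds == labels))

    # Macro F1: unweighted mean of per-class F1 over classes seen in labels or predictions
    f1_scores = []
    for cls in np.union1d(labels, preds):
        tp = np.sum((preds == cls) & (labels == cls))
        fp = np.sum((preds == cls) & (labels != cls))
        fn = np.sum((preds != cls) & (labels == cls))
        denom = 2 * tp + fp + fn
        f1_scores.append(2 * tp / denom if denom > 0 else 0.0)
    macro_f1 = float(np.mean(f1_scores))

    return {"accuracy": accuracy, "macro_f1": macro_f1}
```

Because the function only unpacks its argument, it can also be checked offline by passing a (logits, labels) tuple of NumPy arrays.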
contact: [email protected]