---
license: apache-2.0
---
# RP3D-DiagModel
## About Checkpoint
The detailed parameters we used for training are as follows:
```yaml
start_class: 0
end_class: 5569
backbone: 'resnet'
level: 'articles'   # disorder-level labels
depth: 32
ltype: 'MultiLabel' # binary cross-entropy loss
augment: True       # medical data augmentation
split: 'late'       # late fusion strategy
```
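The `ltype: 'MultiLabel'` option denotes a binary cross-entropy loss, meaning each disorder class is scored as an independent yes/no prediction rather than as one choice among all classes. The following is a minimal, framework-free sketch of the numerically stable per-class BCE on raw logits. It is an illustration only, not the repository's training code, and the function name `multilabel_bce_with_logits` is hypothetical.

```python
import math
from typing import Sequence


def multilabel_bce_with_logits(logits: Sequence[float], targets: Sequence[float]) -> float:
    """Mean binary cross-entropy over independent per-class sigmoid outputs.

    Hypothetical helper for illustration; each logit is treated as its own
    binary classifier, as in multi-label disorder diagnosis.
    """
    if len(logits) != len(targets) or not logits:
        raise ValueError("logits and targets must be non-empty and equal length")
    total = 0.0
    for x, y in zip(logits, targets):
        # Stable rewrite of -[y*log(sigmoid(x)) + (1 - y)*log(1 - sigmoid(x))]
        # that never calls exp() on a large positive number.
        total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
    return total / len(logits)
```

For example, `multilabel_bce_with_logits([2.0, -1.0, 0.0], [1, 0, 1])` averages the loss of three independent binary predictions; a zero logit always contributes `log(2)` regardless of its target.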
### Load Model
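```python
import torch

# RadNet and MedCPT_clinical are defined in the RP3D-Diag GitHub repository.
# num_classes, backbone, depth, ltype, augment, fuse, ke, encoded and adapter
# must be defined beforehand with the values used for training.

# Load the image backbone
model = RadNet(num_cls=num_classes, backbone=backbone, depth=depth, ltype=ltype, augment=augment, fuse=fuse, ke=ke, encoded=encoded, adapter=adapter)
pretrained_weights = torch.load("path/to/pytorch_model_32_late.bin", map_location="cpu")
# strict=False skips mismatched keys; print them to check what was not loaded
missing, unexpect = model.load_state_dict(pretrained_weights, strict=False)
print("missing_cpt:", missing)
print("unexpect_cpt:", unexpect)

# If ke is set to True, also load the MedCPT text encoder
medcpt = MedCPT_clinical(bert_model_name='ncbi/MedCPT-Query-Encoder')
checkpoint = torch.load('path/to/epoch_state.pt', map_location='cpu')['state_dict']
# Remove the 'module.' prefix added when a model is saved while wrapped in (Distributed)DataParallel
load_checkpoint = {key.replace('module.', ''): value for key, value in checkpoint.items()}
missing, unexpect = medcpt.load_state_dict(load_checkpoint, strict=False)
print("missing_cpt:", missing)
print("unexpect_cpt:", unexpect)
```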
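The `key.replace('module.', '')` step removes the `module.` prefix that PyTorch adds to parameter names when a model is saved while wrapped in `torch.nn.DataParallel` or `DistributedDataParallel`. Because `str.replace` also rewrites the substring wherever it appears inside a key, a slightly stricter variant strips only a leading prefix. This is a sketch; `strip_prefix` is a hypothetical helper, not part of the repository.

```python
from typing import Dict, TypeVar

V = TypeVar("V")


def strip_prefix(state_dict: Dict[str, V], prefix: str = "module.") -> Dict[str, V]:
    """Return a copy of state_dict with `prefix` removed from the start of each key.

    Keys that do not start with the prefix are kept unchanged, and occurrences
    of the prefix in the middle of a key are left alone.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
```

It can stand in for the dictionary comprehension above: `load_checkpoint = strip_prefix(checkpoint)`.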
## Why do we provide this checkpoint?
All early-fusion checkpoints can be further fine-tuned from this checkpoint. If you need checkpoints trained with other parameter settings, there are two possible ways:
### Fine-tune from this checkpoint
In the training configuration, set:
```yaml
checkpoint: "None"
safetensor: "path/to/pytorch_model.bin" # path to this checkpoint
```
### Contact Us
Email the author: [email protected]
## About Dataset
Please refer to [RP3D-DiagDS](https://huggingface.co/datasets/QiaoyuZheng/RP3D-DiagDS).
For instructions on downloading and using it, see our [GitHub repository](https://github.com/qiaoyu-zheng/RP3D-Diag).