# SALT-SAM / AllinonSAM / eval / chestXDet / generate_predictions_baselines.py
# NOTE: the following Hugging Face file-page header was captured when this file
# was downloaded ("pythn's picture / Upload with huggingface_hub / 4a1f918
# verified"); it is kept here as a comment so the module remains valid Python.
import argparse
import copy
import os
import sys

import torch
import yaml

# Make the project package importable before the local wildcard imports below.
sys.path.append("/home/ubuntu/Desktop/Domain_Adaptation_Project/repos/SVDSAM/")
from data_utils import *
from model import *
from utils import *
from baselines import UNet, UNext, medt_net
from vit_seg_modeling import VisionTransformer
from vit_seg_modeling import CONFIGS as CONFIGS_ViT_seg
from axialnet import MedT
# ChestXDet class names, in the channel order used throughout this script.
label_names = [
    'Effusion', 'Nodule', 'Cardiomegaly', 'Fibrosis', 'Consolidation',
    'Emphysema', 'Mass', 'Fracture', 'Calcification', 'Pleural Thickening',
    'Pneumothorax', 'Atelectasis', 'Diffuse Nodule',
]
# Map each class name to its 0-based channel index.
label_dict = {name: idx for idx, name in enumerate(label_names)}
def parse_args():
    """Parse command-line options for the prediction-generation script.

    Returns:
        argparse.Namespace with: data_folder, data_config, model_config,
        pretrained_path, save_path, gt_path, device, codes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_folder', default='config_tmp.yml',
                        help='data folder file path')
    parser.add_argument('--data_config', default='config_tmp.yml',
                        help='data config file path')
    parser.add_argument('--model_config', default='model_baseline.yml',
                        help='model config file path')
    parser.add_argument('--pretrained_path', default=None,
                        help='pretrained model path')
    # FIX: help text previously said 'pretrained model path' (copy-paste error);
    # main() uses this as the output folder for saved predictions.
    parser.add_argument('--save_path', default='checkpoints/temp.pth',
                        help='folder to save predictions under')
    parser.add_argument('--gt_path', default='',
                        help='ground truth path (empty to skip GT handling)')
    parser.add_argument('--device', default='cuda:0', help='device to train on')
    parser.add_argument('--codes', default='1,2,1,3,3',
                        help='numeric label to save per instrument')
    args = parser.parse_args()
    return args
def main():
    """Run a trained baseline segmentation model over a folder of chest X-rays.

    For every image in --data_folder, saves a per-class prediction
    visualization (and, when --gt_path is given, the corresponding
    ground-truth visualization) under --save_path, then prints the mean
    per-class dice and IoU across all processed images.
    """
    args = parse_args()
    with open(args.data_config, 'r') as f:
        data_config = yaml.load(f, Loader=yaml.FullLoader)
    with open(args.model_config, 'r') as f:
        model_config = yaml.load(f, Loader=yaml.FullLoader)

    # Parsed for CLI compatibility; not used further in this script.
    codes = [int(c) for c in args.codes.split(',')]

    # Ground-truth pixel encoding: class name -> integer value in the GT mask
    # (0 is background). Intentionally shadows the module-level 0-based dict.
    label_dict = {
        'Effusion': 1,
        'Nodule': 2,
        'Cardiomegaly': 3,
        'Fibrosis': 4,
        'Consolidation': 5,
        'Emphysema': 6,
        'Mass': 7,
        'Fracture': 8,
        'Calcification': 9,
        'Pleural Thickening': 10,
        'Pneumothorax': 11,
        'Atelectasis': 12,
        'Diffuse Nodule': 13,
    }

    # Folders for saved visualizations.
    os.makedirs(os.path.join(args.save_path, "preds"), exist_ok=True)
    os.makedirs(os.path.join(args.save_path, "rescaled_preds"), exist_ok=True)
    if args.gt_path:
        os.makedirs(os.path.join(args.save_path, "rescaled_gt"), exist_ok=True)

    # Build the requested architecture from the model config.
    in_channels = model_config['in_channels']
    out_channels = model_config['num_classes']
    img_size = model_config['img_size']
    if model_config['arch'] == 'Prompt Adapted SAM':
        model = Prompt_Adapted_SAM(model_config, label_dict, args.device,
                                   training_strategy='biastuning')
    elif model_config['arch'] == 'UNet':
        model = UNet(in_channels=in_channels, out_channels=out_channels)
    elif model_config['arch'] == 'UNext':
        model = UNext(num_classes=out_channels, input_channels=in_channels,
                      img_size=img_size)
    elif model_config['arch'] == 'MedT':
        model = MedT(img_size=img_size, num_classes=out_channels)
    elif model_config['arch'] == 'TransUNet':
        config_vit = CONFIGS_ViT_seg['R50-ViT-B_16']
        config_vit.n_classes = out_channels
        config_vit.n_skip = 3
        model = VisionTransformer(config_vit, img_size=img_size,
                                  num_classes=config_vit.n_classes)

    model.load_state_dict(torch.load(args.pretrained_path,
                                     map_location=args.device))
    model = model.to(args.device)
    model = model.eval()

    data_transform = ChestXDet_Transform(config=data_config)

    dices = []
    ious = []
    for img_name in sorted(os.listdir(args.data_folder)):
        img_path = os.path.join(args.data_folder, img_name)
        if args.gt_path:
            gt_path = os.path.join(args.gt_path, img_name)
            if not os.path.exists(gt_path):
                # Fall back to a .png with the same stem.
                gt_path = os.path.join(args.gt_path, img_name[:-4] + '.png')
            if not os.path.exists(gt_path):
                continue

        img = torch.as_tensor(np.array(Image.open(img_path).convert("RGB")))
        img = img.permute(2, 0, 1)  # HWC -> CHW
        C, H, W = img.shape

        if args.gt_path:
            # FIX: the ground-truth image was previously loaded outside this
            # branch, raising a NameError whenever --gt_path was empty.
            label = np.array(Image.open(gt_path))
            # One binary channel per class, from the integer-coded GT mask.
            mask = np.zeros((len(label_dict), img.shape[1], img.shape[2]))
            for ch, cls_name in enumerate(label_dict):
                mask[ch, :, :] = (label == label_dict[cls_name])
            mask = torch.Tensor(mask + 0)
        else:
            # Dummy all-zero mask so the transform still has a target to resize.
            mask = torch.zeros((len(label_dict), H, W))

        img, mask = data_transform(img, mask, is_train=False, apply_norm=True)
        mask = (mask >= 0.5) + 0  # re-binarize after interpolation
        img = img.unsqueeze(0).to(args.device)  # 1xCxHxW
        masks = model(img, '')
        argmax_masks = torch.argmax(masks, dim=1).cpu().numpy()

        classwise_dices = []
        classwise_ious = []
        # FIX: inner loop previously reused `i`, clobbering the image index.
        for j, c1 in enumerate(label_dict):
            res = np.where(argmax_masks == j, 1, 0)
            plt.imshow(res[0], cmap='gray')
            save_dir = os.path.join(args.save_path, c1, 'rescaled_preds')
            os.makedirs(save_dir, exist_ok=True)
            plt.savefig(os.path.join(save_dir, img_name))
            plt.close()
            if args.gt_path:
                plt.imshow(mask[j], cmap='gray')
                save_dir = os.path.join(args.save_path, c1, 'rescaled_gt')
                os.makedirs(save_dir, exist_ok=True)
                plt.savefig(os.path.join(save_dir, img_name))
                plt.close()
            classwise_dices.append(dice_coef(mask[j], torch.Tensor(res[0])))
            classwise_ious.append(iou_coef(mask[j], torch.Tensor(res[0])))
        dices.append(classwise_dices)
        ious.append(classwise_ious)

    # Mean per-class metrics across all processed images.
    print(torch.mean(torch.Tensor(dices), dim=0))
    print(torch.mean(torch.Tensor(ious), dim=0))
# Script entry point: run only when executed directly, not when imported.
if __name__ == '__main__':
    main()