# AllinonSAM/eval/btcv/generate_predictions.py
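"""Generate per-slice predictions on BTCV h5 volumes with a prompt-adapted SAM
model, save prediction / ground-truth visualizations, and report mean Dice and IoU."""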
import argparse
import copy
import os
import sys

import h5py
import matplotlib.pyplot as plt
import torch
import yaml
sys.path.append("/home/ubuntu/Desktop/Domain_Adaptation_Project/repos/SVDSAM/")
from data_utils import *
from model import *
from utils import *
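# data_utils, model and utils come from the SVDSAM repo added to sys.path above;
# they are assumed to provide BTCV_Transform, Prompt_Adapted_SAM, dice_coef and iou_coef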
label_names = ['Spleen', 'Right Kidney', 'Left Kidney', 'Gall Bladder', 'Liver', 'Stomach', 'Aorta', 'Pancreas']
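# the 8 BTCV organ classes; their integer label ids are assigned in label_dict inside main()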
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_folder', default='config_tmp.yml',
                        help='path to the folder of BTCV h5 volumes')
    parser.add_argument('--data_config', default='config_tmp.yml',
                        help='data config file path')
    parser.add_argument('--model_config', default='model_baseline.yml',
                        help='model config file path')
    parser.add_argument('--pretrained_path', default=None,
                        help='pretrained model path')
    parser.add_argument('--save_path', default='checkpoints/temp.pth',
                        help='folder to save predictions and visualizations under')
    parser.add_argument('--gt_path', default='',
                        help='ground truth path')
    parser.add_argument('--device', default='cuda:0', help='device to run on')
    parser.add_argument('--labels_of_interest',
                        default='Left Prograsp Forceps,Maryland Bipolar Forceps,Right Prograsp Forceps,Left Large Needle Driver,Right Large Needle Driver',
                        help='comma-separated labels of interest (for BTCV, organ names from label_dict)')
    parser.add_argument('--codes', default='1,2,1,3,3',
                        help='comma-separated numeric label to save per class')
    args = parser.parse_args()
    return args
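# Example invocation (paths and file names below are illustrative, not from the repo):
#   python generate_predictions.py --data_folder /path/to/btcv_h5 \
#       --model_config model_baseline.yml --pretrained_path checkpoints/btcv.pth \
#       --save_path results/btcv --labels_of_interest Liver --codes 5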
def main():
    args = parse_args()
    with open(args.data_config, 'r') as f:
        data_config = yaml.load(f, Loader=yaml.FullLoader)
    with open(args.model_config, 'r') as f:
        model_config = yaml.load(f, Loader=yaml.FullLoader)
    labels_of_interest = args.labels_of_interest.split(',')
    codes = [int(c) for c in args.codes.split(',')]
    label_dict = {
        "Spleen": 1,
        "Right Kidney": 2,
        "Left Kidney": 3,
        "Gall Bladder": 4,
        "Liver": 5,
        "Stomach": 6,
        "Aorta": 7,
        "Pancreas": 8,
    }
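    # these ids must match the integer values stored in each h5 file's 'label' array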
    # make folders to save visualizations
    os.makedirs(os.path.join(args.save_path, "preds"), exist_ok=True)
    os.makedirs(os.path.join(args.save_path, "rescaled_preds"), exist_ok=True)
    os.makedirs(os.path.join(args.save_path, "rescaled_gt"), exist_ok=True)
    # load model
    model = Prompt_Adapted_SAM(config=model_config, label_text_dict=label_dict, device=args.device, training_strategy='svdtuning')
    # model = Prompt_Adapted_SAM(config=model_config, label_text_dict=label_dict, device=args.device, training_strategy='lora')
    # legacy checkpoint support: older checkpoints named the sam_encoder.neck layers
    # 0/1/2/3 instead of conv1/ln1/conv2/ln2 and carried extra 'trainable' mask_decoder
    # keys; remap/drop those keys here if load_state_dict below fails
    sdict = torch.load(args.pretrained_path, map_location=args.device)
    model.load_state_dict(sdict, strict=True)
    model = model.to(args.device)
    model = model.eval()
    # load data transform
    data_transform = BTCV_Transform(config=data_config)
    # per-slice metric accumulators
    dices = []
    ious = []
    # load data: iterate over the h5 volumes, evaluating every 5th slice
    for vol_idx, h5_name in enumerate(sorted(os.listdir(args.data_folder))):
        h5_path = os.path.join(args.data_folder, h5_name)
        data = h5py.File(h5_path, 'r')
        all_img, all_label = data['image'], data['label']
        for slice_idx in range(all_img.shape[0]):
            if slice_idx % 5 != 0:
                continue
            # replicate the single-channel slice to 3 channels for SAM
            img = torch.as_tensor(all_img[slice_idx]).unsqueeze(0).repeat(3, 1, 1)
            label = torch.as_tensor(all_label[slice_idx])
            C, H, W = img.shape
            # build the ground-truth mask of shape 1xHxW; note this lookup assumes
            # --labels_of_interest is a single organ name present in label_dict
            selected_color = label_dict[args.labels_of_interest]
            temp = (label == selected_color) + 0
            mask = torch.Tensor(temp).unsqueeze(0)
            img, mask = data_transform(img, mask, is_train=False, apply_norm=True)
            mask = (mask >= 0.5) + 0
            # get image embeddings once, then reuse them for every text prompt
            img = img.unsqueeze(0).to(args.device)  # 1xCxHxW
            img_embeds = model.get_image_embeddings(img)
            # generate masks for all labels of interest
            img_embeds_repeated = img_embeds.repeat(len(labels_of_interest), 1, 1, 1)
            x_text = [t for t in labels_of_interest]
            masks = model.get_masks_for_multiple_labels(img_embeds_repeated, x_text).cpu()
            # per-pixel argmax over labels; keep a pixel only if its winning mask fires
            argmax_masks = torch.argmax(masks, dim=0)
            final_mask = torch.zeros(masks[0].shape)
            final_mask_rescaled = torch.zeros(masks[0].shape).unsqueeze(-1).repeat(1, 1, 3)
            # save masks
            for i in range(final_mask.shape[0]):
                for j in range(final_mask.shape[1]):
                    final_mask[i, j] = codes[argmax_masks[i, j]] if masks[argmax_masks[i, j], i, j] >= 0.5 else 0
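            # A vectorized equivalent of the two loops above (a sketch, assuming
            # `masks` is (num_labels, H, W) and `codes` holds small integers):
            #   code_t = torch.tensor(codes, dtype=final_mask.dtype)
            #   winner = code_t[argmax_masks]
            #   final_mask = torch.where(masks.max(dim=0).values >= 0.5, winner, torch.zeros_like(winner))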
            # (final_mask could also be saved directly as an image, e.g. via PIL,
            # rather than through the matplotlib figures below)
            # save the thresholded prediction (first label's mask) and the ground truth
            plt.imshow((masks[0] >= 0.5), cmap='gray')
            plt.savefig(os.path.join(args.save_path, 'rescaled_preds', h5_name[:h5_name.find('.')] + "_" + str(slice_idx) + ".png"))
            plt.close()
            plt.imshow(mask[0], cmap='gray')
            plt.savefig(os.path.join(args.save_path, 'rescaled_gt', h5_name[:h5_name.find('.')] + "_" + str(slice_idx) + ".png"))
            plt.close()
            dices.append(dice_coef(mask, (masks >= 0.5) + 0))
            ious.append(iou_coef(mask, (masks >= 0.5) + 0))
    print("mean dice:", torch.mean(torch.Tensor(dices)).item())
    print("mean iou:", torch.mean(torch.Tensor(ious)).item())
if __name__ == '__main__':
    main()