|
import torch |
|
import yaml |
|
import sys |
|
import copy |
|
import os |
|
import h5py |
|
sys.path.append("/home/ubuntu/Desktop/Domain_Adaptation_Project/repos/SVDSAM/") |
|
|
|
from data_utils import * |
|
from model import * |
|
from utils import * |
|
|
|
# Organ class names. These are the same eight names, in the same order, as the
# keys of `label_dict` built inside main() (which maps them to ids 1..8).
label_names = ['Spleen', 'Right Kidney', 'Left Kidney', 'Gall Bladder', 'Liver', 'Stomach', 'Aorta', 'Pancreas']
|
|
|
def parse_args():
    """Parse the command-line options of the evaluation script.

    Flag names and default values are unchanged; only the help texts were
    corrected so that they describe how ``main`` actually uses each option.

    Returns:
        argparse.Namespace with the attributes ``data_folder``,
        ``data_config``, ``model_config``, ``pretrained_path``, ``save_path``,
        ``gt_path``, ``device``, ``labels_of_interest`` and ``codes``.
        Every value is a string, except ``pretrained_path`` which defaults to
        ``None``.
    """
    # Imported locally so this function does not depend on a star import
    # happening to re-export ``argparse``.
    import argparse

    parser = argparse.ArgumentParser()

    # main() lists this directory and opens every entry as an HDF5 file.
    parser.add_argument('--data_folder', default='config_tmp.yml',
                        help='folder containing the h5 volumes to evaluate')

    parser.add_argument('--data_config', default='config_tmp.yml',
                        help='data config file path')

    parser.add_argument('--model_config', default='model_baseline.yml',
                        help='model config file path')

    parser.add_argument('--pretrained_path', default=None,
                        help='pretrained model path')

    # main() treats this as a directory and creates sub-folders inside it.
    parser.add_argument('--save_path', default='checkpoints/temp.pth',
                        help='output directory for the saved prediction and ground truth images')

    parser.add_argument('--gt_path', default='',
                        help='ground truth path')

    parser.add_argument('--device', default='cuda:0',
                        help='device to run the model on')

    parser.add_argument('--labels_of_interest',
                        default='Left Prograsp Forceps,Maryland Bipolar Forceps,Right Prograsp Forceps,Left Large Needle Driver,Right Large Needle Driver',
                        help='comma separated labels of interest')

    parser.add_argument('--codes', default='1,2,1,3,3',
                        help='comma separated numeric label to save per label of interest')

    return parser.parse_args()
|
|
|
def main():
    """Evaluate a pretrained ``Prompt_Adapted_SAM`` on a folder of HDF5 volumes.

    Steps:
      1. Read the command-line options and the two YAML configs.
      2. Build the model, load the checkpoint given by ``--pretrained_path``
         and switch the model to eval mode.
      3. For every file in ``--data_folder`` (sorted by name), read the
         ``image`` and ``label`` datasets and process every 5th slice along
         the first axis: build the binary ground truth for the requested
         label, run the model, save prediction and ground truth as PNG files
         and record the Dice / IoU scores.
      4. Print the mean Dice and mean IoU over all processed slices.

    Output images go to ``<save_path>/rescaled_preds`` and
    ``<save_path>/rescaled_gt`` and are named ``<volume>_<slice index>.png``.
    ``<save_path>/preds`` is created but nothing is written to it here.
    ``--codes`` and ``--gt_path`` are accepted by the parser but are not used
    by this routine.

    Raises:
        ValueError: if ``--pretrained_path`` was not given.
        KeyError: if ``--labels_of_interest`` is not exactly one key of the
            label dictionary (the ground truth is built for a single label).
    """
    args = parse_args()

    if args.pretrained_path is None:
        # torch.load(None) would otherwise fail with an unrelated-looking error.
        raise ValueError("--pretrained_path is required: no checkpoint to evaluate")

    # NOTE(review): FullLoader can construct more than plain data types. That
    # is fine for trusted local configs; prefer yaml.safe_load if these files
    # can come from an untrusted source.
    with open(args.data_config, 'r') as f:
        data_config = yaml.load(f, Loader=yaml.FullLoader)
    with open(args.model_config, 'r') as f:
        model_config = yaml.load(f, Loader=yaml.FullLoader)

    labels_of_interest = args.labels_of_interest.split(',')

    label_dict = {
        "Spleen": 1,
        "Right Kidney": 2,
        "Left Kidney": 3,
        "Gall Bladder": 4,
        "Liver": 5,
        "Stomach": 6,
        "Aorta": 7,
        "Pancreas": 8,
    }

    # The ground-truth id is the same for every slice, so look it up once and
    # fail before the (expensive) model construction when the label is unknown.
    if args.labels_of_interest not in label_dict:
        raise KeyError(
            "--labels_of_interest must be exactly one of "
            f"{list(label_dict)}; got {args.labels_of_interest!r}"
        )
    selected_color = label_dict[args.labels_of_interest]

    os.makedirs(os.path.join(args.save_path, "preds"), exist_ok=True)
    os.makedirs(os.path.join(args.save_path, "rescaled_preds"), exist_ok=True)
    os.makedirs(os.path.join(args.save_path, "rescaled_gt"), exist_ok=True)

    model = Prompt_Adapted_SAM(config=model_config, label_text_dict=label_dict, device=args.device, training_strategy='svdtuning')

    sdict = torch.load(args.pretrained_path, map_location=args.device)
    model.load_state_dict(sdict, strict=True)
    model = model.to(args.device)
    model = model.eval()

    data_transform = BTCV_Transform(config=data_config)

    dices = []
    ious = []

    for h5_name in sorted(os.listdir(args.data_folder)):
        h5_path = os.path.join(args.data_folder, h5_name)
        # Text before the first '.', used as prefix of the output file names.
        volume_name = h5_name.partition('.')[0]

        # Opened read-only and closed deterministically once the volume is done.
        with h5py.File(h5_path, 'r') as data:
            all_img, all_label = data['image'], data['label']

            # Only every 5th slice is evaluated.
            for slice_idx in range(0, all_img.shape[0], 5):
                # Replicate the single-channel slice into 3 channels.
                img = torch.as_tensor(all_img[slice_idx]).unsqueeze(0).repeat(3, 1, 1)
                label = torch.as_tensor(all_label[slice_idx])

                # Binary ground truth for the selected label id.
                temp = (label == selected_color) + 0
                mask = torch.Tensor(temp).unsqueeze(0)
                img, mask = data_transform(img, mask, is_train=False, apply_norm=True)
                # Re-binarize the transformed mask.
                mask = (mask >= 0.5) + 0

                img = img.unsqueeze(0).to(args.device)
                # Pure inference: no autograd bookkeeping needed.
                with torch.no_grad():
                    img_embeds = model.get_image_embeddings(img)
                    # One copy of the image embedding per text prompt.
                    img_embeds_repeated = img_embeds.repeat(len(labels_of_interest), 1, 1, 1)
                    x_text = list(labels_of_interest)
                    masks = model.get_masks_for_multiple_labels(img_embeds_repeated, x_text).cpu()

                pred = (masks >= 0.5) + 0

                # The slice index in the name keeps the files of one volume
                # from overwriting each other.
                out_name = volume_name + "_" + str(slice_idx) + ".png"

                plt.imshow((masks[0] >= 0.5), cmap='gray')
                plt.savefig(os.path.join(args.save_path, 'rescaled_preds', out_name))
                plt.close()

                plt.imshow((mask[0]), cmap='gray')
                plt.savefig(os.path.join(args.save_path, 'rescaled_gt', out_name))
                plt.close()

                dices.append(dice_coef(mask, pred))
                ious.append(iou_coef(mask, pred))

    print(torch.mean(torch.Tensor(dices)))
    print(torch.mean(torch.Tensor(ious)))
|
|
|
# Run the evaluation only when this file is executed as a script, not when it
# is imported.
if __name__ == '__main__':
    main()
|
|
|
|
|
|
|
|
|
|
|
|