# SALT-SAM/AllinonSAM/eval/endovis/generate_predictions_baselines.py
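"""Generate and save per-class predictions for baseline segmentation models
(UNet, UNext, MedT, TransUNet, or Prompt-Adapted SAM) on EndoVis-style
instrument data, and report mean per-class Dice and IoU when ground-truth
masks are provided."""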
import argparse
import os
import sys
import numpy as np
import torch
import yaml
import matplotlib.pyplot as plt
from PIL import Image
# NOTE: hard-coded repo path; adjust to your local checkout.
sys.path.append("/home/ubuntu/Desktop/Domain_Adaptation_Project/repos/biastuning/")
from data_utils import *
from model import *
from utils import *
from baselines import UNet, UNext, medt_net
from vit_seg_modeling import VisionTransformer
from vit_seg_modeling import CONFIGS as CONFIGS_ViT_seg
from axialnet import MedT
label_names = ['Left Prograsp Forceps', 'Maryland Bipolar Forceps', 'Right Prograsp Forceps', 'Left Large Needle Driver', 'Right Large Needle Driver', 'Left Grasping Retractor', 'Right Grasping Retractor', 'Vessel Sealer', 'Monopolar Curved Scissors']
label_dict = {}
visualize_dict = {}
for i, ln in enumerate(label_names):
    label_dict[ln] = i
    # visualize_dict[ln] = visualize_li[i]
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_folder', default='',
                        help='folder containing the input images')
    parser.add_argument('--data_config', default='config_tmp.yml',
                        help='data config file path')
    parser.add_argument('--model_config', default='model_baseline.yml',
                        help='model config file path')
    parser.add_argument('--pretrained_path', default=None,
                        help='pretrained model path')
    parser.add_argument('--save_path', default='checkpoints/temp.pth',
                        help='folder to save predictions (and ground-truth visualizations)')
    parser.add_argument('--gt_path', default='',
                        help='ground truth path; leave empty to skip metric computation')
    parser.add_argument('--device', default='cuda:0', help='device to run inference on')
    parser.add_argument('--codes', default='1,2,1,3,3', help='numeric label to save per instrument')
    args = parser.parse_args()
    return args
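# Example invocation (hypothetical paths; adjust to your setup):
#   python generate_predictions_baselines.py \
#       --data_folder /path/to/endovis/images \
#       --data_config config_tmp.yml \
#       --model_config model_baseline.yml \
#       --pretrained_path checkpoints/unet_endovis.pth \
#       --save_path outputs/unet_endovis \
#       --gt_path /path/to/endovis/masks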
def main():
    args = parse_args()
    with open(args.data_config, 'r') as f:
        data_config = yaml.load(f, Loader=yaml.FullLoader)
    with open(args.model_config, 'r') as f:
        model_config = yaml.load(f, Loader=yaml.FullLoader)
    # Parsed but currently unused below; the hard-coded label_dict supersedes it.
    codes = [int(c) for c in args.codes.split(',')]
    label_dict = {
        'Left Prograsp Forceps': 2,
        'Maryland Bipolar Forceps': 1,
        'Right Prograsp Forceps': 2,
        'Left Large Needle Driver': 3,
        'Right Large Needle Driver': 3,
        'Left Grasping Retractor': 5,
        'Right Grasping Retractor': 5,
        'Vessel Sealer': 4,
        'Monopolar Curved Scissors': 6
    }
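    # This local mapping shadows the module-level label_dict: left/right
    # instances of the same instrument collapse onto one EndoVis numeric
    # class code (see the reference mapping at the bottom of this file).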
    # make folders to save visualizations
    os.makedirs(os.path.join(args.save_path, "preds"), exist_ok=True)
    os.makedirs(os.path.join(args.save_path, "rescaled_preds"), exist_ok=True)
    if args.gt_path:
        os.makedirs(os.path.join(args.save_path, "rescaled_gt"), exist_ok=True)
    # load model
    # the img size in the model config should match the data config
    in_channels = model_config['in_channels']
    out_channels = model_config['num_classes']
    img_size = model_config['img_size']
    if model_config['arch'] == 'Prompt Adapted SAM':
        model = Prompt_Adapted_SAM(model_config, label_dict, args.device, training_strategy='biastuning')
    elif model_config['arch'] == 'UNet':
        model = UNet(in_channels=in_channels, out_channels=out_channels)
    elif model_config['arch'] == 'UNext':
        model = UNext(num_classes=out_channels, input_channels=in_channels, img_size=img_size)
    elif model_config['arch'] == 'MedT':
        model = MedT(img_size=img_size, num_classes=out_channels)
    elif model_config['arch'] == 'TransUNet':
        config_vit = CONFIGS_ViT_seg['R50-ViT-B_16']
        config_vit.n_classes = out_channels
        config_vit.n_skip = 3
        # if args.vit_name.find('R50') != -1:
        #     config_vit.patches.grid = (int(args.img_size / args.vit_patches_size), int(args.img_size / args.vit_patches_size))
        model = VisionTransformer(config_vit, img_size=img_size, num_classes=config_vit.n_classes)
    else:
        raise ValueError(f"Unsupported architecture: {model_config['arch']}")
    model.load_state_dict(torch.load(args.pretrained_path, map_location=args.device))
    model = model.to(args.device)
    model = model.eval()
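    # Inference only: disable autograd to save memory (assumption: nothing
    # downstream in this script needs gradients).
    torch.set_grad_enabled(False)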
    # load data transform
    data_transform = ENDOVIS_Transform(config=data_config)
    # per-image, per-class metrics
    dices = []
    ious = []
    # load data
    for i, img_name in enumerate(sorted(os.listdir(args.data_folder))):
        # if i % 5 != 0:
        #     continue
        img_path = os.path.join(args.data_folder, img_name)
        if args.gt_path:
            # for test data, the labels are arranged differently, so use the
            # commented line below instead
            gt_path = os.path.join(args.gt_path, img_name)
            # gt_path = os.path.join(args.gt_path, label_name, img_name)
        img = torch.as_tensor(np.array(Image.open(img_path).convert("RGB")))
        img = img.permute(2, 0, 1)
        C, H, W = img.shape
        if args.gt_path:
            # build a one-hot mask of shape num_classes x H x W from the label image
            label = torch.as_tensor(np.array(Image.open(gt_path)))
            mask = np.zeros((len(label_dict), img.shape[1], img.shape[2]))
            # note: use a different loop variable than the outer image index i
            for k, c in enumerate(label_dict.keys()):
                mask[k, :, :] = (label == label_dict[c]) + 0
            mask = torch.as_tensor(mask)
        else:
            # dummy all-zero mask so the transform has something to operate on
            mask = torch.zeros((len(label_dict), H, W))
        img, mask = data_transform(img, mask, is_train=False, apply_norm=True)
        mask = (mask >= 0.5) + 0
        # get predictions
        img = img.unsqueeze(0).to(args.device)  # 1 x C x H x W
        masks = model(img, '')
        argmax_masks = torch.argmax(masks, dim=1).cpu().numpy()
        # print("argmax masks shape: ", argmax_masks.shape)
        classwise_dices = []
        classwise_ious = []
        for j, c1 in enumerate(label_dict):
            res = np.where(argmax_masks == j, 1, 0)
            # print("res shape: ", res.shape)
            plt.imshow(res[0], cmap='gray')
            save_dir = os.path.join(args.save_path, c1, 'rescaled_preds')
            os.makedirs(save_dir, exist_ok=True)
            plt.savefig(os.path.join(save_dir, img_name))
            plt.close()
            if args.gt_path:
                plt.imshow(mask[j], cmap='gray')
                save_dir = os.path.join(args.save_path, c1, 'rescaled_gt')
                os.makedirs(save_dir, exist_ok=True)
                plt.savefig(os.path.join(save_dir, img_name))
                plt.close()
                # metrics are only meaningful when ground truth is available
                classwise_dices.append(dice_coef(mask[j], torch.Tensor(res[0])))
                classwise_ious.append(iou_coef(mask[j], torch.Tensor(res[0])))
            # break
        dices.append(classwise_dices)
        ious.append(classwise_ious)
        # print("classwise_dices: ", classwise_dices)
        # print("classwise ious: ", classwise_ious)
    if args.gt_path:
        # per-class means over the whole dataset
        print(torch.mean(torch.Tensor(dices), dim=0))
        print(torch.mean(torch.Tensor(ious), dim=0))
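        # Printing the class order makes the per-class scores interpretable;
        # dict insertion order is preserved in Python 3.7+.
        print(list(label_dict.keys()))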
if __name__ == '__main__':
    main()
# EndoVis 2017 instrument type to numeric class code, for reference:
# {
#     "Bipolar Forceps": 1,
#     "Prograsp Forceps": 2,
#     "Large Needle Driver": 3,
#     "Vessel Sealer": 4,
#     "Grasping Retractor": 5,
#     "Monopolar Curved Scissors": 6,
#     "Other": 7
# }