# SALT-SAM / AllinonSAM / eval / endovis / testing_results_dice.py
# Commit 4a1f918 by pythn (verified): "Upload with huggingface_hub"
import os
from PIL import Image
import sys
from matplotlib import pyplot as plt
import torch
sys.path.append("/home/ubuntu/Desktop/Domain_Adaptation_Project/repos/SVDSAM/")
from utils import *
# Root folder of the evaluation outputs. It is relative to the current working
# directory and holds one sub-folder per dataset. Alternative root (previously
# used, kept for reference):
#   /home/ubuntu/Desktop/Domain_Adaptation_Project/repos/SVDSAM/eval/endovis/svdshiftscale_ev17_tal_focal075_alpha2_1e-3
test_path = "endovis17_lora16"

# Groups of instrument label folders that are scored as a single class:
# the evaluation loop ORs the left/right predictions together, and a 3-tuple
# additionally folds in the third label's ground truth and predictions.
# Original note: "when not differentiating between the forceps, add mbp to the
# first tuple" (i.e. append it as a third element).
# Short-name variant of the same grouping:
#   [('lgr', 'rgr'), ('llnd', 'rlnd'), ('lpf', 'rpf')]
instruments = [
    ('Left Grasping Retractor', 'Right Grasping Retractor'),
    ('Left Large Needle Driver', 'Right Large Needle Driver'),
    ('Left Prograsp Forceps', 'Right Prograsp Forceps'),
]
# Crop applied to channel 0 of every mask before scoring: rows 58:-52,
# columns 143:-126.
# NOTE(review): presumably trims a border/padding introduced when the masks
# were rescaled — confirm against the preprocessing that wrote them.
_CROP = (slice(58, -52), slice(143, -126))


def _load_mask(path):
    """Return the cropped boolean mask stored at *path*.

    Channel 0 of the image is cropped with ``_CROP`` and thresholded at 0.5.
    The ``[:, :, 0]`` index requires a multi-channel (H, W, C) image.
    NOTE(review): the 0.5 threshold assumes ``plt.imread`` yields floats in
    [0, 1] (the PNG case); 8-bit JPEGs are read as 0-255 — confirm format.
    """
    return plt.imread(path)[:, :, 0][_CROP] >= 0.5


def _union_masks(dirs, frame):
    """OR together the masks named *frame* found in each directory of *dirs*."""
    mask = _load_mask(os.path.join(dirs[0], frame))
    for extra_dir in dirs[1:]:
        mask = mask | _load_mask(os.path.join(extra_dir, frame))
    return mask


# Expected layout:
#   <test_path>/<dataset>/<instrument label>/{rescaled_gt,rescaled_preds}/<frame>
# For every dataset and instrument group, each frame's prediction (the union
# of the left/right, and optional third, label predictions) is scored against
# the ground truth with dice_coef / iou_coef (from the `utils` star import),
# and the per-group means are printed.
for dataset in sorted(os.listdir(test_path)):
    dataset_dir = os.path.join(test_path, dataset)
    # Only sub-directories are datasets; skip stray files instead of crashing.
    if not os.path.isdir(dataset_dir):
        continue
    for instrument in instruments:
        dices = []
        ious = []
        # Ground truth comes from the first label's folder only (the second
        # label's rescaled_gt is never read); a 3-tuple also ORs in the third
        # label's ground truth and predictions.
        gt_dirs = [os.path.join(dataset_dir, instrument[0], 'rescaled_gt')]
        pred_dirs = [
            os.path.join(dataset_dir, instrument[0], 'rescaled_preds'),
            os.path.join(dataset_dir, instrument[1], 'rescaled_preds'),
        ]
        if len(instrument) == 3:
            gt_dirs.append(os.path.join(dataset_dir, instrument[2], 'rescaled_gt'))
            pred_dirs.append(os.path.join(dataset_dir, instrument[2], 'rescaled_preds'))
        # Frames are enumerated from the first label's prediction folder and
        # looked up by the same filename in every other folder.
        for frame in sorted(os.listdir(pred_dirs[0])):
            # `+ 0` turns the boolean masks into integer 0/1 arrays, as before.
            gold = torch.Tensor(_union_masks(gt_dirs, frame) + 0).unsqueeze(0)
            pred = torch.Tensor(_union_masks(pred_dirs, frame) + 0).unsqueeze(0)
            dices.append(dice_coef(gold, pred))
            ious.append(iou_coef(gold, pred))
        print(f"Dataset: {dataset}, instrument: {instrument}, dice: {torch.mean(torch.Tensor(dices))}, iou: {torch.mean(torch.Tensor(ious))}")
    # Blank-line separator between datasets.
    print('\n')