File size: 3,064 Bytes
4a1f918
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import os
from PIL import Image
import sys
from matplotlib import pyplot as plt
import torch

sys.path.append("/home/ubuntu/Desktop/Domain_Adaptation_Project/repos/SVDSAM/")
from utils import *

# test_path = "/home/ubuntu/Desktop/Domain_Adaptation_Project/repos/SVDSAM/eval/endovis/svdshiftscale_ev17_tal_focal075_alpha2_1e-3"
test_path = "endovis17_lora16"

# Each tuple is one evaluation class: the predicted masks of all its members
# are OR-ed together before scoring.
# When not differentiating between the forceps, add mbp to the first tuple.
# instruments = [('lgr','rgr'),('llnd','rlnd'),('lpf','rpf')]
instruments = [('Left Grasping Retractor','Right Grasping Retractor'),('Left Large Needle Driver','Right Large Needle Driver'),('Left Prograsp Forceps','Right Prograsp Forceps')]

# Fixed crop (rows, cols) applied to every mask before scoring.
# NOTE(review): presumably strips the border around the endoscope image —
# confirm against the step that wrote the rescaled_* files.
CROP_ROWS = slice(58, -52)
CROP_COLS = slice(143, -126)
# Pixel values >= this threshold (on channel 0) count as foreground.
MASK_THRESHOLD = 0.5


def _load_mask(path):
    """Read the image at `path`, keep channel 0, crop it, and binarize it.

    Returns a boolean array. Assumes the image has a channel axis
    (plt.imread output indexed as [:, :, 0]).
    """
    return plt.imread(path)[:, :, 0][CROP_ROWS, CROP_COLS] >= MASK_THRESHOLD


def _union_masks(dirs, frame):
    """Return the pixel-wise OR of the mask file `frame` found in each of `dirs`."""
    mask = _load_mask(os.path.join(dirs[0], frame))
    for directory in dirs[1:]:
        mask = mask | _load_mask(os.path.join(directory, frame))
    return mask


def evaluate(root, instrument_groups):
    """Print the mean Dice and IoU for every (dataset, instrument group) under `root`.

    Expected layout:
        root/<dataset>/<instrument name>/{rescaled_gt,rescaled_preds}/<frame>

    For each group (a tuple of instrument folder names):
      * prediction = OR of `rescaled_preds` masks of every member;
      * ground truth = `rescaled_gt` of the first member, OR-ed with the
        `rescaled_gt` of members from index 2 onward. The second member's
        ground truth is never read.
        NOTE(review): presumably the first member's rescaled_gt already covers
        both left/right variants — confirm.
    Frames are enumerated from the first member's `rescaled_preds` folder, and
    every other folder must contain a file with the same name.
    `dice_coef` / `iou_coef` come from the project's `utils` module.

    Args:
        root: Directory containing one sub-directory per dataset/sequence.
        instrument_groups: Iterable of tuples of instrument folder names.
    """
    for dataset in sorted(os.listdir(root)):
        dataset_dir = os.path.join(root, dataset)
        if not os.path.isdir(dataset_dir):
            # Skip stray files (logs, .DS_Store, ...) that would crash os.listdir below.
            continue
        for instrument in instrument_groups:
            pred_dirs = [os.path.join(dataset_dir, name, 'rescaled_preds') for name in instrument]
            gt_dirs = [
                os.path.join(dataset_dir, name, 'rescaled_gt')
                for name in (instrument[0], *instrument[2:])
            ]
            dices = []
            ious = []
            for frame in sorted(os.listdir(pred_dirs[0])):
                # `+ 0` turns the boolean masks into 0/1 ints; unsqueeze adds a batch dim.
                gold = torch.Tensor(_union_masks(gt_dirs, frame) + 0).unsqueeze(0)
                pred = torch.Tensor(_union_masks(pred_dirs, frame) + 0).unsqueeze(0)
                dices.append(dice_coef(gold, pred))
                ious.append(iou_coef(gold, pred))

            mean_dice = torch.mean(torch.Tensor(dices))
            mean_iou = torch.mean(torch.Tensor(ious))
            print(f"Dataset: {dataset}, instrument: {instrument}, dice: {mean_dice}, iou: {mean_iou}")
        print('\n')


evaluate(test_path, instruments)