|
import numpy as np

PATCH_SIZE = 256
OVERLAP = 32


def split_image_into_patches(image):
    """Split an HxWxC image into overlapping PATCH_SIZE x PATCH_SIZE tiles.

    Returns a list of (y, x, patch) tuples. The final start position in each
    dimension is clamped to the image edge so border pixels are not dropped
    (the image is assumed to be at least PATCH_SIZE in both dimensions).
    """
    height, width, _ = image.shape
    stride = PATCH_SIZE - OVERLAP

    ys = list(range(0, height - PATCH_SIZE + 1, stride))
    xs = list(range(0, width - PATCH_SIZE + 1, stride))
    if ys[-1] != height - PATCH_SIZE:
        ys.append(height - PATCH_SIZE)
    if xs[-1] != width - PATCH_SIZE:
        xs.append(width - PATCH_SIZE)

    patches = []
    for y in ys:
        for x in xs:
            patches.append((y, x, image[y:y + PATCH_SIZE, x:x + PATCH_SIZE]))
    return patches
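
# Quick sanity check (hypothetical shapes, not from the original script): a
# 576x704 RGB frame yields start rows [0, 224, 320] and start columns
# [0, 224, 448], i.e. nine overlapping patches:
#   patches = split_image_into_patches(np.zeros((576, 704, 3), dtype=np.uint8))
#   assert len(patches) == 9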
|
|
|
def stitch_patches_to_image(patches, image_shape):
    """Average overlapping patch predictions and threshold to a binary mask."""
    stitched_image = np.zeros(image_shape)
    # Epsilon avoids division by zero for any pixel left uncovered.
    overlap_mask = np.zeros(image_shape[:2]) + 1e-10

    for y, x, p in patches:
        if p.shape[:2] != (PATCH_SIZE, PATCH_SIZE):
            raise ValueError(
                f"patch at ({y}, {x}) has shape {p.shape}, expected "
                f"({PATCH_SIZE}, {PATCH_SIZE}) for image shape {image_shape}"
            )
        stitched_image[y:y + PATCH_SIZE, x:x + PATCH_SIZE] += p
        overlap_mask[y:y + PATCH_SIZE, x:x + PATCH_SIZE] += 1

    # Average the overlapping predictions, then binarize at 0.5.
    stitched_image = (stitched_image / overlap_mask) > 0.5
    return stitched_image.astype(np.uint8)
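
# Minimal usage sketch: a single patch of ones covering the whole canvas
# stitches back to an all-ones binary mask (sums are divided by the overlap
# count, then thresholded at 0.5):
#   mask = stitch_patches_to_image([(0, 0, np.ones((256, 256)))], (256, 256))
#   assert mask.dtype == np.uint8 and mask.min() == 1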
|
|
|
import argparse
import os
import sys

import torch
import yaml
from PIL import Image

sys.path.append("/home/ubuntu/Desktop/Domain_Adaptation_Project/repos/biastuning/")

from data_utils import *
from model import *
from utils import *
|
|
|
label_names = ['Left Prograsp Forceps', 'Maryland Bipolar Forceps', 'Right Prograsp Forceps', 'Left Large Needle Driver', 'Right Large Needle Driver']
visualize_li = [[1, 0, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1], [0, 0, 1]]
label_dict = {}
visualize_dict = {}
for i, ln in enumerate(label_names):
    label_dict[ln] = i
    visualize_dict[ln] = visualize_li[i]
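
# For reference, the loop above builds:
#   label_dict = {'Left Prograsp Forceps': 0, 'Maryland Bipolar Forceps': 1,
#                 'Right Prograsp Forceps': 2, 'Left Large Needle Driver': 3,
#                 'Right Large Needle Driver': 4}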
|
|
|
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_folder', required=True,
                        help='folder containing the input images')
    parser.add_argument('--data_config', default='config_tmp.yml',
                        help='data config file path')
    parser.add_argument('--model_config', default='model_baseline.yml',
                        help='model config file path')
    parser.add_argument('--pretrained_path', default=None,
                        help='pretrained model path')
    parser.add_argument('--save_path', default='checkpoints/temp.pth',
                        help='folder in which to save the predicted masks')
    parser.add_argument('--device', default='cuda:0',
                        help='device to run inference on')
    parser.add_argument('--labels_of_interest',
                        default='Left Prograsp Forceps,Maryland Bipolar Forceps,Right Prograsp Forceps,Left Large Needle Driver,Right Large Needle Driver',
                        help='comma-separated labels of interest')
    parser.add_argument('--codes', default='1,2,1,3,3',
                        help='numeric label to save per instrument')
    return parser.parse_args()
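
# Example invocation (script name and paths are placeholders):
#   python endovis_patch_inference.py --data_folder /path/to/frames \
#       --pretrained_path checkpoints/model.pth --save_path outputs/run1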
|
|
|
def main():
    args = parse_args()
    with open(args.data_config, 'r') as f:
        data_config = yaml.load(f, Loader=yaml.FullLoader)
    with open(args.model_config, 'r') as f:
        model_config = yaml.load(f, Loader=yaml.FullLoader)
    labels_of_interest = args.labels_of_interest.split(',')
    codes = [int(c) for c in args.codes.split(',')]

    os.makedirs(os.path.join(args.save_path, "preds"), exist_ok=True)
    os.makedirs(os.path.join(args.save_path, "rescaled_preds"), exist_ok=True)

    # Load the model and put it in eval mode on the target device.
    model = Prompt_Adapted_SAM(config=model_config, label_text_dict=label_dict, device=args.device)
    model.load_state_dict(torch.load(args.pretrained_path, map_location=args.device))
    model = model.eval().to(args.device)

    data_transform = ENDOVIS_Transform(config=data_config)

    for img_name in sorted(os.listdir(args.data_folder)):
        img_path = os.path.join(args.data_folder, img_name)
        original_img = torch.as_tensor(np.array(Image.open(img_path).convert("RGB")))
        patches = split_image_into_patches(original_img)
        patch_masks = []

        for y, x, p in patches:
            img = p.permute(2, 0, 1)  # HWC -> CHW

            # Dummy single-channel label: the transform expects an (img, label) pair.
            label = torch.zeros((1, *img.shape[1:]))
            img, _ = data_transform(img, label, is_train=False, apply_norm=True, crop=False, resize=False)

            img = img.unsqueeze(0).to(args.device)
            with torch.no_grad():
                img_embeds = model.get_image_embeddings(img)

                # Reuse the same image embedding for every text prompt.
                img_embeds_repeated = img_embeds.repeat(len(labels_of_interest), 1, 1, 1)
                masks = model.get_masks_for_multiple_labels(img_embeds_repeated, labels_of_interest).cpu()

            # Collapse the per-label masks into a single foreground mask per patch.
            masks, _ = torch.max(masks, dim=0)
            patch_masks.append((y, x, masks.numpy()))

        print("original shape: ", original_img.shape)
        final_mask = stitch_patches_to_image(patch_masks, original_img.shape[:2])
        print("final mask shape: ", final_mask.shape)
        save_im = Image.fromarray(final_mask)
        save_im.save(os.path.join(args.save_path, 'preds', img_name))


if __name__ == '__main__':
    main()
|
|
|
|
|
|
|
|
|
|
|
|