# Copyright by HQ-SAM team
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import argparse
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import matplotlib.pyplot as plt
import cv2
import random
from typing import Dict, List, Tuple

from segment_anything_training import sam_model_registry
from segment_anything_training.modeling import TwoWayTransformer, MaskDecoder

from utils.dataloader import get_im_gt_name_dict, create_dataloaders, RandomHFlip, Resize, LargeScaleJitter
from utils.loss_mask import loss_masks
import utils.misc as misc
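
# Training / evaluation entry point for HQ-SAM: the SAM image encoder, prompt encoder and
# original mask decoder stay frozen; only the HQ decoder modules defined below are trained.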


class LayerNorm2d(nn.Module):
    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x


class MLP(nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        sigmoid_output: bool = False,
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )
        self.sigmoid_output = sigmoid_output

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        if self.sigmoid_output:
            x = F.sigmoid(x)
        return x
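

# MaskDecoderHQ: SAM's mask decoder with its pretrained weights loaded and frozen,
# extended with a learnable high-quality (HQ) output token and feature-fusion layers.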
class MaskDecoderHQ(MaskDecoder):
    def __init__(self, model_type):
        super().__init__(transformer_dim=256,
                         transformer=TwoWayTransformer(
                             depth=2,
                             embedding_dim=256,
                             mlp_dim=2048,
                             num_heads=8,
                         ),
                         num_multimask_outputs=3,
                         activation=nn.GELU,
                         iou_head_depth=3,
                         iou_head_hidden_dim=256)
        assert model_type in ["vit_b", "vit_l", "vit_h"]

        checkpoint_dict = {"vit_b": "pretrained_checkpoint/sam_vit_b_maskdecoder.pth",
                           "vit_l": "pretrained_checkpoint/sam_vit_l_maskdecoder.pth",
                           "vit_h": "pretrained_checkpoint/sam_vit_h_maskdecoder.pth"}
        checkpoint_path = checkpoint_dict[model_type]
        self.load_state_dict(torch.load(checkpoint_path))
        print("HQ Decoder init from SAM MaskDecoder")
        for n, p in self.named_parameters():
            p.requires_grad = False

        transformer_dim = 256
        vit_dim_dict = {"vit_b": 768, "vit_l": 1024, "vit_h": 1280}
        vit_dim = vit_dim_dict[model_type]
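
        # HQ-SAM additions: one extra output token (hf_token) with its own hypernetwork MLP (hf_mlp),
        # plus convolutional heads that fuse early ViT features with the decoder's mask features.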
        self.hf_token = nn.Embedding(1, transformer_dim)
        self.hf_mlp = MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
        self.num_mask_tokens = self.num_mask_tokens + 1

        self.compress_vit_feat = nn.Sequential(
            nn.ConvTranspose2d(vit_dim, transformer_dim, kernel_size=2, stride=2),
            LayerNorm2d(transformer_dim),
            nn.GELU(),
            nn.ConvTranspose2d(transformer_dim, transformer_dim // 8, kernel_size=2, stride=2))

        self.embedding_encoder = nn.Sequential(
            nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
            LayerNorm2d(transformer_dim // 4),
            nn.GELU(),
            nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
        )

        self.embedding_maskfeature = nn.Sequential(
            nn.Conv2d(transformer_dim // 8, transformer_dim // 4, 3, 1, 1),
            LayerNorm2d(transformer_dim // 4),
            nn.GELU(),
            nn.Conv2d(transformer_dim // 4, transformer_dim // 8, 3, 1, 1))

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool,
        hq_token_only: bool,
        interm_embeddings: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predict masks given image and prompt embeddings.

        Arguments:
          image_embeddings (torch.Tensor): the embeddings from the ViT image encoder
          image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
          sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
          dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
          multimask_output (bool): whether to return multiple masks or a single mask
          hq_token_only (bool): if True, return only the HQ masks; otherwise return both SAM and HQ masks
          interm_embeddings (torch.Tensor): intermediate (early-layer) embeddings from the ViT image encoder

        Returns:
          torch.Tensor: batched predicted hq masks
        """
        vit_features = interm_embeddings[0].permute(0, 3, 1, 2)  # early-layer ViT feature, after 1st global attention block in ViT
        hq_features = self.embedding_encoder(image_embeddings) + self.compress_vit_feat(vit_features)

        batch_len = len(image_embeddings)
        masks = []
        iou_preds = []
        for i_batch in range(batch_len):
            mask, iou_pred = self.predict_masks(
                image_embeddings=image_embeddings[i_batch].unsqueeze(0),
                image_pe=image_pe[i_batch],
                sparse_prompt_embeddings=sparse_prompt_embeddings[i_batch],
                dense_prompt_embeddings=dense_prompt_embeddings[i_batch],
                hq_feature=hq_features[i_batch].unsqueeze(0)
            )
            masks.append(mask)
            iou_preds.append(iou_pred)
        masks = torch.cat(masks, 0)
        iou_preds = torch.cat(iou_preds, 0)

        # Select the correct mask or masks for output
        if multimask_output:
            # mask with highest score
            mask_slice = slice(1, self.num_mask_tokens - 1)
            iou_preds = iou_preds[:, mask_slice]
            iou_preds, max_iou_idx = torch.max(iou_preds, dim=1)
            iou_preds = iou_preds.unsqueeze(1)
            masks_multi = masks[:, mask_slice, :, :]
            masks_sam = masks_multi[torch.arange(masks_multi.size(0)), max_iou_idx].unsqueeze(1)
        else:
            # single mask output, default
            mask_slice = slice(0, 1)
            masks_sam = masks[:, mask_slice]

        masks_hq = masks[:, slice(self.num_mask_tokens - 1, self.num_mask_tokens), :, :]

        if hq_token_only:
            return masks_hq
        else:
            return masks_sam, masks_hq

    def predict_masks(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        hq_feature: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predicts masks. See 'forward' for more details."""

        output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight, self.hf_token.weight], dim=0)
        output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

        # Expand per-image data in batch direction to be per-mask
        src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
        src = src + dense_prompt_embeddings
        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
        b, c, h, w = src.shape

        # Run the transformer
        hs, src = self.transformer(src, pos_src, tokens)
        iou_token_out = hs[:, 0, :]
        mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]

        # Upscale mask embeddings and predict masks using the mask tokens
        src = src.transpose(1, 2).view(b, c, h, w)
        upscaled_embedding_sam = self.output_upscaling(src)
        upscaled_embedding_ours = self.embedding_maskfeature(upscaled_embedding_sam) + hq_feature

        hyper_in_list: List[torch.Tensor] = []
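        # The first four tokens use SAM's original hypernetwork MLPs on the SAM features;
        # the extra HQ token goes through hf_mlp and is applied to the fused HQ features.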
        for i in range(self.num_mask_tokens):
            if i < 4:
                hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
            else:
                hyper_in_list.append(self.hf_mlp(mask_tokens_out[:, i, :]))
        hyper_in = torch.stack(hyper_in_list, dim=1)
        b, c, h, w = upscaled_embedding_sam.shape

        masks_sam = (hyper_in[:, :4] @ upscaled_embedding_sam.view(b, c, h * w)).view(b, -1, h, w)
        masks_ours = (hyper_in[:, 4:] @ upscaled_embedding_ours.view(b, c, h * w)).view(b, -1, h, w)
        masks = torch.cat([masks_sam, masks_ours], dim=1)

        iou_pred = self.iou_prediction_head(iou_token_out)

        return masks, iou_pred


def show_anns(masks, input_point, input_box, input_label, filename, image, ious, boundary_ious):
    if len(masks) == 0:
        return

    for i, (mask, iou, biou) in enumerate(zip(masks, ious, boundary_ious)):
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        show_mask(mask, plt.gca())
        if input_box is not None:
            show_box(input_box, plt.gca())
        if (input_point is not None) and (input_label is not None):
            show_points(input_point, input_label, plt.gca())

        plt.axis('off')
        plt.savefig(filename + '_' + str(i) + '.png', bbox_inches='tight', pad_inches=-0.1)
        plt.close()


def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)


def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))


def get_args_parser():
    parser = argparse.ArgumentParser('HQ-SAM', add_help=False)

    parser.add_argument("--output", type=str, required=True,
                        help="Path to the directory where masks and checkpoints will be output")
    parser.add_argument("--model-type", type=str, default="vit_l",
                        help="The type of model to load, in ['vit_h', 'vit_l', 'vit_b']")
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="The path to the SAM checkpoint to use for mask generation.")
    parser.add_argument("--device", type=str, default="cuda",
                        help="The device to run generation on.")

    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--learning_rate', default=1e-3, type=float)
    parser.add_argument('--start_epoch', default=0, type=int)
    parser.add_argument('--lr_drop_epoch', default=10, type=int)
    parser.add_argument('--max_epoch_num', default=12, type=int)
    parser.add_argument('--input_size', default=[1024, 1024], type=list)
    parser.add_argument('--batch_size_train', default=4, type=int)
    parser.add_argument('--batch_size_valid', default=1, type=int)
    parser.add_argument('--model_save_fre', default=1, type=int)

    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    parser.add_argument('--rank', default=0, type=int,
                        help='rank of the current distributed process')
    parser.add_argument('--local_rank', type=int, help='local rank for dist')
    parser.add_argument('--find_unused_params', action='store_true')

    parser.add_argument('--eval', action='store_true')
    parser.add_argument('--visualize', action='store_true')
    parser.add_argument("--restore-model", type=str,
                        help="The path to the hq_decoder training checkpoint for evaluation")

    return parser.parse_args()


def main(net, train_datasets, valid_datasets, args):

    misc.init_distributed_mode(args)
    print('world size: {}'.format(args.world_size))
    print('rank: {}'.format(args.rank))
    print('local_rank: {}'.format(args.local_rank))
    print("args: " + str(args) + '\n')

    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    ### --- Step 1: Train or Valid dataset ---
    if not args.eval:
        print("--- create training dataloader ---")
        train_im_gt_list = get_im_gt_name_dict(train_datasets, flag="train")
        train_dataloaders, train_datasets = create_dataloaders(train_im_gt_list,
                                                               my_transforms=[
                                                                   RandomHFlip(),
                                                                   LargeScaleJitter()
                                                               ],
                                                               batch_size=args.batch_size_train,
                                                               training=True)
        print(len(train_dataloaders), " train dataloaders created")

    print("--- create valid dataloader ---")
    valid_im_gt_list = get_im_gt_name_dict(valid_datasets, flag="valid")
    valid_dataloaders, valid_datasets = create_dataloaders(valid_im_gt_list,
                                                           my_transforms=[
                                                               Resize(args.input_size)
                                                           ],
                                                           batch_size=args.batch_size_valid,
                                                           training=False)
    print(len(valid_dataloaders), " valid dataloaders created")

    ### --- Step 2: DistributedDataParallel ---
    if torch.cuda.is_available():
        net.cuda()
    net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[args.gpu], find_unused_parameters=args.find_unused_params)
    net_without_ddp = net.module

    ### --- Step 3: Train or Evaluate ---
    if not args.eval:
        print("--- define optimizer ---")
        optimizer = optim.Adam(net_without_ddp.parameters(), lr=args.learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop_epoch)
        lr_scheduler.last_epoch = args.start_epoch

        train(args, net, optimizer, train_dataloaders, valid_dataloaders, lr_scheduler)
    else:
        sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
        _ = sam.to(device=args.device)
        sam = torch.nn.parallel.DistributedDataParallel(sam, device_ids=[args.gpu], find_unused_parameters=args.find_unused_params)

        if args.restore_model:
            print("restore model from:", args.restore_model)
            if torch.cuda.is_available():
                net_without_ddp.load_state_dict(torch.load(args.restore_model))
            else:
                net_without_ddp.load_state_dict(torch.load(args.restore_model, map_location="cpu"))

        evaluate(args, net, sam, valid_dataloaders, args.visualize)


def train(args, net, optimizer, train_dataloaders, valid_dataloaders, lr_scheduler):
    if misc.is_main_process():
        os.makedirs(args.output, exist_ok=True)

    epoch_start = args.start_epoch
    epoch_num = args.max_epoch_num
    train_num = len(train_dataloaders)

    net.train()
    _ = net.to(device=args.device)

    sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
    _ = sam.to(device=args.device)
    sam = torch.nn.parallel.DistributedDataParallel(sam, device_ids=[args.gpu], find_unused_parameters=args.find_unused_params)

    for epoch in range(epoch_start, epoch_num):
        print("epoch: ", epoch, " learning rate: ", optimizer.param_groups[0]["lr"])
        metric_logger = misc.MetricLogger(delimiter="  ")
        train_dataloaders.batch_sampler.sampler.set_epoch(epoch)

        for data in metric_logger.log_every(train_dataloaders, 1000):
            inputs, labels = data['image'], data['label']
            if torch.cuda.is_available():
                inputs = inputs.cuda()
                labels = labels.cuda()

            imgs = inputs.permute(0, 2, 3, 1).cpu().numpy()

            # input prompt
            input_keys = ['box', 'point', 'noise_mask']
            labels_box = misc.masks_to_boxes(labels[:, 0, :, :])
            try:
                labels_points = misc.masks_sample_points(labels[:, 0, :, :])
            except:
                # less than 10 points
                input_keys = ['box', 'noise_mask']
            labels_256 = F.interpolate(labels, size=(256, 256), mode='bilinear')
            labels_noisemask = misc.masks_noise(labels_256)

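            # Build SAM's batched input: one dict per image with the image tensor, a randomly
            # chosen prompt (box, point or noisy mask) and the original image size.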
            batched_input = []
            for b_i in range(len(imgs)):
                dict_input = dict()
                input_image = torch.as_tensor(imgs[b_i].astype(dtype=np.uint8), device=sam.device).permute(2, 0, 1).contiguous()
                dict_input['image'] = input_image
                input_type = random.choice(input_keys)
                if input_type == 'box':
                    dict_input['boxes'] = labels_box[b_i:b_i+1]
                elif input_type == 'point':
                    point_coords = labels_points[b_i:b_i+1]
                    dict_input['point_coords'] = point_coords
                    dict_input['point_labels'] = torch.ones(point_coords.shape[1], device=point_coords.device)[None, :]
                elif input_type == 'noise_mask':
                    dict_input['mask_inputs'] = labels_noisemask[b_i:b_i+1]
                else:
                    raise NotImplementedError
                dict_input['original_size'] = imgs[b_i].shape[:2]
                batched_input.append(dict_input)

            with torch.no_grad():
                batched_output, interm_embeddings = sam(batched_input, multimask_output=False)

            batch_len = len(batched_output)
            encoder_embedding = torch.cat([batched_output[i_l]['encoder_embedding'] for i_l in range(batch_len)], dim=0)
            image_pe = [batched_output[i_l]['image_pe'] for i_l in range(batch_len)]
            sparse_embeddings = [batched_output[i_l]['sparse_embeddings'] for i_l in range(batch_len)]
            dense_embeddings = [batched_output[i_l]['dense_embeddings'] for i_l in range(batch_len)]

            masks_hq = net(
                image_embeddings=encoder_embedding,
                image_pe=image_pe,
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=False,
                hq_token_only=True,
                interm_embeddings=interm_embeddings,
            )

            loss_mask, loss_dice = loss_masks(masks_hq, labels/255.0, len(masks_hq))
            loss = loss_mask + loss_dice

            loss_dict = {"loss_mask": loss_mask, "loss_dice": loss_dice}

            # reduce losses over all GPUs for logging purposes
            loss_dict_reduced = misc.reduce_dict(loss_dict)
            losses_reduced_scaled = sum(loss_dict_reduced.values())
            loss_value = losses_reduced_scaled.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            metric_logger.update(training_loss=loss_value, **loss_dict_reduced)

        print("Finished epoch: ", epoch)
        metric_logger.synchronize_between_processes()
        print("Averaged stats:", metric_logger)
        train_stats = {k: meter.global_avg for k, meter in metric_logger.meters.items() if meter.count > 0}

        lr_scheduler.step()
        test_stats = evaluate(args, net, sam, valid_dataloaders)
        train_stats.update(test_stats)

        net.train()

        if epoch % args.model_save_fre == 0:
            model_name = "/epoch_" + str(epoch) + ".pth"
            print('come here save at', args.output + model_name)
            misc.save_on_master(net.module.state_dict(), args.output + model_name)

    # Finish training
    print("Training Reaches The Maximum Epoch Number")

    # merge sam and hq_decoder
    if misc.is_main_process():
        sam_ckpt = torch.load(args.checkpoint)
        hq_decoder = torch.load(args.output + model_name)
        for key in hq_decoder.keys():
            sam_key = 'mask_decoder.' + key
            if sam_key not in sam_ckpt.keys():
                sam_ckpt[sam_key] = hq_decoder[key]
        model_name = "/sam_hq_epoch_" + str(epoch) + ".pth"
        torch.save(sam_ckpt, args.output + model_name)


def compute_iou(preds, target):
    assert target.shape[1] == 1, 'only support one mask per image now'
    if preds.shape[2] != target.shape[2] or preds.shape[3] != target.shape[3]:
        postprocess_preds = F.interpolate(preds, size=target.size()[2:], mode='bilinear', align_corners=False)
    else:
        postprocess_preds = preds
    iou = 0
    for i in range(0, len(preds)):
        iou = iou + misc.mask_iou(postprocess_preds[i], target[i])
    return iou / len(preds)


def compute_boundary_iou(preds, target):
    assert target.shape[1] == 1, 'only support one mask per image now'
    if preds.shape[2] != target.shape[2] or preds.shape[3] != target.shape[3]:
        postprocess_preds = F.interpolate(preds, size=target.size()[2:], mode='bilinear', align_corners=False)
    else:
        postprocess_preds = preds
    iou = 0
    for i in range(0, len(preds)):
        iou = iou + misc.boundary_iou(target[i], postprocess_preds[i])
    return iou / len(preds)


def evaluate(args, net, sam, valid_dataloaders, visualize=False):
    net.eval()
    print("Validating...")
    test_stats = {}

    for k in range(len(valid_dataloaders)):
        metric_logger = misc.MetricLogger(delimiter="  ")
        valid_dataloader = valid_dataloaders[k]
        print('valid_dataloader len:', len(valid_dataloader))

        for data_val in metric_logger.log_every(valid_dataloader, 1000):
            imidx_val, inputs_val, labels_val, shapes_val, labels_ori = data_val['imidx'], data_val['image'], data_val['label'], data_val['shape'], data_val['ori_label']

            if torch.cuda.is_available():
                inputs_val = inputs_val.cuda()
                labels_val = labels_val.cuda()
                labels_ori = labels_ori.cuda()

            imgs = inputs_val.permute(0, 2, 3, 1).cpu().numpy()

            labels_box = misc.masks_to_boxes(labels_val[:, 0, :, :])
            input_keys = ['box']
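            # Evaluation prompts are boxes only; the point / noise_mask branches below are never taken here.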
            batched_input = []
            for b_i in range(len(imgs)):
                dict_input = dict()
                input_image = torch.as_tensor(imgs[b_i].astype(dtype=np.uint8), device=sam.device).permute(2, 0, 1).contiguous()
                dict_input['image'] = input_image
                input_type = random.choice(input_keys)
                if input_type == 'box':
                    dict_input['boxes'] = labels_box[b_i:b_i+1]
                elif input_type == 'point':
                    point_coords = labels_points[b_i:b_i+1]
                    dict_input['point_coords'] = point_coords
                    dict_input['point_labels'] = torch.ones(point_coords.shape[1], device=point_coords.device)[None, :]
                elif input_type == 'noise_mask':
                    dict_input['mask_inputs'] = labels_noisemask[b_i:b_i+1]
                else:
                    raise NotImplementedError
                dict_input['original_size'] = imgs[b_i].shape[:2]
                batched_input.append(dict_input)

            with torch.no_grad():
                batched_output, interm_embeddings = sam(batched_input, multimask_output=False)

            batch_len = len(batched_output)
            encoder_embedding = torch.cat([batched_output[i_l]['encoder_embedding'] for i_l in range(batch_len)], dim=0)
            image_pe = [batched_output[i_l]['image_pe'] for i_l in range(batch_len)]
            sparse_embeddings = [batched_output[i_l]['sparse_embeddings'] for i_l in range(batch_len)]
            dense_embeddings = [batched_output[i_l]['dense_embeddings'] for i_l in range(batch_len)]

            masks_sam, masks_hq = net(
                image_embeddings=encoder_embedding,
                image_pe=image_pe,
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=False,
                hq_token_only=False,
                interm_embeddings=interm_embeddings,
            )

            iou = compute_iou(masks_hq, labels_ori)
            boundary_iou = compute_boundary_iou(masks_hq, labels_ori)

            if visualize:
                print("visualize")
                os.makedirs(args.output, exist_ok=True)
                masks_hq_vis = (F.interpolate(masks_hq.detach(), (1024, 1024), mode="bilinear", align_corners=False) > 0).cpu()
                for ii in range(len(imgs)):
                    base = data_val['imidx'][ii].item()
                    print('base:', base)
                    save_base = os.path.join(args.output, str(k) + '_' + str(base))
                    imgs_ii = imgs[ii].astype(dtype=np.uint8)
                    show_iou = torch.tensor([iou.item()])
                    show_boundary_iou = torch.tensor([boundary_iou.item()])
                    show_anns(masks_hq_vis[ii], None, labels_box[ii].cpu(), None, save_base, imgs_ii, show_iou, show_boundary_iou)

            loss_dict = {"val_iou_" + str(k): iou, "val_boundary_iou_" + str(k): boundary_iou}
            loss_dict_reduced = misc.reduce_dict(loss_dict)
            metric_logger.update(**loss_dict_reduced)

        print('============================')
        # gather the stats from all processes
        metric_logger.synchronize_between_processes()
        print("Averaged stats:", metric_logger)
        resstat = {k: meter.global_avg for k, meter in metric_logger.meters.items() if meter.count > 0}
        test_stats.update(resstat)

    return test_stats


if __name__ == "__main__":

    ### --------------- Configuring the Train and Valid datasets ---------------
    dataset_dis = {"name": "DIS5K-TR",
                   "im_dir": "./data/DIS5K/DIS-TR/im",
                   "gt_dir": "./data/DIS5K/DIS-TR/gt",
                   "im_ext": ".jpg",
                   "gt_ext": ".png"}

    dataset_thin = {"name": "ThinObject5k-TR",
                    "im_dir": "./data/thin_object_detection/ThinObject5K/images_train",
                    "gt_dir": "./data/thin_object_detection/ThinObject5K/masks_train",
                    "im_ext": ".jpg",
                    "gt_ext": ".png"}

    dataset_fss = {"name": "FSS",
                   "im_dir": "./data/cascade_psp/fss_all",
                   "gt_dir": "./data/cascade_psp/fss_all",
                   "im_ext": ".jpg",
                   "gt_ext": ".png"}

    dataset_duts = {"name": "DUTS-TR",
                    "im_dir": "./data/cascade_psp/DUTS-TR",
                    "gt_dir": "./data/cascade_psp/DUTS-TR",
                    "im_ext": ".jpg",
                    "gt_ext": ".png"}

    dataset_duts_te = {"name": "DUTS-TE",
                       "im_dir": "./data/cascade_psp/DUTS-TE",
                       "gt_dir": "./data/cascade_psp/DUTS-TE",
                       "im_ext": ".jpg",
                       "gt_ext": ".png"}

    dataset_ecssd = {"name": "ECSSD",
                     "im_dir": "./data/cascade_psp/ecssd",
                     "gt_dir": "./data/cascade_psp/ecssd",
                     "im_ext": ".jpg",
                     "gt_ext": ".png"}

    dataset_msra = {"name": "MSRA10K",
                    "im_dir": "./data/cascade_psp/MSRA_10K",
                    "gt_dir": "./data/cascade_psp/MSRA_10K",
                    "im_ext": ".jpg",
                    "gt_ext": ".png"}

    # valid set
    dataset_coift_val = {"name": "COIFT",
                         "im_dir": "./data/thin_object_detection/COIFT/images",
                         "gt_dir": "./data/thin_object_detection/COIFT/masks",
                         "im_ext": ".jpg",
                         "gt_ext": ".png"}

    dataset_hrsod_val = {"name": "HRSOD",
                         "im_dir": "./data/thin_object_detection/HRSOD/images",
                         "gt_dir": "./data/thin_object_detection/HRSOD/masks_max255",
                         "im_ext": ".jpg",
                         "gt_ext": ".png"}

    dataset_thin_val = {"name": "ThinObject5k-TE",
                        "im_dir": "./data/thin_object_detection/ThinObject5K/images_test",
                        "gt_dir": "./data/thin_object_detection/ThinObject5K/masks_test",
                        "im_ext": ".jpg",
                        "gt_ext": ".png"}

    dataset_dis_val = {"name": "DIS5K-VD",
                       "im_dir": "./data/DIS5K/DIS-VD/im",
                       "gt_dir": "./data/DIS5K/DIS-VD/gt",
                       "im_ext": ".jpg",
                       "gt_ext": ".png"}

    train_datasets = [dataset_dis, dataset_thin, dataset_fss, dataset_duts, dataset_duts_te, dataset_ecssd, dataset_msra]
    valid_datasets = [dataset_dis_val, dataset_coift_val, dataset_hrsod_val, dataset_thin_val]

    args = get_args_parser()

    net = MaskDecoderHQ(args.model_type)

    main(net, train_datasets, valid_datasets, args)