# Copyright by HQ-SAM team
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import argparse
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import matplotlib.pyplot as plt
import cv2
import random
from typing import Dict, List, Tuple
from segment_anything_training import sam_model_registry
from segment_anything_training.modeling import TwoWayTransformer, MaskDecoder
from utils.dataloader import get_im_gt_name_dict, create_dataloaders, RandomHFlip, Resize, LargeScaleJitter
from utils.loss_mask import loss_masks
import utils.misc as misc
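# LayerNorm2d and MLP below mirror the small building blocks used inside SAM:
# a channel-wise LayerNorm for NCHW feature maps and a feed-forward MLP used as a hypernetwork head.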
class LayerNorm2d(nn.Module):
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(num_channels))
self.bias = nn.Parameter(torch.zeros(num_channels))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
class MLP(nn.Module):
def __init__(
self,
input_dim: int,
hidden_dim: int,
output_dim: int,
num_layers: int,
sigmoid_output: bool = False,
) -> None:
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(
nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
)
self.sigmoid_output = sigmoid_output
def forward(self, x):
for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
if self.sigmoid_output:
            x = torch.sigmoid(x)
return x
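# MaskDecoderHQ: SAM's mask decoder with its pretrained weights loaded and frozen,
# extended with a learnable High-Quality (HQ) output token that predicts a refined mask
# from global decoder features fused with early (local) ViT features.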
class MaskDecoderHQ(MaskDecoder):
def __init__(self, model_type):
super().__init__(transformer_dim=256,
transformer=TwoWayTransformer(
depth=2,
embedding_dim=256,
mlp_dim=2048,
num_heads=8,
),
num_multimask_outputs=3,
activation=nn.GELU,
iou_head_depth= 3,
iou_head_hidden_dim= 256,)
assert model_type in ["vit_b","vit_l","vit_h"]
checkpoint_dict = {"vit_b":"pretrained_checkpoint/sam_vit_b_maskdecoder.pth",
"vit_l":"pretrained_checkpoint/sam_vit_l_maskdecoder.pth",
'vit_h':"pretrained_checkpoint/sam_vit_h_maskdecoder.pth"}
checkpoint_path = checkpoint_dict[model_type]
self.load_state_dict(torch.load(checkpoint_path))
print("HQ Decoder init from SAM MaskDecoder")
for n,p in self.named_parameters():
p.requires_grad = False
transformer_dim=256
vit_dim_dict = {"vit_b":768,"vit_l":1024,"vit_h":1280}
vit_dim = vit_dim_dict[model_type]
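        # HQ-SAM additions (the only trainable parameters): an extra output token (hf_token) with its
        # hypernetwork MLP (hf_mlp), conv branches that lift early ViT features (compress_vit_feat)
        # and SAM's decoder embeddings (embedding_encoder) into a shared high-resolution HQ feature,
        # and a small conv block (embedding_maskfeature) that refines SAM's upscaled mask features.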
self.hf_token = nn.Embedding(1, transformer_dim)
self.hf_mlp = MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
self.num_mask_tokens = self.num_mask_tokens + 1
self.compress_vit_feat = nn.Sequential(
nn.ConvTranspose2d(vit_dim, transformer_dim, kernel_size=2, stride=2),
LayerNorm2d(transformer_dim),
nn.GELU(),
nn.ConvTranspose2d(transformer_dim, transformer_dim // 8, kernel_size=2, stride=2))
self.embedding_encoder = nn.Sequential(
nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
LayerNorm2d(transformer_dim // 4),
nn.GELU(),
nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
)
self.embedding_maskfeature = nn.Sequential(
nn.Conv2d(transformer_dim // 8, transformer_dim // 4, 3, 1, 1),
LayerNorm2d(transformer_dim // 4),
nn.GELU(),
nn.Conv2d(transformer_dim // 4, transformer_dim // 8, 3, 1, 1))
def forward(
self,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
multimask_output: bool,
hq_token_only: bool,
interm_embeddings: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Predict masks given image and prompt embeddings.
Arguments:
image_embeddings (torch.Tensor): the embeddings from the ViT image encoder
image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
multimask_output (bool): Whether to return multiple masks or a single
mask.
Returns:
torch.Tensor: batched predicted hq masks
"""
vit_features = interm_embeddings[0].permute(0, 3, 1, 2) # early-layer ViT feature, after 1st global attention block in ViT
hq_features = self.embedding_encoder(image_embeddings) + self.compress_vit_feat(vit_features)
batch_len = len(image_embeddings)
masks = []
iou_preds = []
for i_batch in range(batch_len):
mask, iou_pred = self.predict_masks(
image_embeddings=image_embeddings[i_batch].unsqueeze(0),
image_pe=image_pe[i_batch],
sparse_prompt_embeddings=sparse_prompt_embeddings[i_batch],
dense_prompt_embeddings=dense_prompt_embeddings[i_batch],
hq_feature = hq_features[i_batch].unsqueeze(0)
)
masks.append(mask)
iou_preds.append(iou_pred)
masks = torch.cat(masks,0)
iou_preds = torch.cat(iou_preds,0)
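        # Mask token layout along dim 1: index 0 is SAM's single-mask output, indices 1..3 are
        # SAM's multimask outputs, and the last index is the HQ mask from the hf_token.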
# Select the correct mask or masks for output
if multimask_output:
# mask with highest score
mask_slice = slice(1,self.num_mask_tokens-1)
iou_preds = iou_preds[:, mask_slice]
iou_preds, max_iou_idx = torch.max(iou_preds,dim=1)
iou_preds = iou_preds.unsqueeze(1)
masks_multi = masks[:, mask_slice, :, :]
masks_sam = masks_multi[torch.arange(masks_multi.size(0)),max_iou_idx].unsqueeze(1)
else:
            # single mask output, default
mask_slice = slice(0, 1)
masks_sam = masks[:,mask_slice]
masks_hq = masks[:,slice(self.num_mask_tokens-1, self.num_mask_tokens), :, :]
if hq_token_only:
return masks_hq
else:
return masks_sam, masks_hq
def predict_masks(
self,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
hq_feature: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Predicts masks. See 'forward' for more details."""
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight, self.hf_token.weight], dim=0)
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
# Expand per-image data in batch direction to be per-mask
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
src = src + dense_prompt_embeddings
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
b, c, h, w = src.shape
# Run the transformer
hs, src = self.transformer(src, pos_src, tokens)
iou_token_out = hs[:, 0, :]
mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
# Upscale mask embeddings and predict masks using the mask tokens
src = src.transpose(1, 2).view(b, c, h, w)
upscaled_embedding_sam = self.output_upscaling(src)
upscaled_embedding_ours = self.embedding_maskfeature(upscaled_embedding_sam) + hq_feature
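        # The first num_multimask_outputs + 1 tokens reuse SAM's frozen hypernetwork MLPs and mask
        # features; the extra HQ token uses hf_mlp and the fused HQ features instead.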
hyper_in_list: List[torch.Tensor] = []
for i in range(self.num_mask_tokens):
if i < 4:
hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
else:
hyper_in_list.append(self.hf_mlp(mask_tokens_out[:, i, :]))
hyper_in = torch.stack(hyper_in_list, dim=1)
b, c, h, w = upscaled_embedding_sam.shape
masks_sam = (hyper_in[:,:4] @ upscaled_embedding_sam.view(b, c, h * w)).view(b, -1, h, w)
masks_ours = (hyper_in[:,4:] @ upscaled_embedding_ours.view(b, c, h * w)).view(b, -1, h, w)
masks = torch.cat([masks_sam,masks_ours],dim=1)
iou_pred = self.iou_prediction_head(iou_token_out)
return masks, iou_pred
def show_anns(masks, input_point, input_box, input_label, filename, image, ious, boundary_ious):
if len(masks) == 0:
return
for i, (mask, iou, biou) in enumerate(zip(masks, ious, boundary_ious)):
plt.figure(figsize=(10,10))
plt.imshow(image)
show_mask(mask, plt.gca())
if input_box is not None:
show_box(input_box, plt.gca())
if (input_point is not None) and (input_label is not None):
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.savefig(filename+'_'+str(i)+'.png',bbox_inches='tight',pad_inches=-0.1)
plt.close()
def show_mask(mask, ax, random_color=False):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30/255, 144/255, 255/255, 0.6])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
def show_points(coords, labels, ax, marker_size=375):
pos_points = coords[labels==1]
neg_points = coords[labels==0]
ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
def show_box(box, ax):
x0, y0 = box[0], box[1]
w, h = box[2] - box[0], box[3] - box[1]
ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
def get_args_parser():
parser = argparse.ArgumentParser('HQ-SAM', add_help=False)
parser.add_argument("--output", type=str, required=True,
help="Path to the directory where masks and checkpoints will be output")
parser.add_argument("--model-type", type=str, default="vit_l",
help="The type of model to load, in ['vit_h', 'vit_l', 'vit_b']")
parser.add_argument("--checkpoint", type=str, required=True,
help="The path to the SAM checkpoint to use for mask generation.")
parser.add_argument("--device", type=str, default="cuda",
help="The device to run generation on.")
parser.add_argument('--seed', default=42, type=int)
parser.add_argument('--learning_rate', default=1e-3, type=float)
parser.add_argument('--start_epoch', default=0, type=int)
parser.add_argument('--lr_drop_epoch', default=10, type=int)
parser.add_argument('--max_epoch_num', default=12, type=int)
    parser.add_argument('--input_size', default=[1024,1024], type=int, nargs=2)
parser.add_argument('--batch_size_train', default=4, type=int)
parser.add_argument('--batch_size_valid', default=1, type=int)
parser.add_argument('--model_save_fre', default=1, type=int)
parser.add_argument('--world_size', default=1, type=int,
help='number of distributed processes')
parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    parser.add_argument('--rank', default=0, type=int,
                        help='rank of the current process')
parser.add_argument('--local_rank', type=int, help='local rank for dist')
parser.add_argument('--find_unused_params', action='store_true')
parser.add_argument('--eval', action='store_true')
parser.add_argument('--visualize', action='store_true')
parser.add_argument("--restore-model", type=str,
help="The path to the hq_decoder training checkpoint for evaluation")
return parser.parse_args()
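# A typical launch (paths, GPU count and checkpoint names below are placeholders; adjust to your setup):
#   python -m torch.distributed.launch --nproc_per_node=<num_gpus> train.py \
#       --checkpoint <path/to/sam_vit_l.pth> --model-type vit_l --output <work_dir>
# Evaluation only, restoring a trained HQ decoder checkpoint:
#   python -m torch.distributed.launch --nproc_per_node=1 train.py \
#       --checkpoint <path/to/sam_vit_l.pth> --model-type vit_l --output <work_dir> \
#       --eval --restore-model <path/to/epoch_x.pth>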
def main(net, train_datasets, valid_datasets, args):
misc.init_distributed_mode(args)
print('world size: {}'.format(args.world_size))
print('rank: {}'.format(args.rank))
print('local_rank: {}'.format(args.local_rank))
print("args: " + str(args) + '\n')
seed = args.seed + misc.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
### --- Step 1: Train or Valid dataset ---
if not args.eval:
print("--- create training dataloader ---")
train_im_gt_list = get_im_gt_name_dict(train_datasets, flag="train")
train_dataloaders, train_datasets = create_dataloaders(train_im_gt_list,
my_transforms = [
RandomHFlip(),
LargeScaleJitter()
],
batch_size = args.batch_size_train,
training = True)
print(len(train_dataloaders), " train dataloaders created")
print("--- create valid dataloader ---")
valid_im_gt_list = get_im_gt_name_dict(valid_datasets, flag="valid")
valid_dataloaders, valid_datasets = create_dataloaders(valid_im_gt_list,
my_transforms = [
Resize(args.input_size)
],
batch_size=args.batch_size_valid,
training=False)
print(len(valid_dataloaders), " valid dataloaders created")
### --- Step 2: DistributedDataParallel---
if torch.cuda.is_available():
net.cuda()
net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[args.gpu], find_unused_parameters=args.find_unused_params)
net_without_ddp = net.module
### --- Step 3: Train or Evaluate ---
if not args.eval:
print("--- define optimizer ---")
optimizer = optim.Adam(net_without_ddp.parameters(), lr=args.learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop_epoch)
lr_scheduler.last_epoch = args.start_epoch
train(args, net, optimizer, train_dataloaders, valid_dataloaders, lr_scheduler)
else:
sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
_ = sam.to(device=args.device)
sam = torch.nn.parallel.DistributedDataParallel(sam, device_ids=[args.gpu], find_unused_parameters=args.find_unused_params)
if args.restore_model:
print("restore model from:", args.restore_model)
if torch.cuda.is_available():
net_without_ddp.load_state_dict(torch.load(args.restore_model))
else:
net_without_ddp.load_state_dict(torch.load(args.restore_model,map_location="cpu"))
evaluate(args, net, sam, valid_dataloaders, args.visualize)
def train(args, net, optimizer, train_dataloaders, valid_dataloaders, lr_scheduler):
if misc.is_main_process():
os.makedirs(args.output, exist_ok=True)
epoch_start = args.start_epoch
epoch_num = args.max_epoch_num
train_num = len(train_dataloaders)
net.train()
_ = net.to(device=args.device)
sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
_ = sam.to(device=args.device)
sam = torch.nn.parallel.DistributedDataParallel(sam, device_ids=[args.gpu], find_unused_parameters=args.find_unused_params)
for epoch in range(epoch_start,epoch_num):
print("epoch: ",epoch, " learning rate: ", optimizer.param_groups[0]["lr"])
metric_logger = misc.MetricLogger(delimiter=" ")
train_dataloaders.batch_sampler.sampler.set_epoch(epoch)
for data in metric_logger.log_every(train_dataloaders,1000):
inputs, labels = data['image'], data['label']
if torch.cuda.is_available():
inputs = inputs.cuda()
labels = labels.cuda()
imgs = inputs.permute(0, 2, 3, 1).cpu().numpy()
# input prompt
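            # For each image, one prompt type is sampled at random: a GT-derived box, points sampled
            # from the GT mask, or a noisy low-resolution (256x256) version of the GT mask.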
input_keys = ['box','point','noise_mask']
labels_box = misc.masks_to_boxes(labels[:,0,:,:])
try:
labels_points = misc.masks_sample_points(labels[:,0,:,:])
            except Exception:
                # fewer than 10 foreground points are available; fall back to box / noisy-mask prompts
                input_keys = ['box','noise_mask']
labels_256 = F.interpolate(labels, size=(256, 256), mode='bilinear')
labels_noisemask = misc.masks_noise(labels_256)
batched_input = []
for b_i in range(len(imgs)):
dict_input = dict()
input_image = torch.as_tensor(imgs[b_i].astype(dtype=np.uint8), device=sam.device).permute(2, 0, 1).contiguous()
dict_input['image'] = input_image
input_type = random.choice(input_keys)
if input_type == 'box':
dict_input['boxes'] = labels_box[b_i:b_i+1]
elif input_type == 'point':
point_coords = labels_points[b_i:b_i+1]
dict_input['point_coords'] = point_coords
dict_input['point_labels'] = torch.ones(point_coords.shape[1], device=point_coords.device)[None,:]
elif input_type == 'noise_mask':
dict_input['mask_inputs'] = labels_noisemask[b_i:b_i+1]
else:
raise NotImplementedError
dict_input['original_size'] = imgs[b_i].shape[:2]
batched_input.append(dict_input)
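            # The frozen SAM model runs under no_grad to produce the image/prompt embeddings;
            # interm_embeddings are the early ViT features consumed by the HQ decoder.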
with torch.no_grad():
batched_output, interm_embeddings = sam(batched_input, multimask_output=False)
batch_len = len(batched_output)
encoder_embedding = torch.cat([batched_output[i_l]['encoder_embedding'] for i_l in range(batch_len)], dim=0)
image_pe = [batched_output[i_l]['image_pe'] for i_l in range(batch_len)]
sparse_embeddings = [batched_output[i_l]['sparse_embeddings'] for i_l in range(batch_len)]
dense_embeddings = [batched_output[i_l]['dense_embeddings'] for i_l in range(batch_len)]
masks_hq = net(
image_embeddings=encoder_embedding,
image_pe=image_pe,
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=False,
hq_token_only=True,
interm_embeddings=interm_embeddings,
)
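            # Supervise only the HQ mask: loss_masks returns a mask (cross-entropy) term and a dice
            # term; GT masks are stored as 0/255 images, hence the division by 255.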
loss_mask, loss_dice = loss_masks(masks_hq, labels/255.0, len(masks_hq))
loss = loss_mask + loss_dice
loss_dict = {"loss_mask": loss_mask, "loss_dice":loss_dice}
# reduce losses over all GPUs for logging purposes
loss_dict_reduced = misc.reduce_dict(loss_dict)
losses_reduced_scaled = sum(loss_dict_reduced.values())
loss_value = losses_reduced_scaled.item()
optimizer.zero_grad()
loss.backward()
optimizer.step()
metric_logger.update(training_loss=loss_value, **loss_dict_reduced)
print("Finished epoch: ", epoch)
metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger)
train_stats = {k: meter.global_avg for k, meter in metric_logger.meters.items() if meter.count > 0}
lr_scheduler.step()
test_stats = evaluate(args, net, sam, valid_dataloaders)
train_stats.update(test_stats)
net.train()
if epoch % args.model_save_fre == 0:
model_name = "/epoch_"+str(epoch)+".pth"
            print("saving checkpoint to", args.output + model_name)
misc.save_on_master(net.module.state_dict(), args.output + model_name)
    # finish training
    print("Training reached the maximum number of epochs")
    # merge the trained HQ decoder weights into the SAM checkpoint so it can be used as a standalone SAM-HQ checkpoint
if misc.is_main_process():
sam_ckpt = torch.load(args.checkpoint)
hq_decoder = torch.load(args.output + model_name)
for key in hq_decoder.keys():
sam_key = 'mask_decoder.'+key
if sam_key not in sam_ckpt.keys():
sam_ckpt[sam_key] = hq_decoder[key]
model_name = "/sam_hq_epoch_"+str(epoch)+".pth"
torch.save(sam_ckpt, args.output + model_name)
def compute_iou(preds, target):
assert target.shape[1] == 1, 'only support one mask per image now'
if(preds.shape[2]!=target.shape[2] or preds.shape[3]!=target.shape[3]):
postprocess_preds = F.interpolate(preds, size=target.size()[2:], mode='bilinear', align_corners=False)
else:
postprocess_preds = preds
iou = 0
for i in range(0,len(preds)):
iou = iou + misc.mask_iou(postprocess_preds[i],target[i])
return iou / len(preds)
def compute_boundary_iou(preds, target):
assert target.shape[1] == 1, 'only support one mask per image now'
if(preds.shape[2]!=target.shape[2] or preds.shape[3]!=target.shape[3]):
postprocess_preds = F.interpolate(preds, size=target.size()[2:], mode='bilinear', align_corners=False)
else:
postprocess_preds = preds
iou = 0
for i in range(0,len(preds)):
iou = iou + misc.boundary_iou(target[i],postprocess_preds[i])
return iou / len(preds)
def evaluate(args, net, sam, valid_dataloaders, visualize=False):
net.eval()
print("Validating...")
test_stats = {}
for k in range(len(valid_dataloaders)):
metric_logger = misc.MetricLogger(delimiter=" ")
valid_dataloader = valid_dataloaders[k]
print('valid_dataloader len:', len(valid_dataloader))
for data_val in metric_logger.log_every(valid_dataloader,1000):
imidx_val, inputs_val, labels_val, shapes_val, labels_ori = data_val['imidx'], data_val['image'], data_val['label'], data_val['shape'], data_val['ori_label']
if torch.cuda.is_available():
inputs_val = inputs_val.cuda()
labels_val = labels_val.cuda()
labels_ori = labels_ori.cuda()
imgs = inputs_val.permute(0, 2, 3, 1).cpu().numpy()
labels_box = misc.masks_to_boxes(labels_val[:,0,:,:])
input_keys = ['box']
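            # Validation uses GT-derived box prompts only, so the 'point' and 'noise_mask'
            # branches below are never taken here (they are kept for symmetry with the training loop).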
batched_input = []
for b_i in range(len(imgs)):
dict_input = dict()
input_image = torch.as_tensor(imgs[b_i].astype(dtype=np.uint8), device=sam.device).permute(2, 0, 1).contiguous()
dict_input['image'] = input_image
input_type = random.choice(input_keys)
if input_type == 'box':
dict_input['boxes'] = labels_box[b_i:b_i+1]
elif input_type == 'point':
point_coords = labels_points[b_i:b_i+1]
dict_input['point_coords'] = point_coords
dict_input['point_labels'] = torch.ones(point_coords.shape[1], device=point_coords.device)[None,:]
elif input_type == 'noise_mask':
dict_input['mask_inputs'] = labels_noisemask[b_i:b_i+1]
else:
raise NotImplementedError
dict_input['original_size'] = imgs[b_i].shape[:2]
batched_input.append(dict_input)
with torch.no_grad():
batched_output, interm_embeddings = sam(batched_input, multimask_output=False)
batch_len = len(batched_output)
encoder_embedding = torch.cat([batched_output[i_l]['encoder_embedding'] for i_l in range(batch_len)], dim=0)
image_pe = [batched_output[i_l]['image_pe'] for i_l in range(batch_len)]
sparse_embeddings = [batched_output[i_l]['sparse_embeddings'] for i_l in range(batch_len)]
dense_embeddings = [batched_output[i_l]['dense_embeddings'] for i_l in range(batch_len)]
masks_sam, masks_hq = net(
image_embeddings=encoder_embedding,
image_pe=image_pe,
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=False,
hq_token_only=False,
interm_embeddings=interm_embeddings,
)
iou = compute_iou(masks_hq,labels_ori)
boundary_iou = compute_boundary_iou(masks_hq,labels_ori)
if visualize:
print("visualize")
os.makedirs(args.output, exist_ok=True)
masks_hq_vis = (F.interpolate(masks_hq.detach(), (1024, 1024), mode="bilinear", align_corners=False) > 0).cpu()
for ii in range(len(imgs)):
base = data_val['imidx'][ii].item()
print('base:', base)
save_base = os.path.join(args.output, str(k)+'_'+ str(base))
imgs_ii = imgs[ii].astype(dtype=np.uint8)
show_iou = torch.tensor([iou.item()])
show_boundary_iou = torch.tensor([boundary_iou.item()])
show_anns(masks_hq_vis[ii], None, labels_box[ii].cpu(), None, save_base , imgs_ii, show_iou, show_boundary_iou)
loss_dict = {"val_iou_"+str(k): iou, "val_boundary_iou_"+str(k): boundary_iou}
loss_dict_reduced = misc.reduce_dict(loss_dict)
metric_logger.update(**loss_dict_reduced)
print('============================')
# gather the stats from all processes
metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger)
resstat = {k: meter.global_avg for k, meter in metric_logger.meters.items() if meter.count > 0}
test_stats.update(resstat)
return test_stats
if __name__ == "__main__":
### --------------- Configuring the Train and Valid datasets ---------------
dataset_dis = {"name": "DIS5K-TR",
"im_dir": "./data/DIS5K/DIS-TR/im",
"gt_dir": "./data/DIS5K/DIS-TR/gt",
"im_ext": ".jpg",
"gt_ext": ".png"}
dataset_thin = {"name": "ThinObject5k-TR",
"im_dir": "./data/thin_object_detection/ThinObject5K/images_train",
"gt_dir": "./data/thin_object_detection/ThinObject5K/masks_train",
"im_ext": ".jpg",
"gt_ext": ".png"}
dataset_fss = {"name": "FSS",
"im_dir": "./data/cascade_psp/fss_all",
"gt_dir": "./data/cascade_psp/fss_all",
"im_ext": ".jpg",
"gt_ext": ".png"}
dataset_duts = {"name": "DUTS-TR",
"im_dir": "./data/cascade_psp/DUTS-TR",
"gt_dir": "./data/cascade_psp/DUTS-TR",
"im_ext": ".jpg",
"gt_ext": ".png"}
dataset_duts_te = {"name": "DUTS-TE",
"im_dir": "./data/cascade_psp/DUTS-TE",
"gt_dir": "./data/cascade_psp/DUTS-TE",
"im_ext": ".jpg",
"gt_ext": ".png"}
dataset_ecssd = {"name": "ECSSD",
"im_dir": "./data/cascade_psp/ecssd",
"gt_dir": "./data/cascade_psp/ecssd",
"im_ext": ".jpg",
"gt_ext": ".png"}
dataset_msra = {"name": "MSRA10K",
"im_dir": "./data/cascade_psp/MSRA_10K",
"gt_dir": "./data/cascade_psp/MSRA_10K",
"im_ext": ".jpg",
"gt_ext": ".png"}
# valid set
dataset_coift_val = {"name": "COIFT",
"im_dir": "./data/thin_object_detection/COIFT/images",
"gt_dir": "./data/thin_object_detection/COIFT/masks",
"im_ext": ".jpg",
"gt_ext": ".png"}
dataset_hrsod_val = {"name": "HRSOD",
"im_dir": "./data/thin_object_detection/HRSOD/images",
"gt_dir": "./data/thin_object_detection/HRSOD/masks_max255",
"im_ext": ".jpg",
"gt_ext": ".png"}
dataset_thin_val = {"name": "ThinObject5k-TE",
"im_dir": "./data/thin_object_detection/ThinObject5K/images_test",
"gt_dir": "./data/thin_object_detection/ThinObject5K/masks_test",
"im_ext": ".jpg",
"gt_ext": ".png"}
dataset_dis_val = {"name": "DIS5K-VD",
"im_dir": "./data/DIS5K/DIS-VD/im",
"gt_dir": "./data/DIS5K/DIS-VD/gt",
"im_ext": ".jpg",
"gt_ext": ".png"}
train_datasets = [dataset_dis, dataset_thin, dataset_fss, dataset_duts, dataset_duts_te, dataset_ecssd, dataset_msra]
valid_datasets = [dataset_dis_val, dataset_coift_val, dataset_hrsod_val, dataset_thin_val]
args = get_args_parser()
net = MaskDecoderHQ(args.model_type)
main(net, train_datasets, valid_datasets, args)