import os
import sys
import time

import numpy as np
from skimage import io

import gc
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import torch.nn.functional as F

from data_loader_cache import get_im_gt_name_dict, create_dataloaders, GOSRandomHFlip, GOSResize, GOSRandomCrop, GOSNormalize
from basics import f1_mae_torch
from models import *

device = 'cuda' if torch.cuda.is_available() else 'cpu'


def get_gt_encoder(train_dataloaders, train_datasets, valid_dataloaders, valid_datasets, hypar, train_dataloaders_val, train_datasets_val):
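    """Train the ground-truth (GT) encoder (ISNetGTEncoder) used for intermediate
    feature supervision. The encoder learns to reconstruct the GT masks from the
    masks themselves; it is returned as soon as its max F1 on the non-augmented
    training set exceeds 0.99, or restored directly from saved weights when
    hypar["gt_encoder_model"] is set."""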
    torch.manual_seed(hypar["seed"])
    if torch.cuda.is_available():
        torch.cuda.manual_seed(hypar["seed"])

    print("define gt encoder ...")
    net = ISNetGTEncoder()

    # restore a pre-trained GT encoder and skip training if one is provided
    if(hypar["gt_encoder_model"]!=""):
        model_path = hypar["model_path"]+"/"+hypar["gt_encoder_model"]
        if torch.cuda.is_available():
            net.load_state_dict(torch.load(model_path))
            net.cuda()
        else:
            net.load_state_dict(torch.load(model_path, map_location="cpu"))
        print("gt encoder restored from the saved weights ...")
        return net

    if torch.cuda.is_available():
        net.cuda()

    print("--- define optimizer for GT Encoder ---")
    optimizer = optim.Adam(net.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

    model_path = hypar["model_path"]
    model_save_fre = hypar["model_save_fre"]
    max_ite = hypar["max_ite"]
    batch_size_train = hypar["batch_size_train"]
    batch_size_valid = hypar["batch_size_valid"]

    if(not os.path.exists(model_path)):
        os.mkdir(model_path)

    ite_num = hypar["start_ite"]  # iteration counter, continues from start_ite when resuming
    ite_num4val = 0               # iterations since the last validation/checkpoint
    running_loss = 0.0            # total loss accumulated over the current validation period
    running_tar_loss = 0.0        # target (final-output) loss over the current period
    last_f1 = [0 for x in range(len(valid_dataloaders))]

    train_num = train_datasets[0].__len__()

    net.train()

    start_last = time.time()
    gos_dataloader = train_dataloaders[0]
    epoch_num = hypar["max_epoch_num"]
    notgood_cnt = 0
    for epoch in range(epoch_num):

        for i, data in enumerate(gos_dataloader):

            if(ite_num >= max_ite):
                print("Training Reached the Maximal Iteration Number ", max_ite)
                sys.exit()

            ite_num = ite_num + 1
            ite_num4val = ite_num4val + 1

            # the GT encoder consumes only the labels (no input images)
            labels = data['label']

            if(hypar["model_digit"]=="full"):
                labels = labels.type(torch.FloatTensor)
            else:
                labels = labels.type(torch.HalfTensor)

            if torch.cuda.is_available():
                labels_v = Variable(labels.cuda(), requires_grad=False)
            else:
                labels_v = Variable(labels, requires_grad=False)

            start_inf_loss_back = time.time()
            optimizer.zero_grad()

            ds, fs = net(labels_v)
            loss2, loss = net.compute_loss(ds, labels_v)

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            running_tar_loss += loss2.item()

            del ds, loss2, loss
            end_inf_loss_back = time.time()-start_inf_loss_back

            print("GT Encoder Training>>>"+model_path.split('/')[-1]+" - [epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f, time-per-iter: %3f s, time_read: %3f" % (
                epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val, time.time()-start_last, time.time()-start_last-end_inf_loss_back))
            start_last = time.time()

            if ite_num % model_save_fre == 0:  # validate and checkpoint every model_save_fre iterations
                notgood_cnt += 1

                tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time = valid_gt_encoder(net, train_dataloaders_val, train_datasets_val, hypar, epoch)

                net.train()

                tmp_out = 0
                print("last_f1:", last_f1)
                print("tmp_f1:", tmp_f1)
                for fi in range(len(last_f1)):
                    if(tmp_f1[fi]>last_f1[fi]):
                        tmp_out = 1
                print("tmp_out:", tmp_out)
                if(tmp_out):
                    notgood_cnt = 0
                    last_f1 = tmp_f1
                    tmp_f1_str = [str(round(f1x,4)) for f1x in tmp_f1]
                    tmp_mae_str = [str(round(mx,4)) for mx in tmp_mae]
                    maxf1 = '_'.join(tmp_f1_str)
                    meanM = '_'.join(tmp_mae_str)

                    model_name = "/GTENCODER-gpu_itr_"+str(ite_num)+\
                                 "_traLoss_"+str(np.round(running_loss / ite_num4val,4))+\
                                 "_traTarLoss_"+str(np.round(running_tar_loss / ite_num4val,4))+\
                                 "_valLoss_"+str(np.round(val_loss /(i_val+1),4))+\
                                 "_valTarLoss_"+str(np.round(tar_loss /(i_val+1),4))+\
                                 "_maxF1_" + maxf1 + \
                                 "_mae_" + meanM + \
                                 "_time_" + str(np.round(np.mean(np.array(tmp_time))/batch_size_valid,6))+".pth"
                    torch.save(net.state_dict(), model_path + model_name)

                running_loss = 0.0
                running_tar_loss = 0.0
                ite_num4val = 0

                # the GT encoder is considered converged once it can almost
                # perfectly reconstruct the GT masks
                if(tmp_f1[0]>0.99):
                    print("GT encoder is well-trained and obtained ...")
                    return net

                if(notgood_cnt >= hypar["early_stop"]):
                    print("No improvements in the last "+str(notgood_cnt)+" validation periods, so training stopped!")
                    sys.exit()

    print("Training Reaches The Maximum Epoch Number")
    return net

def valid_gt_encoder(net, valid_dataloaders, valid_datasets, hypar, epoch=0):
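    """Evaluate the GT encoder: run it over each validation dataloader, compute
    per-image precision/recall/F1 over 255 thresholds plus MAE against the
    original-resolution GT masks, and return the per-dataset max-F1 and mean MAE."""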
    net.eval()
    print("Validating...")
    epoch_num = hypar["max_epoch_num"]

    val_loss = 0.0
    tar_loss = 0.0

    tmp_f1 = []
    tmp_mae = []
    tmp_time = []

    start_valid = time.time()
    for k in range(len(valid_dataloaders)):

        valid_dataloader = valid_dataloaders[k]
        valid_dataset = valid_datasets[k]

        val_num = valid_dataset.__len__()
        mybins = np.arange(0,256)  # 255 thresholds for the precision/recall curves
        PRE = np.zeros((val_num,len(mybins)-1))
        REC = np.zeros((val_num,len(mybins)-1))
        F1 = np.zeros((val_num,len(mybins)-1))
        MAE = np.zeros((val_num))

        val_cnt = 0.0
        i_val = None

        for i_val, data_val in enumerate(valid_dataloader):

            imidx_val, labels_val, shapes_val = data_val['imidx'], data_val['label'], data_val['shape']

            if(hypar["model_digit"]=="full"):
                labels_val = labels_val.type(torch.FloatTensor)
            else:
                labels_val = labels_val.type(torch.HalfTensor)

            if torch.cuda.is_available():
                labels_val_v = Variable(labels_val.cuda(), requires_grad=False)
            else:
                labels_val_v = Variable(labels_val, requires_grad=False)

            t_start = time.time()
            ds_val = net(labels_val_v)[0]
            t_end = time.time()-t_start
            tmp_time.append(t_end)

            loss2_val, loss_val = net.compute_loss(ds_val, labels_val_v)

            for t in range(hypar["batch_size_valid"]):
                val_cnt = val_cnt + 1.0
                print("num of val: ", val_cnt)
                i_test = imidx_val[t].data.numpy()

                pred_val = ds_val[0][t,:,:,:]  # highest-resolution side output

                # resize the prediction back to the original image size
                # (F.upsample is deprecated; F.interpolate is the equivalent call)
                pred_val = torch.squeeze(F.interpolate(torch.unsqueeze(pred_val,0),(shapes_val[t][0],shapes_val[t][1]),mode='bilinear'))

                # min-max normalize the prediction to [0,1]
                ma = torch.max(pred_val)
                mi = torch.min(pred_val)
                pred_val = (pred_val-mi)/(ma-mi)

                gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test]))
                if gt.max()==1:
                    gt = gt*255
                with torch.no_grad():
                    gt = torch.tensor(gt).to(device)

                pre, rec, f1, mae = f1_mae_torch(pred_val*255, gt, valid_dataset, i_test, mybins, hypar)

                PRE[i_test,:] = pre
                REC[i_test,:] = rec
                F1[i_test,:] = f1
                MAE[i_test] = mae

                del gt
                gc.collect()
                torch.cuda.empty_cache()

            val_loss += loss_val.item()
            tar_loss += loss2_val.item()

            print("[validating: %5d/%5d] val_ls:%f, tar_ls: %f, f1: %f, mae: %f, time: %f" % (i_val, val_num, val_loss / (i_val + 1), tar_loss / (i_val + 1), np.amax(F1[i_test,:]), MAE[i_test], t_end))

            # delete ds_val only after the batch loop so that indexing it
            # still works when batch_size_valid > 1
            del ds_val, loss2_val, loss_val

        print('============================')
        PRE_m = np.mean(PRE,0)
        REC_m = np.mean(REC,0)
        # F-beta measure with beta^2 = 0.3, the convention in salient object detection
        f1_m = (1+0.3)*PRE_m*REC_m/(0.3*PRE_m+REC_m+1e-8)

        tmp_f1.append(np.amax(f1_m))
        tmp_mae.append(np.mean(MAE))
        print("The max F1 Score: %f" % (np.max(f1_m)))
        print("MAE: ", np.mean(MAE))

    return tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time

def train(net, optimizer, train_dataloaders, train_datasets, valid_dataloaders, valid_datasets, hypar, train_dataloaders_val, train_datasets_val):
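    """Main training loop for IS-Net. With hypar["interm_sup"] enabled, a frozen
    GT encoder provides target feature maps for intermediate supervision;
    otherwise the network is trained against the GT masks alone. Validates and
    checkpoints every hypar["model_save_fre"] iterations, with early stopping
    after hypar["early_stop"] validation periods without an F1 improvement."""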
    if hypar["interm_sup"]:
        print("Get the gt encoder ...")
        featurenet = get_gt_encoder(train_dataloaders, train_datasets, valid_dataloaders, valid_datasets, hypar, train_dataloaders_val, train_datasets_val)

        # freeze the GT encoder: it only supplies target features
        for param in featurenet.parameters():
            param.requires_grad = False

    model_path = hypar["model_path"]
    model_save_fre = hypar["model_save_fre"]
    max_ite = hypar["max_ite"]
    batch_size_train = hypar["batch_size_train"]
    batch_size_valid = hypar["batch_size_valid"]

    if(not os.path.exists(model_path)):
        os.mkdir(model_path)

    ite_num = hypar["start_ite"]
    ite_num4val = 0
    running_loss = 0.0
    running_tar_loss = 0.0
    last_f1 = [0 for x in range(len(valid_dataloaders))]

    train_num = train_datasets[0].__len__()

    net.train()

    start_last = time.time()
    gos_dataloader = train_dataloaders[0]
    epoch_num = hypar["max_epoch_num"]
    notgood_cnt = 0
    for epoch in range(epoch_num):

        for i, data in enumerate(gos_dataloader):

            if(ite_num >= max_ite):
                print("Training Reached the Maximal Iteration Number ", max_ite)
                sys.exit()

            ite_num = ite_num + 1
            ite_num4val = ite_num4val + 1

            inputs, labels = data['image'], data['label']

            if(hypar["model_digit"]=="full"):
                inputs = inputs.type(torch.FloatTensor)
                labels = labels.type(torch.FloatTensor)
            else:
                inputs = inputs.type(torch.HalfTensor)
                labels = labels.type(torch.HalfTensor)

            if torch.cuda.is_available():
                inputs_v, labels_v = Variable(inputs.cuda(), requires_grad=False), Variable(labels.cuda(), requires_grad=False)
            else:
                inputs_v, labels_v = Variable(inputs, requires_grad=False), Variable(labels, requires_grad=False)

            start_inf_loss_back = time.time()
            optimizer.zero_grad()

            if hypar["interm_sup"]:
                # supervise the side outputs and the intermediate features
                # against those produced by the frozen GT encoder
                ds, dfs = net(inputs_v)
                _, fs = featurenet(labels_v)
                loss2, loss = net.compute_loss_kl(ds, labels_v, dfs, fs, mode='MSE')
            else:
                ds, _ = net(inputs_v)
                loss2, loss = net.compute_loss(ds, labels_v)

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            running_tar_loss += loss2.item()

            del ds, loss2, loss
            end_inf_loss_back = time.time()-start_inf_loss_back

            print(">>>"+model_path.split('/')[-1]+" - [epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f, time-per-iter: %3f s, time_read: %3f" % (
                epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val, time.time()-start_last, time.time()-start_last-end_inf_loss_back))
            start_last = time.time()

            if ite_num % model_save_fre == 0:  # validate and checkpoint every model_save_fre iterations
                notgood_cnt += 1
                net.eval()
                tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time = valid(net, valid_dataloaders, valid_datasets, hypar, epoch)
                net.train()

                tmp_out = 0
                print("last_f1:", last_f1)
                print("tmp_f1:", tmp_f1)
                for fi in range(len(last_f1)):
                    if(tmp_f1[fi]>last_f1[fi]):
                        tmp_out = 1
                print("tmp_out:", tmp_out)
                if(tmp_out):
                    notgood_cnt = 0
                    last_f1 = tmp_f1
                    tmp_f1_str = [str(round(f1x,4)) for f1x in tmp_f1]
                    tmp_mae_str = [str(round(mx,4)) for mx in tmp_mae]
                    maxf1 = '_'.join(tmp_f1_str)
                    meanM = '_'.join(tmp_mae_str)

                    model_name = "/gpu_itr_"+str(ite_num)+\
                                 "_traLoss_"+str(np.round(running_loss / ite_num4val,4))+\
                                 "_traTarLoss_"+str(np.round(running_tar_loss / ite_num4val,4))+\
                                 "_valLoss_"+str(np.round(val_loss /(i_val+1),4))+\
                                 "_valTarLoss_"+str(np.round(tar_loss /(i_val+1),4))+\
                                 "_maxF1_" + maxf1 + \
                                 "_mae_" + meanM + \
                                 "_time_" + str(np.round(np.mean(np.array(tmp_time))/batch_size_valid,6))+".pth"
                    torch.save(net.state_dict(), model_path + model_name)

                running_loss = 0.0
                running_tar_loss = 0.0
                ite_num4val = 0

                if(notgood_cnt >= hypar["early_stop"]):
                    print("No improvements in the last "+str(notgood_cnt)+" validation periods, so training stopped!")
                    sys.exit()

    print("Training Reaches The Maximum Epoch Number")

def valid(net, valid_dataloaders, valid_datasets, hypar, epoch=0):
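    """Validation/inference pass for IS-Net: run the network over each validation
    dataloader, rescale each prediction to the original image size, and compute
    max-F1 (over 255 thresholds) and MAE per dataset. When no GT is available
    (inference-only datasets), an all-zero mask is used as a placeholder."""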
    net.eval()
    print("Validating...")
    epoch_num = hypar["max_epoch_num"]

    val_loss = 0.0
    tar_loss = 0.0
    val_cnt = 0.0

    tmp_f1 = []
    tmp_mae = []
    tmp_time = []

    start_valid = time.time()

    for k in range(len(valid_dataloaders)):

        valid_dataloader = valid_dataloaders[k]
        valid_dataset = valid_datasets[k]

        val_num = valid_dataset.__len__()
        mybins = np.arange(0,256)  # 255 thresholds for the precision/recall curves
        PRE = np.zeros((val_num,len(mybins)-1))
        REC = np.zeros((val_num,len(mybins)-1))
        F1 = np.zeros((val_num,len(mybins)-1))
        MAE = np.zeros((val_num))

        for i_val, data_val in enumerate(valid_dataloader):
            val_cnt = val_cnt + 1.0
            imidx_val, inputs_val, labels_val, shapes_val = data_val['imidx'], data_val['image'], data_val['label'], data_val['shape']

            if(hypar["model_digit"]=="full"):
                inputs_val = inputs_val.type(torch.FloatTensor)
                labels_val = labels_val.type(torch.FloatTensor)
            else:
                inputs_val = inputs_val.type(torch.HalfTensor)
                labels_val = labels_val.type(torch.HalfTensor)

            if torch.cuda.is_available():
                inputs_val_v, labels_val_v = Variable(inputs_val.cuda(), requires_grad=False), Variable(labels_val.cuda(), requires_grad=False)
            else:
                inputs_val_v, labels_val_v = Variable(inputs_val, requires_grad=False), Variable(labels_val, requires_grad=False)

            t_start = time.time()
            ds_val = net(inputs_val_v)[0]
            t_end = time.time()-t_start
            tmp_time.append(t_end)

            loss2_val, loss_val = net.compute_loss(ds_val, labels_val_v)

            for t in range(hypar["batch_size_valid"]):
                i_test = imidx_val[t].data.numpy()

                pred_val = ds_val[0][t,:,:,:]  # highest-resolution side output

                # resize the prediction back to the original image size
                # (F.upsample is deprecated; F.interpolate is the equivalent call)
                pred_val = torch.squeeze(F.interpolate(torch.unsqueeze(pred_val,0),(shapes_val[t][0],shapes_val[t][1]),mode='bilinear'))

                # min-max normalize the prediction to [0,1]
                ma = torch.max(pred_val)
                mi = torch.min(pred_val)
                pred_val = (pred_val-mi)/(ma-mi)

                if len(valid_dataset.dataset["ori_gt_path"]) != 0:
                    gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test]))
                    if gt.max()==1:
                        gt = gt*255
                else:
                    # no GT available (inference-only dataset): score against a zero mask
                    gt = np.zeros((shapes_val[t][0],shapes_val[t][1]))
                with torch.no_grad():
                    gt = torch.tensor(gt).to(device)

                pre, rec, f1, mae = f1_mae_torch(pred_val*255, gt, valid_dataset, i_test, mybins, hypar)

                PRE[i_test,:] = pre
                REC[i_test,:] = rec
                F1[i_test,:] = f1
                MAE[i_test] = mae

                del gt
                gc.collect()
                torch.cuda.empty_cache()

            val_loss += loss_val.item()
            tar_loss += loss2_val.item()

            print("[validating: %5d/%5d] val_ls:%f, tar_ls: %f, f1: %f, mae: %f, time: %f" % (i_val, val_num, val_loss / (i_val + 1), tar_loss / (i_val + 1), np.amax(F1[i_test,:]), MAE[i_test], t_end))

            # delete ds_val only after the batch loop so that indexing it
            # still works when batch_size_valid > 1
            del ds_val, loss2_val, loss_val

        print('============================')
        PRE_m = np.mean(PRE,0)
        REC_m = np.mean(REC,0)
        # F-beta measure with beta^2 = 0.3, the convention in salient object detection
        f1_m = (1+0.3)*PRE_m*REC_m/(0.3*PRE_m+REC_m+1e-8)

        tmp_f1.append(np.amax(f1_m))
        tmp_mae.append(np.mean(MAE))

    return tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time

def main(train_datasets,
         valid_datasets,
         hypar):
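    """Build the dataloaders, the model and the optimizer, then dispatch to
    train() or valid() according to hypar["mode"]."""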
    dataloaders_train = []
    dataloaders_valid = []

    if(hypar["mode"]=="train"):
        print("--- create training dataloader ---")

        train_nm_im_gt_list = get_im_gt_name_dict(train_datasets, flag="train")

        # augmented dataloader for training
        train_dataloaders, train_datasets = create_dataloaders(train_nm_im_gt_list,
                                                               cache_size=hypar["cache_size"],
                                                               cache_boost=hypar["cache_boost_train"],
                                                               my_transforms=[
                                                                   GOSRandomHFlip(),
                                                                   GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
                                                               ],
                                                               batch_size=hypar["batch_size_train"],
                                                               shuffle=True)
        # non-augmented dataloader over the training set, used to validate the GT encoder
        train_dataloaders_val, train_datasets_val = create_dataloaders(train_nm_im_gt_list,
                                                                       cache_size=hypar["cache_size"],
                                                                       cache_boost=hypar["cache_boost_train"],
                                                                       my_transforms=[
                                                                           GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
                                                                       ],
                                                                       batch_size=hypar["batch_size_valid"],
                                                                       shuffle=False)
        print(len(train_dataloaders), " train dataloaders created")

    print("--- create valid dataloader ---")

    valid_nm_im_gt_list = get_im_gt_name_dict(valid_datasets, flag="valid")

    valid_dataloaders, valid_datasets = create_dataloaders(valid_nm_im_gt_list,
                                                           cache_size=hypar["cache_size"],
                                                           cache_boost=hypar["cache_boost_valid"],
                                                           my_transforms=[
                                                               GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
                                                           ],
                                                           batch_size=hypar["batch_size_valid"],
                                                           shuffle=False)
    print(len(valid_dataloaders), " valid dataloaders created")

    print("--- build model ---")
    net = hypar["model"]

    # convert to half precision except BatchNorm, which stays in float for stability
    if(hypar["model_digit"]=="half"):
        net.half()
        for layer in net.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.float()

    if torch.cuda.is_available():
        net.cuda()

    if(hypar["restore_model"]!=""):
        print("restore model from:")
        print(hypar["model_path"]+"/"+hypar["restore_model"])
        if torch.cuda.is_available():
            net.load_state_dict(torch.load(hypar["model_path"]+"/"+hypar["restore_model"]))
        else:
            net.load_state_dict(torch.load(hypar["model_path"]+"/"+hypar["restore_model"], map_location="cpu"))

    print("--- define optimizer ---")
    optimizer = optim.Adam(net.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

    if(hypar["mode"]=="train"):
        train(net,
              optimizer,
              train_dataloaders,
              train_datasets,
              valid_dataloaders,
              valid_datasets,
              hypar,
              train_dataloaders_val, train_datasets_val)
    else:
        valid(net,
              valid_dataloaders,
              valid_datasets,
              hypar)


if __name__ == "__main__":
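
    # --------- configure the train/valid datasets (DIS5K splits) ---------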
    train_datasets, valid_datasets = [], []

    dataset_tr = {"name": "DIS5K-TR",
                  "im_dir": "../DIS5K/DIS-TR/im",
                  "gt_dir": "../DIS5K/DIS-TR/gt",
                  "im_ext": ".jpg",
                  "gt_ext": ".png",
                  "cache_dir": "../DIS5K-Cache/DIS-TR"}

    dataset_vd = {"name": "DIS5K-VD",
                  "im_dir": "../DIS5K/DIS-VD/im",
                  "gt_dir": "../DIS5K/DIS-VD/gt",
                  "im_ext": ".jpg",
                  "gt_ext": ".png",
                  "cache_dir": "../DIS5K-Cache/DIS-VD"}

    dataset_te1 = {"name": "DIS5K-TE1",
                   "im_dir": "../DIS5K/DIS-TE1/im",
                   "gt_dir": "../DIS5K/DIS-TE1/gt",
                   "im_ext": ".jpg",
                   "gt_ext": ".png",
                   "cache_dir": "../DIS5K-Cache/DIS-TE1"}

    dataset_te2 = {"name": "DIS5K-TE2",
                   "im_dir": "../DIS5K/DIS-TE2/im",
                   "gt_dir": "../DIS5K/DIS-TE2/gt",
                   "im_ext": ".jpg",
                   "gt_ext": ".png",
                   "cache_dir": "../DIS5K-Cache/DIS-TE2"}

    dataset_te3 = {"name": "DIS5K-TE3",
                   "im_dir": "../DIS5K/DIS-TE3/im",
                   "gt_dir": "../DIS5K/DIS-TE3/gt",
                   "im_ext": ".jpg",
                   "gt_ext": ".png",
                   "cache_dir": "../DIS5K-Cache/DIS-TE3"}

    dataset_te4 = {"name": "DIS5K-TE4",
                   "im_dir": "../DIS5K/DIS-TE4/im",
                   "gt_dir": "../DIS5K/DIS-TE4/gt",
                   "im_ext": ".jpg",
                   "gt_ext": ".png",
                   "cache_dir": "../DIS5K-Cache/DIS-TE4"}

    # template for a custom dataset; leave "gt_dir"/"gt_ext" empty for inference-only data
    dataset_demo = {"name": "your-dataset",
                    "im_dir": "../your-dataset/im",
                    "gt_dir": "",
                    "im_ext": ".jpg",
                    "gt_ext": "",
                    "cache_dir": "../your-dataset/cache"}

    train_datasets = [dataset_tr]
    valid_datasets = [dataset_vd]

    # --------- configure the hyperparameters ---------
    hypar = {}

    hypar["mode"] = "train"  # "train" to train; anything else runs validation/inference only

    hypar["interm_sup"] = False  # True enables intermediate feature supervision via the GT encoder

    if hypar["mode"] == "train":
        hypar["valid_out_dir"] = ""
        hypar["model_path"] = "../saved_models/IS-Net-test"
        hypar["restore_model"] = ""  # checkpoint file name to resume from; "" trains from scratch
        hypar["start_ite"] = 0
        hypar["gt_encoder_model"] = ""
    else:
        hypar["valid_out_dir"] = "../your-results/"
        hypar["model_path"] = "../saved_models/IS-Net"
        hypar["restore_model"] = "isnet.pth"

    hypar["model_digit"] = "full"  # "full" for float32, "half" for float16
    hypar["seed"] = 0

    hypar["cache_size"] = [1024, 1024]  # size to which images are resized for caching
    hypar["cache_boost_train"] = False  # True keeps the cached tensors in memory
    hypar["cache_boost_valid"] = False

    hypar["input_size"] = [1024, 1024]  # model input size
    hypar["crop_size"] = [1024, 1024]   # random crop size (when GOSRandomCrop is used)
    hypar["random_flip_h"] = 1          # horizontal flip augmentation on/off
    hypar["random_flip_v"] = 0          # vertical flip augmentation on/off

    print("building model...")
    hypar["model"] = ISNetDIS()
    hypar["early_stop"] = 20        # stop after this many validation periods without F1 improvement
    hypar["model_save_fre"] = 2000  # validate and checkpoint every 2000 iterations

    hypar["batch_size_train"] = 8
    hypar["batch_size_valid"] = 1
    print("batch size: ", hypar["batch_size_train"])

    hypar["max_ite"] = 10000000
    hypar["max_epoch_num"] = 1000000

    main(train_datasets,
         valid_datasets,
         hypar=hypar)
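
    # Usage sketch (assumes the DIS5K folder layout above and the repo modules
    # data_loader_cache, basics and models importable from this directory):
    #   python train_valid_inference_main.py    # file name is an assumption
    # Set hypar["mode"] to anything other than "train" and point
    # hypar["restore_model"] at a checkpoint (e.g. isnet.pth) to run
    # validation/inference instead of training.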