Last commit not found
import datetime | |
import os | |
import time | |
import torch | |
import torch.utils.data | |
from torch import nn | |
from functools import reduce | |
import operator | |
from bert.multimodal_bert import MultiModalBert | |
import torchvision | |
from lib import multimodal_segmentation_ppm | |
import transforms as T | |
import utils | |
import numpy as np | |
import torch.nn.functional as F | |
import gc | |
from collections import OrderedDict | |
import torch.backends.cudnn as cudnn | |
#from ffrecord.torch import DataLoader,Dataset | |
from modeling.MaskFormerModel import MaskFormerHead | |
from addict import Dict | |
from mask2former_utils.criterion import SetCriterion, Criterion | |
from mask2former_utils.matcher import HungarianMatcher | |
from bert.modeling_bert import BertLMPredictionHead, BertEncoder | |
class WrapperModel(nn.Module):
    """Referring-segmentation model with an auxiliary masked-language-modeling (MLM) branch.

    ``image_model`` and ``language_model`` are driven stage by stage through their
    ``forward_stem`` / ``forward_stage{k}`` / ``forward_pwam{k}`` methods (k = 1..4);
    the four image residuals are handed to ``classifier`` as ``{'s1'..'s4'}``.
    A separate BERT encoder (``mlm_transformer``) fuses projected language tokens,
    projected stage-4 visual tokens and one token embedded from ``position`` and
    predicts vocabulary logits for the attended language positions.
    """

    # Sentence length (in tokens) assumed when sizing ``mlm_pos_embeds``; the fused
    # MLM sequence only lines up with the positional table for this length.
    NUM_LANG_TOKENS = 20

    def __init__(self, image_model, language_model, classifier, args):
        super(WrapperModel, self).__init__()
        self.image_model = image_model
        self.language_model = language_model
        self.classifier = classifier
        # Projects the mask-averaged stage-4 language feature returned by forward().
        self.lang_proj = nn.Linear(768, 256)
        config = Dict({
            "architectures": [
                "BertForMaskedLM"
            ],
            "attention_probs_dropout_prob": 0.1,
            "gradient_checkpointing": False,
            "hidden_act": "gelu",
            "hidden_dropout_prob": 0.1,
            "hidden_size": 512,
            "initializer_range": 0.02,
            "intermediate_size": 3072,
            "layer_norm_eps": 1e-12,
            "model_type": "bert",
            "num_attention_heads": 8,
            "num_hidden_layers": 8,
            "pad_token_id": 0,
            "position_embedding_type": "absolute",
            "transformers_version": "4.6.0.dev0",
            "type_vocab_size": 2,
            "use_cache": True,
            "vocab_size": 30522
        })
        self.mlm_transformer = BertEncoder(config)
        # Remembered so forward() builds a head mask matching the encoder depth.
        self.mlm_num_layers = config.num_hidden_layers
        self.mlm_vis_proj = nn.Conv2d(1024, 512, 1)
        self.mlm_lang_proj = nn.Linear(768, 512)
        self.mlm_head = BertLMPredictionHead(config)
        if args.img_size % 4 != 0:
            raise ValueError(
                "img_size must be divisible by 4, got {}".format(args.img_size))
        # Language tokens + visual tokens; assumes the stage-4 map is
        # (img_size // 4) // 8 per side -- TODO confirm for sizes not divisible by 32.
        num_img_tokens = self.NUM_LANG_TOKENS + ((args.img_size // 4) // 8) ** 2
        print(num_img_tokens)
        # +1 slot for the token embedded from ``position``.
        self.mlm_pos_embeds = nn.Embedding(num_img_tokens + 1, 512)
        # Modality ids used in forward(): 0 = language, 1 = visual, 2 = position token.
        self.mlm_modal_embeds = nn.Embedding(3, 512)
        # Not referenced by forward(); kept so parameter/checkpoint keys stay the same.
        self.mlm_mask_embed = nn.Embedding(1, 512)
        self.mlm_pos_mlp = nn.Sequential(
            nn.Linear(2, 512),
            nn.LayerNorm(512),
            nn.Linear(512, 512),
            nn.GELU(),
        )

    def _get_binary_mask(self, target, num_classes=None):
        """Return per-class binary masks for a 2-D integer label map.

        Channel ``k`` of the result is 1 where ``target == k + 1`` (label 0 is
        dropped).  ``num_classes`` defaults to ``self.num_classes`` when such an
        attribute exists, otherwise to ``target.max()``.
        """
        if num_classes is None:
            num_classes = getattr(self, 'num_classes', None)
        if num_classes is None:
            num_classes = int(target.max())
        y, x = target.size()
        target_onehot = torch.zeros(num_classes + 1, y, x, device=target.device)
        target_onehot = target_onehot.scatter(dim=0, index=target.unsqueeze(0).long(), value=1)
        return target_onehot[1:]

    def semantic_inference(self, mask_cls, mask_pred):
        """Combine class scores and sigmoid mask logits into per-class maps (b, c, h, w).

        NOTE(review): softmax is taken over dim=1 and class index 0 is dropped. The
        common Mask2Former layout (b, q, c+1) with the no-object class last would use
        dim=-1 and ``[..., :-1]`` -- confirm against the criterion's label convention.
        """
        mask_cls = F.softmax(mask_cls, dim=1)[..., 1:]
        mask_pred = mask_pred.sigmoid()
        semseg = torch.einsum("bqc,bqhw->bchw", mask_cls, mask_pred)
        return semseg

    def forward(self, image, sentences, attentions, mlm_targets, mlm_masks, position):
        """Run the segmentation branch and the MLM branch.

        Args:
            image: image batch passed to ``image_model.forward_stem``.
            sentences: token ids fed to the language-only MLM pass.
            attentions: (B, L) token attention mask.
            mlm_targets: token ids; ``squeeze(1)`` is the language input of the
                segmentation pass, and the same ids are the MLM labels.
            mlm_masks: positions with ``mlm_masks > 0`` keep their label; all other
                labels are replaced by -1.
            position: per-sample 2-vector embedded by ``mlm_pos_mlp``.

        Returns:
            (predictions, mask_features, projected pooled language feature,
             MLM logits for attended tokens reshaped to (-1, vocab_size),
             MLM labels for attended tokens).
        """
        l_mask = attentions.unsqueeze(dim=-1)

        # Segmentation pass: image and language features are updated stage by stage,
        # each stage followed by the cross-modal ``forward_pwam{k}`` exchange.
        img, Wh, Ww = self.image_model.forward_stem(image)
        lang, extended_attention_mask = self.language_model.forward_stem(mlm_targets.squeeze(1), attentions)
        features = {}
        for k in range(1, 5):
            img = getattr(self.image_model, f'forward_stage{k}')(img, Wh, Ww)
            lang = getattr(self.language_model, f'forward_stage{k}')(lang, extended_attention_mask)
            # Both pwam calls must see this stage's pre-exchange features.
            img_residual, _, _, img_next, Wh, Ww = getattr(self.image_model, f'forward_pwam{k}')(
                img, Wh, Ww, lang, l_mask)
            lang_residual, lang = getattr(self.language_model, f'forward_pwam{k}')(
                img, lang, extended_attention_mask)
            img = img_next
            features[f's{k}'] = img_residual
        # After the loop these hold the stage-4 residuals.
        i4_residual, l4_residual = img_residual, lang_residual
        predictions, mask_features = self.classifier(features)

        # MLM pass: the language model alone (no pwam exchange) on ``sentences``.
        l_mlm, extended_attention_mask = self.language_model.forward_stem(sentences, attentions)
        for k in range(1, 5):
            l_mlm = getattr(self.language_model, f'forward_stage{k}')(l_mlm, extended_attention_mask)
        mlp_embed = self.mlm_pos_mlp(position)
        # Unmasked positions get label -1 (presumably the loss's ignore_index).
        mlm_targets = torch.where(mlm_masks > 0, mlm_targets, torch.full_like(mlm_targets, -1))

        # Fused sequence: [language tokens | stage-4 visual tokens | position token].
        vis_features = self.mlm_vis_proj(i4_residual).flatten(2).permute(0, 2, 1)
        lang_features = self.mlm_lang_proj(l_mlm)
        mm_features = torch.cat([lang_features, vis_features, mlp_embed.unsqueeze(1)], dim=1)
        n_lang, n_vis = lang_features.shape[1], vis_features.shape[1]
        modal_weight = self.mlm_modal_embeds.weight
        modal_embeds = torch.cat([
            modal_weight[0].expand(n_lang, -1),
            modal_weight[1].expand(n_vis, -1),
            modal_weight[2:3],
        ], dim=0).unsqueeze(0)

        # Language padding follows ``attentions``; visual/position tokens are always
        # attended.  Converted to the additive (0 / -10000) BERT mask layout.
        mixed_attention_mask = torch.cat([
            attentions.unsqueeze(-1),
            torch.ones(attentions.shape[0], n_vis + 1, 1, device=attentions.device),
        ], dim=1)
        mixed_attention_mask = mixed_attention_mask.permute(0, 2, 1).unsqueeze(1)
        mixed_attention_mask = (1 - mixed_attention_mask) * -10000.0
        head_mask = [None] * self.mlm_num_layers
        head_features = self.mlm_transformer(
            mm_features + self.mlm_pos_embeds.weight.unsqueeze(0) + modal_embeds,
            mixed_attention_mask, head_mask)[0]

        # Keep only attended language positions (the mask defines the slice width).
        attended = attentions.bool()
        head_features = head_features[:, :attentions.shape[1]][attended]
        mlm_predictions = self.mlm_head(head_features)
        mlm_predictions = mlm_predictions.reshape(-1, self.language_model.config.vocab_size)
        mlm_targets = mlm_targets.squeeze(1)[attended]

        pooled_lang = (l4_residual * l_mask).sum(1) / l_mask.sum(1)
        return predictions, mask_features, self.lang_proj(pooled_lang), mlm_predictions, mlm_targets
# IoU calculation for validation
def IoU(pred, gt):
    """Binarize ``pred`` at 0.5 and compute its IoU against ``gt``.

    Returns:
        (iou, intersection, union) where ``iou`` is a Python float, or the int 0
        when either the intersection or the union is zero; intersection and union
        are the summed tensors.
    """
    binary_pred = pred > 0.5
    inter = torch.mul(binary_pred, gt).sum()
    # pred + gt counts overlapping pixels twice, so subtract the intersection once.
    uni = torch.add(binary_pred, gt).sum() - inter
    if inter != 0 and uni != 0:
        return float(inter) / float(uni), inter, uni
    return 0, inter, uni
def get_dataset(image_set, transform, args):
    """Build the ``ReferDataset`` split ``image_set`` with ``transform`` applied to images.

    Returns:
        (dataset, num_classes) where num_classes is the hard-coded value 2.
    """
    # Imported lazily so the module can be loaded without the data package.
    from data.dataset_refer_bert_mlm import ReferDataset
    dataset = ReferDataset(
        args,
        split=image_set,
        image_transforms=transform,
        target_transforms=None,
    )
    return dataset, 2
def get_transform(args):
    """Compose resize to ``args.img_size`` square, tensor conversion and normalization.

    The normalization constants are the usual ImageNet channel mean/std.
    """
    side = args.img_size
    return T.Compose([
        T.Resize(side, side),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
#def criterion(input, target): | |
# weight = torch.FloatTensor([0.9, 1.1]).cuda() | |
# return nn.functional.cross_entropy(input, target, weight=weight) | |
def evaluate(model, data_loader):
    """Evaluate referring segmentation on ``data_loader`` and print a summary.

    Each batch must unpack as (image, target, sentences, attentions, mlm_targets,
    mlm_masks, position).  ``pred_masks[0]`` is compared against the whole
    ``target``, so the loader is assumed to use batch_size=1.

    Returns:
        (mean per-sample IoU, overall IoU = sum(I) / sum(U)), both in percent and
        as Python floats.  Each is 0.0 when nothing was evaluated / the union is empty.
    """
    model.eval()
    # DistributedDataParallel exposes the wrapped network as ``.module``; accept
    # an unwrapped model too.
    net = getattr(model, 'module', model)
    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Test:'
    eval_seg_iou_list = [.5, .6, .7, .8, .9]
    seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32)
    seg_total = 0
    acc_ious = 0.0
    cum_I, cum_U = 0, 0
    with torch.no_grad():
        for data in metric_logger.log_every(data_loader, 100, header):
            image, target, sentences, attentions, mlm_targets, mlm_masks, position = (
                t.cuda(non_blocking=True) for t in data)
            sentences = sentences.squeeze(1)
            attentions = attentions.squeeze(1)
            output = model(image, sentences, attentions, mlm_targets, mlm_masks, position)[0]
            mask_pred_results = F.interpolate(output["pred_masks"], size=target.shape[-2:],
                                              mode='bilinear', align_corners=True)
            pred_masks = net.semantic_inference(output["pred_logits"], mask_pred_results)
            iou, I, U = IoU(pred_masks[0], target)
            acc_ious += iou
            cum_I += I
            cum_U += U
            for n_eval_iou, eval_seg_iou in enumerate(eval_seg_iou_list):
                seg_correct[n_eval_iou] += (iou >= eval_seg_iou)
            seg_total += 1

    mean_iou = acc_ious / seg_total if seg_total else 0.0
    # Convert accumulated (possibly CUDA) tensors to plain floats once.
    cum_I, cum_U = float(cum_I), float(cum_U)
    overall_iou = cum_I * 100. / cum_U if cum_U > 0 else 0.0
    print('Final results:')
    print('Mean IoU is %.2f\n' % (mean_iou * 100.))
    results_str = ''
    for n_eval_iou, eval_seg_iou in enumerate(eval_seg_iou_list):
        results_str += ' precision@%s = %.2f\n' % \
            (str(eval_seg_iou), seg_correct[n_eval_iou] * 100. / max(seg_total, 1))
    results_str += ' overall IoU = %.2f\n' % overall_iou
    print(results_str)
    return 100 * mean_iou, overall_iou
def _plic_contrastive_term(anchor_scores, negative_scores, alpha, temp):
    """Focal-weighted InfoNCE where column 0 of every row is the positive.

    anchor_scores: (1, N) similarity of each of N items to its positive anchor.
    negative_scores: (1, N, M) similarity of each item to M negatives.
    """
    logits = torch.cat([anchor_scores.unsqueeze(-1), negative_scores], dim=2)
    labels = torch.zeros(logits.shape[0] * logits.shape[1], dtype=torch.long, device=logits.device)
    # Rows whose positive already dominates get a small weight.  The weight is
    # computed on untempered logits and is not detached (gradient flows through it).
    hardness = 1.0 - torch.softmax(logits, -1)[:, :, 0]
    ce = F.cross_entropy(logits.view(-1, logits.shape[-1]) / temp, labels, reduction='none')
    return (torch.pow(hardness, alpha) * ce).mean()


def _plic_losses(mask_features, target, avg_lang_feature, args):
    """Return weighted, batch-averaged PLIC losses (pos, neg, lang) as tensors.

    Samples whose downsampled target lacks either label 0 or label 1 pixels are
    skipped; if every sample is skipped all three losses are zero tensors.
    """
    B, C = mask_features.shape[:2]
    target_reshape = F.interpolate(target.unsqueeze(1).float(), size=mask_features.shape[-2:],
                                   mode='nearest').long()
    target_reshape = target_reshape.repeat(1, C, 1, 1)
    # Tensors (not Python floats) so callers can always call ``.item()``.
    plic_pos_loss = mask_features.new_zeros(())
    plic_neg_loss = mask_features.new_zeros(())
    plic_lang_loss = mask_features.new_zeros(())
    for i in range(B):
        tgt = target_reshape[[i]]
        if (tgt == 0).sum() == 0 or (tgt == 1).sum() == 0:
            continue
        feats = mask_features[[i]]
        avg_pos_feature = (feats * tgt).sum(-1).sum(-1) / tgt.sum(-1).sum(-1)
        avg_neg_feature = (feats * (1.0 - tgt)).sum(-1).sum(-1) / (1.0 - tgt).sum(-1).sum(-1)
        avg_pos_feature = F.normalize(avg_pos_feature, dim=1)
        avg_neg_feature = F.normalize(avg_neg_feature, dim=1)
        # Boolean indexing flattens channel-major, so view(1, C, -1) regroups per channel.
        pos_features = F.normalize(feats[tgt == 1].view(1, C, -1), dim=1)
        neg_features = F.normalize(feats[tgt == 0].view(1, C, -1), dim=1)
        lang = avg_lang_feature[[i]]

        lang_pos_scores = torch.einsum("bq,bqn->bn", lang, pos_features)
        lang_neg_scores = torch.einsum("bq,bqn->bn", lang, neg_features)
        lang_loss = _plic_contrastive_term(
            lang_pos_scores,
            lang_neg_scores.unsqueeze(1).expand(-1, lang_pos_scores.shape[1], -1),
            args.plic_lang_alpha, args.plic_lang_temp)
        pos_loss = _plic_contrastive_term(
            torch.einsum("bq,bqn->bn", avg_pos_feature, pos_features),
            torch.einsum("bqn,bqm->bnm", pos_features, neg_features),
            args.plic_pos_alpha, args.plic_pos_temp)
        neg_loss = _plic_contrastive_term(
            torch.einsum("bq,bqn->bn", avg_neg_feature, neg_features),
            torch.einsum("bqn,bqm->bnm", neg_features, pos_features),
            args.plic_neg_alpha, args.plic_neg_temp)

        plic_pos_loss = plic_pos_loss + pos_loss
        plic_neg_loss = plic_neg_loss + neg_loss
        plic_lang_loss = plic_lang_loss + lang_loss
    return (args.plic_pos_weight * plic_pos_loss) / B, \
        (args.plic_neg_weight * plic_neg_loss) / B, \
        (args.plic_lang_weight * plic_lang_loss) / B


def _weighted_set_losses(criterion, output, target):
    """Apply ``criterion.weight_dict`` and bucket the losses into (ce, dice, mask).

    Keys absent from the weight dict are ignored; weighted keys that contain
    neither '_ce' nor '_dice' are summed into the mask bucket.
    """
    losses = criterion(output, target)
    weight_dict = criterion.weight_dict
    loss_ce = loss_dice = loss_mask = 0.0
    for k, v in losses.items():
        if k not in weight_dict:
            continue
        weighted = v * weight_dict[k]
        if '_ce' in k:
            loss_ce = loss_ce + weighted
        elif '_dice' in k:
            loss_dice = loss_dice + weighted
        else:
            loss_mask = loss_mask + weighted
    return loss_ce, loss_dice, loss_mask


def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, epoch, print_freq,
                    iterations, args):
    """Train ``model`` for one epoch.

    Per batch the loss is the weighted set-criterion losses (ce + dice + mask)
    plus the PLIC losses (pos + neg + lang) plus ``args.smlm_weight`` times the
    MLM cross-entropy (labels of -1 ignored).  The LR scheduler steps every batch.

    Returns:
        ``iterations`` incremented by the number of batches processed.
    """
    model.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
    header = 'Epoch: [{}]'.format(epoch)
    for data in metric_logger.log_every(data_loader, print_freq, header):
        image, target, sentences, attentions, mlm_targets, mlm_masks, position = (
            t.cuda(non_blocking=True) for t in data)
        sentences = sentences.squeeze(1)
        attentions = attentions.squeeze(1)
        output, mask_features, avg_lang_feature, mlm_predictions, mlm_targets = model(
            image, sentences, attentions, mlm_targets, mlm_masks, position)
        avg_lang_feature = F.normalize(avg_lang_feature, dim=1)

        # Upsample final and auxiliary mask logits to the label resolution.
        target_shape = target.shape[-2:]
        output['pred_masks'] = F.interpolate(output['pred_masks'], size=target_shape,
                                             mode='bilinear', align_corners=True)
        if "aux_outputs" in output:
            for aux in output["aux_outputs"]:
                aux['pred_masks'] = F.interpolate(aux['pred_masks'], size=target_shape,
                                                  mode='bilinear', align_corners=True)

        plic_pos_loss, plic_neg_loss, plic_lang_loss = _plic_losses(
            mask_features, target, avg_lang_feature, args)
        plic_loss = plic_pos_loss + plic_neg_loss + plic_lang_loss
        loss_ce, loss_dice, loss_mask = _weighted_set_losses(criterion, output, target)
        smlm_loss = args.smlm_weight * nn.CrossEntropyLoss(ignore_index=-1)(mlm_predictions, mlm_targets)
        loss = loss_ce + loss_dice + loss_mask + plic_loss + smlm_loss

        optimizer.zero_grad()  # set_to_none=True is only available in pytorch 1.6+
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        torch.cuda.synchronize()
        iterations += 1

        # float() accepts both 0-d tensors and plain numbers (empty loss buckets).
        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"],
                             loss_ce=float(loss_ce), loss_dice=float(loss_dice),
                             loss_mask=float(loss_mask), plic_loss=float(plic_loss),
                             plic_lang_loss=float(plic_lang_loss),
                             plic_pos_loss=float(plic_pos_loss),
                             plic_neg_loss=float(plic_neg_loss), smlm_loss=float(smlm_loss))
    return iterations
def _latest_epoch_checkpoint(output_dir, epochs):
    """Return the highest-epoch existing ``checkpoint-{e}.pth`` (e < epochs), or ''."""
    last_ckpt = ""
    for e in range(epochs):
        ckpt_path = os.path.join(output_dir, f'checkpoint-{e}.pth')
        if os.path.exists(ckpt_path):
            last_ckpt = ckpt_path
    return last_ckpt


def _build_maskformer_cfg(args):
    """Assemble the addict config consumed by MaskFormerHead and the set criterion."""
    cfg = Dict()
    cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE = 4
    cfg.MODEL.MASK_FORMER.DROPOUT = 0.0
    cfg.MODEL.MASK_FORMER.NHEADS = 8
    cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = args.transformer_enc_layers
    cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM = 256
    cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256
    cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES = ["s1", "s2", "s3", "s4"]
    cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 1
    cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256
    cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = args.num_object_queries
    cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = args.dim_feedforward
    cfg.MODEL.MASK_FORMER.DEC_LAYERS = args.dec_layers
    cfg.MODEL.MASK_FORMER.PRE_NORM = False
    cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION = True
    cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT = args.no_object_weight
    cfg.MODEL.MASK_FORMER.CLASS_WEIGHT = args.class_weight
    cfg.MODEL.MASK_FORMER.DICE_WEIGHT = args.dice_weight
    cfg.MODEL.MASK_FORMER.MASK_WEIGHT = args.mask_weight
    cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS = args.train_num_points
    cfg.MODEL.MASK_FORMER.OVERSAMPLE_RATIO = 3.0
    cfg.MODEL.MASK_FORMER.IMPORTANCE_SAMPLE_RATIO = 0.75
    return cfg


def _backbone_feature_shapes():
    """Channel width and stride of the backbone features s1..s4 fed to the head."""
    return {
        's1': Dict({'channel': 128, 'stride': 4}),
        's2': Dict({'channel': 256, 'stride': 8}),
        's3': Dict({'channel': 512, 'stride': 16}),
        's4': Dict({'channel': 1024, 'stride': 32}),
    }


def _build_criterion(cfg):
    """Build the Hungarian matcher and SetCriterion from ``cfg``."""
    mf = cfg.MODEL.MASK_FORMER
    class_weight, dice_weight, mask_weight = mf.CLASS_WEIGHT, mf.DICE_WEIGHT, mf.MASK_WEIGHT
    matcher = HungarianMatcher(
        cost_class=class_weight,
        cost_mask=mask_weight,
        cost_dice=dice_weight,
        num_points=mf.TRAIN_NUM_POINTS,
    )
    weight_dict = {"loss_ce": class_weight, "loss_mask": mask_weight, "loss_dice": dice_weight}
    if mf.DEEP_SUPERVISION:
        # One suffixed copy of every weight per auxiliary decoder layer.
        aux_weight_dict = {}
        for i in range(mf.DEC_LAYERS - 1):
            aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
        weight_dict.update(aux_weight_dict)
    return SetCriterion(
        cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
        matcher=matcher,
        weight_dict=weight_dict,
        eos_coef=mf.NO_OBJECT_WEIGHT,
        losses=["labels", "masks"],
        num_points=mf.TRAIN_NUM_POINTS,
        oversample_ratio=mf.OVERSAMPLE_RATIO,
        importance_sample_ratio=mf.IMPORTANCE_SAMPLE_RATIO,
        device='cuda'
    )


def _build_param_groups(single_model):
    """Return AdamW parameter groups; backbone norm / position-bias params get no decay."""
    backbone_no_decay, backbone_decay = [], []
    for name, param in single_model.image_model.named_parameters():
        if 'norm' in name or 'absolute_pos_embed' in name or 'relative_position_bias_table' in name:
            backbone_no_decay.append(param)
        else:
            backbone_decay.append(param)
    lm = single_model.language_model
    return [
        {'params': backbone_no_decay, 'weight_decay': 0.0},
        {'params': backbone_decay},
        {"params": [p for p in single_model.classifier.parameters() if p.requires_grad]},
        # trainable parameters of the first 10 BERT encoder layers
        {"params": [p for i in range(10) for p in lm.encoder.layer[i].parameters() if p.requires_grad]},
        {"params": lm.pwams.parameters()},
        {"params": lm.res_gates.parameters()},
        {"params": lm.norms.parameters()},
        {"params": single_model.lang_proj.parameters()},
        {'params': single_model.mlm_head.parameters()},
        {'params': single_model.mlm_vis_proj.parameters()},
        {'params': single_model.mlm_lang_proj.parameters()},
        {'params': single_model.mlm_transformer.parameters()},
        {'params': single_model.mlm_pos_embeds.parameters()},
        {'params': single_model.mlm_modal_embeds.parameters()},
        {'params': single_model.mlm_mask_embed.parameters()},
        {'params': single_model.mlm_pos_mlp.parameters()},
    ]


def _save_checkpoint(state, checkpoint_path):
    """Save ``state`` via ``utils.save_on_master`` to a temp file, then move it into place.

    ``os.replace`` (unlike ``os.rename`` on Windows) overwrites an existing file,
    which matters for the repeatedly rewritten best-model checkpoint.
    """
    tmp_path = str(checkpoint_path) + '_TEMP'
    utils.save_on_master(state, tmp_path)
    if utils.is_main_process():
        os.replace(tmp_path, str(checkpoint_path))


def _prune_epoch_checkpoints(output_dir, epochs, max_ckpt):
    """Delete all but the newest ``max_ckpt`` epoch checkpoints (keep all if max_ckpt <= 0)."""
    ckpt_paths = []
    for e in range(epochs):
        ckpt_path = os.path.join(output_dir, f'checkpoint-{e}.pth')
        print(ckpt_path)
        if os.path.exists(ckpt_path):
            ckpt_paths.append(ckpt_path)
    print(ckpt_paths)
    if max_ckpt <= 0:
        return
    for ckpt_path in ckpt_paths[:-max_ckpt]:
        os.remove(ckpt_path)
        print("remove {:s}".format(ckpt_path))


def main(args):
    """Build data, model, criterion and optimizer, then run the train/validate loop.

    Per epoch: train, evaluate on the "val" split, save ``checkpoint-{epoch}.pth``
    (keeping the newest ``args.max_ckpt``), and save
    ``model_best_{args.model_id}.pth`` whenever the overall IoU improves.
    ``args.resume`` may be a checkpoint path or "auto" (latest epoch checkpoint
    found in ``args.output_dir``).
    """
    # fix the seed for reproducibility (offset per rank)
    seed = args.seed + utils.get_rank()
    print('seed', seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    dataset, _ = get_dataset("train", get_transform(args=args), args=args)
    dataset_test, _ = get_dataset("val", get_transform(args=args), args=args)
    print(f"local rank {args.local_rank} / global rank {utils.get_rank()} successfully built train dataset.")

    # Training data is sharded across ranks; validation runs over the full split on every rank.
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=utils.get_world_size(), rank=utils.get_rank(), shuffle=True)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size,
        sampler=train_sampler, num_workers=args.workers, pin_memory=True, drop_last=True)
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, sampler=test_sampler, pin_memory=True, num_workers=args.workers)

    # model initialization
    print(args.model)
    model = multimodal_segmentation_ppm.__dict__[args.model](pretrained=args.pretrained_swin_weights,
                                                              args=args)
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    if args.model != 'lavt_one':
        bert_model = MultiModalBert.from_pretrained(args.ck_bert, embed_dim=model.backbone.embed_dim)
        bert_model.pooler = None  # a work-around for a bug in Transformers = 3.0.2 that appears for DistributedDataParallel
        bert_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(bert_model)
    else:
        bert_model = None

    cfg = _build_maskformer_cfg(args)
    print(cfg)
    maskformer_head = MaskFormerHead(cfg, _backbone_feature_shapes())
    maskformer_head = torch.nn.SyncBatchNorm.convert_sync_batchnorm(maskformer_head)

    model = WrapperModel(model.backbone, bert_model, maskformer_head, args)
    model.cuda()
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                      find_unused_parameters=True)
    single_model = model.module

    criterion = _build_criterion(cfg)

    # resume model weights
    if args.resume == "auto":
        args.resume = _latest_epoch_checkpoint(args.output_dir, args.epochs)
    checkpoint = None
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        single_model.load_state_dict(checkpoint['model'])

    optimizer = torch.optim.AdamW(_build_param_groups(single_model),
                                  lr=args.lr,
                                  weight_decay=args.weight_decay,
                                  amsgrad=args.amsgrad
                                  )
    total_steps = len(data_loader) * args.epochs
    # Polynomial (power 0.9) decay to zero over all training steps; clamped so a
    # step past the end never raises a negative base to a fractional power.
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda x: max(0.0, 1 - x / total_steps) ** 0.9)

    # housekeeping
    start_time = time.time()
    iterations = 0
    best_oIoU = -0.1

    # resume training (optimizer, lr scheduler, and the epoch)
    if checkpoint is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        resume_epoch = checkpoint['epoch']
    else:
        resume_epoch = -999

    for epoch in range(max(0, resume_epoch + 1), args.epochs):
        data_loader.sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, epoch, args.print_freq,
                        iterations, args)
        iou, overallIoU = evaluate(model, data_loader_test)
        print('Average object IoU {}'.format(iou))
        print('Overall IoU {}'.format(overallIoU))

        dict_to_save = {'model': single_model.state_dict(),
                        'optimizer': optimizer.state_dict(), 'epoch': epoch, 'args': args,
                        'lr_scheduler': lr_scheduler.state_dict()}
        _save_checkpoint(dict_to_save, os.path.join(args.output_dir, 'checkpoint-{}.pth'.format(epoch)))
        if utils.is_main_process():
            _prune_epoch_checkpoints(args.output_dir, args.epochs, args.max_ckpt)

        if best_oIoU < overallIoU:
            print('Better epoch: {}\n'.format(epoch))
            _save_checkpoint(dict_to_save,
                             os.path.join(args.output_dir, 'model_best_{}.pth'.format(args.model_id)))
            best_oIoU = overallIoU
        torch.cuda.empty_cache()

    # summarize
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
if __name__ == "__main__":
    # Parse CLI options, prepare the output directory and the process group, then train.
    from args import get_parser
    args = get_parser().parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    # set up distributed learning
    utils.init_distributed_mode(args)
    print('Image size: {}'.format(str(args.img_size)))
    main(args)