import os
import random
import logging
from copy import deepcopy
from collections import defaultdict

import cv2
import torch
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader

from isegm.utils.log import logger, TqdmToLogger, SummaryWriterAvg
from isegm.utils.vis import draw_probmap, draw_points
from isegm.utils.misc import save_checkpoint
from isegm.utils.serialization import get_config_repr
from isegm.utils.distributed import get_dp_wrapper, get_sampler, reduce_loss_dict
from .optimizer import get_optimizer


class ISTrainer(object):
    def __init__(self, model, cfg, model_cfg, loss_cfg,
                 trainset, valset,
                 optimizer='adam',
                 optimizer_params=None,
                 image_dump_interval=200,
                 checkpoint_interval=10,
                 tb_dump_period=25,
                 max_interactive_points=0,
                 lr_scheduler=None,
                 metrics=None,
                 additional_val_metrics=None,
                 net_inputs=('images', 'points'),
                 max_num_next_clicks=0,
                 click_models=None,
                 prev_mask_drop_prob=0.0,
                 ):
        self.cfg = cfg
        self.model_cfg = model_cfg
        self.max_interactive_points = max_interactive_points
        self.loss_cfg = loss_cfg
        self.val_loss_cfg = deepcopy(loss_cfg)
        self.tb_dump_period = tb_dump_period
        self.net_inputs = net_inputs
        self.max_num_next_clicks = max_num_next_clicks

        self.click_models = click_models
        self.prev_mask_drop_prob = prev_mask_drop_prob

        if cfg.distributed:
            cfg.batch_size //= cfg.ngpus
            cfg.val_batch_size //= cfg.ngpus

        if metrics is None:
            metrics = []
        self.train_metrics = metrics
        self.val_metrics = deepcopy(metrics)
        if additional_val_metrics is not None:
            self.val_metrics.extend(additional_val_metrics)

        self.checkpoint_interval = checkpoint_interval
        self.image_dump_interval = image_dump_interval
        self.task_prefix = ''
        self.sw = None

        self.trainset = trainset
        self.valset = valset

        logger.info(f'Dataset of {trainset.get_samples_number()} samples was loaded for training.')
        logger.info(f'Dataset of {valset.get_samples_number()} samples was loaded for validation.')

        self.train_data = DataLoader(
            trainset, cfg.batch_size,
            sampler=get_sampler(trainset, shuffle=True, distributed=cfg.distributed),
            drop_last=True, pin_memory=True,
            num_workers=cfg.workers
        )

        self.val_data = DataLoader(
            valset, cfg.val_batch_size,
            sampler=get_sampler(valset, shuffle=False, distributed=cfg.distributed),
            drop_last=True, pin_memory=True,
            num_workers=cfg.workers
        )

        self.optim = get_optimizer(model, optimizer, optimizer_params)
        model = self._load_weights(model)

        if cfg.multi_gpu:
            model = get_dp_wrapper(cfg.distributed)(model, device_ids=cfg.gpu_ids,
                                                    output_device=cfg.gpu_ids[0])

        if self.is_master:
            logger.info(model)
            logger.info(get_config_repr(model._config))

        self.device = cfg.device
        self.net = model.to(self.device)
        self.lr = optimizer_params['lr']

        if lr_scheduler is not None:
            self.lr_scheduler = lr_scheduler(optimizer=self.optim)
            if cfg.start_epoch > 0:
                for _ in range(cfg.start_epoch):
                    self.lr_scheduler.step()

        self.tqdm_out = TqdmToLogger(logger, level=logging.INFO)

        if self.click_models is not None:
            for click_model in self.click_models:
                for param in click_model.parameters():
                    param.requires_grad = False
                click_model.to(self.device)
                click_model.eval()

    def run(self, num_epochs, start_epoch=None, validation=True):
        if start_epoch is None:
            start_epoch = self.cfg.start_epoch

        logger.info(f'Starting Epoch: {start_epoch}')
        logger.info(f'Total Epochs: {num_epochs}')
        for epoch in range(start_epoch, num_epochs):
            self.training(epoch)
            if validation:
                self.validation(epoch)
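    # Typical construction (a minimal sketch, not taken from the original experiment
    # scripts): `model`, `cfg`, `model_cfg`, `loss_cfg`, `trainset` and `valset` are
    # assumed to be built elsewhere, and the parameter values below are illustrative.
    #
    #   trainer = ISTrainer(model, cfg, model_cfg, loss_cfg,
    #                       trainset, valset,
    #                       optimizer='adam',
    #                       optimizer_params={'lr': 5e-4},
    #                       max_interactive_points=24,
    #                       max_num_next_clicks=3)
    #   trainer.run(num_epochs=120)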
    def training(self, epoch):
        if self.sw is None and self.is_master:
            self.sw = SummaryWriterAvg(log_dir=str(self.cfg.LOGS_PATH),
                                       flush_secs=10, dump_period=self.tb_dump_period)

        if self.cfg.distributed:
            self.train_data.sampler.set_epoch(epoch)

        log_prefix = 'Train' + self.task_prefix.capitalize()
        tbar = tqdm(self.train_data, file=self.tqdm_out, ncols=100) \
            if self.is_master else self.train_data

        for metric in self.train_metrics:
            metric.reset_epoch_stats()

        self.net.train()
        train_loss = 0.0
        for i, batch_data in enumerate(tbar):
            global_step = epoch * len(self.train_data) + i

            loss, losses_logging, splitted_batch_data, outputs = \
                self.batch_forward(batch_data)

            self.optim.zero_grad()
            loss.backward()
            self.optim.step()

            losses_logging['overall'] = loss
            reduce_loss_dict(losses_logging)

            train_loss += losses_logging['overall'].item()

            if self.is_master:
                for loss_name, loss_value in losses_logging.items():
                    self.sw.add_scalar(tag=f'{log_prefix}Losses/{loss_name}',
                                       value=loss_value.item(),
                                       global_step=global_step)

                for k, v in self.loss_cfg.items():
                    if '_loss' in k and hasattr(v, 'log_states') and self.loss_cfg.get(k + '_weight', 0.0) > 0:
                        v.log_states(self.sw, f'{log_prefix}Losses/{k}', global_step)

                if self.image_dump_interval > 0 and global_step % self.image_dump_interval == 0:
                    self.save_visualization(splitted_batch_data, outputs, global_step, prefix='train')

                self.sw.add_scalar(tag=f'{log_prefix}States/learning_rate',
                                   value=self.lr if not hasattr(self, 'lr_scheduler') else self.lr_scheduler.get_lr()[-1],
                                   global_step=global_step)

                tbar.set_description(f'Epoch {epoch}, training loss {train_loss/(i+1):.4f}')
                for metric in self.train_metrics:
                    metric.log_states(self.sw, f'{log_prefix}Metrics/{metric.name}', global_step)

        if self.is_master:
            for metric in self.train_metrics:
                self.sw.add_scalar(tag=f'{log_prefix}Metrics/{metric.name}',
                                   value=metric.get_epoch_value(),
                                   global_step=epoch, disable_avg=True)

            save_checkpoint(self.net, self.cfg.CHECKPOINTS_PATH, prefix=self.task_prefix,
                            epoch=None, multi_gpu=self.cfg.multi_gpu)

            if isinstance(self.checkpoint_interval, (list, tuple)):
                checkpoint_interval = [x for x in self.checkpoint_interval if x[0] <= epoch][-1][1]
            else:
                checkpoint_interval = self.checkpoint_interval

            if epoch % checkpoint_interval == 0:
                save_checkpoint(self.net, self.cfg.CHECKPOINTS_PATH, prefix=self.task_prefix,
                                epoch=epoch, multi_gpu=self.cfg.multi_gpu)

        if hasattr(self, 'lr_scheduler'):
            self.lr_scheduler.step()

    def validation(self, epoch):
        if self.sw is None and self.is_master:
            self.sw = SummaryWriterAvg(log_dir=str(self.cfg.LOGS_PATH),
                                       flush_secs=10, dump_period=self.tb_dump_period)

        log_prefix = 'Val' + self.task_prefix.capitalize()
        tbar = tqdm(self.val_data, file=self.tqdm_out, ncols=100) if self.is_master else self.val_data

        for metric in self.val_metrics:
            metric.reset_epoch_stats()

        val_loss = 0
        losses_logging = defaultdict(list)

        self.net.eval()
        for i, batch_data in enumerate(tbar):
            global_step = epoch * len(self.val_data) + i
            loss, batch_losses_logging, splitted_batch_data, outputs = \
                self.batch_forward(batch_data, validation=True)

            batch_losses_logging['overall'] = loss
            reduce_loss_dict(batch_losses_logging)
            for loss_name, loss_value in batch_losses_logging.items():
                losses_logging[loss_name].append(loss_value.item())

            val_loss += batch_losses_logging['overall'].item()

            if self.is_master:
                tbar.set_description(f'Epoch {epoch}, validation loss: {val_loss/(i + 1):.4f}')
                for metric in self.val_metrics:
                    metric.log_states(self.sw, f'{log_prefix}Metrics/{metric.name}', global_step)

        if self.is_master:
            for loss_name, loss_values in losses_logging.items():
                self.sw.add_scalar(tag=f'{log_prefix}Losses/{loss_name}',
                                   value=np.array(loss_values).mean(),
                                   global_step=epoch, disable_avg=True)

            for metric in self.val_metrics:
                self.sw.add_scalar(tag=f'{log_prefix}Metrics/{metric.name}',
                                   value=metric.get_epoch_value(),
                                   global_step=epoch, disable_avg=True)
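    # Tensor conventions assumed by batch_forward and get_next_points (inferred from
    # the indexing below rather than stated explicitly in the original code):
    #   batch_data['images']    - float tensor of shape (B, 3, H, W)
    #   batch_data['instances'] - ground-truth mask of shape (B, 1, H, W)
    #   batch_data['points']    - float tensor of shape (B, 2 * num_points, 3); the first
    #                             half of dim 1 holds positive clicks, the second half
    #                             negative clicks, each row encoded as (y, x, click index).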
    def batch_forward(self, batch_data, validation=False):
        metrics = self.val_metrics if validation else self.train_metrics
        losses_logging = dict()

        with torch.set_grad_enabled(not validation):
            batch_data = {k: v.to(self.device) for k, v in batch_data.items()}
            image, gt_mask, points = batch_data['images'], batch_data['instances'], batch_data['points']
            orig_image, orig_gt_mask, orig_points = image.clone(), gt_mask.clone(), points.clone()

            prev_output = torch.zeros_like(image, dtype=torch.float32)[:, :1, :, :]

            last_click_indx = None

            with torch.no_grad():
                # Simulate a random number of corrective clicks: run the (temporarily frozen)
                # network on the current clicks and place the next click inside the largest
                # erroneously segmented region (see get_next_points).
                num_iters = random.randint(0, self.max_num_next_clicks)

                for click_indx in range(num_iters):
                    last_click_indx = click_indx

                    if not validation:
                        self.net.eval()

                    if self.click_models is None or click_indx >= len(self.click_models):
                        eval_model = self.net
                    else:
                        eval_model = self.click_models[click_indx]

                    net_input = torch.cat((image, prev_output), dim=1) if self.net.with_prev_mask else image
                    prev_output = torch.sigmoid(eval_model(net_input, points)['instances'])

                    points = get_next_points(prev_output, orig_gt_mask, points, click_indx + 1)

                    if not validation:
                        self.net.train()

                # Randomly zero out the previous prediction for some samples so the network
                # also learns to segment without a prior mask.
                if self.net.with_prev_mask and self.prev_mask_drop_prob > 0 and last_click_indx is not None:
                    zero_mask = np.random.random(size=prev_output.size(0)) < self.prev_mask_drop_prob
                    prev_output[zero_mask] = torch.zeros_like(prev_output[zero_mask])

            batch_data['points'] = points

            net_input = torch.cat((image, prev_output), dim=1) if self.net.with_prev_mask else image
            output = self.net(net_input, points)

            loss = 0.0
            loss = self.add_loss('instance_loss', loss, losses_logging, validation,
                                 lambda: (output['instances'], batch_data['instances']))
            loss = self.add_loss('instance_aux_loss', loss, losses_logging, validation,
                                 lambda: (output['instances_aux'], batch_data['instances']))

            if self.is_master:
                with torch.no_grad():
                    for m in metrics:
                        m.update(*(output.get(x) for x in m.pred_outputs),
                                 *(batch_data[x] for x in m.gt_outputs))

        return loss, losses_logging, batch_data, output

    def add_loss(self, loss_name, total_loss, losses_logging, validation, lambda_loss_inputs):
        loss_cfg = self.loss_cfg if not validation else self.val_loss_cfg
        loss_weight = loss_cfg.get(loss_name + '_weight', 0.0)
        if loss_weight > 0.0:
            loss_criterion = loss_cfg.get(loss_name)
            loss = loss_criterion(*lambda_loss_inputs())
            loss = torch.mean(loss)
            losses_logging[loss_name] = loss
            loss = loss_weight * loss
            total_loss = total_loss + loss

        return total_loss
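    # Expected loss_cfg layout (a sketch inferred from add_loss and the logging loop in
    # training(); the concrete criteria and weights are set by the experiment scripts,
    # so the entries below are purely illustrative):
    #   loss_cfg['instance_loss'] = <criterion returning per-element losses>
    #   loss_cfg['instance_loss_weight'] = 1.0
    #   loss_cfg['instance_aux_loss'] = <optional auxiliary criterion>
    #   loss_cfg['instance_aux_loss_weight'] = 0.4
    # A term is only added to the total loss when its '<name>_weight' entry is > 0.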
    def save_visualization(self, splitted_batch_data, outputs, global_step, prefix):
        output_images_path = self.cfg.VIS_PATH / prefix
        if self.task_prefix:
            output_images_path /= self.task_prefix

        if not output_images_path.exists():
            output_images_path.mkdir(parents=True)
        image_name_prefix = f'{global_step:06d}'

        def _save_image(suffix, image):
            cv2.imwrite(str(output_images_path / f'{image_name_prefix}_{suffix}.jpg'),
                        image, [cv2.IMWRITE_JPEG_QUALITY, 85])

        images = splitted_batch_data['images']
        points = splitted_batch_data['points']
        instance_masks = splitted_batch_data['instances']

        gt_instance_masks = instance_masks.cpu().numpy()
        predicted_instance_masks = torch.sigmoid(outputs['instances']).detach().cpu().numpy()
        points = points.detach().cpu().numpy()

        image_blob, points = images[0], points[0]
        gt_mask = np.squeeze(gt_instance_masks[0], axis=0)
        predicted_mask = np.squeeze(predicted_instance_masks[0], axis=0)

        image = image_blob.cpu().numpy() * 255
        image = image.transpose((1, 2, 0))

        image_with_points = draw_points(image, points[:self.max_interactive_points], (0, 255, 0))
        image_with_points = draw_points(image_with_points, points[self.max_interactive_points:], (0, 0, 255))

        gt_mask[gt_mask < 0] = 0.25
        gt_mask = draw_probmap(gt_mask)
        predicted_mask = draw_probmap(predicted_mask)
        viz_image = np.hstack((image_with_points, gt_mask, predicted_mask)).astype(np.uint8)

        _save_image('instance_segmentation', viz_image[:, :, ::-1])

    def _load_weights(self, net):
        if self.cfg.weights is not None:
            if os.path.isfile(self.cfg.weights):
                load_weights(net, self.cfg.weights)
                self.cfg.weights = None
            else:
                raise RuntimeError(f"=> no checkpoint found at '{self.cfg.weights}'")
        elif self.cfg.resume_exp is not None:
            checkpoints = list(self.cfg.CHECKPOINTS_PATH.glob(f'{self.cfg.resume_prefix}*.pth'))
            assert len(checkpoints) == 1

            checkpoint_path = checkpoints[0]
            logger.info(f'Load checkpoint from path: {checkpoint_path}')
            load_weights(net, str(checkpoint_path))
        return net

    @property
    def is_master(self):
        return self.cfg.local_rank == 0


def get_next_points(pred, gt, points, click_indx, pred_thresh=0.49):
    assert click_indx > 0
    pred = pred.cpu().numpy()[:, 0, :, :]
    gt = gt.cpu().numpy()[:, 0, :, :] > 0.5

    # False-negative and false-positive masks, padded so the distance transform
    # treats the image border as background.
    fn_mask = np.logical_and(gt, pred < pred_thresh)
    fp_mask = np.logical_and(np.logical_not(gt), pred > pred_thresh)

    fn_mask = np.pad(fn_mask, ((0, 0), (1, 1), (1, 1)), 'constant').astype(np.uint8)
    fp_mask = np.pad(fp_mask, ((0, 0), (1, 1), (1, 1)), 'constant').astype(np.uint8)
    num_points = points.size(1) // 2
    points = points.clone()

    for bindx in range(fn_mask.shape[0]):
        fn_mask_dt = cv2.distanceTransform(fn_mask[bindx], cv2.DIST_L2, 5)[1:-1, 1:-1]
        fp_mask_dt = cv2.distanceTransform(fp_mask[bindx], cv2.DIST_L2, 5)[1:-1, 1:-1]

        fn_max_dist = np.max(fn_mask_dt)
        fp_max_dist = np.max(fp_mask_dt)

        # Sample the next click from deep inside the dominant error region: a positive
        # click if the largest false-negative area wins, otherwise a negative click.
        is_positive = fn_max_dist > fp_max_dist
        dt = fn_mask_dt if is_positive else fp_mask_dt
        inner_mask = dt > max(fn_max_dist, fp_max_dist) / 2.0
        indices = np.argwhere(inner_mask)
        if len(indices) > 0:
            coords = indices[np.random.randint(0, len(indices))]
            if is_positive:
                points[bindx, num_points - click_indx, 0] = float(coords[0])
                points[bindx, num_points - click_indx, 1] = float(coords[1])
                points[bindx, num_points - click_indx, 2] = float(click_indx)
            else:
                points[bindx, 2 * num_points - click_indx, 0] = float(coords[0])
                points[bindx, 2 * num_points - click_indx, 1] = float(coords[1])
                points[bindx, 2 * num_points - click_indx, 2] = float(click_indx)

    return points


def load_weights(model, path_to_weights):
    # Non-strict loading: parameters missing from the checkpoint keep their current values.
    current_state_dict = model.state_dict()
    new_state_dict = torch.load(path_to_weights, map_location='cpu')['state_dict']
    current_state_dict.update(new_state_dict)
    model.load_state_dict(current_state_dict)
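# A minimal smoke test for get_next_points (an illustrative sketch, not part of the
# original pipeline); the shapes follow the conventions noted above: 24 positive and
# 24 negative click slots per sample, with unused slots set to -1.
if __name__ == '__main__':
    dummy_pred = torch.rand(2, 1, 64, 64)                # fake sigmoid probabilities
    dummy_gt = (torch.rand(2, 1, 64, 64) > 0.5).float()  # fake binary ground truth
    dummy_points = torch.full((2, 48, 3), -1.0)          # empty click slots
    updated_points = get_next_points(dummy_pred, dummy_gt, dummy_points, click_indx=1)
    print('sampled clicks (y, x, click index):')
    print(updated_points[updated_points[:, :, 2] == 1.0])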