import logging
import os
import random
from collections import defaultdict
from copy import deepcopy
import cv2
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from isegm.utils.distributed import (get_dp_wrapper, get_sampler,
reduce_loss_dict)
from isegm.utils.log import SummaryWriterAvg, TqdmToLogger, logger
from isegm.utils.misc import save_checkpoint
from isegm.utils.serialization import get_config_repr
from isegm.utils.vis import draw_points, draw_probmap
from .optimizer import get_optimizer
class ISTrainer(object):
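    """Trainer for click-based interactive segmentation models.

    Handles the training/validation loops, iterative click simulation,
    loss computation and logging, metric tracking, checkpointing, and
    visualization dumps.
    """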
def __init__(
self,
model,
cfg,
model_cfg,
loss_cfg,
trainset,
valset,
optimizer="adam",
optimizer_params=None,
image_dump_interval=200,
checkpoint_interval=10,
tb_dump_period=25,
max_interactive_points=0,
lr_scheduler=None,
metrics=None,
additional_val_metrics=None,
net_inputs=("images", "points"),
max_num_next_clicks=0,
click_models=None,
prev_mask_drop_prob=0.0,
):
self.cfg = cfg
self.model_cfg = model_cfg
self.max_interactive_points = max_interactive_points
self.loss_cfg = loss_cfg
self.val_loss_cfg = deepcopy(loss_cfg)
self.tb_dump_period = tb_dump_period
self.net_inputs = net_inputs
self.max_num_next_clicks = max_num_next_clicks
self.click_models = click_models
self.prev_mask_drop_prob = prev_mask_drop_prob
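        # Interpret the configured batch sizes as global and split them per GPU
        # when running distributed training.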
if cfg.distributed:
cfg.batch_size //= cfg.ngpus
cfg.val_batch_size //= cfg.ngpus
if metrics is None:
metrics = []
self.train_metrics = metrics
self.val_metrics = deepcopy(metrics)
if additional_val_metrics is not None:
self.val_metrics.extend(additional_val_metrics)
self.checkpoint_interval = checkpoint_interval
self.image_dump_interval = image_dump_interval
self.task_prefix = ""
self.sw = None
self.trainset = trainset
self.valset = valset
logger.info(
f"Dataset of {trainset.get_samples_number()} samples was loaded for training."
)
logger.info(
f"Dataset of {valset.get_samples_number()} samples was loaded for validation."
)
self.train_data = DataLoader(
trainset,
cfg.batch_size,
sampler=get_sampler(trainset, shuffle=True, distributed=cfg.distributed),
drop_last=True,
pin_memory=True,
num_workers=cfg.workers,
)
self.val_data = DataLoader(
valset,
cfg.val_batch_size,
sampler=get_sampler(valset, shuffle=False, distributed=cfg.distributed),
drop_last=True,
pin_memory=True,
num_workers=cfg.workers,
)
self.optim = get_optimizer(model, optimizer, optimizer_params)
model = self._load_weights(model)
if cfg.multi_gpu:
model = get_dp_wrapper(cfg.distributed)(
model, device_ids=cfg.gpu_ids, output_device=cfg.gpu_ids[0]
)
if self.is_master:
logger.info(model)
logger.info(get_config_repr(model._config))
self.device = cfg.device
self.net = model.to(self.device)
self.lr = optimizer_params["lr"]
if lr_scheduler is not None:
self.lr_scheduler = lr_scheduler(optimizer=self.optim)
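            # When resuming from a later epoch, fast-forward the scheduler to that epoch.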
if cfg.start_epoch > 0:
for _ in range(cfg.start_epoch):
self.lr_scheduler.step()
self.tqdm_out = TqdmToLogger(logger, level=logging.INFO)
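        # Auxiliary click models, if provided, are frozen and used only to
        # simulate clicks inside batch_forward.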
if self.click_models is not None:
for click_model in self.click_models:
for param in click_model.parameters():
param.requires_grad = False
click_model.to(self.device)
click_model.eval()
def run(self, num_epochs, start_epoch=None, validation=True):
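        """Train for epochs in [start_epoch, num_epochs), optionally validating after each epoch."""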
if start_epoch is None:
start_epoch = self.cfg.start_epoch
logger.info(f"Starting Epoch: {start_epoch}")
logger.info(f"Total Epochs: {num_epochs}")
for epoch in range(start_epoch, num_epochs):
self.training(epoch)
if validation:
self.validation(epoch)
def training(self, epoch):
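        """Run a single training epoch and log losses, metrics, and checkpoints."""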
if self.sw is None and self.is_master:
self.sw = SummaryWriterAvg(
log_dir=str(self.cfg.LOGS_PATH),
flush_secs=10,
dump_period=self.tb_dump_period,
)
if self.cfg.distributed:
self.train_data.sampler.set_epoch(epoch)
log_prefix = "Train" + self.task_prefix.capitalize()
tbar = (
tqdm(self.train_data, file=self.tqdm_out, ncols=100)
if self.is_master
else self.train_data
)
for metric in self.train_metrics:
metric.reset_epoch_stats()
self.net.train()
train_loss = 0.0
for i, batch_data in enumerate(tbar):
global_step = epoch * len(self.train_data) + i
loss, losses_logging, splitted_batch_data, outputs = self.batch_forward(
batch_data
)
self.optim.zero_grad()
loss.backward()
self.optim.step()
losses_logging["overall"] = loss
reduce_loss_dict(losses_logging)
train_loss += losses_logging["overall"].item()
if self.is_master:
for loss_name, loss_value in losses_logging.items():
self.sw.add_scalar(
tag=f"{log_prefix}Losses/{loss_name}",
value=loss_value.item(),
global_step=global_step,
)
for k, v in self.loss_cfg.items():
if (
"_loss" in k
and hasattr(v, "log_states")
and self.loss_cfg.get(k + "_weight", 0.0) > 0
):
v.log_states(self.sw, f"{log_prefix}Losses/{k}", global_step)
if (
self.image_dump_interval > 0
and global_step % self.image_dump_interval == 0
):
self.save_visualization(
splitted_batch_data, outputs, global_step, prefix="train"
)
self.sw.add_scalar(
tag=f"{log_prefix}States/learning_rate",
value=self.lr
if not hasattr(self, "lr_scheduler")
else self.lr_scheduler.get_lr()[-1],
global_step=global_step,
)
tbar.set_description(
f"Epoch {epoch}, training loss {train_loss/(i+1):.4f}"
)
for metric in self.train_metrics:
metric.log_states(
self.sw, f"{log_prefix}Metrics/{metric.name}", global_step
)
if self.is_master:
for metric in self.train_metrics:
self.sw.add_scalar(
tag=f"{log_prefix}Metrics/{metric.name}",
value=metric.get_epoch_value(),
global_step=epoch,
disable_avg=True,
)
save_checkpoint(
self.net,
self.cfg.CHECKPOINTS_PATH,
prefix=self.task_prefix,
epoch=None,
multi_gpu=self.cfg.multi_gpu,
)
if isinstance(self.checkpoint_interval, (list, tuple)):
checkpoint_interval = [
x for x in self.checkpoint_interval if x[0] <= epoch
][-1][1]
else:
checkpoint_interval = self.checkpoint_interval
if epoch % checkpoint_interval == 0:
save_checkpoint(
self.net,
self.cfg.CHECKPOINTS_PATH,
prefix=self.task_prefix,
epoch=epoch,
multi_gpu=self.cfg.multi_gpu,
)
if hasattr(self, "lr_scheduler"):
self.lr_scheduler.step()
def validation(self, epoch):
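        """Run a single validation epoch and log averaged losses and metrics."""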
if self.sw is None and self.is_master:
self.sw = SummaryWriterAvg(
log_dir=str(self.cfg.LOGS_PATH),
flush_secs=10,
dump_period=self.tb_dump_period,
)
log_prefix = "Val" + self.task_prefix.capitalize()
tbar = (
tqdm(self.val_data, file=self.tqdm_out, ncols=100)
if self.is_master
else self.val_data
)
for metric in self.val_metrics:
metric.reset_epoch_stats()
val_loss = 0
losses_logging = defaultdict(list)
self.net.eval()
for i, batch_data in enumerate(tbar):
global_step = epoch * len(self.val_data) + i
(
loss,
batch_losses_logging,
splitted_batch_data,
outputs,
) = self.batch_forward(batch_data, validation=True)
batch_losses_logging["overall"] = loss
reduce_loss_dict(batch_losses_logging)
for loss_name, loss_value in batch_losses_logging.items():
losses_logging[loss_name].append(loss_value.item())
val_loss += batch_losses_logging["overall"].item()
if self.is_master:
tbar.set_description(
f"Epoch {epoch}, validation loss: {val_loss/(i + 1):.4f}"
)
for metric in self.val_metrics:
metric.log_states(
self.sw, f"{log_prefix}Metrics/{metric.name}", global_step
)
if self.is_master:
for loss_name, loss_values in losses_logging.items():
self.sw.add_scalar(
tag=f"{log_prefix}Losses/{loss_name}",
value=np.array(loss_values).mean(),
global_step=epoch,
disable_avg=True,
)
for metric in self.val_metrics:
self.sw.add_scalar(
tag=f"{log_prefix}Metrics/{metric.name}",
value=metric.get_epoch_value(),
global_step=epoch,
disable_avg=True,
)
def batch_forward(self, batch_data, validation=False):
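        """Run a forward pass for one batch.

        Simulates up to `max_num_next_clicks` corrective clicks based on the
        current prediction errors, optionally feeds the previous prediction
        back as an extra input channel, then computes the configured losses
        and updates the metrics. Returns (loss, losses_logging, batch_data, output).
        """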
metrics = self.val_metrics if validation else self.train_metrics
losses_logging = dict()
with torch.set_grad_enabled(not validation):
batch_data = {k: v.to(self.device) for k, v in batch_data.items()}
image, gt_mask, points = (
batch_data["images"],
batch_data["instances"],
batch_data["points"],
)
orig_image, orig_gt_mask, orig_points = (
image.clone(),
gt_mask.clone(),
points.clone(),
)
prev_output = torch.zeros_like(image, dtype=torch.float32)[:, :1, :, :]
last_click_indx = None
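            # Simulate a random number of corrective clicks without tracking
            # gradients, feeding each intermediate prediction back as the
            # previous mask.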
with torch.no_grad():
num_iters = random.randint(0, self.max_num_next_clicks)
for click_indx in range(num_iters):
last_click_indx = click_indx
if not validation:
self.net.eval()
if self.click_models is None or click_indx >= len(
self.click_models
):
eval_model = self.net
else:
eval_model = self.click_models[click_indx]
net_input = (
torch.cat((image, prev_output), dim=1)
if self.net.with_prev_mask
else image
)
prev_output = torch.sigmoid(
eval_model(net_input, points)["instances"]
)
points = get_next_points(
prev_output, orig_gt_mask, points, click_indx + 1
)
if not validation:
self.net.train()
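                # Randomly zero out the previous mask for a fraction of samples
                # (controlled by prev_mask_drop_prob).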
if (
self.net.with_prev_mask
and self.prev_mask_drop_prob > 0
and last_click_indx is not None
):
zero_mask = (
np.random.random(size=prev_output.size(0))
< self.prev_mask_drop_prob
)
prev_output[zero_mask] = torch.zeros_like(prev_output[zero_mask])
batch_data["points"] = points
net_input = (
torch.cat((image, prev_output), dim=1)
if self.net.with_prev_mask
else image
)
output = self.net(net_input, points)
loss = 0.0
loss = self.add_loss(
"instance_loss",
loss,
losses_logging,
validation,
lambda: (output["instances"], batch_data["instances"]),
)
loss = self.add_loss(
"instance_aux_loss",
loss,
losses_logging,
validation,
lambda: (output["instances_aux"], batch_data["instances"]),
)
if self.is_master:
with torch.no_grad():
for m in metrics:
m.update(
*(output.get(x) for x in m.pred_outputs),
*(batch_data[x] for x in m.gt_outputs),
)
return loss, losses_logging, batch_data, output
def add_loss(
self, loss_name, total_loss, losses_logging, validation, lambda_loss_inputs
):
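        """Compute a configured loss term and add its weighted value to `total_loss`.

        Terms whose `<loss_name>_weight` is not positive are skipped.
        """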
loss_cfg = self.loss_cfg if not validation else self.val_loss_cfg
loss_weight = loss_cfg.get(loss_name + "_weight", 0.0)
if loss_weight > 0.0:
loss_criterion = loss_cfg.get(loss_name)
loss = loss_criterion(*lambda_loss_inputs())
loss = torch.mean(loss)
losses_logging[loss_name] = loss
loss = loss_weight * loss
total_loss = total_loss + loss
return total_loss
def save_visualization(self, splitted_batch_data, outputs, global_step, prefix):
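        """Save a side-by-side JPEG of the first sample: image with clicks, GT mask, and prediction."""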
output_images_path = self.cfg.VIS_PATH / prefix
if self.task_prefix:
output_images_path /= self.task_prefix
if not output_images_path.exists():
output_images_path.mkdir(parents=True)
image_name_prefix = f"{global_step:06d}"
def _save_image(suffix, image):
cv2.imwrite(
str(output_images_path / f"{image_name_prefix}_{suffix}.jpg"),
image,
[cv2.IMWRITE_JPEG_QUALITY, 85],
)
images = splitted_batch_data["images"]
points = splitted_batch_data["points"]
instance_masks = splitted_batch_data["instances"]
gt_instance_masks = instance_masks.cpu().numpy()
predicted_instance_masks = (
torch.sigmoid(outputs["instances"]).detach().cpu().numpy()
)
points = points.detach().cpu().numpy()
image_blob, points = images[0], points[0]
gt_mask = np.squeeze(gt_instance_masks[0], axis=0)
predicted_mask = np.squeeze(predicted_instance_masks[0], axis=0)
image = image_blob.cpu().numpy() * 255
image = image.transpose((1, 2, 0))
image_with_points = draw_points(
image, points[: self.max_interactive_points], (0, 255, 0)
)
image_with_points = draw_points(
image_with_points, points[self.max_interactive_points :], (0, 0, 255)
)
gt_mask[gt_mask < 0] = 0.25
gt_mask = draw_probmap(gt_mask)
predicted_mask = draw_probmap(predicted_mask)
viz_image = np.hstack((image_with_points, gt_mask, predicted_mask)).astype(
np.uint8
)
_save_image("instance_segmentation", viz_image[:, :, ::-1])
def _load_weights(self, net):
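        """Initialize weights from `cfg.weights`, or resume from the experiment's checkpoint."""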
if self.cfg.weights is not None:
if os.path.isfile(self.cfg.weights):
load_weights(net, self.cfg.weights)
self.cfg.weights = None
else:
raise RuntimeError(f"=> no checkpoint found at '{self.cfg.weights}'")
elif self.cfg.resume_exp is not None:
checkpoints = list(
self.cfg.CHECKPOINTS_PATH.glob(f"{self.cfg.resume_prefix}*.pth")
)
assert len(checkpoints) == 1
checkpoint_path = checkpoints[0]
logger.info(f"Load checkpoint from path: {checkpoint_path}")
load_weights(net, str(checkpoint_path))
return net
@property
def is_master(self):
return self.cfg.local_rank == 0
def get_next_points(pred, gt, points, click_indx, pred_thresh=0.49):
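    """Simulate the next user click from the current prediction errors.

    For each sample, a point is sampled from the interior of the largest
    error region (measured by the distance transform): a positive click if
    false negatives dominate, a negative click otherwise.
    """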
assert click_indx > 0
pred = pred.cpu().numpy()[:, 0, :, :]
gt = gt.cpu().numpy()[:, 0, :, :] > 0.5
fn_mask = np.logical_and(gt, pred < pred_thresh)
fp_mask = np.logical_and(np.logical_not(gt), pred > pred_thresh)
fn_mask = np.pad(fn_mask, ((0, 0), (1, 1), (1, 1)), "constant").astype(np.uint8)
fp_mask = np.pad(fp_mask, ((0, 0), (1, 1), (1, 1)), "constant").astype(np.uint8)
num_points = points.size(1) // 2
points = points.clone()
for bindx in range(fn_mask.shape[0]):
fn_mask_dt = cv2.distanceTransform(fn_mask[bindx], cv2.DIST_L2, 5)[1:-1, 1:-1]
fp_mask_dt = cv2.distanceTransform(fp_mask[bindx], cv2.DIST_L2, 5)[1:-1, 1:-1]
fn_max_dist = np.max(fn_mask_dt)
fp_max_dist = np.max(fp_mask_dt)
is_positive = fn_max_dist > fp_max_dist
dt = fn_mask_dt if is_positive else fp_mask_dt
inner_mask = dt > max(fn_max_dist, fp_max_dist) / 2.0
indices = np.argwhere(inner_mask)
if len(indices) > 0:
coords = indices[np.random.randint(0, len(indices))]
if is_positive:
points[bindx, num_points - click_indx, 0] = float(coords[0])
points[bindx, num_points - click_indx, 1] = float(coords[1])
points[bindx, num_points - click_indx, 2] = float(click_indx)
else:
points[bindx, 2 * num_points - click_indx, 0] = float(coords[0])
points[bindx, 2 * num_points - click_indx, 1] = float(coords[1])
points[bindx, 2 * num_points - click_indx, 2] = float(click_indx)
return points
def load_weights(model, path_to_weights):
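    """Load weights from a checkpoint; keys missing from it keep the model's current values."""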
current_state_dict = model.state_dict()
new_state_dict = torch.load(path_to_weights, map_location="cpu")["state_dict"]
current_state_dict.update(new_state_dict)
model.load_state_dict(current_state_dict)
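# ---------------------------------------------------------------------------
# Illustrative usage sketch (not executed): the trainer is normally driven by
# a model/experiment script that builds the config, model, losses, and
# datasets. All names and values below are placeholders for illustration.
#
#   trainer = ISTrainer(
#       model, cfg, model_cfg, loss_cfg,
#       trainset, valset,
#       optimizer="adam",
#       optimizer_params={"lr": 5e-4},   # placeholder learning rate
#       checkpoint_interval=10,
#       image_dump_interval=300,
#       max_interactive_points=24,       # placeholder click budget
#       max_num_next_clicks=3,
#   )
#   trainer.run(num_epochs=120)          # placeholder epoch count
# ---------------------------------------------------------------------------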