# NOTE(review): removed stray extraction artifacts ("Spaces:" / "Runtime error")
# that were not valid Python and did not belong to the module.
import logging
import os
import random
from collections import defaultdict
from copy import deepcopy

import cv2
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from isegm.utils.distributed import (get_dp_wrapper, get_sampler,
                                     reduce_loss_dict)
from isegm.utils.log import SummaryWriterAvg, TqdmToLogger, logger
from isegm.utils.misc import save_checkpoint
from isegm.utils.serialization import get_config_repr
from isegm.utils.vis import draw_points, draw_probmap

from .optimizer import get_optimizer
class ISTrainer(object):
    """Drives training and validation of an interactive segmentation model.

    Responsibilities: (distributed-aware) data loading, iterative click
    simulation via ``get_next_points``, loss/metric accumulation, TensorBoard
    logging, periodic visualization dumps, and checkpointing.

    Fix vs. previous revision: ``is_master`` is now a ``@property``.  It was a
    plain method, so ``if self.is_master:`` tested the truthiness of a bound
    method (always True) and every distributed rank behaved as the master —
    duplicating TensorBoard writers, visualizations, and checkpoints.
    """

    def __init__(
        self,
        model,
        cfg,
        model_cfg,
        loss_cfg,
        trainset,
        valset,
        optimizer="adam",
        optimizer_params=None,
        image_dump_interval=200,
        checkpoint_interval=10,
        tb_dump_period=25,
        max_interactive_points=0,
        lr_scheduler=None,
        metrics=None,
        additional_val_metrics=None,
        net_inputs=("images", "points"),
        max_num_next_clicks=0,
        click_models=None,
        prev_mask_drop_prob=0.0,
    ):
        self.cfg = cfg
        self.model_cfg = model_cfg
        self.max_interactive_points = max_interactive_points
        self.loss_cfg = loss_cfg
        # Validation keeps its own loss copy so per-loss running states
        # (e.g. adaptive weights) are not mixed with training.
        self.val_loss_cfg = deepcopy(loss_cfg)
        self.tb_dump_period = tb_dump_period
        self.net_inputs = net_inputs
        self.max_num_next_clicks = max_num_next_clicks

        self.click_models = click_models
        self.prev_mask_drop_prob = prev_mask_drop_prob

        if cfg.distributed:
            # Convert the global batch size into a per-process batch size.
            cfg.batch_size //= cfg.ngpus
            cfg.val_batch_size //= cfg.ngpus

        if metrics is None:
            metrics = []
        self.train_metrics = metrics
        self.val_metrics = deepcopy(metrics)
        if additional_val_metrics is not None:
            self.val_metrics.extend(additional_val_metrics)

        self.checkpoint_interval = checkpoint_interval
        self.image_dump_interval = image_dump_interval
        self.task_prefix = ""
        # SummaryWriter is created lazily, and only on the master process.
        self.sw = None

        self.trainset = trainset
        self.valset = valset

        logger.info(
            f"Dataset of {trainset.get_samples_number()} samples was loaded for training."
        )
        logger.info(
            f"Dataset of {valset.get_samples_number()} samples was loaded for validation."
        )

        self.train_data = DataLoader(
            trainset,
            cfg.batch_size,
            sampler=get_sampler(trainset, shuffle=True, distributed=cfg.distributed),
            drop_last=True,
            pin_memory=True,
            num_workers=cfg.workers,
        )

        self.val_data = DataLoader(
            valset,
            cfg.val_batch_size,
            sampler=get_sampler(valset, shuffle=False, distributed=cfg.distributed),
            drop_last=True,
            pin_memory=True,
            num_workers=cfg.workers,
        )

        self.optim = get_optimizer(model, optimizer, optimizer_params)
        model = self._load_weights(model)

        if cfg.multi_gpu:
            model = get_dp_wrapper(cfg.distributed)(
                model, device_ids=cfg.gpu_ids, output_device=cfg.gpu_ids[0]
            )

        if self.is_master:
            logger.info(model)
            logger.info(get_config_repr(model._config))

        self.device = cfg.device
        self.net = model.to(self.device)
        # NOTE: optimizer_params must contain "lr"; the default of None would
        # raise here — callers are expected to always pass a dict.
        self.lr = optimizer_params["lr"]

        if lr_scheduler is not None:
            self.lr_scheduler = lr_scheduler(optimizer=self.optim)
            if cfg.start_epoch > 0:
                # Fast-forward the scheduler when resuming mid-run.
                for _ in range(cfg.start_epoch):
                    self.lr_scheduler.step()

        self.tqdm_out = TqdmToLogger(logger, level=logging.INFO)

        if self.click_models is not None:
            # Click models only simulate user interaction; freeze them.
            for click_model in self.click_models:
                for param in click_model.parameters():
                    param.requires_grad = False
                click_model.to(self.device)
                click_model.eval()

    def run(self, num_epochs, start_epoch=None, validation=True):
        """Train (and optionally validate) for ``num_epochs`` epochs."""
        if start_epoch is None:
            start_epoch = self.cfg.start_epoch

        logger.info(f"Starting Epoch: {start_epoch}")
        logger.info(f"Total Epochs: {num_epochs}")
        for epoch in range(start_epoch, num_epochs):
            self.training(epoch)
            if validation:
                self.validation(epoch)

    def training(self, epoch):
        """Run one epoch over the training set, logging and checkpointing."""
        if self.sw is None and self.is_master:
            self.sw = SummaryWriterAvg(
                log_dir=str(self.cfg.LOGS_PATH),
                flush_secs=10,
                dump_period=self.tb_dump_period,
            )

        if self.cfg.distributed:
            # Reshuffle the distributed sampler differently each epoch.
            self.train_data.sampler.set_epoch(epoch)

        log_prefix = "Train" + self.task_prefix.capitalize()
        tbar = (
            tqdm(self.train_data, file=self.tqdm_out, ncols=100)
            if self.is_master
            else self.train_data
        )

        for metric in self.train_metrics:
            metric.reset_epoch_stats()

        self.net.train()
        train_loss = 0.0
        for i, batch_data in enumerate(tbar):
            global_step = epoch * len(self.train_data) + i

            loss, losses_logging, splitted_batch_data, outputs = self.batch_forward(
                batch_data
            )

            self.optim.zero_grad()
            loss.backward()
            self.optim.step()

            losses_logging["overall"] = loss
            # Average the logged losses across processes (no-op when not
            # distributed).
            reduce_loss_dict(losses_logging)

            train_loss += losses_logging["overall"].item()

            if self.is_master:
                for loss_name, loss_value in losses_logging.items():
                    self.sw.add_scalar(
                        tag=f"{log_prefix}Losses/{loss_name}",
                        value=loss_value.item(),
                        global_step=global_step,
                    )

                for k, v in self.loss_cfg.items():
                    if (
                        "_loss" in k
                        and hasattr(v, "log_states")
                        and self.loss_cfg.get(k + "_weight", 0.0) > 0
                    ):
                        v.log_states(self.sw, f"{log_prefix}Losses/{k}", global_step)

                if (
                    self.image_dump_interval > 0
                    and global_step % self.image_dump_interval == 0
                ):
                    self.save_visualization(
                        splitted_batch_data, outputs, global_step, prefix="train"
                    )

                self.sw.add_scalar(
                    tag=f"{log_prefix}States/learning_rate",
                    value=self.lr
                    if not hasattr(self, "lr_scheduler")
                    else self.lr_scheduler.get_lr()[-1],
                    global_step=global_step,
                )

                tbar.set_description(
                    f"Epoch {epoch}, training loss {train_loss/(i+1):.4f}"
                )
                for metric in self.train_metrics:
                    metric.log_states(
                        self.sw, f"{log_prefix}Metrics/{metric.name}", global_step
                    )

        if self.is_master:
            for metric in self.train_metrics:
                self.sw.add_scalar(
                    tag=f"{log_prefix}Metrics/{metric.name}",
                    value=metric.get_epoch_value(),
                    global_step=epoch,
                    disable_avg=True,
                )

            # Always refresh the "last" checkpoint (epoch=None), ...
            save_checkpoint(
                self.net,
                self.cfg.CHECKPOINTS_PATH,
                prefix=self.task_prefix,
                epoch=None,
                multi_gpu=self.cfg.multi_gpu,
            )

            # ... and additionally keep a per-epoch snapshot on schedule.
            # checkpoint_interval may be a piecewise schedule:
            # [(start_epoch, interval), ...] — pick the latest applicable one.
            if isinstance(self.checkpoint_interval, (list, tuple)):
                checkpoint_interval = [
                    x for x in self.checkpoint_interval if x[0] <= epoch
                ][-1][1]
            else:
                checkpoint_interval = self.checkpoint_interval

            if epoch % checkpoint_interval == 0:
                save_checkpoint(
                    self.net,
                    self.cfg.CHECKPOINTS_PATH,
                    prefix=self.task_prefix,
                    epoch=epoch,
                    multi_gpu=self.cfg.multi_gpu,
                )

        if hasattr(self, "lr_scheduler"):
            self.lr_scheduler.step()

    def validation(self, epoch):
        """Run one epoch over the validation set and log averaged losses."""
        if self.sw is None and self.is_master:
            self.sw = SummaryWriterAvg(
                log_dir=str(self.cfg.LOGS_PATH),
                flush_secs=10,
                dump_period=self.tb_dump_period,
            )

        log_prefix = "Val" + self.task_prefix.capitalize()
        tbar = (
            tqdm(self.val_data, file=self.tqdm_out, ncols=100)
            if self.is_master
            else self.val_data
        )

        for metric in self.val_metrics:
            metric.reset_epoch_stats()

        val_loss = 0
        losses_logging = defaultdict(list)

        self.net.eval()
        for i, batch_data in enumerate(tbar):
            global_step = epoch * len(self.val_data) + i
            (
                loss,
                batch_losses_logging,
                splitted_batch_data,
                outputs,
            ) = self.batch_forward(batch_data, validation=True)

            batch_losses_logging["overall"] = loss
            reduce_loss_dict(batch_losses_logging)
            for loss_name, loss_value in batch_losses_logging.items():
                losses_logging[loss_name].append(loss_value.item())

            val_loss += batch_losses_logging["overall"].item()

            if self.is_master:
                tbar.set_description(
                    f"Epoch {epoch}, validation loss: {val_loss/(i + 1):.4f}"
                )
                for metric in self.val_metrics:
                    metric.log_states(
                        self.sw, f"{log_prefix}Metrics/{metric.name}", global_step
                    )

        if self.is_master:
            for loss_name, loss_values in losses_logging.items():
                self.sw.add_scalar(
                    tag=f"{log_prefix}Losses/{loss_name}",
                    value=np.array(loss_values).mean(),
                    global_step=epoch,
                    disable_avg=True,
                )

            for metric in self.val_metrics:
                self.sw.add_scalar(
                    tag=f"{log_prefix}Metrics/{metric.name}",
                    value=metric.get_epoch_value(),
                    global_step=epoch,
                    disable_avg=True,
                )

    def batch_forward(self, batch_data, validation=False):
        """Forward one batch, simulating up to ``max_num_next_clicks`` clicks.

        Returns ``(loss, losses_logging, batch_data, output)`` where
        ``batch_data["points"]`` has been updated with the simulated clicks.
        """
        metrics = self.val_metrics if validation else self.train_metrics
        losses_logging = dict()

        with torch.set_grad_enabled(not validation):
            batch_data = {k: v.to(self.device) for k, v in batch_data.items()}
            image, gt_mask, points = (
                batch_data["images"],
                batch_data["instances"],
                batch_data["points"],
            )
            orig_image, orig_gt_mask, orig_points = (
                image.clone(),
                gt_mask.clone(),
                points.clone(),
            )

            # Previous-step mask starts empty (single channel, same HxW).
            prev_output = torch.zeros_like(image, dtype=torch.float32)[:, :1, :, :]

            last_click_indx = None

            # Click simulation runs without gradients: only the final forward
            # pass below contributes to the loss.
            with torch.no_grad():
                num_iters = random.randint(0, self.max_num_next_clicks)

                for click_indx in range(num_iters):
                    last_click_indx = click_indx

                    if not validation:
                        self.net.eval()

                    if self.click_models is None or click_indx >= len(
                        self.click_models
                    ):
                        eval_model = self.net
                    else:
                        eval_model = self.click_models[click_indx]

                    net_input = (
                        torch.cat((image, prev_output), dim=1)
                        if self.net.with_prev_mask
                        else image
                    )
                    prev_output = torch.sigmoid(
                        eval_model(net_input, points)["instances"]
                    )

                    points = get_next_points(
                        prev_output, orig_gt_mask, points, click_indx + 1
                    )

                if not validation:
                    self.net.train()

            # Randomly drop the previous mask for some samples so the model
            # also learns to segment from clicks alone.
            if (
                self.net.with_prev_mask
                and self.prev_mask_drop_prob > 0
                and last_click_indx is not None
            ):
                zero_mask = (
                    np.random.random(size=prev_output.size(0))
                    < self.prev_mask_drop_prob
                )
                prev_output[zero_mask] = torch.zeros_like(prev_output[zero_mask])

            batch_data["points"] = points

            net_input = (
                torch.cat((image, prev_output), dim=1)
                if self.net.with_prev_mask
                else image
            )
            output = self.net(net_input, points)

            loss = 0.0
            loss = self.add_loss(
                "instance_loss",
                loss,
                losses_logging,
                validation,
                lambda: (output["instances"], batch_data["instances"]),
            )
            loss = self.add_loss(
                "instance_aux_loss",
                loss,
                losses_logging,
                validation,
                lambda: (output["instances_aux"], batch_data["instances"]),
            )

            if self.is_master:
                with torch.no_grad():
                    for m in metrics:
                        m.update(
                            *(output.get(x) for x in m.pred_outputs),
                            *(batch_data[x] for x in m.gt_outputs),
                        )
        return loss, losses_logging, batch_data, output

    def add_loss(
        self, loss_name, total_loss, losses_logging, validation, lambda_loss_inputs
    ):
        """Add ``loss_name`` (if its weight > 0) to ``total_loss`` and log it.

        ``lambda_loss_inputs`` is lazy so the criterion inputs are only
        evaluated when the loss is actually enabled.
        """
        loss_cfg = self.loss_cfg if not validation else self.val_loss_cfg
        loss_weight = loss_cfg.get(loss_name + "_weight", 0.0)
        if loss_weight > 0.0:
            loss_criterion = loss_cfg.get(loss_name)
            loss = loss_criterion(*lambda_loss_inputs())
            loss = torch.mean(loss)
            losses_logging[loss_name] = loss
            loss = loss_weight * loss
            total_loss = total_loss + loss

        return total_loss

    def save_visualization(self, splitted_batch_data, outputs, global_step, prefix):
        """Dump a side-by-side JPEG (image+clicks | GT | prediction).

        Only the first sample of the batch is visualized.
        """
        output_images_path = self.cfg.VIS_PATH / prefix
        if self.task_prefix:
            output_images_path /= self.task_prefix

        if not output_images_path.exists():
            output_images_path.mkdir(parents=True)
        image_name_prefix = f"{global_step:06d}"

        def _save_image(suffix, image):
            cv2.imwrite(
                str(output_images_path / f"{image_name_prefix}_{suffix}.jpg"),
                image,
                [cv2.IMWRITE_JPEG_QUALITY, 85],
            )

        images = splitted_batch_data["images"]
        points = splitted_batch_data["points"]
        instance_masks = splitted_batch_data["instances"]

        gt_instance_masks = instance_masks.cpu().numpy()
        predicted_instance_masks = (
            torch.sigmoid(outputs["instances"]).detach().cpu().numpy()
        )
        points = points.detach().cpu().numpy()

        image_blob, points = images[0], points[0]
        gt_mask = np.squeeze(gt_instance_masks[0], axis=0)
        predicted_mask = np.squeeze(predicted_instance_masks[0], axis=0)

        image = image_blob.cpu().numpy() * 255
        image = image.transpose((1, 2, 0))

        # First half of the points are positive clicks (green), second half
        # negative (red) — matches the layout used by get_next_points.
        image_with_points = draw_points(
            image, points[: self.max_interactive_points], (0, 255, 0)
        )
        image_with_points = draw_points(
            image_with_points, points[self.max_interactive_points :], (0, 0, 255)
        )

        # Negative GT values mark "ignore" regions; render them dimmed.
        gt_mask[gt_mask < 0] = 0.25
        gt_mask = draw_probmap(gt_mask)
        predicted_mask = draw_probmap(predicted_mask)
        viz_image = np.hstack((image_with_points, gt_mask, predicted_mask)).astype(
            np.uint8
        )

        _save_image("instance_segmentation", viz_image[:, :, ::-1])

    def _load_weights(self, net):
        """Initialize ``net`` from cfg.weights or from a resumed experiment."""
        if self.cfg.weights is not None:
            if os.path.isfile(self.cfg.weights):
                load_weights(net, self.cfg.weights)
                self.cfg.weights = None
            else:
                raise RuntimeError(f"=> no checkpoint found at '{self.cfg.weights}'")
        elif self.cfg.resume_exp is not None:
            checkpoints = list(
                self.cfg.CHECKPOINTS_PATH.glob(f"{self.cfg.resume_prefix}*.pth")
            )
            assert len(checkpoints) == 1

            checkpoint_path = checkpoints[0]
            logger.info(f"Load checkpoint from path: {checkpoint_path}")
            load_weights(net, str(checkpoint_path))
        return net

    @property
    def is_master(self):
        # Property (not a plain method!) so `if self.is_master:` works —
        # a bound method would always be truthy.
        return self.cfg.local_rank == 0
def get_next_points(pred, gt, points, click_indx, pred_thresh=0.49):
    """Simulate the next user click for every sample in a batch.

    For each sample, compare the thresholded prediction against the ground
    truth, pick the larger error region (false negatives -> positive click,
    false positives -> negative click), and place a click at a random point
    deep inside that region (distance-transform based, among pixels farther
    than half the maximum distance).  Positive clicks occupy the first half
    of the points tensor, negative clicks the second half; slots are filled
    from the back, one per ``click_indx``.
    """
    assert click_indx > 0
    pred_np = pred.cpu().numpy()[:, 0, :, :]
    gt_np = gt.cpu().numpy()[:, 0, :, :] > 0.5

    # Error maps w.r.t. the probability threshold.
    missed = np.logical_and(gt_np, pred_np < pred_thresh)
    extra = np.logical_and(np.logical_not(gt_np), pred_np > pred_thresh)

    # A one-pixel zero border makes the distance transform treat image
    # edges as region boundaries; it is cropped away afterwards.
    border = ((0, 0), (1, 1), (1, 1))
    missed = np.pad(missed, border, "constant").astype(np.uint8)
    extra = np.pad(extra, border, "constant").astype(np.uint8)

    half = points.size(1) // 2
    points = points.clone()

    for b in range(missed.shape[0]):
        dt_missed = cv2.distanceTransform(missed[b], cv2.DIST_L2, 5)[1:-1, 1:-1]
        dt_extra = cv2.distanceTransform(extra[b], cv2.DIST_L2, 5)[1:-1, 1:-1]
        max_missed = np.max(dt_missed)
        max_extra = np.max(dt_extra)

        positive_click = max_missed > max_extra
        dt = dt_missed if positive_click else dt_extra
        # Candidate pixels well inside the chosen error region.
        candidates = np.argwhere(dt > max(max_missed, max_extra) / 2.0)
        if len(candidates) == 0:
            continue

        y, x = candidates[np.random.randint(0, len(candidates))]
        slot = (half - click_indx) if positive_click else (2 * half - click_indx)
        points[b, slot, 0] = float(y)
        points[b, slot, 1] = float(x)
        points[b, slot, 2] = float(click_indx)

    return points
def load_weights(model, path_to_weights):
    """Load a checkpoint's ``state_dict`` into *model* (partial load allowed).

    Parameters the checkpoint does not contain keep their current values:
    the model's own state dict is used as the base and the checkpoint
    entries are merged over it before loading.
    """
    merged_state = model.state_dict()
    checkpoint_state = torch.load(path_to_weights, map_location="cpu")["state_dict"]
    merged_state.update(checkpoint_state)
    model.load_state_dict(merged_state)