# Hosting-page residue: uploaded by pg56714 in commit 8e5cc83 ("Upload 115 files", verified).
import random
import sys
from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from PIL import Image
from tqdm import tqdm

from efficientvit.apps.trainer import Trainer
from efficientvit.apps.utils import AverageMeter, get_dist_local_rank, get_dist_size, is_master, sync_tensor
from efficientvit.models.utils import list_join
from efficientvit.samcore.data_provider import SAMDataProvider
from efficientvit.samcore.trainer import SAMRunConfig
from efficientvit.samcore.trainer.utils import (
    compute_boundary_iou,
    compute_iou,
    loss_masks,
    mask_iou_batch,
    masks_sample_points,
)
# Names exported by ``from <this module> import *``: only the trainer class.
__all__ = ["SAMTrainer"]
class SAMTrainer(Trainer):
    """Trainer that fits and validates a promptable segmentation model.

    On top of ``Trainer`` this class builds box / point prompts for every
    image, trains with a "best candidate" mask loss, validates with IoU and
    boundary IoU, and logs metrics to Weights & Biases from the master process.
    """

    def __init__(
        self,
        path: str,
        model: nn.Module,
        data_provider: SAMDataProvider,
    ) -> None:
        """Initialise the base trainer and open a wandb run on the master rank.

        Args:
            path: Forwarded unchanged to ``Trainer.__init__``.
            model: Forwarded unchanged to ``Trainer.__init__``.
            data_provider: Forwarded unchanged to ``Trainer.__init__``.
        """
        super().__init__(
            path=path,
            model=model,
            data_provider=data_provider,
        )
        # Only the master process owns a wandb run. Other ranks get ``None`` so
        # the attribute always exists; every use below is guarded by is_master().
        self.wandb_log = wandb.init(project="efficientvit-sam") if is_master() else None

    @staticmethod
    def _prepare_batch(feed_dict: dict[str, Any]) -> dict[str, Any]:
        """Move a batch to the GPU and rewrite its boxes as corner coordinates.

        Shared by training (``before_step``) and validation (``_validate``) so
        both paths preprocess prompts identically.

        Args:
            feed_dict: Batch with tensor entries ``"image"``, ``"masks"``,
                ``"bboxs"`` and ``"points"``.

        Returns:
            A new dict with the same four keys holding the GPU tensors.
        """
        image = feed_dict["image"].cuda()
        masks = feed_dict["masks"].cuda()
        bboxs = feed_dict["bboxs"].cuda()
        points = feed_dict["points"].cuda()

        if image.shape[2] == 512:
            # NOTE(review): prompts are doubled when dim 2 of the image is 512,
            # presumably to express them in a 1024-pixel frame - confirm
            # against the data provider / model.
            bboxs = bboxs * 2
            points = points * 2
        else:
            # ``.cuda()`` hands back the very same tensor when it already lives
            # on the GPU; copy so the in-place update below can never modify
            # the caller's data.
            bboxs = bboxs.clone()

        # Column 2 += column 0 and column 3 += column 1, i.e. presumably
        # (x, y, w, h) -> (x1, y1, x2, y2).
        bboxs[..., 2] = bboxs[..., 0] + bboxs[..., 2]
        bboxs[..., 3] = bboxs[..., 1] + bboxs[..., 3]

        return {
            "image": image,
            "masks": masks,
            "points": points,
            "bboxs": bboxs,
        }

    def _validate(self, model, data_loader, epoch: int, sub_epoch: int) -> dict[str, Any]:
        """Run one validation pass with box prompts only.

        Args:
            model: Model called as ``model(batched_input, True)``; it must
                return ``(masks, iou_predictions)`` with 5-D ``masks``.
            data_loader: Iterable of batches accepted by ``_prepare_batch``.
            epoch: Zero-based epoch index (used for the progress-bar title).
            sub_epoch: Zero-based sub-epoch index (used for the title).

        Returns:
            Dict with the running averages ``"val_loss"``, ``"val_iou"`` and
            ``"val_boundary_iou"``.
        """
        val_loss = AverageMeter()
        val_iou = AverageMeter()
        val_iou_boundary = AverageMeter()

        with torch.no_grad(), tqdm(
            total=len(data_loader),
            desc=f"Validate Epoch #{epoch + 1}, Sub Epoch #{sub_epoch+1}",
            disable=not is_master(),
            file=sys.stdout,
        ) as progress:
            for data in data_loader:
                batch = self._prepare_batch(data)
                image = batch["image"]
                masks = batch["masks"]
                bboxs = batch["bboxs"]
                height, width = image.shape[2], image.shape[3]

                batched_input = [{"image": img, "boxes": boxes} for img, boxes in zip(image, bboxs)]

                output, iou_predictions = model(batched_input, True)

                # Unpacking doubles as a check that the model output is 5-D;
                # only the size of dim 1 is needed afterwards.
                _, M, _, _, _ = output.shape
                rows = torch.arange(M)
                # For every entry along dim 1 keep the candidate (dim 2) with
                # the highest predicted IoU.
                output = torch.stack(
                    [
                        output[k][rows, iou_predictions[k].argmax(-1).squeeze()]
                        for k in range(len(output))
                    ],
                    dim=0,
                )
                # Upsample to the input resolution, then flatten to (K, 1, H, W).
                output = (
                    F.interpolate(output, size=(height, width), mode="bilinear")
                    .reshape(-1, height, width)
                    .unsqueeze(1)
                )
                masks = masks.reshape(-1, height, width).unsqueeze(1)

                loss_mask, loss_dice = loss_masks(output, masks, len(output))
                # The mask loss is weighted 20x relative to the dice loss.
                loss = loss_mask * 20 + loss_dice
                # Targets are scaled by 255 for the IoU helpers (presumably they
                # expect a 0-255 mask - confirm in trainer.utils).
                iou = compute_iou(output, masks * 255)
                boundary_iou = compute_boundary_iou(output, masks * 255)

                loss = sync_tensor(loss)
                iou = sync_tensor(iou)
                boundary_iou = sync_tensor(boundary_iou)

                # Weight the meters by the batch size summed over all ranks.
                total_bs = image.shape[0] * get_dist_size()
                val_loss.update(loss, total_bs)
                val_iou.update(iou, total_bs)
                val_iou_boundary.update(boundary_iou, total_bs)

                progress.set_postfix(
                    {
                        "loss": val_loss.avg,
                        "iou": val_iou.avg,
                        "boundary_iou": val_iou_boundary.avg,
                        "bs": total_bs,
                    }
                )
                progress.update()

        metrics = {
            "val_loss": val_loss.avg,
            "val_iou": val_iou.avg,
            "val_boundary_iou": val_iou_boundary.avg,
        }
        if is_master():
            # Log a copy so the returned dict is never shared with the logger.
            self.wandb_log.log(dict(metrics))
        return metrics

    def validate(self, model=None, data_loader=None, epoch=0, sub_epoch=0) -> dict[str, Any]:
        """Evaluate ``model`` on ``data_loader`` and return validation metrics.

        Args:
            model: Model to evaluate; defaults to ``self.eval_network``.
            data_loader: Validation batches; defaults to
                ``self.data_provider.valid``.
            epoch: Zero-based epoch index, for display only.
            sub_epoch: Zero-based sub-epoch index, for display only.

        Returns:
            The metrics dict produced by ``_validate``.
        """
        # Compare against None explicitly: container modules define __len__,
        # so an empty one would be falsy and silently replaced by ``or``.
        if model is None:
            model = self.eval_network
        if data_loader is None:
            data_loader = self.data_provider.valid

        model.eval()
        return self._validate(model, data_loader, epoch, sub_epoch)

    def before_step(self, feed_dict: dict[str, Any]) -> dict[str, Any]:
        """Prepare a training batch: move it to the GPU and convert the boxes.

        Args:
            feed_dict: Raw batch with ``"image"``, ``"masks"``, ``"bboxs"`` and
                ``"points"`` tensors.

        Returns:
            A new dict with the same keys holding the prepared GPU tensors.
        """
        return self._prepare_batch(feed_dict)

    def run_step(self, feed_dict: dict[str, Any]) -> dict[str, Any]:
        """Run forward and backward for one prepared batch.

        Each image is prompted with either its boxes or 1-10 points sampled
        from its masks (chosen at random), and the model is randomly asked for
        multi-mask or single-mask output.

        Args:
            feed_dict: Output of ``before_step``.

        Returns:
            Dict with ``"loss"`` (the scalar loss that was back-propagated) and
            ``"output"`` (the raw mask tensor returned by the model).
        """
        image = feed_dict["image"]
        masks = feed_dict["masks"]
        bboxs = feed_dict["bboxs"]
        points = feed_dict["points"]
        height, width = image.shape[2], image.shape[3]

        batched_input = []
        for b_i in range(len(image)):
            dict_input = {"image": image[b_i]}
            use_boxes = random.random() >= 0.5
            if not use_boxes:
                try:
                    n_p = int(random.random() * 10 + 1)
                    point_coords = masks_sample_points(masks[b_i], k=n_p)
                    if image.shape[2] == 512:
                        point_coords = point_coords * 2
                    point_labels = torch.ones((points[b_i].shape[0], n_p), device=image.device)
                except Exception:
                    # Best-effort fallback: if points cannot be sampled, prompt
                    # with the boxes instead. ``Exception`` (not a bare except)
                    # keeps KeyboardInterrupt / SystemExit propagating.
                    use_boxes = True
                else:
                    # Commit the point prompt only once it is complete, so a
                    # failure never leaves coordinates without labels.
                    dict_input["point_coords"] = point_coords
                    dict_input["point_labels"] = point_labels
            if use_boxes:
                dict_input["boxes"] = bboxs[b_i]
            batched_input.append(dict_input)

        # NOTE(review): the loss is assumed to be computed inside the autocast
        # region and backward() outside it; the flattened source does not show
        # where the ``with`` block ends - confirm.
        with torch.autocast(device_type="cuda", dtype=self.amp_dtype, enabled=self.enable_amp):
            multimask_output = random.random() >= 0.5
            output, _ = self.model(batched_input, multimask_output=multimask_output)

            masks = masks.reshape(-1, height, width).unsqueeze(1)

            # One loss per candidate mask (dim 2 of the model output).
            loss_list = []
            for i in range(output.shape[2]):
                output_i = (
                    F.interpolate(output[:, :, i], size=(height, width), mode="bilinear")
                    .reshape(-1, height, width)
                    .unsqueeze(1)
                )
                loss_mask_i, loss_dice_i = loss_masks(output_i, masks, len(output_i), mode="none")
                loss_list.append(loss_mask_i * 20 + loss_dice_i)
            loss = torch.stack(loss_list, -1)

            # Keep only the lowest-loss candidate of every row via a one-hot
            # mask. mean() divides by rows * candidates, so multiplying by the
            # candidate count yields the mean over rows of the minimum loss
            # (assumes a 2-D ``loss`` - TODO confirm loss_masks(mode="none")).
            min_indices = torch.argmin(loss, dim=1)
            mask = torch.zeros_like(loss)
            mask.scatter_(1, min_indices.unsqueeze(1), 1)

            loss = (loss * mask).mean() * loss.shape[-1]

        self.scaler.scale(loss).backward()

        return {"loss": loss, "output": output}

    def _train_one_sub_epoch(self, epoch: int, sub_epoch: int) -> dict[str, Any]:
        """Iterate once over the training loader, optimising and logging.

        Args:
            epoch: Zero-based epoch index (display and logging).
            sub_epoch: Zero-based sub-epoch index (display and logging).

        Returns:
            Dict with the running average ``"train_loss"``.
        """
        train_loss = AverageMeter()

        with tqdm(
            total=len(self.data_provider.train),
            desc=f"Train Epoch #{epoch + 1}, Sub Epoch #{sub_epoch + 1}",
            disable=not is_master(),
            file=sys.stdout,
        ) as progress:
            for data in self.data_provider.train:
                # preprocessing
                feed_dict = self.before_step(data)
                # clear gradient
                self.optimizer.zero_grad()
                # forward & backward
                output_dict = self.run_step(feed_dict)
                # update: optimizer, lr_scheduler
                self.after_step()

                loss = sync_tensor(output_dict["loss"])
                total_bs = data["image"].shape[0] * get_dist_size()
                train_loss.update(loss, total_bs)

                # Distinct learning rates across parameter groups, ascending.
                lrs = sorted({group["lr"] for group in self.optimizer.param_groups})

                if is_master():
                    self.wandb_log.log(
                        {
                            "train_loss": train_loss.avg,
                            "epoch": epoch,
                            "sub_epoch": sub_epoch,
                            "learning_rate": lrs[0],
                        }
                    )

                progress.set_postfix(
                    {
                        "loss": train_loss.avg,
                        "bs": total_bs,
                        "res": data["image"].shape[2],
                        "lr": list_join(lrs, "#", "%.1E"),
                        "progress": self.run_config.progress,
                    }
                )
                progress.update()

        return {
            "train_loss": train_loss.avg,
        }

    def train_one_sub_epoch(self, epoch: int, sub_epoch: int) -> dict[str, Any]:
        """Put the model in train mode, position the data provider, and train.

        Args:
            epoch: Zero-based epoch index.
            sub_epoch: Zero-based sub-epoch index.

        Returns:
            The info dict produced by ``_train_one_sub_epoch``.
        """
        self.model.train()
        self.data_provider.set_epoch_and_sub_epoch(epoch, sub_epoch)
        return self._train_one_sub_epoch(epoch, sub_epoch)

    def train(self) -> None:
        """Alternate training and validation, checkpointing after each round.

        The loop variable counts sub-epochs (from ``self.start_epoch`` up to
        ``self.run_config.n_epochs``); the epoch index is derived from it.
        """
        for sub_epoch in range(self.start_epoch, self.run_config.n_epochs):
            epoch = sub_epoch // self.data_provider.sub_epochs_per_epoch

            self.train_one_sub_epoch(epoch, sub_epoch)
            val_info_dict = self.validate(epoch=epoch, sub_epoch=sub_epoch)

            # Track the best validation IoU seen so far.
            self.best_val = max(val_info_dict["val_iou"], self.best_val)

            # A checkpoint is written after every sub-epoch, best or not.
            self.save_model(
                only_state_dict=False,
                epoch=sub_epoch,
                model_name=f"checkpoint_{epoch}_{sub_epoch}.pt",
            )

    def prep_for_training(self, run_config: SAMRunConfig, amp="fp32") -> None:
        """Wrap the model for distributed training and build the optimiser.

        Args:
            run_config: Run configuration; it is stored on the trainer, its
                ``global_step`` / ``batch_per_epoch`` are initialised, and it
                builds the optimiser and LR scheduler.
            amp: Precision mode stored as ``self.amp`` (presumably read by the
                base class's ``enable_amp`` / ``amp_dtype`` - confirm).

        Raises:
            AssertionError: If the training loader is empty.
        """
        self.run_config = run_config
        self.model = nn.parallel.DistributedDataParallel(
            self.model.cuda(),
            device_ids=[get_dist_local_rank()],
            find_unused_parameters=True,
        )

        self.run_config.global_step = 0
        self.run_config.batch_per_epoch = len(self.data_provider.train)
        assert self.run_config.batch_per_epoch > 0, "Training set is empty"

        # build optimizer
        self.optimizer, self.lr_scheduler = self.run_config.build_optimizer(self.model)

        # amp: ``self.amp`` must be set before the scaler reads enable_amp
        self.amp = amp
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.enable_amp)