import random
import sys
from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from PIL import Image
from tqdm import tqdm

from efficientvit.apps.trainer import Trainer
from efficientvit.apps.utils import AverageMeter, get_dist_local_rank, get_dist_size, is_master, sync_tensor
from efficientvit.models.utils import list_join
from efficientvit.samcore.data_provider import SAMDataProvider
from efficientvit.samcore.trainer import SAMRunConfig
from efficientvit.samcore.trainer.utils import (
    compute_boundary_iou,
    compute_iou,
    loss_masks,
    mask_iou_batch,
    masks_sample_points,
)

__all__ = ["SAMTrainer"]


class SAMTrainer(Trainer):
    def __init__(
        self,
        path: str,
        model: nn.Module,
        data_provider: SAMDataProvider,
    ) -> None:
        super().__init__(
            path=path,
            model=model,
            data_provider=data_provider,
        )
        if is_master():
            self.wandb_log = wandb.init(project="efficientvit-sam")
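
    # Evaluation protocol (summary of the method below): each image is
    # prompted with its ground-truth boxes, the decoder runs in multimask
    # mode, and for each prompt the candidate with the highest predicted
    # IoU is scored against the target with a pixel-wise mask loss
    # (weighted x20) plus a dice loss, mask IoU, and boundary IoU.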
    def _validate(self, model, data_loader, epoch: int, sub_epoch: int) -> dict[str, Any]:
        val_loss = AverageMeter()
        val_iou = AverageMeter()
        val_iou_boundary = AverageMeter()

        with torch.no_grad():
            with tqdm(
                total=len(data_loader),
                desc=f"Validate Epoch #{epoch + 1}, Sub Epoch #{sub_epoch + 1}",
                disable=not is_master(),
                file=sys.stdout,
            ) as t:
                for i, data in enumerate(data_loader):
                    image = data["image"].cuda()
                    masks = data["masks"].cuda()
                    # points are loaded for symmetry with before_step, but only boxes
                    # are used as validation prompts
                    bboxs = data["bboxs"].cuda() * 2 if image.shape[2] == 512 else data["bboxs"].cuda()
                    points = data["points"].cuda() * 2 if image.shape[2] == 512 else data["points"].cuda()

                    # convert boxes from xywh to xyxy
                    bboxs[..., 2] = bboxs[..., 0] + bboxs[..., 2]
                    bboxs[..., 3] = bboxs[..., 1] + bboxs[..., 3]

                    batched_input = []
                    for b_i in range(len(image)):
                        dict_input = dict()
                        dict_input["image"] = image[b_i]
                        dict_input["boxes"] = bboxs[b_i]
                        batched_input.append(dict_input)

                    output, iou_predictions = model(batched_input, multimask_output=True)

                    # keep, for each prompt, the candidate mask with the highest predicted IoU
                    B, M, N, H, W = output.shape
                    output = torch.stack(
                        [
                            output[k][torch.arange(M), iou_predictions[k].argmax(-1).squeeze()]
                            for k in range(len(output))
                        ],
                        dim=0,
                    )

                    output = (
                        F.interpolate(output, size=(image.shape[2], image.shape[3]), mode="bilinear")
                        .reshape(-1, image.shape[2], image.shape[3])
                        .unsqueeze(1)
                    )
                    masks = masks.reshape(-1, image.shape[2], image.shape[3]).unsqueeze(1)

                    loss_mask, loss_dice = loss_masks(output, masks, len(output))
                    loss = loss_mask * 20 + loss_dice

                    iou = compute_iou(output, masks * 255)
                    boundary_iou = compute_boundary_iou(output, masks * 255)

                    # average metrics across distributed ranks
                    loss = sync_tensor(loss)
                    iou = sync_tensor(iou)
                    boundary_iou = sync_tensor(boundary_iou)

                    val_loss.update(loss, image.shape[0] * get_dist_size())
                    val_iou.update(iou, image.shape[0] * get_dist_size())
                    val_iou_boundary.update(boundary_iou, image.shape[0] * get_dist_size())

                    t.set_postfix(
                        {
                            "loss": val_loss.avg,
                            "iou": val_iou.avg,
                            "boundary_iou": val_iou_boundary.avg,
                            "bs": image.shape[0] * get_dist_size(),
                        }
                    )
                    t.update()

        if is_master():
            self.wandb_log.log(
                {"val_loss": val_loss.avg, "val_iou": val_iou.avg, "val_boundary_iou": val_iou_boundary.avg}
            )

        return {
            "val_loss": val_loss.avg,
            "val_iou": val_iou.avg,
            "val_boundary_iou": val_iou_boundary.avg,
        }

    def validate(self, model=None, data_loader=None, epoch=0, sub_epoch=0) -> dict[str, Any]:
        model = model or self.eval_network
        if data_loader is None:
            data_loader = self.data_provider.valid

        model.eval()
        return self._validate(model, data_loader, epoch, sub_epoch)
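
    # Shared prompt preprocessing: boxes arrive as xywh and are converted to
    # xyxy; box/point coordinates are doubled at 512-pixel input so that
    # prompts land in a 1024x1024 coordinate frame, which the prompt encoder
    # presumably expects (inferred from the scaling, not stated in this file).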
    def before_step(self, feed_dict: dict[str, Any]) -> dict[str, Any]:
        image = feed_dict["image"].cuda()
        masks = feed_dict["masks"].cuda()
        bboxs = feed_dict["bboxs"].cuda() * 2 if image.shape[2] == 512 else feed_dict["bboxs"].cuda()
        points = feed_dict["points"].cuda() * 2 if image.shape[2] == 512 else feed_dict["points"].cuda()

        # xywh -> xyxy
        bboxs[..., 2] = bboxs[..., 0] + bboxs[..., 2]
        bboxs[..., 3] = bboxs[..., 1] + bboxs[..., 3]

        return {
            "image": image,
            "masks": masks,
            "points": points,
            "bboxs": bboxs,
        }
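
    # Training step: each sample is prompted with either its ground-truth box
    # or 1-10 points sampled from the target mask (50/50 chance), and the
    # decoder runs in single- or multi-mask mode (again 50/50). When multiple
    # candidates are produced, only the lowest-loss candidate per prompt is
    # supervised, following SAM's ambiguity-aware training scheme.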
    def run_step(self, feed_dict: dict[str, Any]) -> dict[str, Any]:
        image = feed_dict["image"]
        masks = feed_dict["masks"]
        bboxs = feed_dict["bboxs"]
        points = feed_dict["points"]

        batched_input = []
        for b_i in range(len(image)):
            dict_input = dict()
            dict_input["image"] = image[b_i]
            if random.random() >= 0.5:
                dict_input["boxes"] = bboxs[b_i]
            else:
                try:
                    # sample 1-10 positive points from each target mask
                    n_p = int(random.random() * 10 + 1)
                    dict_input["point_coords"] = masks_sample_points(masks[b_i], k=n_p)
                    if image.shape[2] == 512:
                        dict_input["point_coords"] = dict_input["point_coords"] * 2
                    dict_input["point_labels"] = torch.ones((points[b_i].shape[0], n_p), device=image.device)
                except Exception:
                    # fall back to box prompts if point sampling fails (e.g., an empty mask)
                    dict_input["boxes"] = bboxs[b_i]
            batched_input.append(dict_input)

        with torch.autocast(device_type="cuda", dtype=self.amp_dtype, enabled=self.enable_amp):
            if random.random() >= 0.5:
                output, iou_predictions = self.model(batched_input, multimask_output=True)
            else:
                output, iou_predictions = self.model(batched_input, multimask_output=False)

            masks = masks.reshape(-1, image.shape[2], image.shape[3]).unsqueeze(1)

            # score every candidate mask against the target
            loss_list = []
            for i in range(output.shape[2]):
                output_i = (
                    F.interpolate(output[:, :, i], size=(image.shape[2], image.shape[3]), mode="bilinear")
                    .reshape(-1, image.shape[2], image.shape[3])
                    .unsqueeze(1)
                )
                loss_mask_i, loss_dice_i = loss_masks(output_i, masks, len(output_i), mode="none")
                loss_i = loss_mask_i * 20 + loss_dice_i
                loss_list.append(loss_i)

            # keep only the lowest-loss candidate per prompt via a one-hot mask
            loss = torch.stack(loss_list, -1)
            min_indices = torch.argmin(loss, dim=1)
            mask = torch.zeros_like(loss, device=loss.device)
            mask.scatter_(1, min_indices.unsqueeze(1), 1)
            loss = (loss * mask).mean() * loss.shape[-1]

        self.scaler.scale(loss).backward()

        return {"loss": loss, "output": output}
    def _train_one_sub_epoch(self, epoch: int, sub_epoch: int) -> dict[str, Any]:
        train_loss = AverageMeter()

        with tqdm(
            total=len(self.data_provider.train),
            desc=f"Train Epoch #{epoch + 1}, Sub Epoch #{sub_epoch + 1}",
            disable=not is_master(),
            file=sys.stdout,
        ) as t:
            for i, data in enumerate(self.data_provider.train):
                feed_dict = data

                # preprocessing
                feed_dict = self.before_step(feed_dict)
                # clear gradient
                self.optimizer.zero_grad()
                # forward & backward
                output_dict = self.run_step(feed_dict)
                # update: optimizer, lr_scheduler
                self.after_step()

                loss = output_dict["loss"]
                loss = sync_tensor(loss)
                train_loss.update(loss, data["image"].shape[0] * get_dist_size())

                if is_master():
                    self.wandb_log.log(
                        {
                            "train_loss": train_loss.avg,
                            "epoch": epoch,
                            "sub_epoch": sub_epoch,
                            "learning_rate": sorted(set(group["lr"] for group in self.optimizer.param_groups))[0],
                        }
                    )

                t.set_postfix(
                    {
                        "loss": train_loss.avg,
                        "bs": data["image"].shape[0] * get_dist_size(),
                        "res": data["image"].shape[2],
                        "lr": list_join(
                            sorted(set(group["lr"] for group in self.optimizer.param_groups)),
                            "#",
                            "%.1E",
                        ),
                        "progress": self.run_config.progress,
                    }
                )
                t.update()

        return {
            "train_loss": train_loss.avg,
        }

    def train_one_sub_epoch(self, epoch: int, sub_epoch: int) -> dict[str, Any]:
        self.model.train()

        self.data_provider.set_epoch_and_sub_epoch(epoch, sub_epoch)
        train_info_dict = self._train_one_sub_epoch(epoch, sub_epoch)

        return train_info_dict
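
    # Scheduling note: run_config.n_epochs counts *sub*-epochs; each full
    # epoch spans data_provider.sub_epochs_per_epoch of them, and a
    # checkpoint is written after every sub-epoch.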
    def train(self) -> None:
        for sub_epoch in range(self.start_epoch, self.run_config.n_epochs):
            epoch = sub_epoch // self.data_provider.sub_epochs_per_epoch

            train_info_dict = self.train_one_sub_epoch(epoch, sub_epoch)
            val_info_dict = self.validate(epoch=epoch, sub_epoch=sub_epoch)

            val_iou = val_info_dict["val_iou"]
            # best_val is tracked, but a checkpoint is saved unconditionally
            # every sub-epoch
            is_best = val_iou > self.best_val
            self.best_val = max(val_iou, self.best_val)

            self.save_model(
                only_state_dict=False,
                epoch=sub_epoch,
                model_name=f"checkpoint_{epoch}_{sub_epoch}.pt",
            )
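
    # Setup: wrap the model in DistributedDataParallel, build the optimizer
    # and LR scheduler from the run config, and create a GradScaler that is
    # a pass-through unless mixed precision is enabled.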
    def prep_for_training(self, run_config: SAMRunConfig, amp="fp32") -> None:
        self.run_config = run_config
        self.model = nn.parallel.DistributedDataParallel(
            self.model.cuda(),
            device_ids=[get_dist_local_rank()],
            find_unused_parameters=True,
        )
        self.run_config.global_step = 0
        self.run_config.batch_per_epoch = len(self.data_provider.train)
        assert self.run_config.batch_per_epoch > 0, "Training set is empty"

        # build optimizer
        self.optimizer, self.lr_scheduler = self.run_config.build_optimizer(self.model)

        # amp
        self.amp = amp
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.enable_amp)
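

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file). It assumes a
# torchrun-style launch, since DistributedDataParallel requires an initialized
# process group, and hedges on constructor arguments: the SAMDataProvider and
# SAMRunConfig signatures below are placeholders, and the model-zoo import and
# checkpoint name are taken from the efficientvit repo but should be verified.
#
#   import torch.distributed as dist
#   from efficientvit.sam_model_zoo import create_sam_model
#
#   dist.init_process_group(backend="nccl")
#   model = create_sam_model("l0", pretrained=True)
#   data_provider = SAMDataProvider(...)  # dataset/root/batch-size args elided
#   run_config = SAMRunConfig(...)        # optimizer/schedule args elided
#
#   trainer = SAMTrainer(path="exp/sam-l0", model=model, data_provider=data_provider)
#   trainer.prep_for_training(run_config, amp="bf16")
#   trainer.train()
# ---------------------------------------------------------------------------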