"""Train and validate a semantic-segmentation model with Accelerate and W&B.

The script builds one of three models (``unet``, ``vit``, ``dino``), trains it
on a Hugging Face dataset, logs scalar metrics and sample visualisations to
Weights & Biases, and writes periodic / best checkpoints to ``--logs-dir``.
"""

from typing import Sequence, Tuple
from pathlib import Path
from tqdm import tqdm
from accelerate import Accelerator
from accelerate.utils import set_seed
from matplotlib import cm
import numpy as np
import matplotlib.pyplot as plt
import argparse
import json
import wandb
from datasets import load_dataset
import torch
from torch.utils.data import DataLoader

from models.unet import UNet
from dataset import SegmentationDataset, collate_fn
from utils import get_transform, mask_transform, EMA
from get_loss import get_composite_criterion
from models.vit import ViTSegmentation
from models.dino import DINOSegmentationModel

# Number of segmentation classes used for colouring and for metric computation.
NUM_CLASSES = 18

# Names of the values returned by ``compute_metrics``, in return order.
METRIC_NAMES = (
    "macro_precision",
    "macro_recall",
    "macro_accuracy",
    "micro_precision",
    "micro_recall",
    "global_accuracy",
)

# ``matplotlib.cm.get_cmap`` was deprecated in Matplotlib 3.7 and removed in
# 3.9; ``pyplot.get_cmap`` accepts the same (name, lut) arguments.
color_map = plt.get_cmap("tab20", NUM_CLASSES)
fixed_colors = np.array([color_map(i)[:3] for i in range(NUM_CLASSES)]) * 255


def mask_to_color(mask: np.ndarray) -> np.ndarray:
    """Convert a 2-D class-index mask into an ``(h, w, 3)`` uint8 colour image.

    Pixels whose value is not one of the indices covered by ``fixed_colors``
    are left black.
    """
    h, w = mask.shape
    color_mask = np.zeros((h, w, 3), dtype=np.uint8)
    for class_idx, color in enumerate(fixed_colors):
        color_mask[mask == class_idx] = color
    return color_mask


def create_combined_image(
    x: torch.Tensor,
    y: torch.Tensor,
    y_pred: torch.Tensor,
    mean: Sequence[float] = (0.485, 0.456, 0.406),
    std: Sequence[float] = (0.229, 0.224, 0.225),
) -> np.ndarray:
    """Build one image showing inputs, target masks and predicted masks.

    The result has three rows (de-normalised input, coloured target mask,
    coloured predicted mask) and one column per batch element.

    Args:
        x: Normalised input batch of shape ``(batch, channels, height, width)``.
            The canvas has 3 channels, so ``channels`` is presumably 3.
        y: Target class-index masks, one ``(height, width)`` mask per element.
        y_pred: Predicted class-index masks, same layout as ``y``.
        mean: Per-channel mean used to undo the input normalisation.
        std: Per-channel standard deviation used to undo the normalisation.

    Returns:
        A ``(3 * height, batch * width, 3)`` uint8 array.
    """
    batch_size, _, height, width = x.shape
    mean_arr = np.asarray(mean, dtype=np.float64)
    std_arr = np.asarray(std, dtype=np.float64)

    combined_image = np.zeros((height * 3, width * batch_size, 3), dtype=np.uint8)
    for i in range(batch_size):
        # ``.float()`` keeps this working for half / bfloat16 inputs, which
        # NumPy cannot represent (bfloat16) or would compute in low precision.
        image = x[i].detach().float().cpu().permute(1, 2, 0).numpy()
        image = (image * std_arr + mean_arr).clip(0, 1)
        image = (image * 255).astype(np.uint8)

        true_mask_color = mask_to_color(y[i].cpu().numpy())
        pred_mask_color = mask_to_color(y_pred[i].cpu().numpy())

        columns = slice(i * width, (i + 1) * width)
        combined_image[:height, columns, :] = image
        combined_image[height:2 * height, columns, :] = true_mask_color
        combined_image[2 * height:, columns, :] = pred_mask_color
    return combined_image


def compute_metrics(
    y_pred: torch.Tensor,
    y: torch.Tensor,
    num_classes: int = NUM_CLASSES,
) -> Tuple[float, float, float, float, float, float]:
    """Compute segmentation metrics for one batch of class-index masks.

    Args:
        y_pred: Predicted class indices, a 3-D tensor.
        y: Target class indices, same shape as ``y_pred``.
        num_classes: Number of classes to evaluate (indices ``0..num_classes-1``).

    Returns:
        ``(macro_precision, macro_recall, macro_accuracy, micro_precision,
        micro_recall, global_accuracy)``. Macro precision/recall are averaged
        only over classes that occur in ``y``; macro accuracy is the mean of
        the per-class accuracy over all classes; micro values use counts
        pooled over classes; global accuracy is the fraction of equal pixels.
    """
    # One-hot style boolean masks with a trailing class axis.
    pred_mask = y_pred.unsqueeze(-1) == torch.arange(num_classes, device=y_pred.device).reshape(1, 1, 1, -1)
    target_mask = y.unsqueeze(-1) == torch.arange(num_classes, device=y.device).reshape(1, 1, 1, -1)

    reduce_dims = (0, 1, 2)
    class_present = (target_mask.sum(dim=reduce_dims) > 0).float()

    tp = (pred_mask & target_mask).sum(dim=reduce_dims).float()
    fp = (pred_mask & ~target_mask).sum(dim=reduce_dims).float()
    fn = (~pred_mask & target_mask).sum(dim=reduce_dims).float()
    tn = (~pred_mask & ~target_mask).sum(dim=reduce_dims).float()

    overall_tp = tp.sum()
    overall_fp = fp.sum()
    overall_fn = fn.sum()

    # Clamp denominators so classes with no predictions/targets yield 0, not NaN.
    precision = tp / (tp + fp).clamp(min=1e-8)
    recall = tp / (tp + fn).clamp(min=1e-8)
    accuracy = (tp + tn) / (tp + tn + fp + fn)

    present_count = class_present.sum().clamp(min=1e-8)
    macro_precision = ((precision * class_present).sum() / present_count).item()
    macro_recall = ((recall * class_present).sum() / present_count).item()
    macro_accuracy = accuracy.mean().item()

    micro_precision = (overall_tp / (overall_tp + overall_fp).clamp(min=1e-8)).item()
    micro_recall = (overall_tp / (overall_tp + overall_fn).clamp(min=1e-8)).item()

    global_accuracy = ((y_pred == y).sum() / y.numel()).item()

    return macro_precision, macro_recall, macro_accuracy, micro_precision, micro_recall, global_accuracy


def _str_to_bool(value: str) -> bool:
    """Parse a command-line boolean such as ``true`` / ``False`` / ``0``.

    ``type=bool`` would treat every non-empty string (including ``"False"``)
    as ``True``; this parser interprets the text instead.
    """
    lowered = value.strip().lower()
    if lowered in {"true", "1", "yes", "y", "on"}:
        return True
    if lowered in {"false", "0", "no", "n", "off", ""}:
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def parse_args():
    """Parse the command-line arguments of the training script."""
    parser = argparse.ArgumentParser(description="Train a model on human parsing dataset")
    parser.add_argument("--data-path", type=str, default="mattmdjaga/human_parsing_dataset", help="Path to the data")
    parser.add_argument("--batch-size", type=int, default=32, help="Batch size for training and testing")
    parser.add_argument("--pin-memory", type=_str_to_bool, default=True, help="Pin memory for DataLoader")
    parser.add_argument("--num-workers", type=int, default=0, help="Number of workers for DataLoader")
    parser.add_argument("--num-epochs", type=int, default=15, help="Number of training epochs")
    parser.add_argument("--optimizer", type=str, default="AdamW", help="Optimizer type")
    parser.add_argument("--learning-rate", type=float, default=1e-4, help="Learning rate for the optimizer")
    parser.add_argument("--max-norm", type=float, default=1.0, help="Maximum gradient norm for clipping")
    parser.add_argument("--logs-dir", type=str, default="dino-logs", help="Directory for saving logs")
    parser.add_argument("--model", type=str, default="dino", choices=["unet", "vit", "dino"], help="Model class name")
    parser.add_argument("--losses-path", type=str, default="losses_config.json", help="Path to the losses")
    parser.add_argument("--mixed-precision", type=str, default="fp16", choices=["fp16", "bf16", "fp8", "no"], help="Value of the mixed precision")
    parser.add_argument("--gradient-accumulation-steps", type=int, default=2, help="Value of the gradient accumulation steps")
    parser.add_argument("--project-name", type=str, default="human_parsing_segmentation_ttk", help="WandB project name")
    parser.add_argument("--save-frequency", type=int, default=4, help="Frequency of saving model weights")
    parser.add_argument("--log-steps", type=int, default=400, help="Number of steps between logging training images and metrics")
    parser.add_argument("--seed", type=int, default=42, help="Value of the seed")
    return parser.parse_args()


def _build_model_and_optimizer(args, device):
    """Create the model selected by ``args.model`` and its optimizer.

    For ``dino`` only the segmentation head's parameters are optimised; for
    ``unet`` and ``vit`` all model parameters are passed to the optimizer.

    Raises:
        NotImplementedError: If ``args.model`` is not a known model name.
    """
    optimizer_class = getattr(torch.optim, args.optimizer)
    if args.model == "unet":
        model = UNet().to(device)
        params = model.parameters()
    elif args.model == "vit":
        model = ViTSegmentation().to(device)
        params = model.parameters()
    elif args.model == "dino":
        model = DINOSegmentationModel().to(device)
        params = model.segmentation_head.parameters()
    else:
        raise NotImplementedError("Incorrect model name")
    return model, optimizer_class(params, lr=args.learning_rate)


def _log_samples(key: str, caption: str, x: torch.Tensor, y: torch.Tensor, y_pred: torch.Tensor) -> None:
    """Log one combined input / target / prediction image to W&B under ``key``."""
    combined_image = create_combined_image(x, y, y_pred)
    wandb.log({key: [wandb.Image(combined_image, caption=caption)]})


def _save_checkpoint(accelerator, model, model_name: str, path: Path) -> None:
    """Save model weights to ``path``.

    For ``dino`` and ``vit`` only the segmentation head's state dict is
    written; otherwise the full state dict. The model is unwrapped first so
    attribute access and state-dict keys do not depend on any wrapper added
    by ``accelerator.prepare``.
    """
    unwrapped = accelerator.unwrap_model(model)
    if model_name in ("dino", "vit"):
        state_dict = unwrapped.segmentation_head.state_dict()
    else:
        state_dict = unwrapped.state_dict()
    accelerator.save(state_dict, path)


def main() -> None:
    """Run the full training / validation loop described by the CLI arguments."""
    args = parse_args()
    set_seed(args.seed)
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
    )

    with open(args.losses_path, "r", encoding="utf-8") as fp:
        losses_config = json.load(fp)

    with accelerator.main_process_first():
        logs_dir = Path(args.logs_dir)
        logs_dir.mkdir(parents=True, exist_ok=True)
        wandb.init(project=args.project_name, dir=logs_dir)
        wandb.save(args.losses_path)

    model, optimizer = _build_model_and_optimizer(args, accelerator.device)

    transform = get_transform(model.mean, model.std)
    dataset = load_dataset(args.data_path, split="train")
    train_dataset = SegmentationDataset(dataset, train=True, transform=transform, target_transform=mask_transform)
    valid_dataset = SegmentationDataset(dataset, train=False, transform=transform, target_transform=mask_transform)

    loader_kwargs = dict(
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_memory,
        collate_fn=collate_fn,
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    valid_loader = DataLoader(valid_dataset, shuffle=False, **loader_kwargs)

    criterion = get_composite_criterion(losses_config)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.num_epochs * len(train_loader))

    # The validation loader is intentionally not prepared (as in the original
    # script), so batches are moved to the device manually below.
    model, optimizer, train_loader, lr_scheduler = accelerator.prepare(model, optimizer, train_loader, lr_scheduler)

    best_accuracy = 0
    print(f"params: {sum(p.numel() for p in model.parameters())/1e6:.2f} M")
    print(f"trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad)/1e6:.2f} M")

    # Smoothed values shown in the progress bar; raw values go to W&B.
    train_loss_ema = EMA()
    train_metric_emas = {name: EMA() for name in METRIC_NAMES}

    for epoch in range(1, args.num_epochs + 1):
        model.train()
        # NOTE: a leftover debug ``print(...); exit()`` used to sit here and
        # terminated the process before the first training step.
        pbar = tqdm(train_loader, desc=f"Train epoch {epoch}/{args.num_epochs}")
        for index, (x, y) in enumerate(pbar):
            x, y = x.to(accelerator.device), y.squeeze(1).to(accelerator.device)

            with accelerator.accumulate(model):
                with accelerator.autocast():
                    output = model(x)
                    loss = criterion(output, y)
                accelerator.backward(loss)

                train_loss = loss.item()
                grad_norm = None
                _, y_pred = output.max(dim=1)
                train_metrics = dict(zip(METRIC_NAMES, compute_metrics(y_pred, y)))

                # Clip only on steps where gradients are actually synchronised.
                if accelerator.sync_gradients:
                    grad_norm = accelerator.clip_grad_norm_(model.parameters(), args.max_norm).item()
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            if (index + 1) % args.log_steps == 0 and accelerator.is_main_process:
                _log_samples(
                    "train_samples",
                    f"Combined Image (Train, Epoch {epoch}, Batch {index})",
                    x,
                    y,
                    y_pred,
                )

            postfix = {"loss": train_loss_ema(train_loss)}
            for name in METRIC_NAMES:
                postfix[name] = train_metric_emas[name](train_metrics[name])
            pbar.set_postfix(postfix)

            log_data = {
                "train/epoch": epoch,
                "train/loss": train_loss,
                "train/learning_rate": optimizer.param_groups[0]["lr"],
            }
            for name in METRIC_NAMES:
                log_data[f"train/{name}"] = train_metrics[name]
            if grad_norm is not None:
                log_data["train/grad_norm"] = grad_norm
            if accelerator.is_main_process:
                wandb.log(log_data)

        accelerator.wait_for_everyone()

        model.eval()
        valid_loss = 0.0
        valid_metrics = {name: 0.0 for name in METRIC_NAMES}
        with torch.inference_mode():
            pbar = tqdm(valid_loader, desc=f"Val epoch {epoch}/{args.num_epochs}")
            # Iterate the progress bar itself so it actually advances.
            for index, (x, y) in enumerate(pbar):
                x, y = x.to(accelerator.device), y.squeeze(1).to(accelerator.device)
                output = model(x)
                _, y_pred = output.max(dim=1)

                if (index + 1) % args.log_steps == 0 and accelerator.is_main_process:
                    _log_samples(
                        "valid_samples",
                        f"Combined Image (Validation, Epoch {epoch})",
                        x,
                        y,
                        y_pred,
                    )

                for name, value in zip(METRIC_NAMES, compute_metrics(y_pred, y)):
                    valid_metrics[name] += value
                valid_loss += criterion(output, y).item()

        # Per-batch averages; guard against an empty validation loader.
        num_valid_batches = max(len(valid_loader), 1)
        valid_loss = valid_loss / num_valid_batches
        valid_metrics = {name: total / num_valid_batches for name, total in valid_metrics.items()}

        metrics_text = ", ".join(f"valid_{name}: {valid_metrics[name]:.3f}" for name in METRIC_NAMES)
        accelerator.print(f"loss: {valid_loss:.3f}, {metrics_text}")

        if accelerator.is_main_process:
            val_log = {"val/epoch": epoch, "val/loss": valid_loss}
            for name in METRIC_NAMES:
                val_log[f"val/{name}"] = valid_metrics[name]
            wandb.log(val_log)

        if valid_metrics["global_accuracy"] > best_accuracy:
            best_accuracy = valid_metrics["global_accuracy"]
            _save_checkpoint(accelerator, model, args.model, logs_dir / "checkpoint-best.pth")
            accelerator.print(f"new best_accuracy {best_accuracy}, {epoch=}")

        if epoch % args.save_frequency == 0:
            _save_checkpoint(accelerator, model, args.model, logs_dir / f"checkpoint-{epoch:09}.pth")

        accelerator.wait_for_everyone()

    accelerator.wait_for_everyone()
    wandb.finish()


if __name__ == "__main__":
    main()