# eksemyashkina's picture
# Upload 9 files
# af720c2 verified
# raw
# history blame
# 15.3 kB
from typing import Tuple
from pathlib import Path
from tqdm import tqdm
from accelerate import Accelerator
from accelerate.utils import set_seed
from matplotlib import cm
import numpy as np
import matplotlib.pyplot as plt
import argparse
import json
import wandb
from datasets import load_dataset
import torch
from torch.utils.data import DataLoader
from models.unet import UNet
from dataset import SegmentationDataset, collate_fn
from utils import get_transform, mask_transform, EMA
from get_loss import get_composite_criterion
from models.vit import ViTSegmentation
from models.dino import DINOSegmentationModel
# Fixed palette with one colour per class index (18 classes), sampled from "tab20".
# plt.get_cmap is used instead of matplotlib.cm.get_cmap: the latter was
# deprecated in Matplotlib 3.7 and removed in 3.9, while the pyplot function
# keeps the same (name, lut) signature.
color_map = plt.get_cmap('tab20', 18)
# Keep only the first three channels of each colour and scale by 255; the
# values stay floating point here and are cast when written into uint8 images.
fixed_colors = np.array([color_map(i)[:3] for i in range(18)]) * 255
def mask_to_color(mask: np.ndarray) -> np.ndarray:
    """Render a 2-D class-index mask as an RGB image using the module palette.

    Args:
        mask: Array of shape (H, W) holding class indices.

    Returns:
        A uint8 array of shape (H, W, 3). Pixels whose value is one of the
        18 class indices 0..17 receive the matching row of ``fixed_colors``;
        every other pixel stays black (0, 0, 0).

    Raises:
        ValueError: If ``mask`` is not two-dimensional (shape unpacking fails).
    """
    rows, cols = mask.shape
    rgb = np.zeros((rows, cols, 3), dtype=np.uint8)
    for label in range(18):
        selected = mask == label
        # Assignment into the uint8 canvas casts the palette entry.
        rgb[selected] = fixed_colors[label]
    return rgb
def create_combined_image(
    x: torch.Tensor,
    y: torch.Tensor,
    y_pred: torch.Tensor,
    mean: Tuple[float, ...] = (0.485, 0.456, 0.406),
    std: Tuple[float, ...] = (0.229, 0.224, 0.225),
) -> np.ndarray:
    """Build one preview image stacking input, target mask and predicted mask.

    The result has three rows (de-normalised input image, colourised ground
    truth, colourised prediction) and one column per batch element.

    Args:
        x: Normalised images, shape (B, C, H, W). The canvas has three
            channels, so C is expected to be 3.
        y: Ground-truth class-index masks, indexable as ``y[i]`` -> (H, W).
        y_pred: Predicted class-index masks, same layout as ``y``.
        mean: Per-channel mean used to undo the normalisation. Any sequence
            of floats (tuple or list) works; the default is immutable so it
            cannot be altered between calls.
        std: Per-channel standard deviation used to undo the normalisation.

    Returns:
        A uint8 array of shape (3 * H, B * W, 3).
    """
    batch_size, _, height, width = x.shape
    # np.asarray keeps the same promotion rules the plain sequences had
    # (Python floats become float64), so pixel values are unchanged.
    mean_values = np.asarray(mean)
    std_values = np.asarray(std)
    combined_image = np.zeros((height * 3, width * batch_size, 3), dtype=np.uint8)
    for i in range(batch_size):
        columns = slice(i * width, (i + 1) * width)

        # detach() makes the conversion safe even if x is part of a graph.
        image = x[i].detach().cpu().permute(1, 2, 0).numpy()
        image = (image * std_values + mean_values).clip(0, 1)
        image = (image * 255).astype(np.uint8)

        true_mask_color = mask_to_color(y[i].detach().cpu().numpy())
        pred_mask_color = mask_to_color(y_pred[i].detach().cpu().numpy())

        combined_image[:height, columns, :] = image
        combined_image[height:2 * height, columns, :] = true_mask_color
        combined_image[2 * height:, columns, :] = pred_mask_color
    return combined_image
def compute_metrics(y_pred: torch.Tensor, y: torch.Tensor, num_classes: int = 18) -> Tuple[float, float, float, float, float, float]:
    """Compute segmentation metrics from predicted and target class indices.

    Args:
        y_pred: Integer tensor of predicted class indices, e.g. (B, H, W).
        y: Integer tensor of target class indices, same shape as ``y_pred``.
        num_classes: Number of classes; indices outside ``[0, num_classes)``
            match no class.

    Returns:
        ``(macro_precision, macro_recall, macro_accuracy, micro_precision,
        micro_recall, global_accuracy)`` as Python floats.

        * Macro precision/recall are averaged over the classes that occur
          in ``y`` only.
        * Macro accuracy is the mean per-class binary accuracy over all
          ``num_classes`` classes.
        * Micro precision/recall pool the counts over classes.
        * Global accuracy is the fraction of positions where
          ``y_pred == y``.

        Every denominator is guarded, so an empty batch yields 0.0 for all
        metrics instead of NaN.
    """
    # One-hot comparison along a new trailing class axis: (..., num_classes).
    pred_mask = y_pred.unsqueeze(-1) == torch.arange(num_classes, device=y_pred.device)
    target_mask = y.unsqueeze(-1) == torch.arange(num_classes, device=y.device)
    # Reduce over every axis except the class axis.
    reduce_dims = tuple(range(target_mask.dim() - 1))

    class_present = (target_mask.sum(dim=reduce_dims) > 0).float()
    tp = (pred_mask & target_mask).sum(dim=reduce_dims).float()
    fp = (pred_mask & ~target_mask).sum(dim=reduce_dims).float()
    fn = (~pred_mask & target_mask).sum(dim=reduce_dims).float()
    tn = (~pred_mask & ~target_mask).sum(dim=reduce_dims).float()

    overall_tp = tp.sum()
    overall_fp = fp.sum()
    overall_fn = fn.sum()

    precision = tp / (tp + fp).clamp(min=1e-8)
    recall = tp / (tp + fn).clamp(min=1e-8)
    # Clamped like precision/recall so that an empty input gives 0, not NaN.
    accuracy = (tp + tn) / (tp + tn + fp + fn).clamp(min=1e-8)

    present_count = class_present.sum().clamp(min=1e-8)
    macro_precision = ((precision * class_present).sum() / present_count).item()
    macro_recall = ((recall * class_present).sum() / present_count).item()
    macro_accuracy = accuracy.mean().item()
    micro_precision = (overall_tp / (overall_tp + overall_fp).clamp(min=1e-8)).item()
    micro_recall = (overall_tp / (overall_tp + overall_fn).clamp(min=1e-8)).item()
    # max(..., 1) avoids 0/0 when there are no positions at all.
    global_accuracy = ((y_pred == y).sum() / max(y.numel(), 1)).item()
    return macro_precision, macro_recall, macro_accuracy, micro_precision, micro_recall, global_accuracy
def _str_to_bool(value: str) -> bool:
    """Convert a command-line string such as "true"/"False"/"1"/"no" to bool."""
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the command-line options of the training script.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what a plain ``parse_args()``
            call did before this parameter existed.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Train a model on human parsing dataset")
    parser.add_argument("--data-path", type=str, default="mattmdjaga/human_parsing_dataset", help="Path to the data")
    parser.add_argument("--batch-size", type=int, default=32, help="Batch size for training and testing")
    # type=bool would turn every non-empty string (including "False") into
    # True, so the value is parsed explicitly instead.
    parser.add_argument("--pin-memory", type=_str_to_bool, default=True, help="Pin memory for DataLoader")
    parser.add_argument("--num-workers", type=int, default=0, help="Number of workers for DataLoader")
    parser.add_argument("--num-epochs", type=int, default=15, help="Number of training epochs")
    parser.add_argument("--optimizer", type=str, default="AdamW", help="Optimizer type")
    parser.add_argument("--learning-rate", type=float, default=1e-4, help="Learning rate for the optimizer")
    parser.add_argument("--max-norm", type=float, default=1.0, help="Maximum gradient norm for clipping")
    parser.add_argument("--logs-dir", type=str, default="dino-logs", help="Directory for saving logs")
    parser.add_argument("--model", type=str, default="dino", choices=["unet", "vit", "dino"], help="Model class name")
    parser.add_argument("--losses-path", type=str, default="losses_config.json", help="Path to the losses")
    parser.add_argument("--mixed-precision", type=str, default="fp16", choices=["fp16", "bf16", "fp8", "no"], help="Value of the mixed precision")
    parser.add_argument("--gradient-accumulation-steps", type=int, default=2, help="Value of the gradient accumulation steps")
    parser.add_argument("--project-name", type=str, default="human_parsing_segmentation_ttk", help="WandB project name")
    parser.add_argument("--save-frequency", type=int, default=4, help="Frequency of saving model weights")
    parser.add_argument("--log-steps", type=int, default=400, help="Number of steps between logging training images and metrics")
    parser.add_argument("--seed", type=int, default=42, help="Value of the seed")
    return parser.parse_args(argv)
# Metric names in the order compute_metrics returns them.
_METRIC_NAMES = (
    "macro_precision",
    "macro_recall",
    "macro_accuracy",
    "micro_precision",
    "micro_recall",
    "global_accuracy",
)


def _build_model_and_optimizer(args, device):
    """Instantiate the model selected by ``args.model`` and its optimizer."""
    optimizer_class = getattr(torch.optim, args.optimizer)
    if args.model == "unet":
        model = UNet().to(device)
        trainable = model.parameters()
    elif args.model == "vit":
        model = ViTSegmentation().to(device)
        trainable = model.parameters()
    elif args.model == "dino":
        model = DINOSegmentationModel().to(device)
        # Only the segmentation head is handed to the optimizer for DINO.
        trainable = model.segmentation_head.parameters()
    else:
        raise NotImplementedError("Incorrect model name")
    return model, optimizer_class(trainable, lr=args.learning_rate)


def _save_weights(accelerator, model, model_name, path) -> None:
    """Save the segmentation head ("dino"/"vit") or the whole model to ``path``."""
    # Unwrap first: a prepared model may be wrapped (e.g. for distributed
    # training), in which case attributes such as segmentation_head live on
    # the inner module.
    unwrapped = accelerator.unwrap_model(model)
    # NOTE(review): for "vit" all parameters are optimised but only the head
    # is saved here, as in the original code - confirm this is intended.
    if model_name in ["dino", "vit"]:
        state = unwrapped.segmentation_head.state_dict()
    else:
        state = unwrapped.state_dict()
    accelerator.save(state, path)


def main() -> None:
    """Train and validate a segmentation model, logging to Weights & Biases.

    Steps: parse arguments, seed, build the Accelerator, load the loss
    configuration, build model/optimizer/data loaders/scheduler, then for
    each epoch run a training pass with gradient accumulation and clipping
    followed by a validation pass. The checkpoint with the best validation
    global accuracy is written to ``checkpoint-best.pth`` and a periodic
    checkpoint every ``--save-frequency`` epochs.
    """
    args = parse_args()
    set_seed(args.seed)
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
    )
    with open(args.losses_path, "r") as fp:
        losses_config = json.load(fp)

    logs_dir = Path(args.logs_dir)
    with accelerator.main_process_first():
        logs_dir.mkdir(parents=True, exist_ok=True)
    # All wandb.log calls below are main-process-only, so the run is opened
    # (and later closed) on the main process only as well.
    if accelerator.is_main_process:
        wandb.init(project=args.project_name, dir=logs_dir)
        wandb.save(args.losses_path)

    model, optimizer = _build_model_and_optimizer(args, accelerator.device)

    transform = get_transform(model.mean, model.std)
    dataset = load_dataset(args.data_path, split="train")
    train_dataset = SegmentationDataset(dataset, train=True, transform=transform, target_transform=mask_transform)
    valid_dataset = SegmentationDataset(dataset, train=False, transform=transform, target_transform=mask_transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=args.pin_memory, collate_fn=collate_fn)
    valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=args.pin_memory, collate_fn=collate_fn)

    criterion = get_composite_criterion(losses_config)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.num_epochs * len(train_loader))
    # valid_loader is deliberately left out of prepare(), as in the original;
    # its batches are moved to the device by hand in the validation loop.
    model, optimizer, train_loader, lr_scheduler = accelerator.prepare(model, optimizer, train_loader, lr_scheduler)

    best_accuracy = 0
    print(f"params: {sum([p.numel() for p in model.parameters()])/1e6:.2f} M")
    print(f"trainable params: {sum([p.numel() for p in model.parameters() if p.requires_grad])/1e6:.2f} M")

    # One smoother per progress-bar entry; EMA instances are called with the
    # new value and their return value is what gets displayed.
    train_emas = {name: EMA() for name in ("loss", *_METRIC_NAMES)}

    for epoch in range(1, args.num_epochs + 1):
        # ---- training pass -------------------------------------------------
        model.train()
        pbar = tqdm(train_loader, desc=f"Train epoch {epoch}/{args.num_epochs}")
        for index, (x, y) in enumerate(pbar):
            x, y = x.to(accelerator.device), y.squeeze(1).to(accelerator.device)
            with accelerator.accumulate(model):
                with accelerator.autocast():
                    output = model(x)
                    loss = criterion(output, y)
                accelerator.backward(loss)
                train_loss = loss.item()
                grad_norm = None
                _, y_pred = output.max(dim=1)
                train_metrics = dict(zip(_METRIC_NAMES, compute_metrics(y_pred, y)))
                # Clip only on steps where gradients are actually synced,
                # i.e. at the end of an accumulation window.
                if accelerator.sync_gradients:
                    grad_norm = accelerator.clip_grad_norm_(model.parameters(), args.max_norm).item()
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            if (index + 1) % args.log_steps == 0 and accelerator.is_main_process:
                # NOTE(review): the preview is de-normalised with
                # create_combined_image's default mean/std, not
                # model.mean/model.std - confirm they agree.
                combined_image = create_combined_image(x, y, y_pred)
                images_to_log = [wandb.Image(combined_image, caption=f"Combined Image (Train, Epoch {epoch}, Batch {index})")]
                wandb.log({"train_samples": images_to_log})

            postfix = {"loss": train_emas["loss"](train_loss)}
            for name, value in train_metrics.items():
                postfix[name] = train_emas[name](value)
            pbar.set_postfix(postfix)

            log_data = {
                "train/epoch": epoch,
                "train/loss": train_loss,
                "train/learning_rate": optimizer.param_groups[0]["lr"],
            }
            for name, value in train_metrics.items():
                log_data[f"train/{name}"] = value
            if grad_norm is not None:
                log_data["train/grad_norm"] = grad_norm
            if accelerator.is_main_process:
                wandb.log(log_data)
        accelerator.wait_for_everyone()

        # ---- validation pass -----------------------------------------------
        model.eval()
        valid_loss = 0.0
        valid_totals = dict.fromkeys(_METRIC_NAMES, 0.0)
        with torch.inference_mode():
            pbar = tqdm(valid_loader, desc=f"Val epoch {epoch}/{args.num_epochs}")
            # Iterate the progress bar itself so that it advances.
            for index, (x, y) in enumerate(pbar):
                x, y = x.to(accelerator.device), y.squeeze(1).to(accelerator.device)
                output = model(x)
                _, y_pred = output.max(dim=1)
                if (index + 1) % args.log_steps == 0 and accelerator.is_main_process:
                    combined_image = create_combined_image(x, y, y_pred)
                    images_to_log = [wandb.Image(combined_image, caption=f"Combined Image (Validation, Epoch {epoch})")]
                    wandb.log({"valid_samples": images_to_log})
                for name, value in zip(_METRIC_NAMES, compute_metrics(y_pred, y)):
                    valid_totals[name] += value
                loss = criterion(output, y)
                valid_loss += loss.item()

        # Per-batch averages; max(..., 1) guards against an empty loader.
        num_valid_batches = max(len(valid_loader), 1)
        valid_loss = valid_loss / num_valid_batches
        valid_metrics = {name: total / num_valid_batches for name, total in valid_totals.items()}
        valid_global_accuracies = valid_metrics["global_accuracy"]

        summary = ", ".join(f"valid_{name}: {value:.3f}" for name, value in valid_metrics.items())
        accelerator.print(f"loss: {valid_loss:.3f}, {summary}")

        if accelerator.is_main_process:
            val_log = {"val/epoch": epoch, "val/loss": valid_loss}
            for name, value in valid_metrics.items():
                val_log[f"val/{name}"] = value
            wandb.log(val_log)

            if valid_global_accuracies > best_accuracy:
                best_accuracy = valid_global_accuracies
                _save_weights(accelerator, model, args.model, logs_dir / f"checkpoint-best.pth")
                accelerator.print(f"new best_accuracy {best_accuracy}, {epoch=}")
            if epoch % args.save_frequency == 0:
                _save_weights(accelerator, model, args.model, logs_dir / f"checkpoint-{epoch:09}.pth")
        accelerator.wait_for_everyone()

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        wandb.finish()
# Run the training entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()