Spaces:
Runtime error
Runtime error
from typing import Tuple | |
from pathlib import Path | |
from tqdm import tqdm | |
from accelerate import Accelerator | |
from accelerate.utils import set_seed | |
from matplotlib import cm | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import argparse | |
import json | |
import wandb | |
from datasets import load_dataset | |
import torch | |
from torch.utils.data import DataLoader | |
from models.unet import UNet | |
from dataset import SegmentationDataset, collate_fn | |
from utils import get_transform, mask_transform, EMA | |
from get_loss import get_composite_criterion | |
from models.vit import ViTSegmentation | |
from models.dino import DINOSegmentationModel | |
color_map = cm.get_cmap('tab20', 9) | |
fixed_colors = np.array([color_map(i)[:3] for i in range(9)]) * 255 | |
def mask_to_color(mask: np.ndarray) -> np.ndarray: | |
h, w = mask.shape | |
color_mask = np.zeros((h, w, 3), dtype=np.uint8) | |
for class_idx in range(9): | |
color_mask[mask == class_idx] = fixed_colors[class_idx] | |
return color_mask | |
def create_combined_image( | |
x: torch.Tensor, | |
y: torch.Tensor, | |
y_pred: torch.Tensor, | |
mean: list[float] = [0.485, 0.456, 0.406], | |
std: list[float] = [0.229, 0.224, 0.225] | |
) -> np.ndarray: | |
batch_size, _, height, width = x.shape | |
combined_height = height * 3 | |
combined_width = width * batch_size | |
combined_image = np.zeros((combined_height, combined_width, 3), dtype=np.uint8) | |
for i in range(batch_size): | |
image = x[i].cpu().permute(1, 2, 0).numpy() | |
image = (image * std + mean).clip(0, 1) | |
image = (image * 255).astype(np.uint8) | |
true_mask = y[i].cpu().numpy() | |
true_mask_color = mask_to_color(true_mask) | |
pred_mask = y_pred[i].cpu().numpy() | |
pred_mask_color = mask_to_color(pred_mask) | |
combined_image[:height, i * width:(i + 1) * width, :] = image | |
combined_image[height:2 * height, i * width:(i + 1) * width, :] = true_mask_color | |
combined_image[2 * height:, i * width:(i + 1) * width, :] = pred_mask_color | |
return combined_image | |
def compute_metrics(y_pred: torch.Tensor, y: torch.Tensor, num_classes: int = 9) -> Tuple[float, float, float, float, float, float]: | |
pred_mask = y_pred.unsqueeze(-1) == torch.arange(num_classes, device=y_pred.device).reshape(1, 1, 1, -1) | |
target_mask = y.unsqueeze(-1) == torch.arange(num_classes, device=y.device).reshape(1, 1, 1, -1) | |
class_present = (target_mask.sum(dim=(0, 1, 2)) > 0).float() | |
tp = (pred_mask & target_mask).sum(dim=(0, 1, 2)).float() | |
fp = (pred_mask & ~target_mask).sum(dim=(0, 1, 2)).float() | |
fn = (~pred_mask & target_mask).sum(dim=(0, 1, 2)).float() | |
tn = (~pred_mask & ~target_mask).sum(dim=(0, 1, 2)).float() | |
overall_tp = tp.sum() | |
overall_fp = fp.sum() | |
overall_fn = fn.sum() | |
overall_tn = tn.sum() | |
precision = tp / (tp + fp).clamp(min=1e-8) | |
recall = tp / (tp + fn).clamp(min=1e-8) | |
accuracy = (tp + tn) / (tp + tn + fp + fn) | |
macro_precision = ((precision * class_present).sum() / class_present.sum().clamp(min=1e-8)).item() | |
macro_recall = ((recall * class_present).sum() / class_present.sum().clamp(min=1e-8)).item() | |
macro_accuracy = accuracy.mean().item() | |
micro_precision = (overall_tp / (overall_tp + overall_fp).clamp(min=1e-8)).item() | |
micro_recall = (overall_tp / (overall_tp + overall_fn).clamp(min=1e-8)).item() | |
global_accuracy = ((y_pred == y).sum() / (y.shape[0] * y.shape[1] * y.shape[2])).item() | |
return macro_precision, macro_recall, macro_accuracy, micro_precision, micro_recall, global_accuracy | |
def parse_args(): | |
parser = argparse.ArgumentParser(description="Train a model on human parsing dataset") | |
parser.add_argument("--data-path", type=str, default="data/portraits", help="Path to the data") | |
parser.add_argument("--batch-size", type=int, default=32, help="Batch size for training and testing") | |
parser.add_argument("--pin-memory", type=bool, default=True, help="Pin memory for DataLoader") | |
parser.add_argument("--num-workers", type=int, default=0, help="Number of workers for DataLoader") | |
parser.add_argument("--num-epochs", type=int, default=15, help="Number of training epochs") | |
parser.add_argument("--optimizer", type=str, default="AdamW", help="Optimizer type") | |
parser.add_argument("--learning-rate", type=float, default=1e-4, help="Learning rate for the optimizer") | |
parser.add_argument("--max-norm", type=float, default=1.0, help="Maximum gradient norm for clipping") | |
parser.add_argument("--logs-dir", type=str, default="unet-logs", help="Directory for saving logs") | |
parser.add_argument("--model", type=str, default="unet", choices=["unet", "vit", "dino"], help="Model class name") | |
parser.add_argument("--losses-path", type=str, default="losses_config.json", help="Path to the losses") | |
parser.add_argument("--mixed-precision", type=str, default="fp16", choices=["fp16", "bf16", "fp8", "no"], help="Value of the mixed precision") | |
parser.add_argument("--gradient-accumulation-steps", type=int, default=2, help="Value of the gradient accumulation steps") | |
parser.add_argument("--project-name", type=str, default="face_segmentation_unet", help="WandB project name") | |
parser.add_argument("--save-frequency", type=int, default=4, help="Frequency of saving model weights") | |
parser.add_argument("--log-steps", type=int, default=400, help="Number of steps for logging images") | |
parser.add_argument("--seed", type=int, default=42, help="Value of the seed") | |
return parser.parse_args() | |
def main() -> None: | |
args = parse_args() | |
set_seed(args.seed) | |
accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision) | |
with open(args.losses_path, "r") as fp: | |
losses_config = json.load(fp) | |
with accelerator.main_process_first(): | |
logs_dir = Path(args.logs_dir) | |
logs_dir.mkdir(exist_ok=True) | |
wandb.init(project=args.project_name, dir=logs_dir) | |
wandb.save(args.losses_path) | |
optimizer_class = getattr(torch.optim, args.optimizer) | |
if args.model == "unet": | |
model = UNet().to(accelerator.device) | |
optimizer = optimizer_class(model.parameters(), lr=args.learning_rate) | |
elif args.model == "vit": | |
model = ViTSegmentation().to(accelerator.device) | |
optimizer = optimizer_class(model.parameters(), lr=args.learning_rate) | |
elif args.model == "dino": | |
model = DINOSegmentationModel().to(accelerator.device) | |
optimizer = optimizer_class(model.segmentation_head.parameters(), lr=args.learning_rate) | |
else: | |
raise NotImplementedError("Incorrect model name") | |
transform = get_transform(model.mean, model.std) | |
train_dataset = SegmentationDataset(args.data_path, subset="train", transform=transform, target_transform=mask_transform) | |
valid_dataset = SegmentationDataset(args.data_path, subset="test", transform=transform, target_transform=mask_transform) | |
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=args.pin_memory, collate_fn=collate_fn) | |
valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=args.pin_memory, collate_fn=collate_fn) | |
criterion = get_composite_criterion(losses_config) | |
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.num_epochs * len(train_loader)) | |
model, optimizer, train_loader, lr_scheduler = accelerator.prepare(model, optimizer, train_loader, lr_scheduler) | |
best_accuracy = 0 | |
print(f"params: {sum([p.numel() for p in model.parameters()])/1e6:.2f} M") | |
print(f"trainable params: {sum([p.numel() for p in model.parameters() if p.requires_grad])/1e6:.2f} M") | |
train_loss_ema, train_macro_precision_ema, train_macro_recall_ema, train_macro_accuracy_ema, train_micro_precision_ema, train_micro_recall_ema, train_global_accuracy_ema = EMA(), EMA(), EMA(), EMA(), EMA(), EMA(), EMA() | |
for epoch in range(1, args.num_epochs + 1): | |
model.train() | |
pbar = tqdm(train_loader, desc=f"Train epoch {epoch}/{args.num_epochs}") | |
for index, (x, y) in enumerate(pbar): | |
x, y = x.to(accelerator.device), y.squeeze(1).to(accelerator.device) | |
with accelerator.accumulate(model): | |
with accelerator.autocast(): | |
output = model(x) | |
loss = criterion(output, y) | |
accelerator.backward(loss) | |
train_loss = loss.item() | |
grad_norm = None | |
_, y_pred = output.max(dim=1) | |
train_macro_precision, train_macro_recall, train_macro_accuracy, train_micro_precision, train_micro_recall, train_global_accuracy = compute_metrics(y_pred, y) | |
if accelerator.sync_gradients: | |
grad_norm = accelerator.clip_grad_norm_(model.parameters(), args.max_norm).item() | |
optimizer.step() | |
lr_scheduler.step() | |
optimizer.zero_grad() | |
if (index + 1) % args.log_steps == 0 and accelerator.is_main_process: | |
images_to_log = [] | |
combined_image = create_combined_image(x, y, y_pred, model.mean, model.std) | |
images_to_log.append(wandb.Image(combined_image, caption=f"Combined Image (Train, Epoch {epoch}, Batch {index})")) | |
wandb.log({"train_samples": images_to_log}) | |
pbar.set_postfix({"loss": train_loss_ema(train_loss), "macro_precision": train_macro_precision_ema(train_macro_precision), "macro_recall": train_macro_recall_ema(train_macro_recall), "macro_accuracy": train_macro_accuracy_ema(train_macro_accuracy), "micro_precision": train_micro_precision_ema(train_micro_precision), "micro_recall": train_micro_recall_ema(train_micro_recall), "global_accuracy": train_global_accuracy_ema(train_global_accuracy)}) | |
log_data = { | |
"train/epoch": epoch, | |
"train/loss": train_loss, | |
"train/macro_accuracy": train_macro_accuracy, | |
"train/learning_rate": optimizer.param_groups[0]["lr"], | |
"train/macro_precision": train_macro_precision, | |
"train/macro_recall": train_macro_recall, | |
"train/micro_precision": train_micro_precision, | |
"train/micro_recall": train_micro_recall, | |
"train/global_accuracy": train_global_accuracy, | |
} | |
if grad_norm is not None: | |
log_data["train/grad_norm"] = grad_norm | |
if accelerator.is_main_process: | |
wandb.log(log_data) | |
accelerator.wait_for_everyone() | |
model.eval() | |
valid_loss, valid_macro_accuracies, valid_macro_precisions, valid_macro_recalls, valid_global_accuracies, valid_micro_precisions, valid_micro_recalls = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 | |
with torch.inference_mode(): | |
pbar = tqdm(valid_loader, desc=f"Val epoch {epoch}/{args.num_epochs}") | |
for index, (x, y) in enumerate(valid_loader): | |
x, y = x.to(accelerator.device), y.squeeze(1).to(accelerator.device) | |
output = model(x) | |
_, y_pred = output.max(dim=1) | |
if (index + 1) % args.log_steps == 0 and accelerator.is_main_process: | |
images_to_log = [] | |
combined_image = create_combined_image(x, y, y_pred, model.mean, model.std) | |
images_to_log.append(wandb.Image(combined_image, caption=f"Combined Image (Validation, Epoch {epoch})")) | |
wandb.log({"valid_samples": images_to_log}) | |
valid_macro_precision, valid_macro_recall, valid_macro_accuracy, valid_micro_precision, valid_micro_recall, valid_global_accuracy = compute_metrics(y_pred, y) | |
valid_macro_precisions += valid_macro_precision | |
valid_macro_recalls += valid_macro_recall | |
valid_macro_accuracies += valid_macro_accuracy | |
valid_micro_precisions += valid_micro_precision | |
valid_micro_recalls += valid_micro_recall | |
valid_global_accuracies += valid_global_accuracy | |
loss = criterion(output, y) | |
valid_loss += loss.item() | |
valid_loss = valid_loss / len(valid_loader) | |
valid_macro_accuracies = valid_macro_accuracies / len(valid_loader) | |
valid_macro_precisions = valid_macro_precisions / len(valid_loader) | |
valid_macro_recalls = valid_macro_recalls / len(valid_loader) | |
valid_global_accuracies = valid_global_accuracies / len(valid_loader) | |
valid_micro_precisions = valid_micro_precisions / len(valid_loader) | |
valid_micro_recalls = valid_micro_recalls / len(valid_loader) | |
accelerator.print(f"loss: {valid_loss:.3f}, valid_macro_precision: {valid_macro_precisions:.3f}, valid_macro_recall: {valid_macro_recalls:.3f}, valid_macro_accuracy: {valid_macro_accuracies:.3f}, valid_micro_precision: {valid_micro_precisions:.3f}, valid_micro_recall: {valid_micro_recalls:.3f}, valid_global_accuracy: {valid_global_accuracies:.3f}") | |
if accelerator.is_main_process: | |
wandb.log( | |
{ | |
"val/epoch": epoch, | |
"val/loss": valid_loss, | |
"val/macro_accuracy": valid_macro_accuracies, | |
"val/macro_precision": valid_macro_precisions, | |
"val/macro_recall": valid_macro_recalls, | |
"val/global_accuracy": valid_global_accuracies, | |
"val/micro_precision": valid_micro_precisions, | |
"val/micro_recall": valid_micro_recalls, | |
} | |
) | |
if valid_global_accuracies > best_accuracy: | |
best_accuracy = valid_global_accuracies | |
if args.model in ["dino", "vit"]: | |
accelerator.save(model.segmentation_head.state_dict(), logs_dir / f"checkpoint-best.pth") | |
else: | |
accelerator.save(model.state_dict(), logs_dir / f"checkpoint-best.pth") | |
accelerator.print(f"new best_accuracy {best_accuracy}, {epoch=}") | |
if epoch % args.save_frequency == 0: | |
if args.model in ["dino", "vit"]: | |
accelerator.save(model.segmentation_head.state_dict(), logs_dir / f"checkpoint-{epoch:09}.pth") | |
else: | |
accelerator.save(model.state_dict(), logs_dir / f"checkpoint-{epoch:09}.pth") | |
accelerator.wait_for_everyone() | |
accelerator.wait_for_everyone() | |
wandb.finish() | |
if __name__ == "__main__": | |
main() |