# imagenet-classification/train_resnet50.py
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import torch.nn.functional as F
import os
from tqdm import tqdm
import random
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import json
from datetime import timedelta, datetime
import logging
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import socket
import argparse
import math
# Set random seeds for reproducibility
def set_seed(seed=42):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# Training configuration
class Config:
num_epochs = 150
batch_size = 512
learning_rate = 0.1
momentum = 0.9
weight_decay = 1e-4
num_workers = 16
subset_size = None
print_freq = 100
    # Gradient accumulation: the optimizer steps every accum_iter batches, so the
    # effective batch size per process is batch_size * accum_iter
    accum_iter = 1
# Add mixed precision training parameters
use_amp = True # Enable automatic mixed precision
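    # subset_size = None trains on the full dataset; set it to a small integer to
    # smoke-test the pipeline on a random subset (validation then uses
    # subset_size // 10 samples). print_freq controls how often the tqdm
    # progress-bar metrics are refreshed.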
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
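# Data pipeline: ImageFolder datasets with ImageNet-style augmentation for
# training (random resized crop, horizontal flip, color jitter, random affine,
# random erasing) and resize + center crop for validation. When `distributed`
# is True, a DistributedSampler shards each dataset across processes and the
# train loader's shuffle flag is disabled in favor of the sampler.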
def get_data_loaders(subset_size=None, distributed=False, world_size=None, rank=None):
# ImageNet normalization values
normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)
# Modified data augmentation for training
train_transform = transforms.Compose(
[
            transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
transforms.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.9, 1.1)),
transforms.ToTensor(),
normalize,
            transforms.RandomErasing(p=0.5),  # Operates on tensors, so it must follow ToTensor
]
)
# Modified transform for validation
val_transform = transforms.Compose(
[
            transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
]
)
training_folder_name = "ILSVRC/Data/CLS-LOC/train"
val_folder_name = "ILSVRC/Data/CLS-LOC/val"
train_dataset = torchvision.datasets.ImageFolder(
root=training_folder_name, transform=train_transform
)
val_dataset = torchvision.datasets.ImageFolder(
root=val_folder_name, transform=val_transform
)
# Create subset for initial testing
if subset_size:
train_indices = torch.randperm(len(train_dataset))[:subset_size]
val_indices = torch.randperm(len(val_dataset))[: subset_size // 10]
train_dataset = Subset(train_dataset, train_indices)
val_dataset = Subset(val_dataset, val_indices)
# Create samplers for distributed training
train_sampler = None
val_sampler = None
if distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(
train_dataset, num_replicas=world_size, rank=rank
)
val_sampler = torch.utils.data.distributed.DistributedSampler(
val_dataset, num_replicas=world_size, rank=rank
)
# Create data loaders
train_loader = DataLoader(
train_dataset,
batch_size=Config.batch_size,
shuffle=(train_sampler is None),
num_workers=Config.num_workers,
pin_memory=True,
sampler=train_sampler,
persistent_workers=True,
prefetch_factor=2,
)
val_loader = DataLoader(
val_dataset,
batch_size=Config.batch_size,
shuffle=False,
num_workers=Config.num_workers,
pin_memory=True,
sampler=val_sampler,
persistent_workers=True,
prefetch_factor=2,
)
return train_loader, val_loader, train_sampler
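# One training epoch: forward and backward passes run under
# torch.cuda.amp.autocast when Config.use_amp is enabled, the loss is divided
# by Config.accum_iter so accumulated gradients match a single large-batch
# step, and gradients are unscaled and clipped to a max norm of 1.0 before
# each optimizer step.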
def train_epoch(model, train_loader, criterion, optimizer, epoch, device):
epoch_start_time = datetime.now()
model.train()
running_loss = 0.0
correct = 0
total = 0
# Create GradScaler for mixed precision training
scaler = torch.cuda.amp.GradScaler(enabled=Config.use_amp)
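    # Note: the scaler is re-created every epoch, so its loss-scale state does
    # not carry over between epochs.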
pbar = tqdm(train_loader, desc=f"Epoch {epoch}")
optimizer.zero_grad()
for i, data in enumerate(pbar):
try:
images, targets = data
images, targets = images.to(device), targets.to(device)
# Mixed precision training
with torch.cuda.amp.autocast(enabled=Config.use_amp):
outputs = model(images)
loss = criterion(outputs, targets)
loss = (
loss / Config.accum_iter
) # Normalize loss for gradient accumulation
# Backward pass with gradient scaling
scaler.scale(loss).backward()
# Gradient accumulation
if ((i + 1) % Config.accum_iter == 0) or (i + 1 == len(train_loader)):
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
running_loss += loss.item() * Config.accum_iter
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
if i % Config.print_freq == 0:
accuracy = 100.0 * correct / total
pbar.set_postfix(
{
"loss": running_loss / (i + 1),
"acc": f"{accuracy:.2f}%",
"lr": optimizer.param_groups[0]["lr"],
}
)
except Exception as e:
print(f"Error in batch {i}: {str(e)}")
continue
# Calculate epoch time and return metrics
epoch_time = datetime.now() - epoch_start_time
epoch_metrics = {
"time": epoch_time,
"loss": running_loss / len(train_loader),
"accuracy": 100.0 * correct / total,
}
return epoch_metrics
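# Validation: runs without gradients (and under autocast when AMP is enabled),
# computing top-1 and top-5 accuracy from the top-k predictions along with the
# average loss, each tracked with an AverageMeter.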
def validate(model, val_loader, criterion, device):
model.eval()
top1 = AverageMeter()
top5 = AverageMeter()
losses = AverageMeter()
with torch.no_grad(), torch.cuda.amp.autocast(enabled=Config.use_amp):
for images, targets in tqdm(val_loader, desc="Validating"):
images, targets = images.to(device), targets.to(device)
output = model(images)
loss = criterion(output, targets)
# Compute top-1 and top-5 accuracy
            maxk = 5  # evaluate both top-1 and top-5 accuracy
batch_size = targets.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(targets.view(1, -1).expand_as(pred))
# Top-1 accuracy
top1_acc = correct[0].float().sum() * 100.0 / batch_size
top1.update(top1_acc.item(), batch_size)
# Top-5 accuracy
top5_acc = correct[:5].float().sum() * 100.0 / batch_size
top5.update(top5_acc.item(), batch_size)
losses.update(loss.item(), batch_size)
return top1.avg, top5.avg, losses.avg
# Add ResNet building blocks
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, in_channels, out_channels, stride=1):
super().__init__()
self.conv1 = nn.Conv2d(
in_channels,
out_channels,
kernel_size=3,
stride=stride,
padding=1,
bias=False,
)
self.bn1 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
)
self.bn2 = nn.BatchNorm2d(out_channels)
self.shortcut = nn.Sequential()
if stride != 1 or in_channels != out_channels:
self.shortcut = nn.Sequential(
nn.Conv2d(
in_channels, out_channels, kernel_size=1, stride=stride, bias=False
),
nn.BatchNorm2d(out_channels),
)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x)
out = F.relu(out)
return out
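# Bottleneck block used by ResNet-50/101/152: a 1x1 conv reduces channels, a
# 3x3 conv (carrying the stride) processes them, and a final 1x1 conv expands
# them by `expansion` (4x). A projection shortcut matches dimensions whenever
# the stride or channel count changes.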
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, in_channels, out_channels, stride=1):
super().__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(
out_channels,
out_channels,
kernel_size=3,
stride=stride,
padding=1,
bias=False,
)
self.bn2 = nn.BatchNorm2d(out_channels)
self.conv3 = nn.Conv2d(
out_channels, out_channels * self.expansion, kernel_size=1, bias=False
)
self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
self.shortcut = nn.Sequential()
if stride != 1 or in_channels != out_channels * self.expansion:
self.shortcut = nn.Sequential(
nn.Conv2d(
in_channels,
out_channels * self.expansion,
kernel_size=1,
stride=stride,
bias=False,
),
nn.BatchNorm2d(out_channels * self.expansion),
)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = F.relu(self.bn2(self.conv2(out)))
out = self.bn3(self.conv3(out))
out += self.shortcut(x)
out = F.relu(out)
return out
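# Full network: a 7x7 stride-2 stem with max pooling, four stages of residual
# blocks (stride 2 at the start of stages 2-4), global average pooling, and a
# linear classifier. Conv layers use Kaiming (fan-out) initialization and
# BatchNorm layers start at weight 1, bias 0.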
class ResNet(nn.Module):
def __init__(self, block, num_blocks, num_classes=1000):
super().__init__()
self.in_channels = 64
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)
# Initialize weights
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def _make_layer(self, block, out_channels, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(block(self.in_channels, out_channels, stride))
self.in_channels = out_channels * block.expansion
return nn.Sequential(*layers)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.maxpool(out)
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
out = self.avgpool(out)
out = torch.flatten(out, 1)
out = self.fc(out)
return out
# ResNet-50: Bottleneck blocks with the standard [3, 4, 6, 3] layer configuration
def create_resnet50():
return ResNet(Bottleneck, [3, 4, 6, 3])
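# Illustrative sanity check (not executed as part of training): a forward pass
# on a single 224x224 RGB image should produce logits of shape (1, 1000).
#
#   model = create_resnet50()
#   logits = model(torch.randn(1, 3, 224, 224))
#   assert logits.shape == (1, 1000)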
# Add logging setup function
def setup_logging(log_dir):
# Create local log directory
os.makedirs(log_dir, exist_ok=True)
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
handlers=[
logging.FileHandler(os.path.join(log_dir, "training.log")),
logging.StreamHandler(),
],
)
return logging.getLogger(__name__)
# Add distributed training setup
def setup_distributed():
parser = argparse.ArgumentParser()
parser.add_argument("--nodes", type=int, default=1)
args = parser.parse_args()
if "LOCAL_RANK" not in os.environ:
os.environ["LOCAL_RANK"] = "-1"
args.local_rank = int(os.environ["LOCAL_RANK"])
if "WORLD_SIZE" in os.environ:
args.world_size = int(os.environ["WORLD_SIZE"])
else:
args.world_size = args.nodes
return args
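# Example launch (assumes the torchrun launcher, which exports LOCAL_RANK and
# WORLD_SIZE as this script expects):
#
#   torchrun --nproc_per_node=<num_gpus> train_resnet50.py
#
# Running the script directly with `python train_resnet50.py` leaves
# LOCAL_RANK at -1 and falls back to single-process DataParallel training.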
# Collect dataset statistics for logging (handles both full and Subset-wrapped datasets)
def get_dataset_stats(train_loader, val_loader):
stats = {
"num_train_samples": len(train_loader.dataset),
"num_val_samples": len(val_loader.dataset),
"num_classes": len(train_loader.dataset.dataset.classes)
if hasattr(train_loader.dataset, "dataset")
else len(train_loader.dataset.classes),
"batch_size": train_loader.batch_size,
"num_train_batches": len(train_loader),
"num_val_batches": len(val_loader),
"device_count": torch.cuda.device_count(),
"max_epochs": Config.num_epochs,
"learning_rate": Config.learning_rate,
"weight_decay": Config.weight_decay,
"num_workers": Config.num_workers,
}
# Get class distribution
if hasattr(train_loader.dataset, "dataset"):
# For subset dataset
classes = train_loader.dataset.dataset.classes
class_to_idx = train_loader.dataset.dataset.class_to_idx
else:
# For full dataset
classes = train_loader.dataset.classes
class_to_idx = train_loader.dataset.class_to_idx
stats["classes"] = classes
stats["class_to_idx"] = class_to_idx
return stats
# Main entry point: supports single-process (DataParallel) and distributed (DDP) training
def main():
start_time = datetime.now()
args = setup_distributed()
# Setup distributed training
if args.local_rank != -1:
torch.cuda.set_device(args.local_rank)
dist.init_process_group(
backend="nccl",
init_method="env://", # Use environment variables for initialization
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Setup logging and tensorboard
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
log_dir = f"runs/resnet50_{timestamp}"
if args.local_rank <= 0: # Only create directories for master process
os.makedirs(log_dir, exist_ok=True)
writer = SummaryWriter(log_dir)
logger = setup_logging(log_dir)
logger.info(f"Starting training on {socket.gethostname()}")
logger.info(f"Available GPUs: {torch.cuda.device_count()}")
logger.info(f"Training started at: {start_time}")
set_seed()
# Create model
model = create_resnet50()
if args.local_rank != -1:
model = DDP(model.to(device), device_ids=[args.local_rank])
else:
model = torch.nn.DataParallel(model).to(device)
# Rest of your training setup
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = optim.SGD(
model.parameters(),
lr=Config.learning_rate,
momentum=Config.momentum,
weight_decay=Config.weight_decay,
nesterov=True,
)
# Cosine annealing with warmup
warmup_epochs = 5
def warmup_lr_scheduler(epoch):
if epoch < warmup_epochs:
return epoch / warmup_epochs
return 0.5 * (
1
+ math.cos(
math.pi * (epoch - warmup_epochs) / (Config.num_epochs - warmup_epochs)
)
)
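    # LambdaLR multiplies the base learning rate by the factor returned above:
    # a linear warmup from 0 to Config.learning_rate over the first
    # warmup_epochs epochs, then cosine decay toward 0 for the remaining epochs.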
scheduler = optim.lr_scheduler.LambdaLR(optimizer, warmup_lr_scheduler)
# Get data loaders with distributed sampler
train_loader, val_loader, train_sampler = get_data_loaders(
subset_size=Config.subset_size,
distributed=(args.local_rank != -1),
world_size=dist.get_world_size() if args.local_rank != -1 else None,
rank=args.local_rank if args.local_rank != -1 else None,
)
# Log dataset statistics
if args.local_rank <= 0:
dataset_stats = get_dataset_stats(train_loader, val_loader)
logger.info("Dataset Statistics:")
logger.info(f"Training samples: {dataset_stats['num_train_samples']}")
logger.info(f"Validation samples: {dataset_stats['num_val_samples']}")
logger.info(f"Number of classes: {dataset_stats['num_classes']}")
logger.info(f"Batch size: {dataset_stats['batch_size']}")
logger.info(f"Training batches per epoch: {dataset_stats['num_train_batches']}")
logger.info(f"Validation batches per epoch: {dataset_stats['num_val_batches']}")
    best_acc = 0
    total_training_time = timedelta()
    # Training loop
for epoch in range(Config.num_epochs):
if args.local_rank <= 0:
logger.info(f"Starting epoch {epoch}")
if train_sampler is not None:
train_sampler.set_epoch(epoch)
# Train for one epoch and get metrics
train_metrics = train_epoch(
model, train_loader, criterion, optimizer, epoch, device
)
total_training_time += train_metrics["time"]
if args.local_rank <= 0: # Only validate on master process
# Log training metrics
logger.info(
f"Epoch {epoch} completed in {train_metrics['time']}, "
f"Training Loss: {train_metrics['loss']:.4f}, "
f"Training Accuracy: {train_metrics['accuracy']:.2f}%"
)
top1_acc, top5_acc, val_loss = validate(
model, val_loader, criterion, device
)
# Log validation metrics
logger.info(
f"Validation metrics - "
f"Top1 Acc: {top1_acc:.2f}%, "
f"Top5 Acc: {top5_acc:.2f}%, "
f"Val Loss: {val_loss:.4f}"
)
# Log cumulative time
logger.info(f"Total training time so far: {total_training_time}")
# Log to tensorboard
writer.add_scalar("Training/Loss", train_metrics["loss"], epoch)
writer.add_scalar("Training/Accuracy", train_metrics["accuracy"], epoch)
writer.add_scalar(
"Training/Time", train_metrics["time"].total_seconds(), epoch
)
writer.add_scalar("Validation/Top1_Accuracy", top1_acc, epoch)
writer.add_scalar("Validation/Top5_Accuracy", top5_acc, epoch)
writer.add_scalar("Validation/Loss", val_loss, epoch)
            is_best = top1_acc > best_acc
            best_acc = max(top1_acc, best_acc)
            # Save a checkpoint whenever the best top-1 accuracy improves
            if is_best:
                torch.save(
                    {
                        "epoch": epoch,
                        "model_state_dict": model.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                        "scheduler_state_dict": scheduler.state_dict(),
                        "best_acc": best_acc,
                        "top1_accuracy": top1_acc,
                        "top5_accuracy": top5_acc,
                    },
                    os.path.join(log_dir, "best_model.pth"),
                )
if top1_acc >= 70.0:
logger.info(
f"\nTarget accuracy of 70% achieved! Current accuracy: {top1_acc:.2f}%"
)
torch.save(
{
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"best_acc": best_acc,
"top1_accuracy": top1_acc,
"top5_accuracy": top5_acc,
},
os.path.join(log_dir, "target_achieved_model.pth"),
)
# break
# Save metrics after each epoch
# with open(os.path.join(log_dir, "metrics.json"), "w") as f:
# json.dump(train_metrics, f, indent=4)
scheduler.step()
if args.local_rank <= 0:
end_time = datetime.now()
training_time = end_time - start_time
writer.close()
logger.info("\nTraining completed!")
logger.info(f"Total training time: {training_time}")
logger.info(f"Best Top-1 Accuracy: {train_metrics['best_top1_acc']:.2f}%")
logger.info(
f"Target accuracy of 70% {'achieved' if train_metrics['best_top1_acc'] >= 70.0 else 'not achieved'}"
)
if __name__ == "__main__":
main()