Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
import torchvision | |
from torch.utils.data import DataLoader, Subset | |
from torchvision import datasets, transforms | |
import torch.nn.functional as F | |
import os | |
from tqdm import tqdm | |
import random | |
import numpy as np | |
from torch.utils.tensorboard import SummaryWriter | |
import json | |
from datetime import timedelta, datetime | |
import logging | |
import torch.distributed as dist | |
import torch.multiprocessing as mp | |
from torch.nn.parallel import DistributedDataParallel as DDP | |
import socket | |
import argparse | |
import math | |
# Reproducibility helpers
def set_seed(seed=42):
    """Seed every random number generator this script draws from.

    Seeds Python's ``random`` module, NumPy's global generator, the torch CPU
    generator and every CUDA device generator, in that order.

    Args:
        seed: Integer seed applied to all generators (default 42).
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    torch.cuda.manual_seed_all(seed)
# Training configuration
class Config:
    """Static hyper-parameters read as class attributes (``Config.batch_size``)."""

    # --- optimisation schedule ---
    num_epochs = 150
    learning_rate = 0.1
    momentum = 0.9
    weight_decay = 1e-4

    # --- data pipeline ---
    batch_size = 512  # batch size handed to each DataLoader (per process)
    num_workers = 16
    subset_size = None  # None = full dataset; an int selects a random subset

    # --- logging ---
    print_freq = 100  # batches between progress-bar postfix refreshes

    # --- gradient accumulation / precision ---
    # Number of batches whose gradients are summed before each optimizer step.
    # Raising it emulates a larger effective batch when GPU memory is limited.
    accum_iter = 1
    use_amp = True  # autocast + GradScaler mixed-precision training
class AverageMeter(object):
    """Computes and stores the average and current value.

    ``update(val, n)`` treats ``val`` as summarising ``n`` samples, so ``avg``
    is a sample-weighted running mean.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed over ``n`` samples and refresh ``avg``.

        Args:
            val: Latest value (e.g. a per-batch mean).
            n: Number of samples ``val`` summarises; used as its weight.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Skip the division while no samples have been counted
        # (e.g. update(val, n=0) on a fresh meter) instead of dividing by zero.
        if self.count:
            self.avg = self.sum / self.count
def get_data_loaders(
    subset_size=None,
    distributed=False,
    world_size=None,
    rank=None,
    train_dir="ILSVRC/Data/CLS-LOC/train",
    val_dir="ILSVRC/Data/CLS-LOC/val",
    subset_seed=42,
    shard_val=False,
):
    """Build the training and validation DataLoaders.

    Both splits are loaded with ``torchvision.datasets.ImageFolder``, so each
    directory is expected to hold one sub-folder per class.
    # NOTE(review): a raw ILSVRC ``val`` folder is flat -- confirm it has been
    # reorganised into class sub-folders.

    Args:
        subset_size: If truthy, train on a random subset of this many images and
            validate on ``max(1, subset_size // 10)`` images.
        distributed: Shard the training set across processes with a
            ``DistributedSampler``.
        world_size: Number of processes (``num_replicas``) when distributed.
        rank: Rank of this process when distributed (should be the global rank).
        train_dir: Root of the training ImageFolder.
        val_dir: Root of the validation ImageFolder.
        subset_seed: Seed of the private generator used to draw the subset, so
            every process selects the same images.
        shard_val: Also shard the validation set across processes. Off by
            default because validation runs on the master process only, which
            would otherwise score just ``1/world_size`` of the images.

    Returns:
        ``(train_loader, val_loader, train_sampler)``; ``train_sampler`` is
        ``None`` unless ``distributed``.
    """
    # ImageNet normalization values
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    train_transform = transforms.Compose(
        [
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            transforms.RandomAffine(
                degrees=15, translate=(0.1, 0.1), scale=(0.9, 1.1)
            ),
            transforms.ToTensor(),
            normalize,
            # RandomErasing works on tensors, so it must come after ToTensor.
            transforms.RandomErasing(p=0.5),
        ]
    )
    val_transform = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]
    )

    train_dataset = torchvision.datasets.ImageFolder(
        root=train_dir, transform=train_transform
    )
    val_dataset = torchvision.datasets.ImageFolder(
        root=val_dir, transform=val_transform
    )

    # Optional random subset for quick experiments.
    if subset_size:
        # A private, fixed-seed generator makes every process draw the same
        # subset; the global torch RNG is not guaranteed to be in the same
        # state on every rank.
        gen = torch.Generator().manual_seed(subset_seed)
        train_indices = torch.randperm(len(train_dataset), generator=gen)[:subset_size]
        val_count = max(1, subset_size // 10)
        val_indices = torch.randperm(len(val_dataset), generator=gen)[:val_count]
        train_dataset = Subset(train_dataset, train_indices.tolist())
        val_dataset = Subset(val_dataset, val_indices.tolist())

    train_sampler = None
    val_sampler = None
    if distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset, num_replicas=world_size, rank=rank
        )
        if shard_val:
            val_sampler = torch.utils.data.distributed.DistributedSampler(
                val_dataset, num_replicas=world_size, rank=rank, shuffle=False
            )

    # DataLoader rejects persistent_workers (and, in recent PyTorch,
    # prefetch_factor) when num_workers == 0, so only pass them with workers.
    worker_kwargs = {}
    if Config.num_workers > 0:
        worker_kwargs = {"persistent_workers": True, "prefetch_factor": 2}

    train_loader = DataLoader(
        train_dataset,
        batch_size=Config.batch_size,
        shuffle=(train_sampler is None),
        num_workers=Config.num_workers,
        pin_memory=True,
        sampler=train_sampler,
        **worker_kwargs,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=Config.batch_size,
        shuffle=False,
        num_workers=Config.num_workers,
        pin_memory=True,
        sampler=val_sampler,
        **worker_kwargs,
    )
    return train_loader, val_loader, train_sampler
def train_epoch(model, train_loader, criterion, optimizer, epoch, device, scaler=None):
    """Train ``model`` for one pass over ``train_loader``.

    Uses autocast and gradient scaling when ``Config.use_amp`` is set. Gradients
    are accumulated over ``Config.accum_iter`` batches, and the global gradient
    norm is clipped to 1.0 before every optimizer step.

    Args:
        model: Network to train (switched to train mode here).
        train_loader: Iterable of ``(images, targets)`` batches; must support ``len``.
        criterion: Loss called as ``criterion(outputs, targets)``.
        optimizer: Optimizer stepped through the grad scaler.
        epoch: Epoch index, used only for the progress-bar label.
        device: Device the batches are moved to.
        scaler: Optional ``torch.cuda.amp.GradScaler`` to reuse across epochs.
            When ``None``, a fresh scaler is created, so its dynamic loss scale
            restarts from the initial value every epoch.

    Returns:
        dict with ``"time"`` (``timedelta``), ``"loss"`` (summed batch loss
        divided by ``len(train_loader)``) and ``"accuracy"`` (percent of the
        processed samples classified correctly). Both numbers are 0.0 when
        nothing was processed.
    """
    epoch_start_time = datetime.now()
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    if scaler is None:
        scaler = torch.cuda.amp.GradScaler(enabled=Config.use_amp)
    num_batches = len(train_loader)
    log = logging.getLogger(__name__)

    pbar = tqdm(train_loader, desc=f"Epoch {epoch}")
    optimizer.zero_grad()
    for i, data in enumerate(pbar):
        try:
            images, targets = data
            images, targets = images.to(device), targets.to(device)
            with torch.cuda.amp.autocast(enabled=Config.use_amp):
                outputs = model(images)
                loss = criterion(outputs, targets)
                # Scale down so the accumulated gradient is an average.
                loss = loss / Config.accum_iter
            scaler.scale(loss).backward()

            # Step every accum_iter batches, and on the final batch.
            if (i + 1) % Config.accum_iter == 0 or (i + 1) == num_batches:
                # Unscale first so clipping sees the true gradient magnitudes.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            running_loss += loss.item() * Config.accum_iter
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            if i % Config.print_freq == 0:
                accuracy = 100.0 * correct / total
                pbar.set_postfix(
                    {
                        "loss": running_loss / (i + 1),
                        "acc": f"{accuracy:.2f}%",
                        "lr": optimizer.param_groups[0]["lr"],
                    }
                )
        except Exception:
            # Best effort: keep training, but record the full traceback rather
            # than only the exception message.
            log.exception("Error in batch %d", i)
            continue

    epoch_time = datetime.now() - epoch_start_time
    return {
        "time": epoch_time,
        "loss": running_loss / num_batches if num_batches else 0.0,
        "accuracy": 100.0 * correct / total if total else 0.0,
    }
def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on every batch of ``val_loader`` without gradients.

    The model is put in eval mode and left there. Autocast follows
    ``Config.use_amp``.

    Returns:
        ``(top1, top5, loss)``: sample-weighted averages over all batches, with
        accuracies in percent. If the model outputs fewer than 5 classes, "top-5"
        is top-k over all of them.
    """
    model.eval()
    top1 = AverageMeter()
    top5 = AverageMeter()
    losses = AverageMeter()
    with torch.no_grad(), torch.cuda.amp.autocast(enabled=Config.use_amp):
        for images, targets in tqdm(val_loader, desc="Validating"):
            images, targets = images.to(device), targets.to(device)
            output = model(images)
            loss = criterion(output, targets)
            batch_size = targets.size(0)

            # topk raises if k exceeds the number of classes, so clamp it.
            maxk = min(5, output.size(1))
            _, pred = output.topk(maxk, 1, True, True)
            pred = pred.t()
            # correct[k, j] is True when the (k+1)-th ranked guess for sample j
            # matches its target.
            correct = pred.eq(targets.view(1, -1).expand_as(pred))

            top1_acc = correct[0].float().sum() * 100.0 / batch_size
            top1.update(top1_acc.item(), batch_size)
            top5_acc = correct[:maxk].float().sum() * 100.0 / batch_size
            top5.update(top5_acc.item(), batch_size)
            losses.update(loss.item(), batch_size)
    return top1.avg, top5.avg, losses.avg
# ResNet building blocks
class BasicBlock(nn.Module):
    """Residual block of two 3x3 convolutions plus a shortcut connection.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by both convolutions.
        stride: Stride of the first convolution (and of the projection shortcut).
    """

    # Output channels = out_channels * expansion.
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            # 1x1 conv + BN so the shortcut matches the main path's shape.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            # An empty Sequential passes its input through unchanged.
            self.shortcut = nn.Sequential()

    def forward(self, x):
        residual = self.shortcut(x)
        h = F.relu(self.bn1(self.conv1(x)))
        h = self.bn2(self.conv2(h))
        return F.relu(h + residual)
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual block producing ``4 * out_channels`` channels.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Width of the inner convolutions.
        stride: Stride of the 3x3 convolution (and of the projection shortcut).
    """

    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        width = out_channels
        wide = out_channels * self.expansion

        self.conv1 = nn.Conv2d(in_channels, width, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width)
        self.conv2 = nn.Conv2d(
            width, width, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(width)
        self.conv3 = nn.Conv2d(width, wide, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(wide)

        if stride == 1 and in_channels == wide:
            # Shapes already match: identity shortcut.
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, wide, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(wide),
            )

    def forward(self, x):
        identity = self.shortcut(x)
        h = x
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            h = F.relu(bn(conv(h)))
        h = self.bn3(self.conv3(h))
        return F.relu(h + identity)
class ResNet(nn.Module):
    """ResNet for 3-channel images: stem, four residual stages, linear head.

    Args:
        block: Residual block class with an ``expansion`` attribute, constructed
            as ``block(in_channels, out_channels, stride)``.
        num_blocks: Number of blocks in each of the four stages (indices 0..3).
        num_classes: Size of the classifier output (default 1000).
    """

    def __init__(self, block, num_blocks, num_classes=1000):
        super().__init__()
        # Running channel count; _make_layer reads and advances it.
        self.in_channels = 64
        # Stem: 7x7 stride-2 convolution, BN, then 3x3 stride-2 max-pooling.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for idx, (width, first_stride) in enumerate(stage_specs):
            # setattr on a Module registers the stage as submodule layer1..layer4.
            setattr(
                self,
                f"layer{idx + 1}",
                self._make_layer(block, width, num_blocks[idx], stride=first_stride),
            )

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Kaiming init for convolutions; BatchNorm starts as the identity.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Stack blocks for one stage; only the first block uses ``stride``.

        At least one block is always built, even when ``num_blocks`` < 1.
        """
        per_block_strides = [stride, *([1] * (num_blocks - 1))]
        blocks = []
        for s in per_block_strides:
            blocks.append(block(self.in_channels, out_channels, s))
            self.in_channels = out_channels * block.expansion
        return nn.Sequential(*blocks)

    def forward(self, x):
        h = self.maxpool(F.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            h = stage(h)
        h = torch.flatten(self.avgpool(h), 1)
        return self.fc(h)
# Model factory used by main()
def create_resnet50():
    """Return a freshly initialised ResNet-50 (Bottleneck blocks, 1000 classes)."""
    stage_depths = [3, 4, 6, 3]  # blocks per stage
    return ResNet(block=Bottleneck, num_blocks=stage_depths)
# Logging setup
def setup_logging(log_dir):
    """Route INFO-level logging to ``<log_dir>/training.log`` and the console.

    Creates ``log_dir`` if needed. Configuration goes through
    ``logging.basicConfig``, which leaves an already-configured root logger
    untouched.

    Returns:
        The logger for this module.
    """
    os.makedirs(log_dir, exist_ok=True)
    log_file = os.path.join(log_dir, "training.log")
    handlers = [logging.FileHandler(log_file), logging.StreamHandler()]
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        handlers=handlers,
    )
    return logging.getLogger(__name__)
# Distributed launch settings
def setup_distributed():
    """Parse CLI arguments and resolve this process's distributed settings.

    The local rank comes from the ``LOCAL_RANK`` environment variable when it is
    set. Otherwise it comes from a ``--local_rank`` / ``--local-rank`` argument,
    which launchers may pass on the command line; without these flags declared,
    ``argparse`` would abort on it. ``-1`` means "not distributed". The resolved
    value is written back to ``LOCAL_RANK``.

    Returns:
        ``argparse.Namespace`` with ``nodes``, ``local_rank`` and ``world_size``
        (``WORLD_SIZE`` from the environment, else ``--nodes``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--nodes", type=int, default=1)
    parser.add_argument(
        "--local_rank", "--local-rank", dest="local_rank", type=int, default=None
    )
    args = parser.parse_args()

    env_rank = os.environ.get("LOCAL_RANK")
    if env_rank is not None:
        args.local_rank = int(env_rank)
    elif args.local_rank is None:
        args.local_rank = -1
    os.environ["LOCAL_RANK"] = str(args.local_rank)

    if "WORLD_SIZE" in os.environ:
        args.world_size = int(os.environ["WORLD_SIZE"])
    else:
        args.world_size = args.nodes
    return args
# Dataset statistics
def get_dataset_stats(train_loader, val_loader):
    """Summarise the loaders, key ``Config`` values and the class vocabulary.

    If ``train_loader.dataset`` has a ``.dataset`` attribute (as a ``Subset``
    does), class names are read from that wrapped dataset.

    Returns:
        dict with sample/batch counts, device count, selected ``Config``
        values, ``classes`` and ``class_to_idx``.
    """
    train_ds = train_loader.dataset
    source = train_ds.dataset if hasattr(train_ds, "dataset") else train_ds
    classes = source.classes
    return {
        "num_train_samples": len(train_ds),
        "num_val_samples": len(val_loader.dataset),
        "num_classes": len(classes),
        "batch_size": train_loader.batch_size,
        "num_train_batches": len(train_loader),
        "num_val_batches": len(val_loader),
        "device_count": torch.cuda.device_count(),
        "max_epochs": Config.num_epochs,
        "learning_rate": Config.learning_rate,
        "weight_decay": Config.weight_decay,
        "num_workers": Config.num_workers,
        "classes": classes,
        "class_to_idx": source.class_to_idx,
    }
# Training entry point (single process + DataParallel, or DDP)
def _warmup_cosine_factor(epoch, warmup_epochs, total_epochs):
    """LambdaLR multiplier: linear warm-up, then cosine decay towards 0.

    Warm-up uses ``(epoch + 1) / warmup_epochs`` so the first epoch trains with a
    non-zero learning rate; ``epoch / warmup_epochs`` would give 0 at epoch 0.
    """
    if epoch < warmup_epochs:
        return (epoch + 1) / warmup_epochs
    # Guard against total_epochs == warmup_epochs (division by zero).
    decay_span = max(1, total_epochs - warmup_epochs)
    progress = (epoch - warmup_epochs) / decay_span
    return 0.5 * (1 + math.cos(math.pi * progress))


def _checkpoint_state(epoch, model, optimizer, scheduler, best_acc, top1_acc, top5_acc):
    """Collect model/optimizer/scheduler state and accuracy figures for torch.save."""
    return {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "best_acc": best_acc,
        "top1_accuracy": top1_acc,
        "top5_accuracy": top5_acc,
    }


def main():
    """Train ResNet-50, optionally with DistributedDataParallel.

    A local rank other than -1 (see ``setup_distributed``) selects DDP with the
    NCCL backend. Otherwise the model is wrapped in ``DataParallel``. Only the
    master process (global rank 0) logs, writes TensorBoard events, validates
    and saves checkpoints.
    """
    start_time = datetime.now()
    args = setup_distributed()
    distributed = args.local_rank != -1

    if distributed:
        torch.cuda.set_device(args.local_rank)
        dist.init_process_group(
            backend="nccl",
            init_method="env://",  # Use environment variables for initialization
        )
    # Use the global rank: local_rank repeats on every node, so it can identify
    # neither the master process nor this process's data shard on multi-node runs.
    rank = dist.get_rank() if distributed else 0
    world_size = dist.get_world_size() if distributed else None
    is_master = rank == 0
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Logging and TensorBoard (master only)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_dir = f"runs/resnet50_{timestamp}"
    writer = None
    logger = None
    if is_master:
        os.makedirs(log_dir, exist_ok=True)
        writer = SummaryWriter(log_dir)
        logger = setup_logging(log_dir)
        logger.info(f"Starting training on {socket.gethostname()}")
        logger.info(f"Available GPUs: {torch.cuda.device_count()}")
        logger.info(f"Training started at: {start_time}")

    # Seed every process, not only the master.
    set_seed()

    model = create_resnet50()
    if distributed:
        model = DDP(model.to(device), device_ids=[args.local_rank])
    else:
        model = torch.nn.DataParallel(model).to(device)

    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = optim.SGD(
        model.parameters(),
        lr=Config.learning_rate,
        momentum=Config.momentum,
        weight_decay=Config.weight_decay,
        nesterov=True,
    )

    # Cosine annealing with linear warm-up, stepped once per epoch.
    warmup_epochs = 5

    def warmup_lr_scheduler(epoch):
        return _warmup_cosine_factor(epoch, warmup_epochs, Config.num_epochs)

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, warmup_lr_scheduler)

    train_loader, val_loader, train_sampler = get_data_loaders(
        subset_size=Config.subset_size,
        distributed=distributed,
        world_size=world_size,
        rank=rank if distributed else None,
    )

    if is_master:
        dataset_stats = get_dataset_stats(train_loader, val_loader)
        logger.info("Dataset Statistics:")
        logger.info(f"Training samples: {dataset_stats['num_train_samples']}")
        logger.info(f"Validation samples: {dataset_stats['num_val_samples']}")
        logger.info(f"Number of classes: {dataset_stats['num_classes']}")
        logger.info(f"Batch size: {dataset_stats['batch_size']}")
        logger.info(f"Training batches per epoch: {dataset_stats['num_train_batches']}")
        logger.info(f"Validation batches per epoch: {dataset_stats['num_val_batches']}")

    best_acc = 0.0
    total_training_time = timedelta()
    for epoch in range(Config.num_epochs):
        if is_master:
            logger.info(f"Starting epoch {epoch}")
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)

        train_metrics = train_epoch(
            model, train_loader, criterion, optimizer, epoch, device
        )
        total_training_time += train_metrics["time"]

        if is_master:
            logger.info(
                f"Epoch {epoch} completed in {train_metrics['time']}, "
                f"Training Loss: {train_metrics['loss']:.4f}, "
                f"Training Accuracy: {train_metrics['accuracy']:.2f}%"
            )
            # NOTE(review): only the master validates; the other ranks
            # presumably block in the next epoch's first collective until it
            # finishes -- confirm this stays within the process-group timeout.
            top1_acc, top5_acc, val_loss = validate(
                model, val_loader, criterion, device
            )
            logger.info(
                f"Validation metrics - "
                f"Top1 Acc: {top1_acc:.2f}%, "
                f"Top5 Acc: {top5_acc:.2f}%, "
                f"Val Loss: {val_loss:.4f}"
            )
            logger.info(f"Total training time so far: {total_training_time}")

            writer.add_scalar("Training/Loss", train_metrics["loss"], epoch)
            writer.add_scalar("Training/Accuracy", train_metrics["accuracy"], epoch)
            writer.add_scalar(
                "Training/Time", train_metrics["time"].total_seconds(), epoch
            )
            writer.add_scalar("Validation/Top1_Accuracy", top1_acc, epoch)
            writer.add_scalar("Validation/Top5_Accuracy", top5_acc, epoch)
            writer.add_scalar("Validation/Loss", val_loss, epoch)

            is_best = top1_acc > best_acc
            best_acc = max(top1_acc, best_acc)
            state = _checkpoint_state(
                epoch, model, optimizer, scheduler, best_acc, top1_acc, top5_acc
            )
            # Only overwrite best_model.pth when validation accuracy improved.
            if is_best:
                torch.save(state, os.path.join(log_dir, "best_model.pth"))
            if top1_acc >= 70.0:
                logger.info(
                    f"\nTarget accuracy of 70% achieved! Current accuracy: {top1_acc:.2f}%"
                )
                torch.save(state, os.path.join(log_dir, "target_achieved_model.pth"))

        scheduler.step()

    if is_master:
        training_time = datetime.now() - start_time
        writer.close()
        logger.info("\nTraining completed!")
        logger.info(f"Total training time: {training_time}")
        # Report best_acc tracked above; train_epoch's metrics dict has no
        # best-accuracy entry.
        logger.info(f"Best Top-1 Accuracy: {best_acc:.2f}%")
        logger.info(
            f"Target accuracy of 70% {'achieved' if best_acc >= 70.0 else 'not achieved'}"
        )

    if distributed:
        dist.destroy_process_group()
if __name__ == "__main__":
    # Start training only when run as a script, not when imported.
    main()