# resnet-train / train.py
# Sreekanth Tangirala -- "first commit" (de2aabe)
# File size: 3.62 kB
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
from model import get_model, save_model
from tqdm import tqdm
def get_transforms():
    """
    Build the preprocessing pipeline applied to each image.

    The steps run in order: resize to 224, convert to a tensor, then
    normalize the three channels with fixed per-channel mean/std constants.

    Returns:
        transforms.Compose: the composed transform.
    """
    # NOTE(review): these look like the usual ImageNet channel statistics --
    # confirm they match whatever the model was pretrained with.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    pipeline = [
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(pipeline)
def get_data(subset_size=None, *, batch_size=32, num_workers=2):
    """
    Load the CIFAR-10 training split and wrap it in a shuffling DataLoader.

    The dataset is read from ``./data`` and downloaded there if missing.
    Every image goes through the pipeline returned by ``get_transforms()``.

    Args:
        subset_size (int): If provided (and non-zero), keep only a random
            subset of this many training images.
        batch_size (int): Number of samples per batch. Defaults to 32.
        num_workers (int): Number of DataLoader worker processes.
            Defaults to 2.

    Returns:
        DataLoader: loader over the (possibly subsampled) training set,
        reshuffled every epoch.
    """
    trainset = torchvision.datasets.CIFAR10(
        root='./data',
        train=True,
        download=True,
        transform=get_transforms(),
    )

    if subset_size:
        # Subset is documented to take a sequence of indices; pass plain
        # Python ints rather than a tensor so each dataset lookup receives
        # an ordinary integer index.
        indices = torch.randperm(len(trainset))[:subset_size].tolist()
        trainset = Subset(trainset, indices)

    return DataLoader(
        trainset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
def train_model(model, trainloader, epochs=100, device='cuda', *,
                target_acc=70.0, checkpoint_path='best_model.pth'):
    """
    Train the model, checkpointing the best epoch and stopping early.

    Uses cross-entropy loss with Adam (lr=0.001). After each epoch the
    training accuracy is fed to a ReduceLROnPlateau scheduler, the model is
    saved through ``save_model`` whenever that accuracy improves, and
    training stops as soon as the accuracy exceeds ``target_acc``.

    Args:
        model: The network to train; called as ``model(inputs)`` and
            expected to return per-class scores.
        trainloader: DataLoader yielding ``(inputs, labels)`` batches.
        epochs (int): Maximum number of epochs to train.
        device (str): Device to train on ('cuda' or 'cpu').
        target_acc (float): Training accuracy (in percent) above which
            training stops early. Defaults to 70.
        checkpoint_path (str): Path handed to ``save_model`` for the best
            model so far. Defaults to 'best_model.pth'.

    Returns:
        float: The best per-epoch training accuracy (percent) observed.

    Raises:
        ValueError: If ``trainloader`` yields no batches.
    """
    # Fail early with a clear message instead of a ZeroDivisionError when
    # the per-epoch averages are computed.
    if len(trainloader) == 0:
        raise ValueError("trainloader is empty; nothing to train on")

    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    # 'max' mode: the tracked metric is accuracy, so the learning rate is
    # reduced when it stops increasing.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        'max',
        patience=5,
    )

    best_acc = 0.0

    # The context manager closes the epoch bar even on early stop or error.
    with tqdm(range(epochs), desc='Training') as epoch_pbar:
        for epoch in epoch_pbar:
            model.train()
            running_loss = 0.0
            correct = 0
            total = 0

            batch_pbar = tqdm(trainloader, leave=False, desc=f'Epoch {epoch+1}')
            for inputs, labels in batch_pbar:
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                # Read the scalar loss once; each .item() call forces a
                # host/device synchronisation.
                batch_loss = loss.item()
                running_loss += batch_loss

                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

                batch_pbar.set_postfix({'loss': f'{batch_loss:.3f}'})

            epoch_acc = 100. * correct / total
            avg_loss = running_loss / len(trainloader)

            epoch_pbar.set_postfix({
                'loss': f'{avg_loss:.3f}',
                'accuracy': f'{epoch_acc:.2f}%'
            })

            scheduler.step(epoch_acc)

            if epoch_acc > best_acc:
                best_acc = epoch_acc
                save_model(model, checkpoint_path)

            if epoch_acc > target_acc:
                print(f"\nReached target accuracy of {target_acc:g}%!")
                break

    return best_acc
if __name__ == "__main__":
    # Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print(f"Using device: {device}")

    # A 5000-image random subset is used for initial testing.
    loader = get_data(subset_size=5000)

    # Build the network with 10 output classes.
    net = get_model(num_classes=10)

    # Train for at most 10 epochs on the selected device.
    train_model(model=net, trainloader=loader, epochs=10, device=device)