# Uploaded by torinriley ("Upload 4 files", commit 29a4de2, verified).
import os
import torch
import torch.nn as nn
import torchvision.transforms as T
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchvision.models import resnet18
# Define SSLModel with ResNet-18 backbone
class SSLModel(nn.Module):
    """Backbone network followed by a two-layer MLP projection head.

    The backbone's ``fc`` layer is replaced with ``nn.Identity`` so the
    backbone emits its pooled features, which the projection head maps to
    ``projection_dim`` outputs.

    Args:
        backbone: Module exposing an ``fc`` attribute with ``in_features``
            (e.g. a torchvision ResNet). It is modified in place.
        projection_dim: Size of the projection head's output.
    """

    def __init__(self, backbone, projection_dim=128):
        super().__init__()
        # Read the feature width before the classifier is swapped out below.
        feature_dim = backbone.fc.in_features
        self.backbone = backbone

        hidden_layer = nn.Linear(feature_dim, 512)
        output_layer = nn.Linear(512, projection_dim)
        self.projection_head = nn.Sequential(hidden_layer, nn.ReLU(), output_layer)

        # Drop the classification head so the backbone returns raw features.
        self.backbone.fc = nn.Identity()

    def forward(self, x):
        """Return the projection-head output for the input batch ``x``."""
        return self.projection_head(self.backbone(x))
# Contrastive Loss
def contrastive_loss(z_i, z_j, temperature=0.5):
    """Compute a SimCLR-style contrastive (NT-Xent-like) loss for two views.

    Row ``k`` of ``z_i`` and row ``k`` of ``z_j`` form a positive pair; every
    other row of the concatenated batch acts as a negative.

    The loss is evaluated as ``logsumexp(sim) - pos`` rather than
    ``-log(exp(pos) / sum(exp(sim)))``. The two are mathematically equal, but
    the log-sum-exp form stays finite when the similarities are large, where
    the naive form underflows to ``0 / 0`` and yields NaN.

    NOTE(review): similarities are raw dot products; the embeddings are not
    L2-normalized here. Confirm whether callers are expected to normalize.

    Args:
        z_i: 2-D tensor of projections for the first view, one row per sample.
        z_j: 2-D tensor of projections for the second view, same shape as
            ``z_i``.
        temperature: Divisor applied to the similarity matrix.

    Returns:
        Scalar tensor: the mean loss over all ``2 * batch_size`` anchors.
    """
    batch_size = z_i.shape[0]

    # Stack both views: shape (2 * batch_size, projection_dim).
    z = torch.cat([z_i, z_j], dim=0)

    # Pairwise dot-product similarities scaled by the temperature.
    sim_matrix = torch.mm(z, z.T) / temperature

    # Exclude each sample's similarity with itself. This must happen before
    # any max/log-sum-exp stabilisation; otherwise the (usually dominant)
    # diagonal is used as the shift and the remaining terms underflow.
    self_mask = torch.eye(
        sim_matrix.size(0), dtype=torch.bool, device=sim_matrix.device
    )
    sim_matrix = sim_matrix.masked_fill(self_mask, float("-inf"))

    # Positive pairs sit on the +/- batch_size off-diagonals:
    # (z_i[k], z_j[k]) above the main diagonal, (z_j[k], z_i[k]) below it.
    pos_sim = torch.cat([
        torch.diag(sim_matrix, batch_size),
        torch.diag(sim_matrix, -batch_size),
    ])

    # -log softmax of the positive entry, computed stably.
    loss = torch.logsumexp(sim_matrix, dim=1) - pos_sim
    return loss.mean()
class TwoCropTransform:
    """Apply one transform twice to a sample and return both results.

    With a stochastic augmentation pipeline each call draws its own random
    parameters, so the two results are distinct views of the same image.
    Defined at module level so DataLoader worker processes can pickle it.
    """

    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, sample):
        first_view = self.base_transform(sample)
        second_view = self.base_transform(sample)
        return first_view, second_view


def train_ssl():
    """Train ``SSLModel`` (ResNet-18 backbone) on CIFAR-10 with a contrastive loss.

    Each training image is augmented twice, both views are passed through the
    model, and ``contrastive_loss`` is minimised with Adam under a cosine
    learning-rate schedule. If the hard-coded resume checkpoint exists,
    model/optimizer (and, when present, scheduler) state is restored from it.
    A checkpoint is written to ``checkpoints/`` after every epoch.

    Side effects: downloads CIFAR-10 into ``./data`` if missing, creates the
    ``checkpoints`` directory, writes checkpoint files and prints progress.
    """
    # Device setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Data transformations (applied twice per image via TwoCropTransform)
    augment = T.Compose([
        T.RandomResizedCrop(32),
        T.RandomHorizontalFlip(),
        T.ColorJitter(0.4, 0.4, 0.4, 0.1),
        T.RandomGrayscale(p=0.2),
        T.GaussianBlur(kernel_size=3),
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # Normalize to [-1, 1]
    ])

    # Load dataset. Wrapping the transform makes every sample a pair of
    # independently augmented views; feeding the same tensor through the
    # model twice would give identical embeddings and a degenerate objective.
    train_dataset = datasets.CIFAR10(
        root="./data",
        train=True,
        transform=TwoCropTransform(augment),
        download=True,
    )
    train_loader = DataLoader(
        train_dataset, batch_size=256, shuffle=True, pin_memory=True, num_workers=4
    )

    # Initialize model. resnet18() with no arguments builds an untrained
    # network and avoids the deprecated ``pretrained`` keyword.
    model = SSLModel(resnet18()).to(device)

    # Optimizer and scheduler
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-4)
    # NOTE(review): T_max=10 is shorter than total_epochs=15 below, so the
    # cosine schedule starts increasing again after epoch 10 - confirm intended.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

    # Resume training (if checkpoint exists)
    # NOTE(review): this path differs from the "checkpoints/" directory that
    # checkpoints are written to below - confirm it points at the right run.
    start_epoch = 1
    checkpoint_path = "models/saves/run2/ssl_checkpoint_epoch_14.pth"
    if os.path.exists(checkpoint_path):
        print(f"Resuming training from checkpoint: {checkpoint_path}")
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        # Older checkpoints were saved without scheduler state; skip if absent.
        if "scheduler_state_dict" in checkpoint:
            scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        start_epoch = checkpoint["epoch"] + 1

    # Create "checkpoints" directory if it doesn't exist
    os.makedirs("checkpoints", exist_ok=True)

    # Training loop
    model.train()
    total_epochs = 15  # Adjust based on the training plan
    for epoch in range(start_epoch, total_epochs + 1):
        epoch_loss = 0
        batches_done = 0  # batches that actually contributed to epoch_loss
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}/{total_epochs}", unit="batch")
        for batch in progress_bar:
            (view_i, view_j), _ = batch
            view_i = view_i.to(device, non_blocking=True)
            view_j = view_j.to(device, non_blocking=True)

            # Embed the two augmented views
            z_i = model(view_i)
            z_j = model(view_j)

            # Validate embeddings
            assert not torch.isnan(z_i).any(), "z_i contains NaN values!"
            assert not torch.isnan(z_j).any(), "z_j contains NaN values!"

            # Best effort: a failing loss computation skips the batch.
            try:
                loss = contrastive_loss(z_i, z_j)
            except Exception as e:
                print(f"Loss computation failed: {e}")
                continue

            optimizer.zero_grad()
            loss.backward()
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            # Accumulate epoch loss
            batch_loss = loss.item()
            epoch_loss += batch_loss
            batches_done += 1
            progress_bar.set_postfix(loss=f"{batch_loss:.4f}")

        scheduler.step()
        # Average over the batches that ran so skipped batches do not bias it.
        print(f"Epoch {epoch}, Average Loss: {epoch_loss / max(batches_done, 1):.4f}")

        # Save checkpoint
        save_path = f"checkpoints/ssl_checkpoint_epoch_{epoch}.pth"
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
        }, save_path)
        print(f"Model saved to {save_path}")
# Script entry point: run training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    train_ssl()