import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchvision.models import resnet18
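

# Two-view wrapper assumed by the contrastive training loop below: the dataset
# returns two independently augmented crops of each image, which is what a
# SimCLR-style contrastive loss expects. (Minimal helper sketch.)
class TwoCropTransform:
    """Apply the same stochastic transform twice, yielding two views of one image."""

    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, x):
        return self.base_transform(x), self.base_transform(x)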


# Define SSLModel with ResNet-18 backbone
class SSLModel(nn.Module):
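    """ResNet-18 encoder with a two-layer MLP projection head for contrastive SSL."""
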
    def __init__(self, backbone, projection_dim=128):
        super(SSLModel, self).__init__()
        self.backbone = backbone
        self.projection_head = nn.Sequential(
            nn.Linear(backbone.fc.in_features, 512),
            nn.ReLU(),
            nn.Linear(512, projection_dim)
        )
        self.backbone.fc = nn.Identity()  # Remove classification head

    def forward(self, x):
        features = self.backbone(x)
        projections = self.projection_head(features)
        return projections


# Contrastive Loss
def contrastive_loss(z_i, z_j, temperature=0.5):
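    """NT-Xent-style contrastive loss for two batches of projections.

    z_i, z_j: (batch_size, projection_dim) projections of two augmented views
    of the same images. Matching rows are positives; all other rows are negatives.
    """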
    batch_size = z_i.shape[0]

    # Concatenate both views
    z = torch.cat([z_i, z_j], dim=0)  # Shape: (2 * batch_size, projection_dim)

    # L2-normalize embeddings so the similarities below are cosine similarities
    z = F.normalize(z, dim=1)

    # Similarity matrix: pairwise cosine similarity scaled by temperature
    sim_matrix = torch.mm(z, z.T) / temperature  # Shape: (2 * batch_size, 2 * batch_size)

    # Subtract the row-wise max for numerical stability (log-sum-exp trick)
    sim_matrix = sim_matrix - torch.max(sim_matrix, dim=1, keepdim=True)[0]

    # Mask out self-similarity
    mask = torch.eye(sim_matrix.size(0), device=sim_matrix.device).bool()
    sim_matrix = sim_matrix.masked_fill(mask, -float("inf"))

    # Positive pairs (z_i[k], z_j[k]) and (z_j[k], z_i[k]) sit batch_size off the main diagonal
    pos_sim = torch.cat([
        torch.diag(sim_matrix, batch_size),
        torch.diag(sim_matrix, -batch_size)
    ])

    # Compute contrastive loss
    loss = -torch.log(torch.exp(pos_sim) / torch.sum(torch.exp(sim_matrix), dim=1))
    return loss.mean()
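
# Illustrative usage (hypothetical shapes): for two views of a batch of 4 images,
#   z_i, z_j = torch.randn(4, 128), torch.randn(4, 128)
#   loss = contrastive_loss(z_i, z_j)  # scalar tensor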


def train_ssl():
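    """Run SimCLR-style self-supervised pre-training of ResNet-18 on CIFAR-10."""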
    # Device setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Data Transformations
    transform = T.Compose([
        T.RandomResizedCrop(32),
        T.RandomHorizontalFlip(),
        T.ColorJitter(0.4, 0.4, 0.4, 0.1),
        T.RandomGrayscale(p=0.2),
        T.GaussianBlur(kernel_size=3),
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize to [-1, 1]
    ])

    # Load Dataset; each sample yields two independently augmented views of the same image
    train_dataset = datasets.CIFAR10(root="./data", train=True, transform=TwoCropTransform(transform), download=True)
    train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, pin_memory=True, num_workers=4)

    # Initialize Model (randomly initialized backbone, no pretrained weights)
    model = SSLModel(resnet18(weights=None)).to(device)

    # Optimizer and Scheduler
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

    # Resume Training (if checkpoint exists)
    start_epoch = 1
    checkpoint_path = "models/saves/run2/ssl_checkpoint_epoch_14.pth"
    if os.path.exists(checkpoint_path):
        print(f"Resuming training from checkpoint: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        start_epoch = checkpoint["epoch"] + 1

    # Create "checkpoints" directory if it doesn't exist
    os.makedirs("checkpoints", exist_ok=True)

    # Training Loop
    model.train()
    total_epochs = 15  # Adjust based on the training plan
    for epoch in range(start_epoch, total_epochs + 1):
        epoch_loss = 0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}/{total_epochs}", unit="batch")

        for batch in progress_bar:
            # Each batch holds two independently augmented views of the same images
            (view_i, view_j), _ = batch
            view_i = view_i.to(device, non_blocking=True)
            view_j = view_j.to(device, non_blocking=True)

            # Encode both views
            z_i = model(view_i)
            z_j = model(view_j)

            # Validate embeddings
            assert not torch.isnan(z_i).any(), "z_i contains NaN values!"
            assert not torch.isnan(z_j).any(), "z_j contains NaN values!"

            try:
                loss = contrastive_loss(z_i, z_j)
            except Exception as e:
                print(f"Loss computation failed: {e}")
                continue

            optimizer.zero_grad()
            loss.backward()

            # Clip Gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            # Accumulate epoch loss
            epoch_loss += loss.item()
            progress_bar.set_postfix(loss=f"{loss.item():.4f}")

        scheduler.step()
        print(f"Epoch {epoch}, Average Loss: {epoch_loss / len(train_loader):.4f}")

        # Save checkpoint
        save_path = f"checkpoints/ssl_checkpoint_epoch_{epoch}.pth"
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        }, save_path)
        print(f"Model saved to {save_path}")


if __name__ == "__main__":
    train_ssl()