torinriley committed on
Commit
29a4de2
·
verified ·
1 Parent(s): 1826e11

Upload 4 files

Browse files
Files changed (4) hide show
  1. Tuned_Model.pth +3 -0
  2. inference.py +78 -0
  3. model.py +120 -0
  4. train.py +147 -0
Tuned_Model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:63d913d1a97519e4710c567617fe0d0fa6eebc94b9afb4e575956e16387d114d
3
+ size 134316314
inference.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
import torchvision.transforms as T
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import numpy as np
from model import SSLModel

# Device setup: prefer GPU when available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the saved SSL model.
# pretrained=False: all weights come from the checkpoint loaded below.
model = SSLModel(resnet18(pretrained=False)).to(device)
saved_model_path = "models/saves/run2/ssl_checkpoint_epoch_15.pth"
# Checkpoint is a dict containing (at least) "model_state_dict" — the format
# written by the training scripts in this repo.
checkpoint = torch.load(saved_model_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
print(f"Model loaded from {saved_model_path}")

# Deterministic eval-time preprocessing (no augmentation); Normalize maps
# pixels into [-1, 1], matching the training-time transform.
transform = T.Compose([
    T.Resize(32),
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
dataset = datasets.CIFAR10(root="./data", train=False, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=256, shuffle=False)

# Extract embeddings and corresponding labels
embeddings = []
labels = []

print("Extracting embeddings...")
with torch.no_grad():
    for imgs, lbls in dataloader:
        imgs = imgs.to(device)
        # NOTE(review): model(imgs) returns the projection-head output, not
        # backbone features; SSL evaluation usually visualizes backbone
        # features instead — confirm which embedding space is intended.
        z = model(imgs)  # Get the embeddings
        embeddings.append(z.cpu().numpy())
        labels.append(lbls.numpy())

# Concatenate all embeddings and labels
embeddings = np.concatenate(embeddings, axis=0)
labels = np.concatenate(labels, axis=0)

# Reduce dimensionality using t-SNE (down to 2-D for plotting).
print("Reducing dimensionality...")
tsne = TSNE(n_components=2, random_state=42, init="pca", learning_rate="auto")
reduced_embeddings = tsne.fit_transform(embeddings)
50
+
51
# Plot embeddings
def plot_embeddings(embeddings, labels, class_names):
    """Scatter-plot 2-D embeddings colored by class label and show the figure.

    Args:
        embeddings: (N, 2) array of 2-D points (e.g. t-SNE output).
        labels: (N,) integer class labels used to color the points.
        class_names: sequence of class-name strings indexed by label value.
    """
    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(
        embeddings[:, 0],
        embeddings[:, 1],
        c=labels,
        cmap="tab10",
        alpha=0.7
    )
    # legend_elements() yields one proxy handle per color; pair each with its
    # class name. (The original bound the returned Legend to an unused local.)
    plt.legend(
        handles=scatter.legend_elements()[0],
        labels=class_names,
        loc="upper right",
        title="Classes"
    )
    plt.title("t-SNE Visualization of SSL Embeddings")
    plt.xlabel("Dimension 1")
    plt.ylabel("Dimension 2")
    plt.grid(True)
    plt.show()
72
+
73
# Get CIFAR-10 class names (torchvision datasets expose them as .classes)
class_names = dataset.classes

# Plot the embeddings
plot_embeddings(reduced_embeddings, labels, class_names)
78
+
model.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchvision.transforms as T
4
+ import torchvision.datasets as datasets
5
+ from torch.utils.data import DataLoader
6
+ from tqdm import tqdm
7
+ from torchvision.models import resnet18
8
+
9
+
10
class SSLModel(nn.Module):
    """Backbone plus MLP projection head for contrastive self-supervised learning.

    The backbone's final ``fc`` layer is replaced with ``nn.Identity`` so it
    emits raw features; a two-layer MLP then maps those features into the
    ``projection_dim``-dimensional space the contrastive loss operates on.
    """

    def __init__(self, backbone, projection_dim=128):
        super().__init__()
        # Read the feature width before the classifier head is discarded.
        feature_dim = backbone.fc.in_features
        self.backbone = backbone
        self.projection_head = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Linear(512, projection_dim),
        )
        self.backbone.fc = nn.Identity()

    def forward(self, x):
        return self.projection_head(self.backbone(x))
25
+
26
+
27
def contrastive_loss(z_i, z_j, temperature=0.5):
    """NT-Xent-style contrastive loss over two batches of projections.

    Args:
        z_i: (B, D) projections of the first view.
        z_j: (B, D) projections of the second view; row k of ``z_j`` is the
            positive pair of row k of ``z_i``.
        temperature: softmax temperature scaling the similarity logits.

    Returns:
        Scalar tensor: mean loss over all 2*B anchors.

    NOTE(review): similarities are raw dot products, not cosine similarities;
    SimCLR L2-normalizes ``z`` first — confirm the unnormalized form is intended.
    """
    batch_size = z_i.shape[0]

    # Stack both views: rows [0, B) are z_i, rows [B, 2B) are z_j.
    z = torch.cat([z_i, z_j], dim=0)

    # Pairwise similarity logits, shape (2B, 2B).
    sim_matrix = torch.mm(z, z.T) / temperature

    # Exclude each anchor's self-similarity from the denominator
    # (exp(-inf) == 0 inside logsumexp).
    mask = torch.eye(2 * batch_size, dtype=torch.bool, device=sim_matrix.device)
    sim_matrix = sim_matrix.masked_fill(mask, float("-inf"))

    # Positive logits: (i, i+B) for the first view, (i+B, i) for the second.
    pos_sim = torch.cat([
        torch.diag(sim_matrix, batch_size),
        torch.diag(sim_matrix, -batch_size),
    ])

    # -log(exp(pos) / sum(exp(row))) == logsumexp(row) - pos. Using
    # torch.logsumexp is numerically stable and replaces the manual
    # row-max shift plus explicit exp/sum of the original.
    loss = torch.logsumexp(sim_matrix, dim=1) - pos_sim
    return loss.mean()
50
+
51
+
52
if __name__ == "__main__":
    # SimCLR-style random augmentation pipeline for 32x32 CIFAR images.
    transform = T.Compose([
        T.RandomResizedCrop(32),
        T.RandomHorizontalFlip(),
        T.ColorJitter(0.4, 0.4, 0.4, 0.1),
        T.RandomGrayscale(p=0.2),
        T.GaussianBlur(kernel_size=3),
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize to [-1, 1]
    ])

    # Load CIFAR-10 dataset
    train_dataset = datasets.CIFAR10(root="./data", train=True, transform=transform, download=True)
    train_loader = DataLoader(
        train_dataset,
        batch_size=256,
        shuffle=True,
        pin_memory=True,
        num_workers=4
    )

    # Fresh (untrained) ResNet-18 backbone; the walrus binds `device` for reuse below.
    model = SSLModel(resnet18(pretrained=False)).to(device := torch.device("cuda" if torch.cuda.is_available() else "cpu"))

    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-4)

    # Cosine decay over the full run (T_max=10 matches range(10) below).
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

    # Training loop
    model.train()
    for epoch in range(10):
        epoch_loss = 0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/10", unit="batch")

        for batch in progress_bar:
            imgs, _ = batch  # labels unused: self-supervised objective
            imgs = imgs.to(device, non_blocking=True)

            # Create two augmented views
            # NOTE(review): both forward passes receive the SAME tensor — the
            # random augmentations ran once in the DataLoader, so z_i and z_j
            # are projections of one view, not two independently augmented
            # views as SimCLR requires. Confirm this is intended.
            z_i = model(imgs)
            z_j = model(imgs)

            # Compute contrastive loss; best-effort skip of a failing batch.
            try:
                loss = contrastive_loss(z_i, z_j)
            except Exception as e:
                print(f"Loss computation failed: {e}")
                continue

            optimizer.zero_grad()
            loss.backward()

            # Cap the gradient norm to stabilize early training.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()

            epoch_loss += loss.item()
            progress_bar.set_postfix(loss=f"{loss.item():.4f}")

        scheduler.step()
        print(f"Epoch {epoch + 1}, Average Loss: {epoch_loss / len(train_loader):.4f}")

        # Save checkpoint (every epoch) in the dict format inference.py loads.
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        }, f"ssl_checkpoint_epoch_{epoch + 1}.pth")
        print(f"Model saved to ssl_checkpoint_epoch_{epoch + 1}.pth")
120
+
train.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import torchvision.transforms as T
5
+ import torchvision.datasets as datasets
6
+ from torch.utils.data import DataLoader
7
+ from tqdm import tqdm
8
+ from torchvision.models import resnet18
9
+
10
+
11
# Define SSLModel with ResNet-18 backbone
class SSLModel(nn.Module):
    """Wrap a classification backbone for contrastive SSL training.

    Replaces ``backbone.fc`` with an identity mapping so the backbone emits
    features, and appends a Linear-ReLU-Linear projection head producing
    ``projection_dim``-dimensional outputs.
    """

    def __init__(self, backbone, projection_dim=128):
        super(SSLModel, self).__init__()
        in_dim = backbone.fc.in_features
        head_layers = [
            nn.Linear(in_dim, 512),
            nn.ReLU(),
            nn.Linear(512, projection_dim),
        ]
        self.backbone = backbone
        self.projection_head = nn.Sequential(*head_layers)
        # Drop the classification head after sizing the projection MLP.
        self.backbone.fc = nn.Identity()

    def forward(self, x):
        feats = self.backbone(x)
        return self.projection_head(feats)
27
+
28
+
29
# Contrastive Loss
def contrastive_loss(z_i, z_j, temperature=0.5):
    """NT-Xent-style contrastive loss over two batches of projections.

    Args:
        z_i: (B, D) projections of the first view.
        z_j: (B, D) projections of the second view; row k of ``z_j`` is the
            positive pair of row k of ``z_i``.
        temperature: softmax temperature scaling the similarity logits.

    Returns:
        Scalar tensor: mean loss over all 2*B anchors.

    NOTE(review): similarities are raw dot products, not cosine similarities;
    SimCLR L2-normalizes ``z`` first — confirm the unnormalized form is intended.
    """
    batch_size = z_i.shape[0]

    # Stack both views: rows [0, B) are z_i, rows [B, 2B) are z_j.
    z = torch.cat([z_i, z_j], dim=0)

    # Pairwise similarity logits, shape (2B, 2B).
    sim_matrix = torch.mm(z, z.T) / temperature

    # Exclude each anchor's self-similarity from the denominator
    # (exp(-inf) == 0 inside logsumexp).
    mask = torch.eye(2 * batch_size, dtype=torch.bool, device=sim_matrix.device)
    sim_matrix = sim_matrix.masked_fill(mask, float("-inf"))

    # Positive logits: (i, i+B) for the first view, (i+B, i) for the second.
    pos_sim = torch.cat([
        torch.diag(sim_matrix, batch_size),
        torch.diag(sim_matrix, -batch_size),
    ])

    # -log(exp(pos) / sum(exp(row))) == logsumexp(row) - pos. Using
    # torch.logsumexp is numerically stable and replaces the manual
    # row-max shift plus explicit exp/sum of the original.
    loss = torch.logsumexp(sim_matrix, dim=1) - pos_sim
    return loss.mean()
55
+
56
+
57
def train_ssl():
    """Run SimCLR-style SSL training on CIFAR-10, resuming from a checkpoint if present.

    Side effects: downloads CIFAR-10 into ./data, creates ./checkpoints, and
    writes one checkpoint file per completed epoch.
    """
    # Device setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Data Transformations (random; applied once per sample per epoch)
    transform = T.Compose([
        T.RandomResizedCrop(32),
        T.RandomHorizontalFlip(),
        T.ColorJitter(0.4, 0.4, 0.4, 0.1),
        T.RandomGrayscale(p=0.2),
        T.GaussianBlur(kernel_size=3),
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize to [-1, 1]
    ])

    # Load Dataset
    train_dataset = datasets.CIFAR10(root="./data", train=True, transform=transform, download=True)
    train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, pin_memory=True, num_workers=4)

    # Initialize Model (fresh backbone; weights may be replaced by the checkpoint below)
    model = SSLModel(resnet18(pretrained=False)).to(device)

    # Optimizer and Scheduler
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-4)
    # NOTE(review): T_max=10 while the loop below runs up to epoch 15 —
    # confirm the cosine schedule is meant to extend past T_max.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

    # Resume Training (if checkpoint exists)
    # NOTE(review): resumes from models/saves/run2/ but saves to checkpoints/,
    # so a later restart will not pick up the newest file — verify the paths.
    start_epoch = 1
    checkpoint_path = "models/saves/run2/ssl_checkpoint_epoch_14.pth"
    if os.path.exists(checkpoint_path):
        print(f"Resuming training from checkpoint: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        start_epoch = checkpoint["epoch"] + 1

    # Create "checkpoints" directory if it doesn't exist
    os.makedirs("checkpoints", exist_ok=True)

    # Training Loop
    model.train()
    total_epochs = 15  # Adjust based on the training plan
    for epoch in range(start_epoch, total_epochs + 1):
        epoch_loss = 0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}/{total_epochs}", unit="batch")

        for batch in progress_bar:
            imgs, _ = batch  # labels unused: self-supervised objective
            imgs = imgs.to(device, non_blocking=True)

            # Create two augmented views
            # NOTE(review): both forward passes receive the SAME tensor — the
            # random augmentations ran once in the DataLoader, so z_i and z_j
            # are projections of one view, not two independently augmented
            # views as SimCLR requires. Confirm this is intended.
            z_i = model(imgs)
            z_j = model(imgs)

            # Validate embeddings
            # NOTE(review): `assert` is stripped under `python -O`; raise an
            # exception explicitly if this check must survive optimized runs.
            assert not torch.isnan(z_i).any(), "z_i contains NaN values!"
            assert not torch.isnan(z_j).any(), "z_j contains NaN values!"

            # Best-effort: skip any batch whose loss computation fails.
            try:
                loss = contrastive_loss(z_i, z_j)
            except Exception as e:
                print(f"Loss computation failed: {e}")
                continue

            optimizer.zero_grad()
            loss.backward()

            # Clip Gradients (cap the global norm to stabilize training)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            # Accumulate epoch loss
            epoch_loss += loss.item()
            progress_bar.set_postfix(loss=f"{loss.item():.4f}")

        scheduler.step()
        print(f"Epoch {epoch}, Average Loss: {epoch_loss / len(train_loader):.4f}")

        # Save checkpoint in the dict format the inference script loads
        save_path = f"checkpoints/ssl_checkpoint_epoch_{epoch}.pth"
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        }, save_path)
        print(f"Model saved to {save_path}")
143
+
144
+
145
# Script entry point: run training only when executed directly.
if __name__ == "__main__":
    train_ssl()
147
+