torinriley
committed on
Upload 4 files
Browse files- Tuned_Model.pth +3 -0
- inference.py +78 -0
- model.py +120 -0
- train.py +147 -0
Tuned_Model.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:63d913d1a97519e4710c567617fe0d0fa6eebc94b9afb4e575956e16387d114d
|
3 |
+
size 134316314
|
inference.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torchvision.transforms as T
|
3 |
+
import torchvision.datasets as datasets
|
4 |
+
from torch.utils.data import DataLoader
|
5 |
+
from torchvision.models import resnet18
|
6 |
+
from sklearn.manifold import TSNE
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
import numpy as np
|
9 |
+
from model import SSLModel
|
10 |
+
|
11 |
+
# Device setup: prefer GPU when available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the saved SSL model.
# weights=None replaces the deprecated pretrained=False flag (torchvision >= 0.13).
model = SSLModel(resnet18(weights=None)).to(device)
saved_model_path = "models/saves/run2/ssl_checkpoint_epoch_15.pth"
checkpoint = torch.load(saved_model_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
print(f"Model loaded from {saved_model_path}")

# Deterministic eval-time preprocessing; the mean/std of 0.5 maps pixels into
# [-1, 1], matching the normalization used during training.
transform = T.Compose([
    T.Resize(32),
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
dataset = datasets.CIFAR10(root="./data", train=False, transform=transform, download=True)
# shuffle=False keeps embeddings aligned with dataset order.
dataloader = DataLoader(dataset, batch_size=256, shuffle=False)
29 |
+
|
30 |
+
# Extract embeddings and corresponding labels
print("Extracting embeddings...")
embedding_batches, label_batches = [], []
with torch.no_grad():
    for batch_imgs, batch_lbls in dataloader:
        projections = model(batch_imgs.to(device))  # Get the embeddings
        embedding_batches.append(projections.cpu().numpy())
        label_batches.append(batch_lbls.numpy())

# Stack the per-batch arrays into single (N, dim) / (N,) arrays.
embeddings = np.concatenate(embedding_batches, axis=0)
labels = np.concatenate(label_batches, axis=0)

# Reduce dimensionality using t-SNE
print("Reducing dimensionality...")
tsne = TSNE(n_components=2, random_state=42, init="pca", learning_rate="auto")
reduced_embeddings = tsne.fit_transform(embeddings)
51 |
+
# Plot embeddings
def plot_embeddings(embeddings, labels, class_names):
    """Scatter-plot 2-D embeddings, colored by label with a class legend."""
    plt.figure(figsize=(10, 8))
    points = plt.scatter(
        embeddings[:, 0], embeddings[:, 1],
        c=labels, cmap="tab10", alpha=0.7,
    )
    # Build one legend handle per color group.
    handles, _ = points.legend_elements()
    plt.legend(handles=handles, labels=class_names, loc="upper right", title="Classes")
    plt.title("t-SNE Visualization of SSL Embeddings")
    plt.xlabel("Dimension 1")
    plt.ylabel("Dimension 2")
    plt.grid(True)
    plt.show()
72 |
+
|
73 |
+
# Get CIFAR-10 class names (index-aligned with the integer labels above)
class_names = dataset.classes

# Plot the embeddings
plot_embeddings(reduced_embeddings, labels, class_names)
78 |
+
|
model.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torchvision.transforms as T
|
4 |
+
import torchvision.datasets as datasets
|
5 |
+
from torch.utils.data import DataLoader
|
6 |
+
from tqdm import tqdm
|
7 |
+
from torchvision.models import resnet18
|
8 |
+
|
9 |
+
|
10 |
+
class SSLModel(nn.Module):
    """Backbone encoder plus a 2-layer MLP projection head (SimCLR-style)."""

    def __init__(self, backbone, projection_dim=128):
        super().__init__()
        # Read the feature width before the classifier is replaced.
        feature_dim = backbone.fc.in_features
        self.backbone = backbone
        self.projection_head = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Linear(512, projection_dim),
        )
        # Drop the classification head so the backbone emits raw features.
        self.backbone.fc = nn.Identity()

    def forward(self, x):
        # Backbone features -> low-dimensional projection for the contrastive loss.
        return self.projection_head(self.backbone(x))
25 |
+
|
26 |
+
|
27 |
+
def contrastive_loss(z_i, z_j, temperature=0.5):
    """NT-Xent-style contrastive loss over two batches of projections.

    Args:
        z_i, z_j: (batch, dim) projections of the two views of the same batch.
        temperature: softmax temperature scaling the similarity logits.

    Returns:
        Scalar mean loss over all 2 * batch anchors.

    NOTE(review): similarities are raw dot products; SimCLR normally
    L2-normalizes z_i / z_j first (cosine similarity). Kept as-is to
    preserve existing behavior -- confirm whether normalization was intended.
    """
    # Concatenate both views: row k and row k + batch are a positive pair.
    z = torch.cat([z_i, z_j], dim=0)  # (2 * batch_size, projection_dim)

    # Similarity matrix computation (dot product normalized by temperature).
    sim_matrix = torch.mm(z, z.T) / temperature  # (2 * batch_size, 2 * batch_size)

    n = sim_matrix.size(0)

    # Mask out self-similarity so it never enters the denominator.
    mask = torch.eye(n, device=sim_matrix.device, dtype=torch.bool)
    sim_matrix = sim_matrix.masked_fill(mask, -float("inf"))

    # Extract positive similarities (z_i, z_j) and (z_j, z_i).
    pos_sim = torch.cat([
        torch.diag(sim_matrix, n // 2),
        torch.diag(sim_matrix, -(n // 2)),
    ])

    # -log softmax of the positive logit, written with logsumexp: this is the
    # numerically stable form of the previous manual max-subtraction + exp/sum
    # (also drops the unused batch_size local).
    loss = torch.logsumexp(sim_matrix, dim=1) - pos_sim
    return loss.mean()
50 |
+
|
51 |
+
|
52 |
+
if __name__ == "__main__":
    # Stochastic SimCLR-style augmentations; every call samples new parameters.
    transform = T.Compose([
        T.RandomResizedCrop(32),
        T.RandomHorizontalFlip(),
        T.ColorJitter(0.4, 0.4, 0.4, 0.1),
        T.RandomGrayscale(p=0.2),
        T.GaussianBlur(kernel_size=3),
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize to [-1, 1]
    ])

    class TwoViews:
        """Apply the stochastic transform twice -> two independent views."""

        def __init__(self, base_transform):
            self.base_transform = base_transform

        def __call__(self, img):
            return self.base_transform(img), self.base_transform(img)

    # Load CIFAR-10 dataset; each sample now yields a PAIR of augmented views.
    # BUGFIX: previously a single augmented image was forwarded twice, so the
    # two "views" were identical and the contrastive task was degenerate.
    train_dataset = datasets.CIFAR10(root="./data", train=True, transform=TwoViews(transform), download=True)
    train_loader = DataLoader(
        train_dataset,
        batch_size=256,
        shuffle=True,
        pin_memory=True,
        num_workers=4
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # weights=None replaces the deprecated pretrained=False flag.
    model = SSLModel(resnet18(weights=None)).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

    # Training loop
    model.train()
    for epoch in range(10):
        epoch_loss = 0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/10", unit="batch")

        # Default collate stacks the two views into two batch tensors.
        for (view_a, view_b), _ in progress_bar:
            z_i = model(view_a.to(device, non_blocking=True))
            z_j = model(view_b.to(device, non_blocking=True))

            # Best-effort: skip a bad batch rather than abort the whole run.
            try:
                loss = contrastive_loss(z_i, z_j)
            except Exception as e:
                print(f"Loss computation failed: {e}")
                continue

            optimizer.zero_grad()
            loss.backward()

            # Gradient clipping guards against occasional loss spikes.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            epoch_loss += loss.item()
            progress_bar.set_postfix(loss=f"{loss.item():.4f}")

        scheduler.step()
        print(f"Epoch {epoch + 1}, Average Loss: {epoch_loss / len(train_loader):.4f}")

        # Save checkpoint ("epoch" is 0-based; the filename is 1-based).
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        }, f"ssl_checkpoint_epoch_{epoch + 1}.pth")
        print(f"Model saved to ssl_checkpoint_epoch_{epoch + 1}.pth")
train.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torchvision.transforms as T
|
5 |
+
import torchvision.datasets as datasets
|
6 |
+
from torch.utils.data import DataLoader
|
7 |
+
from tqdm import tqdm
|
8 |
+
from torchvision.models import resnet18
|
9 |
+
|
10 |
+
|
11 |
+
# Define SSLModel with ResNet-18 backbone
class SSLModel(nn.Module):
    """Backbone encoder with an MLP projection head for contrastive SSL.

    NOTE(review): duplicates the SSLModel defined in model.py -- consider
    importing it from there instead of redefining it here.
    """

    def __init__(self, backbone, projection_dim=128):
        super(SSLModel, self).__init__()
        self.backbone = backbone
        # Projection head sized from the backbone classifier's input width.
        hidden = 512
        self.projection_head = nn.Sequential(
            nn.Linear(backbone.fc.in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, projection_dim),
        )
        self.backbone.fc = nn.Identity()  # Remove classification head

    def forward(self, x):
        feats = self.backbone(x)
        return self.projection_head(feats)
+
|
28 |
+
|
29 |
+
# Contrastive Loss
def contrastive_loss(z_i, z_j, temperature=0.5):
    """NT-Xent-style contrastive loss over two batches of projections.

    Args:
        z_i, z_j: (batch, dim) projections of the two views of the same batch.
        temperature: softmax temperature scaling the similarity logits.

    Returns:
        Scalar mean loss over all 2 * batch anchors.

    NOTE(review): similarities are raw dot products; SimCLR normally
    L2-normalizes z_i / z_j first (cosine similarity). Kept as-is to
    preserve existing behavior -- confirm whether normalization was intended.
    """
    # Concatenate both views: row k and row k + batch are a positive pair.
    z = torch.cat([z_i, z_j], dim=0)  # Shape: (2 * batch_size, projection_dim)

    # Similarity matrix computation (dot product normalized by temperature).
    sim_matrix = torch.mm(z, z.T) / temperature  # Shape: (2 * batch_size, 2 * batch_size)

    n = sim_matrix.size(0)

    # Mask out self-similarity so it never enters the denominator.
    mask = torch.eye(n, device=sim_matrix.device, dtype=torch.bool)
    sim_matrix = sim_matrix.masked_fill(mask, -float("inf"))

    # Extract positive similarities (z_i, z_j) and (z_j, z_i).
    pos_sim = torch.cat([
        torch.diag(sim_matrix, n // 2),
        torch.diag(sim_matrix, -(n // 2)),
    ])

    # -log softmax of the positive logit, written with logsumexp: this is the
    # numerically stable form of the previous manual max-subtraction + exp/sum
    # (also drops the unused batch_size local).
    loss = torch.logsumexp(sim_matrix, dim=1) - pos_sim
    return loss.mean()
+
|
56 |
+
|
57 |
+
def train_ssl():
    """Run self-supervised contrastive training on CIFAR-10.

    Resumes from a checkpoint when one exists, trains to ``total_epochs``,
    and saves a checkpoint after every epoch.
    """
    # Device setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Data Transformations: stochastic pipeline, new parameters every call.
    transform = T.Compose([
        T.RandomResizedCrop(32),
        T.RandomHorizontalFlip(),
        T.ColorJitter(0.4, 0.4, 0.4, 0.1),
        T.RandomGrayscale(p=0.2),
        T.GaussianBlur(kernel_size=3),
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize to [-1, 1]
    ])

    class TwoViews:
        """Apply the stochastic transform twice -> two independent views."""

        def __init__(self, base_transform):
            self.base_transform = base_transform

        def __call__(self, img):
            return self.base_transform(img), self.base_transform(img)

    # Load Dataset. BUGFIX: the loop previously forwarded the SAME augmented
    # image twice, so both "views" were identical and the contrastive task was
    # degenerate; each sample now yields two independent augmentations.
    train_dataset = datasets.CIFAR10(root="./data", train=True, transform=TwoViews(transform), download=True)
    train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, pin_memory=True, num_workers=4)

    # Initialize Model (weights=None replaces the deprecated pretrained=False).
    model = SSLModel(resnet18(weights=None)).to(device)

    # Optimizer and Scheduler.
    # NOTE(review): T_max=10 but total_epochs is 15 -- confirm the intended
    # cosine schedule length.
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

    # Resume Training (if checkpoint exists).
    # NOTE(review): resume path differs from the "checkpoints" save directory
    # below, and the scheduler state is not restored on resume -- verify intent.
    start_epoch = 1
    checkpoint_path = "models/saves/run2/ssl_checkpoint_epoch_14.pth"
    if os.path.exists(checkpoint_path):
        print(f"Resuming training from checkpoint: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        start_epoch = checkpoint["epoch"] + 1

    # Create "checkpoints" directory if it doesn't exist
    os.makedirs("checkpoints", exist_ok=True)

    # Training Loop
    model.train()
    total_epochs = 15  # Adjust based on the training plan
    for epoch in range(start_epoch, total_epochs + 1):
        epoch_loss = 0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}/{total_epochs}", unit="batch")

        # Default collate stacks the two views into two batch tensors.
        for (view_a, view_b), _ in progress_bar:
            z_i = model(view_a.to(device, non_blocking=True))
            z_j = model(view_b.to(device, non_blocking=True))

            # Fail fast on NaNs. Explicit raise instead of `assert`: asserts
            # are stripped under `python -O`.
            if torch.isnan(z_i).any() or torch.isnan(z_j).any():
                raise RuntimeError("Model produced NaN embeddings!")

            # Best-effort: skip a bad batch instead of aborting the run.
            try:
                loss = contrastive_loss(z_i, z_j)
            except Exception as e:
                print(f"Loss computation failed: {e}")
                continue

            optimizer.zero_grad()
            loss.backward()

            # Clip Gradients to dampen occasional loss spikes.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            # Accumulate epoch loss
            epoch_loss += loss.item()
            progress_bar.set_postfix(loss=f"{loss.item():.4f}")

        scheduler.step()
        print(f"Epoch {epoch}, Average Loss: {epoch_loss / len(train_loader):.4f}")

        # Save checkpoint
        save_path = f"checkpoints/ssl_checkpoint_epoch_{epoch}.pth"
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        }, save_path)
        print(f"Model saved to {save_path}")
+
|
144 |
+
|
145 |
+
# Script entry point: only run training when executed directly, not on import.
if __name__ == "__main__":
    train_ssl()
147 |
+
|