File size: 2,321 Bytes
29a4de2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import torch
import torchvision.transforms as T
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import numpy as np
from model import SSLModel  

# Device setup: use the GPU when PyTorch can see one, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the saved SSL model.
# `weights=None` builds a randomly initialised ResNet-18 backbone. It replaces the
# `pretrained=False` flag that torchvision deprecated in 0.13. The actual
# parameters are restored from the checkpoint below.
model = SSLModel(resnet18(weights=None)).to(device)
saved_model_path = "models/saves/run2/ssl_checkpoint_epoch_15.pth"
# map_location lets a checkpoint saved on one device be restored on another
# (e.g. a GPU-trained checkpoint on a CPU-only machine).
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
checkpoint = torch.load(saved_model_path, map_location=device)
# The checkpoint is a dict that stores the weights under "model_state_dict".
model.load_state_dict(checkpoint["model_state_dict"])
# Inference mode for layers such as BatchNorm (use running statistics).
model.eval()
print(f"Model loaded from {saved_model_path}")

# CIFAR-10 test split. ToTensor scales pixels to [0, 1], and the per-channel
# (x - 0.5) / 0.5 normalisation then maps them to [-1, 1].
transform = T.Compose(
    [
        T.Resize(32),  # CIFAR-10 images are already 32x32, so their size is unchanged
        T.ToTensor(),
        T.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)
dataset = datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=transform,
)
# shuffle=False gives a deterministic, dataset-order pass over the images.
dataloader = DataLoader(dataset, shuffle=False, batch_size=256)

# Extract embeddings and corresponding labels, one batch at a time.
embeddings = []
labels = []

print("Extracting embeddings...")
with torch.no_grad():  # inference only: no autograd graph is needed
    for imgs, lbls in dataloader:
        imgs = imgs.to(device)
        # NOTE(review): assumes SSLModel's forward returns a single tensor whose
        # first dimension is the batch; confirm against model.py.
        z = model(imgs)  # Get the embeddings
        # t-SNE needs a 2-D (n_samples, n_features) matrix. Flatten any trailing
        # dimensions (e.g. a (B, C, 1, 1) pooled feature map); this is a no-op
        # when the model already returns (B, D).
        embeddings.append(z.flatten(start_dim=1).cpu().numpy())
        labels.append(lbls.numpy())

# Concatenate the per-batch arrays into one embedding matrix and one label vector.
embeddings = np.concatenate(embeddings, axis=0)
labels = np.concatenate(labels, axis=0)

# Project the high-dimensional embeddings down to 2 dimensions with t-SNE.
print("Reducing dimensionality...")
tsne = TSNE(
    n_components=2,
    init="pca",            # start from a PCA layout instead of a random one
    learning_rate="auto",  # let scikit-learn scale the step size to the sample count
    random_state=42,       # fixed seed so repeated runs give the same plot
)
reduced_embeddings = tsne.fit_transform(embeddings)

# Plot embeddings
def plot_embeddings(embeddings, labels, class_names,
                    title="t-SNE Visualization of SSL Embeddings"):
    """Scatter-plot 2-D embeddings coloured by class label and display the figure.

    Args:
        embeddings: 2-D array with one row per sample; column 0 is plotted on
            the x-axis and column 1 on the y-axis.
        labels: integer class index per sample. It drives the point colours
            and selects each legend entry's text from ``class_names``.
        class_names: sequence indexed by class label, giving the display name
            of each class.
        title: figure title. Defaults to the title previously hard-coded here.
    """
    labels = np.asarray(labels)

    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(
        embeddings[:, 0],
        embeddings[:, 1],
        c=labels,
        cmap="tab10",
        alpha=0.7,
    )

    # num=None makes legend_elements() return exactly one handle per distinct
    # value of ``c``, in ascending order. The default num="auto" switches to a
    # tick locator once there are more than 9 distinct values, so the handle
    # count is not guaranteed to match the number of classes. Naming the handles
    # after the classes actually present (rather than zipping against the full
    # class list) keeps the legend correct when a class is missing from ``labels``.
    handles, _ = scatter.legend_elements(num=None)
    present_classes = np.unique(labels)
    plt.legend(
        handles=handles,
        labels=[class_names[int(c)] for c in present_classes],
        loc="upper right",
        title="Classes",
    )

    plt.title(title)
    plt.xlabel("Dimension 1")
    plt.ylabel("Dimension 2")
    plt.grid(True)
    plt.show()

# torchvision's CIFAR-10 dataset exposes its label names, indexed by the
# integer class label.
class_names = dataset.classes

# Render the 2-D t-SNE projection, coloured by ground-truth class.
plot_embeddings(
    embeddings=reduced_embeddings,
    labels=labels,
    class_names=class_names,
)