Commit 1fff313
harry committed
Parent(s): 91c78a9

feat: baseline model

Files changed:
- .gitignore +4 -1
- .vscode/settings.json +3 -0
- mnist_classifier/__init__.py +0 -0
- mnist_classifier/configs/config.yaml +0 -15
- mnist_classifier/data/datamodule.py +0 -27
- mnist_classifier/dataset.py +21 -0
- mnist_classifier/model.py +27 -0
- mnist_classifier/models/mnist_model.py +0 -124
- mnist_classifier/train.py +77 -47
- mnist_classifier/utils/metrics.py +0 -6
- poetry.lock +0 -0
- pyproject.toml +7 -7
- tests/test_model.py +0 -11
.gitignore CHANGED
@@ -4,4 +4,7 @@ __pycache__/
 wandb/
 checkpoints/
 *.egg-info/
-dist/
+dist/
+mnist/
+data/
+runs/
.vscode/settings.json ADDED
@@ -0,0 +1,3 @@
+{
+    "python.analysis.typeCheckingMode": "standard"
+}
mnist_classifier/__init__.py ADDED
(empty file)
mnist_classifier/configs/config.yaml DELETED
@@ -1,15 +0,0 @@
-training:
-  batch_size: 64
-  max_epochs: 10
-  learning_rate: 0.001
-  early_stopping_patience: 5
-
-model:
-  conv1_channels: 32
-  conv2_channels: 64
-  fc1_size: 128
-  dropout_rate: 0.25
-
-wandb:
-  project: "mnist-classifier"
-  entity: "bardenha"
mnist_classifier/data/datamodule.py DELETED
@@ -1,27 +0,0 @@
-from typing import Dict, Any
-
-import pytorch_lightning as pl
-from datasets import load_dataset
-from torch.utils.data import DataLoader
-
-class MNISTDataModule(pl.LightningDataModule):
-    def __init__(self, config: Dict[str, Any]):
-        super().__init__()
-        self.config = config
-
-    def setup(self, stage=None):
-        self.dataset = load_dataset('mnist')
-        self.dataset = self.dataset.with_transform(self.config.transform_dataset)
-
-    def train_dataloader(self):
-        return DataLoader(
-            self.dataset['train'],
-            batch_size=self.config.batch_size,
-            shuffle=True
-        )
-
-    def val_dataloader(self):
-        return DataLoader(
-            self.dataset['test'],  # Using test set as validation
-            batch_size=self.config.batch_size
-        )
mnist_classifier/dataset.py ADDED
@@ -0,0 +1,21 @@
+import torch
+from torchvision import datasets, transforms
+
+class MNISTDataModule:
+    def __init__(self, batch_size=64, val_batch_size=1000):
+        self.batch_size = batch_size
+        self.val_batch_size = val_batch_size
+
+    def get_dataloaders(self):
+        """Create training and test dataloaders."""
+        transform = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize((0.5,), (0.5,))
+        ])
+        train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
+        test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
+
+        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
+        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=self.val_batch_size, shuffle=False)
+
+        return train_loader, test_loader
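
A minimal usage sketch for the new MNISTDataModule (this mirrors how train.py below consumes it; the batch shapes come from torchvision's standard MNIST layout, not from the diff itself):

    # Sketch: iterate the new dataloaders once and check shapes.
    from mnist_classifier.dataset import MNISTDataModule

    dm = MNISTDataModule(batch_size=64, val_batch_size=1000)
    train_loader, test_loader = dm.get_dataloaders()

    images, labels = next(iter(train_loader))
    print(images.shape)  # torch.Size([64, 1, 28, 28])
    print(labels.shape)  # torch.Size([64])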
mnist_classifier/model.py ADDED
@@ -0,0 +1,27 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class MNISTModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 32, 3, 1)
+        self.conv2 = nn.Conv2d(32, 64, 3, 1)
+        self.dropout1 = nn.Dropout2d(0.25)
+        self.dropout2 = nn.Dropout2d(0.5)
+        self.fc1 = nn.Linear(9216, 128)
+        self.fc2 = nn.Linear(128, 10)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = F.relu(x)
+        x = self.conv2(x)
+        x = F.relu(x)
+        x = F.max_pool2d(x, 2)
+        x = self.dropout1(x)
+        x = torch.flatten(x, 1)
+        x = self.fc1(x)
+        x = F.relu(x)
+        x = self.dropout2(x)
+        x = self.fc2(x)
+        return F.log_softmax(x, dim=1)
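
The fc1 input size of 9216 follows from the conv stack on a 28x28 input: conv1 (3x3, stride 1, no padding) gives 26x26, conv2 gives 24x24, the 2x2 max pool gives 12x12, and 64 * 12 * 12 = 9216. A quick shape-check sketch (not part of the commit):

    # Shape walk-through for MNISTModel:
    # [1, 1, 28, 28] -> conv1 -> [1, 32, 26, 26] -> conv2 -> [1, 64, 24, 24]
    #                -> max_pool2d(2) -> [1, 64, 12, 12] -> flatten -> [1, 9216]
    import torch
    from mnist_classifier.model import MNISTModel

    model = MNISTModel()
    model.eval()  # disable dropout for a deterministic check
    with torch.no_grad():
        out = model(torch.randn(1, 1, 28, 28))
    print(out.shape)             # torch.Size([1, 10])
    print(out.exp().sum(dim=1))  # ~1.0: forward returns log-probabilities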
mnist_classifier/models/mnist_model.py DELETED
@@ -1,124 +0,0 @@
-from typing import Dict, Any
-
-import pytorch_lightning as pl
-import torch
-import torch.nn as nn
-import torchmetrics
-import wandb
-
-# Simple CNN architecture for MNIST
-class MNISTNet(nn.Module):
-    def __init__(self, config: Dict[str, Any]):
-        super().__init__()
-        self.conv1 = nn.Conv2d(1, config['model']['conv1_channels'], kernel_size=3)
-        self.conv2 = nn.Conv2d(config['model']['conv1_channels'],
-                               config['model']['conv2_channels'], kernel_size=3)
-        self.pool = nn.MaxPool2d(2)
-        self.dropout = nn.Dropout(config['model']['dropout_rate'])
-        self.fc1 = nn.Linear(config['model']['conv2_channels'] * 5 * 5,
-                             config['model']['fc1_size'])
-        self.fc2 = nn.Linear(config['model']['fc1_size'], 10)
-
-    def forward(self, x):
-        x = torch.relu(self.conv1(x))
-        x = self.pool(torch.relu(self.conv2(x)))
-        x = self.dropout(x)
-        x = x.view(-1, 64 * 5 * 5)
-        x = torch.relu(self.fc1(x))
-        x = self.fc2(x)
-        return x
-
-
-class MNISTModule(pl.LightningModule):
-    def __init__(self, config: Dict[str, Any]):
-        super().__init__()
-        self.config = config
-        self.model = MNISTNet(config)
-
-        # Initialize metrics
-        self.train_accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=10)
-        self.val_accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=10)
-        self.train_f1 = torchmetrics.F1Score(task='multiclass', num_classes=10)
-        self.val_f1 = torchmetrics.F1Score(task='multiclass', num_classes=10)
-        self.confusion_matrix = torchmetrics.ConfusionMatrix(task='multiclass', num_classes=10)
-
-    def forward(self, x):
-        return self.model(x)
-
-    def training_step(self, batch, batch_idx):
-        x, y = batch['pixel_values'], batch['label']
-        logits = self(x)
-        loss = nn.CrossEntropyLoss()(logits, y)
-
-        # Calculate and log metrics
-        preds = torch.argmax(logits, dim=1)
-        self.train_accuracy(preds, y)
-        self.train_f1(preds, y)
-
-        # Log metrics
-        self.log('train_loss', loss, prog_bar=True)
-        self.log('train_accuracy', self.train_accuracy, prog_bar=True)
-        self.log('train_f1', self.train_f1, prog_bar=True)
-
-        return loss
-
-    def validation_step(self, batch, batch_idx):
-        x, y = batch['pixel_values'], batch['label']
-        logits = self(x)
-        loss = nn.CrossEntropyLoss()(logits, y)
-
-        # Calculate metrics
-        preds = torch.argmax(logits, dim=1)
-        self.val_accuracy(preds, y)
-        self.val_f1(preds, y)
-        self.confusion_matrix(preds, y)
-
-        # Log metrics
-        self.log('val_loss', loss, prog_bar=True)
-        self.log('val_accuracy', self.val_accuracy, prog_bar=True)
-        self.log('val_f1', self.val_f1, prog_bar=True)
-
-        # Log sample predictions periodically
-        if batch_idx == 0:  # First batch of each epoch
-            self._log_sample_predictions(x, y, preds)
-
-    def _log_sample_predictions(self, images, labels, predictions):
-        # Log a grid of sample predictions
-        if self.logger:
-            n_samples = min(16, len(images))
-            self.logger.experiment.log({
-                "sample_predictions": [
-                    wandb.Image(
-                        images[i],
-                        caption=f"True: {labels[i].item()} Pred: {predictions[i].item()}"
-                    )
-                    for i in range(n_samples)
-                ]
-            })
-
-    def on_validation_epoch_end(self):
-        # Log confusion matrix at the end of each validation epoch
-        conf_mat = self.confusion_matrix.compute()
-        self.logger.experiment.log({
-            "confusion_matrix": wandb.plot.confusion_matrix(
-                probs=None,
-                y_true=conf_mat.flatten(),
-                preds=None,
-                class_names=range(10)
-            )
-        })
-        self.confusion_matrix.reset()
-
-    def configure_optimizers(self):
-        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
-        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
-            optimizer, mode='min', factor=0.1, patience=3, verbose=True
-        )
-        return {
-            "optimizer": optimizer,
-            "lr_scheduler": {
-                "scheduler": scheduler,
-                "monitor": "val_loss"
-            }
-        }
mnist_classifier/train.py CHANGED
(The deleted side of this hunk is truncated in the rendered view; only fragments of the old lines survive: "import", "from", "from mnist_classifier.", "def".)
@@ -1,49 +1,79 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard.writer import SummaryWriter
+from mnist_classifier.dataset import MNISTDataModule
+from mnist_classifier.model import MNISTModel
+
+def train():
+    # Set device
+    device = torch.device('cuda')
+    print(f"Using device: {device}")
+
+    # Initialize tensorboard
+    writer = SummaryWriter('runs/mnist_experiment')
+
+    # Setup data
+    data_module = MNISTDataModule(batch_size=64, val_batch_size=1000)
+    train_loader, test_loader = data_module.get_dataloaders()
+
+    # Initialize model, optimizer, and loss function
+    model = MNISTModel().to(device)
+    optimizer = optim.Adam(model.parameters())
+    criterion = nn.CrossEntropyLoss()
+
+    # Training loop
+    num_epochs = 10
+    for epoch in range(num_epochs):
+        model.train()
+        running_loss = 0.0
+        correct = 0
+        total = 0
+
+        for batch_idx, batch in enumerate(train_loader):
+            images, labels = batch[0].to(device), batch[1].to(device)
+
+            optimizer.zero_grad()
+            outputs = model(images)
+            loss = criterion(outputs, labels)
+            loss.backward()
+            optimizer.step()
+
+            running_loss += loss.item()
+            _, predicted = outputs.max(1)
+            total += labels.size(0)
+            correct += predicted.eq(labels).sum().item()
+
+            if batch_idx % 100 == 99:
+                writer.add_scalar('training loss',
+                                  running_loss / 100,
+                                  epoch * len(train_loader) + batch_idx)
+                writer.add_scalar('training accuracy',
+                                  100. * correct / total,
+                                  epoch * len(train_loader) + batch_idx)
+                running_loss = 0.0
+
+        # Validation phase
+        model.eval()
+        test_loss = 0
+        correct = 0
+        total = 0
+        with torch.no_grad():
+            for batch in test_loader:
+                images = batch[0].to(device)
+                labels = batch[1].to(device)
+                outputs = model(images)
+                loss = criterion(outputs, labels)
+
+                test_loss += loss.item()
+                _, predicted = outputs.max(1)
+                total += labels.size(0)
+                correct += predicted.eq(labels).sum().item()
+
+        accuracy = 100. * correct / total
+        writer.add_scalar('test accuracy', accuracy, epoch)
+        print(f'Epoch {epoch+1}: Test Accuracy: {accuracy:.2f}%')
 
 if __name__ == "__main__":
-
+    train()
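
One caveat: the new train() hardcodes torch.device('cuda'), so it will raise on a CPU-only machine even though it prints the device. A common fallback, shown here as a suggestion and not part of the commit:

    # Hypothetical device fallback (not in the commit): use CUDA when available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")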
mnist_classifier/utils/metrics.py DELETED
@@ -1,6 +0,0 @@
-import yaml
-from pathlib import Path
-
-def load_config(config_path: str):
-    with open(config_path, 'r') as f:
-        return yaml.safe_load(f)
poetry.lock CHANGED
(The diff for this file is too large to render; see the raw diff.)
pyproject.toml CHANGED
@@ -8,14 +8,14 @@ readme = "README.md"
 
 [tool.poetry.dependencies]
 python = "^3.10"
-torch = "^2.4.0"
-torchvision = "^0.15.0"
-pytorch-lightning = "^2.0.0"
-wandb = "^0.15.0"
-torchmetrics = "^1.0.0"
-datasets = "^2.0.0"
-huggingface-hub = "^0.16.0"
 pyyaml = "^6.0"
+torch = "^2.5.1"
+torchvision = "^0.20.1"
+transformers = "^4.46.3"
+datasets = "^3.1.0"
+tensorboard = "^2.18.0"
+tqdm = "^4.67.0"
+types-tqdm = "^4.67.0.20241119"
 
 [tool.poetry.group.dev.dependencies]
 pytest = "^7.0.0"
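
Since the dependency set moves from pytorch-lightning/wandb/torchmetrics to plain torch plus TensorBoard, a quick import check after `poetry install` confirms the new environment resolved (a sketch, not part of the commit; exact versions come from poetry.lock):

    # Hypothetical post-install sanity check.
    import torch
    import torchvision
    import datasets

    print("torch:", torch.__version__)              # expect a 2.5.x build per pyproject
    print("torchvision:", torchvision.__version__)  # expect a 0.20.x build
    print("cuda available:", torch.cuda.is_available())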
tests/test_model.py DELETED
@@ -1,11 +0,0 @@
-import pytest
-import torch
-from mnist_classifier.models.mnist_model import MNISTNet
-from mnist_classifier.utils.metrics import load_config
-
-def test_mnist_net_forward():
-    config = load_config('mnist_classifier/configs/config.yaml')
-    model = MNISTNet(config)
-    x = torch.randn(1, 1, 28, 28)
-    output = model(x)
-    assert output.shape == (1, 10)
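
The deleted test targeted the removed MNISTNet and its YAML config; the new MNISTModel ships with no test in this commit. A sketch of an equivalent test against the new module (hypothetical, not part of the commit):

    # Hypothetical replacement test for the new MNISTModel.
    import torch
    from mnist_classifier.model import MNISTModel

    def test_mnist_model_forward():
        model = MNISTModel()
        model.eval()  # deterministic: dropout disabled
        x = torch.randn(1, 1, 28, 28)
        with torch.no_grad():
            output = model(x)
        assert output.shape == (1, 10)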