harry committed on
Commit
1fff313
·
1 Parent(s): 91c78a9

feat: baseline model

Browse files
.gitignore CHANGED
@@ -4,4 +4,7 @@ __pycache__/
4
  wandb/
5
  checkpoints/
6
  *.egg-info/
7
- dist/
 
 
 
 
4
  wandb/
5
  checkpoints/
6
  *.egg-info/
7
+ dist/
8
+ mnist/
9
+ data/
10
+ runs/
.vscode/settings.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "python.analysis.typeCheckingMode": "standard"
3
+ }
mnist_classifier/__init__.py ADDED
File without changes
mnist_classifier/configs/config.yaml DELETED
@@ -1,15 +0,0 @@
1
- training:
2
- batch_size: 64
3
- max_epochs: 10
4
- learning_rate: 0.001
5
- early_stopping_patience: 5
6
-
7
- model:
8
- conv1_channels: 32
9
- conv2_channels: 64
10
- fc1_size: 128
11
- dropout_rate: 0.25
12
-
13
- wandb:
14
- project: "mnist-classifier"
15
- entity: "bardenha"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
mnist_classifier/data/datamodule.py DELETED
@@ -1,27 +0,0 @@
1
- from typing import Dict, Any
2
-
3
- import pytorch_lightning as pl
4
- from datasets import load_dataset
5
- from torch.utils.data import DataLoader
6
-
7
- class MNISTDataModule(pl.LightningDataModule):
8
- def __init__(self, config: Dict[str, Any]):
9
- super().__init__()
10
- self.config = config
11
-
12
- def setup(self, stage=None):
13
- self.dataset = load_dataset('mnist')
14
- self.dataset = self.dataset.with_transform(self.config.transform_dataset)
15
-
16
- def train_dataloader(self):
17
- return DataLoader(
18
- self.dataset['train'],
19
- batch_size=self.config.batch_size,
20
- shuffle=True
21
- )
22
-
23
- def val_dataloader(self):
24
- return DataLoader(
25
- self.dataset['test'], # Using test set as validation
26
- batch_size=self.config.batch_size
27
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
mnist_classifier/dataset.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision import datasets, transforms
3
+
4
class MNISTDataModule:
    """Build the MNIST training and test ``DataLoader`` objects.

    Args:
        batch_size: Samples per batch for the training loader.
        val_batch_size: Samples per batch for the test loader.
        data_dir: Directory the dataset is downloaded to and read from.
            Defaults to ``'./data'``, the location previously hard-coded.
        num_workers: Worker processes used by both loaders; ``0`` loads
            data in the main process (the ``DataLoader`` default).
    """

    def __init__(self, batch_size=64, val_batch_size=1000,
                 data_dir='./data', num_workers=0):
        self.batch_size = batch_size
        self.val_batch_size = val_batch_size
        self.data_dir = data_dir
        self.num_workers = num_workers

    def _build_dataset(self, train, transform):
        """Return the MNIST split selected by *train*, downloading if absent."""
        return datasets.MNIST(
            root=self.data_dir,
            train=train,
            transform=transform,
            download=True,
        )

    def get_dataloaders(self):
        """Create training and test dataloaders.

        Returns:
            A ``(train_loader, test_loader)`` tuple. The training loader
            shuffles each epoch; the test loader keeps dataset order.
        """
        # Convert images to tensors, then shift/scale with mean 0.5 and
        # std 0.5 on the single grayscale channel.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])
        train_dataset = self._build_dataset(True, transform)
        test_dataset = self._build_dataset(False, transform)

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=self.val_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )

        return train_loader, test_loader
mnist_classifier/model.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
class MNISTModel(nn.Module):
    """Small convolutional classifier for 1x28x28 digit images.

    Two 3x3 convolutions are followed by 2x2 max pooling, dropout, and
    two fully connected layers. ``forward`` returns log-probabilities
    over the 10 classes.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # Channel-wise dropout on the 4-D feature map after pooling.
        self.dropout1 = nn.Dropout2d(0.25)
        # Element-wise dropout: this layer runs on the flattened 2-D
        # (batch, 128) activations, where Dropout2d is deprecated.
        self.dropout2 = nn.Dropout(0.5)
        # 28x28 -> 26x26 (conv1) -> 24x24 (conv2) -> 12x12 (pool);
        # 64 channels * 12 * 12 = 9216 flattened features.
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images.

        Args:
            x: Input batch; the layer sizes above require shape
                ``(batch, 1, 28, 28)``.

        Returns:
            Tensor of shape ``(batch, 10)`` holding ``log_softmax`` outputs.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)  # keep the batch dimension
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
mnist_classifier/models/mnist_model.py DELETED
@@ -1,124 +0,0 @@
1
- from typing import Dict, Any
2
-
3
- import pytorch_lightning as pl
4
- import torch
5
- import torch.nn as nn
6
- import torchmetrics
7
- import wandb
8
-
9
- # Simple CNN architecture for MNIST
10
- class MNISTNet(nn.Module):
11
- def __init__(self, config: Dict[str, Any]):
12
- super().__init__()
13
- self.conv1 = nn.Conv2d(1, config['model']['conv1_channels'], kernel_size=3)
14
- self.conv2 = nn.Conv2d(config['model']['conv1_channels'],
15
- config['model']['conv2_channels'], kernel_size=3)
16
- self.pool = nn.MaxPool2d(2)
17
- self.dropout = nn.Dropout(config['model']['dropout_rate'])
18
- self.fc1 = nn.Linear(config['model']['conv2_channels'] * 5 * 5,
19
- config['model']['fc1_size'])
20
- self.fc2 = nn.Linear(config['model']['fc1_size'], 10)
21
-
22
- def forward(self, x):
23
- x = torch.relu(self.conv1(x))
24
- x = self.pool(torch.relu(self.conv2(x)))
25
- x = self.dropout(x)
26
- x = x.view(-1, 64 * 5 * 5)
27
- x = torch.relu(self.fc1(x))
28
- x = self.fc2(x)
29
- return x
30
-
31
-
32
-
33
- class MNISTModule(pl.LightningModule):
34
- def __init__(self, config: Dict[str, Any]):
35
- super().__init__()
36
- self.config = config
37
- self.model = MNISTNet(config)
38
-
39
- # Initialize metrics
40
- self.train_accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=10)
41
- self.val_accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=10)
42
- self.train_f1 = torchmetrics.F1Score(task='multiclass', num_classes=10)
43
- self.val_f1 = torchmetrics.F1Score(task='multiclass', num_classes=10)
44
- self.confusion_matrix = torchmetrics.ConfusionMatrix(task='multiclass', num_classes=10)
45
-
46
- def forward(self, x):
47
- return self.model(x)
48
-
49
- def training_step(self, batch, batch_idx):
50
- x, y = batch['pixel_values'], batch['label']
51
- logits = self(x)
52
- loss = nn.CrossEntropyLoss()(logits, y)
53
-
54
- # Calculate and log metrics
55
- preds = torch.argmax(logits, dim=1)
56
- self.train_accuracy(preds, y)
57
- self.train_f1(preds, y)
58
-
59
- # Log metrics
60
- self.log('train_loss', loss, prog_bar=True)
61
- self.log('train_accuracy', self.train_accuracy, prog_bar=True)
62
- self.log('train_f1', self.train_f1, prog_bar=True)
63
-
64
- return loss
65
-
66
- def validation_step(self, batch, batch_idx):
67
- x, y = batch['pixel_values'], batch['label']
68
- logits = self(x)
69
- loss = nn.CrossEntropyLoss()(logits, y)
70
-
71
- # Calculate metrics
72
- preds = torch.argmax(logits, dim=1)
73
- self.val_accuracy(preds, y)
74
- self.val_f1(preds, y)
75
- self.confusion_matrix(preds, y)
76
-
77
- # Log metrics
78
- self.log('val_loss', loss, prog_bar=True)
79
- self.log('val_accuracy', self.val_accuracy, prog_bar=True)
80
- self.log('val_f1', self.val_f1, prog_bar=True)
81
-
82
- # Log sample predictions periodically
83
- if batch_idx == 0: # First batch of each epoch
84
- self._log_sample_predictions(x, y, preds)
85
-
86
- def _log_sample_predictions(self, images, labels, predictions):
87
- # Log a grid of sample predictions
88
- if self.logger:
89
- n_samples = min(16, len(images))
90
- self.logger.experiment.log({
91
- "sample_predictions": [
92
- wandb.Image(
93
- images[i],
94
- caption=f"True: {labels[i].item()} Pred: {predictions[i].item()}"
95
- )
96
- for i in range(n_samples)
97
- ]
98
- })
99
-
100
- def on_validation_epoch_end(self):
101
- # Log confusion matrix at the end of each validation epoch
102
- conf_mat = self.confusion_matrix.compute()
103
- self.logger.experiment.log({
104
- "confusion_matrix": wandb.plot.confusion_matrix(
105
- probs=None,
106
- y_true=conf_mat.flatten(),
107
- preds=None,
108
- class_names=range(10)
109
- )
110
- })
111
- self.confusion_matrix.reset()
112
-
113
- def configure_optimizers(self):
114
- optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
115
- scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
116
- optimizer, mode='min', factor=0.1, patience=3, verbose=True
117
- )
118
- return {
119
- "optimizer": optimizer,
120
- "lr_scheduler": {
121
- "scheduler": scheduler,
122
- "monitor": "val_loss"
123
- }
124
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
mnist_classifier/train.py CHANGED
@@ -1,49 +1,79 @@
1
- import pytorch_lightning as pl
2
- from pytorch_lightning.loggers import WandbLogger
3
- from pathlib import Path
4
-
5
- from mnist_classifier.models.mnist_model import MNISTModule
6
- from mnist_classifier.data.datamodule import MNISTDataModule
7
- from mnist_classifier.utils.metrics import load_config
8
-
9
- def main():
10
- config = load_config(Path("mnist_classifier/configs/config.yaml"))
11
-
12
- # Initialize wandb logger
13
- wandb_logger = WandbLogger(
14
- project=config['wandb']['project'],
15
- entity=config['wandb']['entity']
16
- )
17
-
18
- # Initialize trainer
19
- trainer = pl.Trainer(
20
- max_epochs=config['training']['max_epochs'],
21
- accelerator='gpu',
22
- devices=[0],
23
- logger=wandb_logger,
24
- callbacks=[
25
- pl.callbacks.ModelCheckpoint(
26
- dirpath='checkpoints',
27
- filename='mnist-{epoch:02d}-{val_loss:.2f}',
28
- save_top_k=3,
29
- monitor='val_loss',
30
- mode='min'
31
- ),
32
- pl.callbacks.EarlyStopping(
33
- monitor='val_loss',
34
- patience=config['training']['early_stopping_patience'],
35
- mode='min'
36
- ),
37
- pl.callbacks.LearningRateMonitor(logging_interval='epoch')
38
- ]
39
- )
40
-
41
- # Initialize data module and model
42
- data_module = MNISTDataModule(config)
43
- model = MNISTModule(config)
44
-
45
- # Train
46
- trainer.fit(model, data_module)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
  if __name__ == "__main__":
49
- main()
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from torch.utils.data import DataLoader
5
+ from torch.utils.tensorboard.writer import SummaryWriter
6
+ from mnist_classifier.dataset import MNISTDataModule
7
+ from mnist_classifier.model import MNISTModel
8
+
9
def train(num_epochs=10, log_dir='runs/mnist_experiment'):
    """Train ``MNISTModel`` on MNIST and log metrics to TensorBoard.

    Each epoch runs one pass over the training loader and then evaluates
    on the test loader. Training loss/accuracy are written every 100
    batches; test loss/accuracy are written once per epoch, and the test
    accuracy is also printed.

    Args:
        num_epochs: Number of passes over the training data.
        log_dir: Directory the TensorBoard event files are written to.
    """
    # Fall back to the CPU so the script also runs without a GPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    writer = SummaryWriter(log_dir)
    try:
        # Setup data
        data_module = MNISTDataModule(batch_size=64, val_batch_size=1000)
        train_loader, test_loader = data_module.get_dataloaders()

        # Initialize model, optimizer, and loss function
        model = MNISTModel().to(device)
        optimizer = optim.Adam(model.parameters())
        # NOTE(review): MNISTModel.forward appears to return
        # log-probabilities already; CrossEntropyLoss still gives the
        # correct loss because log_softmax is idempotent, but nn.NLLLoss
        # would be the direct match — confirm before changing.
        criterion = nn.CrossEntropyLoss()

        steps_per_epoch = len(train_loader)

        for epoch in range(num_epochs):
            # ---- Training phase ----
            model.train()
            running_loss = 0.0
            correct = 0
            total = 0

            for batch_idx, batch in enumerate(train_loader):
                images, labels = batch[0].to(device), batch[1].to(device)

                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

                if batch_idx % 100 == 99:
                    global_step = epoch * steps_per_epoch + batch_idx
                    # Loss is the mean over the last 100 batches; accuracy
                    # is cumulative since the start of the epoch.
                    writer.add_scalar('training loss',
                                      running_loss / 100,
                                      global_step)
                    writer.add_scalar('training accuracy',
                                      100. * correct / total,
                                      global_step)
                    running_loss = 0.0

            # ---- Validation phase ----
            model.eval()
            test_loss = 0.0
            correct = 0
            total = 0
            with torch.no_grad():
                for batch in test_loader:
                    images = batch[0].to(device)
                    labels = batch[1].to(device)
                    outputs = model(images)
                    loss = criterion(outputs, labels)

                    test_loss += loss.item()
                    _, predicted = outputs.max(1)
                    total += labels.size(0)
                    correct += predicted.eq(labels).sum().item()

            accuracy = 100. * correct / total
            # Mean of the per-batch losses over the test loader.
            writer.add_scalar('test loss', test_loss / len(test_loader), epoch)
            writer.add_scalar('test accuracy', accuracy, epoch)
            print(f'Epoch {epoch+1}: Test Accuracy: {accuracy:.2f}%')
    finally:
        # Flush pending events and release the file handle even on failure.
        writer.close()
77
 
78
  if __name__ == "__main__":
79
+ train()
mnist_classifier/utils/metrics.py DELETED
@@ -1,6 +0,0 @@
1
- import yaml
2
- from pathlib import Path
3
-
4
- def load_config(config_path: str):
5
- with open(config_path, 'r') as f:
6
- return yaml.safe_load(f)
 
 
 
 
 
 
 
poetry.lock CHANGED
The diff for this file is too large to render. See raw diff
 
pyproject.toml CHANGED
@@ -8,14 +8,14 @@ readme = "README.md"
8
 
9
  [tool.poetry.dependencies]
10
  python = "^3.10"
11
- torch = "^2.4.0"
12
- torchvision = "^0.15.0"
13
- pytorch-lightning = "^2.0.0"
14
- wandb = "^0.15.0"
15
- torchmetrics = "^1.0.0"
16
- datasets = "^2.0.0"
17
- huggingface-hub = "^0.16.0"
18
  pyyaml = "^6.0"
 
 
 
 
 
 
 
19
 
20
  [tool.poetry.group.dev.dependencies]
21
  pytest = "^7.0.0"
 
8
 
9
  [tool.poetry.dependencies]
10
  python = "^3.10"
 
 
 
 
 
 
 
11
  pyyaml = "^6.0"
12
+ torch = "^2.5.1"
13
+ torchvision = "^0.20.1"
14
+ transformers = "^4.46.3"
15
+ datasets = "^3.1.0"
16
+ tensorboard = "^2.18.0"
17
+ tqdm = "^4.67.0"
18
+ types-tqdm = "^4.67.0.20241119"
19
 
20
  [tool.poetry.group.dev.dependencies]
21
  pytest = "^7.0.0"
tests/test_model.py DELETED
@@ -1,11 +0,0 @@
1
- import pytest
2
- import torch
3
- from mnist_classifier.models.mnist_model import MNISTNet
4
- from mnist_classifier.utils.metrics import load_config
5
-
6
- def test_mnist_net_forward():
7
- config = load_config('mnist_classifier/configs/config.yaml')
8
- model = MNISTNet(config)
9
- x = torch.randn(1, 1, 28, 28)
10
- output = model(x)
11
- assert output.shape == (1, 10)