---
title: Variational Autoencoder (VAE) - MNIST
emoji: 🎨
colorFrom: blue
colorTo: purple
sdk: pytorch
app_file: Untitled.ipynb
pinned: false
license: mit
tags:
- deep-learning
- generative-ai
- pytorch
- vae
- variational-autoencoder
- mnist
- computer-vision
- unsupervised-learning
- representation-learning
datasets:
- mnist
---

# Variational Autoencoder (VAE) - MNIST Implementation

A PyTorch implementation of a Variational Autoencoder trained on the MNIST dataset, with training metrics and visualizations of the learned 2-D latent space.

## Model Description

This repository contains a complete implementation of a Variational Autoencoder (VAE) trained on the MNIST handwritten digits dataset. The model learns to encode images into a 2-dimensional latent space and decode them back to reconstructed images, enabling both data compression and generation of new digit-like images.

### Architecture Details

- **Model Type**: Variational Autoencoder (VAE)
- **Framework**: PyTorch
- **Input**: 28×28 grayscale images (784 dimensions)
- **Latent Space**: 2 dimensions (for visualization)
- **Hidden Layers**: 256 → 128 (encoder), 128 → 256 (decoder)
- **Total Parameters**: ~400K
- **Model Size**: 1.8MB
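The sketch below is a rough check on the size figures. It rebuilds the listed layer sizes from plain `nn.Linear` layers and counts their parameters. It assumes fully connected ReLU layers with biases and separate μ and log σ² heads. The notebook's actual module layout may differ.

```python
import torch.nn as nn

# Assumed layout: 784 -> 256 -> 128 -> (mu, logvar) and 2 -> 128 -> 256 -> 784.
# This mirrors the sizes listed above; it is not the notebook's exact class.
def build_layers(input_dim=784, hidden=(256, 128), latent_dim=2):
    encoder = nn.Sequential(
        nn.Linear(input_dim, hidden[0]), nn.ReLU(),
        nn.Linear(hidden[0], hidden[1]), nn.ReLU(),
    )
    fc_mu = nn.Linear(hidden[1], latent_dim)      # mean head
    fc_logvar = nn.Linear(hidden[1], latent_dim)  # log-variance head
    decoder = nn.Sequential(
        nn.Linear(latent_dim, hidden[1]), nn.ReLU(),
        nn.Linear(hidden[1], hidden[0]), nn.ReLU(),
        nn.Linear(hidden[0], input_dim), nn.Sigmoid(),
    )
    return nn.ModuleList([encoder, fc_mu, fc_logvar, decoder])

layers = build_layers()
n_params = sum(p.numel() for p in layers.parameters())
size_mb = n_params * 4 / 1e6  # float32 = 4 bytes per parameter
print(f"{n_params:,} parameters, ~{size_mb:.2f} MB in float32")
```

Under these assumptions the count is 469,268 parameters, or about 1.88 MB of float32 weights. That lines up with the 1.8MB model size listed above.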

### Key Components

1. **Encoder Network**: Maps input images to latent distribution parameters (μ, σ²)
2. **Reparameterization Trick**: Enables differentiable sampling from the latent distribution
3. **Decoder Network**: Reconstructs images from latent space samples
4. **Loss Function**: Combines reconstruction loss (binary cross-entropy) and KL divergence
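The reparameterization trick and the loss from the list above can be sketched as follows. The sketch rests on three assumptions, and the notebook's exact reduction may differ:

- the encoder predicts log σ² rather than σ², a common choice for numerical stability;
- the decoder outputs per-pixel probabilities;
- both loss terms are summed over pixels and latent dimensions, then averaged over the batch.

```python
import torch
import torch.nn.functional as F

def reparameterize(mu, logvar):
    """Draw z ~ N(mu, sigma^2) as z = mu + sigma * eps so gradients reach mu and logvar."""
    std = torch.exp(0.5 * logvar)  # sigma = exp(log(sigma^2) / 2)
    eps = torch.randn_like(std)    # eps ~ N(0, I), carries no gradient
    return mu + std * eps

def vae_loss(recon_x, x, mu, logvar, beta=1.0):
    """Return (total, reconstruction, KL), each averaged per image in the batch."""
    batch_size = x.size(0)
    # Reconstruction term: binary cross-entropy summed over the 784 pixels
    recon = F.binary_cross_entropy(recon_x, x.view(batch_size, -1), reduction="sum")
    # Closed-form KL(N(mu, sigma^2) || N(0, I)) for a diagonal Gaussian
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    recon, kl = recon / batch_size, kl / batch_size
    return recon + beta * kl, recon, kl

# Example on random tensors shaped like a batch of 4 MNIST images
x = torch.rand(4, 1, 28, 28)
mu, logvar = torch.zeros(4, 2), torch.zeros(4, 2)
z = reparameterize(mu, logvar)
recon_x = torch.sigmoid(torch.randn(4, 784))
total, recon, kl = vae_loss(recon_x, x, mu, logvar)
print(z.shape, kl.item())  # KL is zero when mu = 0 and logvar = 0
```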

## Training Details

- **Dataset**: MNIST (60,000 training images, 10,000 test images)
- **Batch Size**: 128
- **Epochs**: 20
- **Optimizer**: Adam
- **Learning Rate**: 1e-3
- **Beta Parameter**: 1.0 (standard VAE)
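The hyperparameters above can be wired into a training loop as in the sketch below. It is hedged in three ways:

- `TinyVAE` is a hypothetical stand-in with the documented layer sizes, not the notebook's class.
- Random tensors replace MNIST so the snippet runs offline.
- The commented `datasets.MNIST` loader shows where the real data would go.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

# Hyperparameters from the list above
BATCH_SIZE, EPOCHS, LR, BETA = 128, 20, 1e-3, 1.0

class TinyVAE(nn.Module):
    """Hypothetical stand-in model: 784 -> 256 -> 128 -> 2 and back."""
    def __init__(self, latent_dim=2):
        super().__init__()
        self.enc = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 128), nn.ReLU())
        self.mu, self.logvar = nn.Linear(128, latent_dim), nn.Linear(128, latent_dim)
        self.dec = nn.Sequential(nn.Linear(latent_dim, 128), nn.ReLU(), nn.Linear(128, 256),
                                 nn.ReLU(), nn.Linear(256, 784), nn.Sigmoid())

    def forward(self, x):
        h = self.enc(x.view(x.size(0), -1))
        mu, logvar = self.mu(h), self.logvar(h)
        z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)  # reparameterization
        return self.dec(z), mu, logvar

def train_epoch(model, loader, optimizer, beta=BETA):
    """One pass over `loader`; returns the mean per-image loss."""
    model.train()
    total, n = 0.0, 0
    for x, _ in loader:
        recon, mu, logvar = model(x)
        bce = F.binary_cross_entropy(recon, x.view(x.size(0), -1), reduction="sum")
        kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        loss = (bce + beta * kl) / x.size(0)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total += loss.item() * x.size(0)
        n += x.size(0)
    return total / n

# With torchvision, the real loader would look like:
# DataLoader(datasets.MNIST("data", train=True, download=True,
#            transform=transforms.ToTensor()), batch_size=BATCH_SIZE, shuffle=True)
# Here random images stand in so the sketch runs offline.
fake = TensorDataset(torch.rand(256, 1, 28, 28), torch.zeros(256, dtype=torch.long))
loader = DataLoader(fake, batch_size=BATCH_SIZE, shuffle=True)
model = TinyVAE()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
losses = [train_epoch(model, loader, optimizer) for _ in range(2)]  # use EPOCHS for the full run
print([round(l, 2) for l in losses])
```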

## Model Performance

### Metrics
- **Final Training Loss**: ~85.2
- **Final Validation Loss**: ~86.1
- **Reconstruction Loss**: ~83.5
- **KL Divergence**: ~1.7
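Assume the reconstruction and KL figures are per-image averages over the training set. Then the reported total matches the β-weighted objective with β = 1.0. The KL term takes the usual closed form for a diagonal Gaussian posterior and a standard normal prior over the 2 latent dimensions:

$$
D_{\mathrm{KL}} = -\frac{1}{2}\sum_{j=1}^{2}\left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)
$$

$$
\mathcal{L} = \mathcal{L}_{\text{rec}} + \beta\, D_{\mathrm{KL}} \approx 83.5 + 1.0 \times 1.7 = 85.2
$$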

### Capabilities
- ✅ High-quality digit reconstruction
- ✅ Smooth latent space interpolation
- ✅ Generation of new digit-like samples
- ✅ Well-organized latent space with digit clusters

## Usage

### Quick Start

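```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

# VAE with the layer sizes listed above (784 -> 256 -> 128 -> 2 and back).
# The attribute names are illustrative: they must match the notebook's class
# exactly, or load_state_dict() will reject the saved weights.
class VAE(nn.Module):
    def __init__(self, input_dim=784, latent_dim=2, hidden_dim=256, beta=1.0):
        super().__init__()
        self.beta = beta
        # Encoder trunk plus two heads for mu and log-variance
        self.enc1 = nn.Linear(input_dim, hidden_dim)
        self.enc2 = nn.Linear(hidden_dim, hidden_dim // 2)
        self.fc_mu = nn.Linear(hidden_dim // 2, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim // 2, latent_dim)
        # Decoder mirrors the encoder
        self.dec1 = nn.Linear(latent_dim, hidden_dim // 2)
        self.dec2 = nn.Linear(hidden_dim // 2, hidden_dim)
        self.dec3 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        h = F.relu(self.enc2(F.relu(self.enc1(x))))
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        return mu + std * torch.randn_like(std)  # z = mu + sigma * eps

    def decode(self, z):
        h = F.relu(self.dec2(F.relu(self.dec1(z))))
        return torch.sigmoid(self.dec3(h))  # pixel probabilities in [0, 1]

    def forward(self, x):
        mu, logvar = self.encode(x.view(x.size(0), -1))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# Load trained weights (after downloading the files)
model = VAE()
model.load_state_dict(torch.load('vae_logs_latent2_beta1.0/best_vae_model.pth', map_location='cpu'))
model.eval()

# Generate new samples by decoding draws from the N(0, I) prior
with torch.no_grad():
    z = torch.randn(16, 2)  # 16 samples, 2D latent space
    generated_images = model.decode(z).view(-1, 28, 28)

# Plot the samples in a 4x4 grid
fig, axes = plt.subplots(4, 4, figsize=(4, 4))
for ax, img in zip(axes.flat, generated_images):
    ax.imshow(img.numpy(), cmap='gray')
    ax.axis('off')
plt.show()
```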

### Visualizations Available

1. **Latent Space Visualization**: 2D scatter plot showing digit clusters
2. **Reconstructions**: Original vs. reconstructed digit comparisons  
3. **Generated Samples**: New digits sampled from the latent space
4. **Interpolations**: Smooth transitions between different digits
5. **Training Curves**: Loss components over training epochs
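The sketch below shows the interpolation and latent-grid visualizations. It assumes a `decode` callable that maps a `(batch, 2)` latent tensor to `(batch, 784)` pixel probabilities. With the trained model, a `decode` method would fill that role if the class exposes one. The default grid range of ±3 is also an assumption, chosen because it covers most of the N(0, I) prior's mass.

```python
import torch

def interpolate(decode, z_start, z_end, steps=8):
    """Decode evenly spaced points on the line from z_start to z_end."""
    alphas = torch.linspace(0.0, 1.0, steps).unsqueeze(1)  # (steps, 1)
    z = (1 - alphas) * z_start + alphas * z_end             # (steps, latent_dim)
    with torch.no_grad():
        return decode(z).view(steps, 28, 28)

def latent_grid(decode, n=10, limit=3.0):
    """Decode an n x n grid covering [-limit, limit]^2 of the 2-D latent space."""
    coords = torch.linspace(-limit, limit, n)
    zs = torch.cartesian_prod(coords, coords)  # (n*n, 2); requires latent_dim == 2
    with torch.no_grad():
        imgs = decode(zs).view(n, n, 28, 28)
    # Stitch the tiles into one (n*28, n*28) image suitable for plt.imshow
    return imgs.permute(0, 2, 1, 3).reshape(n * 28, n * 28)

# Stand-in decoder for illustration only; it maps each latent point to a flat grey level
dummy_decode = lambda z: torch.sigmoid(z @ torch.ones(2, 784))
frames = interpolate(dummy_decode, torch.tensor([-2.0, 0.0]), torch.tensor([2.0, 0.0]))
grid = latent_grid(dummy_decode, n=5)
print(frames.shape, grid.shape)  # torch.Size([8, 28, 28]) torch.Size([140, 140])
```

Linear interpolation in latent space is the simplest choice. It is a natural fit here because the prior is an isotropic Gaussian centred at the origin.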

## Files and Outputs

- `Untitled.ipynb`: Complete implementation with training and visualization
- `best_vae_model.pth`: Trained model weights
- `training_metrics.csv`: Detailed training metrics
- `generated_samples.png`: Grid of generated digit samples
- `latent_space_visualization.png`: 2D latent space plot
- `reconstruction_comparison.png`: Original vs reconstructed images
- `latent_interpolation.png`: Interpolation between digit pairs
- `comprehensive_training_curves.png`: Training loss curves

## Applications

This VAE implementation can be used for:

- **Generative Modeling**: Create new handwritten digit images
- **Dimensionality Reduction**: Compress images to 2D representations
- **Anomaly Detection**: Identify unusual digits using reconstruction error
- **Data Augmentation**: Generate synthetic training data
- **Representation Learning**: Learn meaningful features for downstream tasks
- **Educational Purposes**: Understand VAE concepts and implementation
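For the anomaly-detection use case, one common recipe is sketched below. It scores each image by its reconstruction error and flags scores above a high quantile of scores on known-normal data. The `encode`/`decode` callables, the 0.99 quantile and the toy stand-ins are all assumptions for illustration. With the trained model they would be its encoder and decoder methods, if it exposes them under those roles.

```python
import torch
import torch.nn.functional as F

def anomaly_scores(encode, decode, x):
    """Per-image reconstruction error (BCE summed over pixels); higher = less typical."""
    flat = x.view(x.size(0), -1)
    with torch.no_grad():
        mu, _ = encode(flat)  # use the posterior mean so scores are deterministic
        recon = decode(mu)
        return F.binary_cross_entropy(recon, flat, reduction="none").sum(dim=1)

def flag_anomalies(scores, reference_scores, quantile=0.99):
    """Flag inputs scoring above a quantile of scores on known-normal data."""
    threshold = torch.quantile(reference_scores, quantile)
    return scores > threshold

# Toy stand-ins: the "decoder" always predicts dark pixels (p = 0.1),
# so dark images reconstruct well and bright ones do not.
toy_encode = lambda x: (torch.zeros(x.size(0), 2), torch.zeros(x.size(0), 2))
toy_decode = lambda z: torch.full((z.size(0), 784), 0.1)

normal = torch.rand(100, 1, 28, 28) * 0.1  # mostly dark "normal" images
odd = torch.ones(3, 1, 28, 28)             # all-white images
ref_scores = anomaly_scores(toy_encode, toy_decode, normal)
flags_normal = flag_anomalies(ref_scores, ref_scores)
flags_odd = flag_anomalies(anomaly_scores(toy_encode, toy_decode, odd), ref_scores)
print(flags_odd.tolist())  # [True, True, True]
```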

## Research and Educational Value

This implementation can serve as an educational resource for:

- Understanding Variational Autoencoder theory and practice
- Learning PyTorch implementation techniques
- Exploring generative modeling concepts
- Analyzing latent space representations
- Studying the balance between reconstruction and regularization

## Citation

If you use this implementation in your research or projects, please cite:

```bibtex
@misc{vae_mnist_implementation,
  title={Variational Autoencoder Implementation for MNIST},
  author={Gruhesh Kurra},
  year={2024},
  url={https://huggingface.co/karthik-2905/VariationalAutoencoders}
}
```

## License

This project is licensed under the MIT License - see the LICENSE file for details.

## Additional Resources

- **GitHub Repository**: [VariationalAutoencoders](https://github.com/GruheshKurra/VariationalAutoencoders)
- **Detailed Documentation**: Check `grok.md` for comprehensive VAE explanations
- **Training Logs**: Complete metrics and analysis in the log directories

---

**Tags**: deep-learning, generative-ai, pytorch, vae, mnist, computer-vision, unsupervised-learning

**Model Card Authors**: Gruhesh Kurra