---
title: Variational Autoencoder (VAE) - MNIST
emoji: 🎨
colorFrom: blue
colorTo: purple
library_name: pytorch
app_file: Untitled.ipynb
pinned: false
license: mit
tags:
- deep-learning
- generative-ai
- pytorch
- vae
- variational-autoencoder
- mnist
- computer-vision
- unsupervised-learning
- representation-learning
datasets:
- mnist
---
# Variational Autoencoder (VAE) - MNIST Implementation
A comprehensive PyTorch implementation of Variational Autoencoders trained on the MNIST dataset with detailed analysis and visualizations.
## Model Description
This repository contains a complete implementation of a Variational Autoencoder (VAE) trained on the MNIST handwritten digits dataset. The model learns to encode images into a 2-dimensional latent space and decode them back to reconstructed images, enabling both data compression and generation of new digit-like images.
### Architecture Details
- **Model Type**: Variational Autoencoder (VAE)
- **Framework**: PyTorch
- **Input**: 28×28 grayscale images (784 dimensions)
- **Latent Space**: 2 dimensions (for visualization)
- **Hidden Layers**: 256 → 128 (encoder), 128 → 256 (decoder)
- **Total Parameters**: ~400K
- **Model Size**: 1.8MB
### Key Components
1. **Encoder Network**: Maps input images to latent distribution parameters (μ, σ²)
2. **Reparameterization Trick**: Enables differentiable sampling from the latent distribution
3. **Decoder Network**: Reconstructs images from latent space samples
4. **Loss Function**: Combines reconstruction loss (binary cross-entropy) and KL divergence; this loss and the sampling step in (2) are sketched below
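As a concrete illustration of components 2 and 4, here is a minimal sketch of the reparameterized sampling step and the combined loss, assuming inputs flattened to 784-dimensional vectors in [0, 1]. Function names here are illustrative, not necessarily the notebook's:
```python
import torch
import torch.nn.functional as F

def reparameterize(mu, logvar):
    """z = mu + sigma * eps with eps ~ N(0, I), so sampling stays differentiable."""
    std = torch.exp(0.5 * logvar)
    return mu + std * torch.randn_like(std)

def vae_loss(recon_x, x, mu, logvar, beta=1.0):
    """Binary cross-entropy reconstruction term plus beta-weighted KL term."""
    bce = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over batch and dimensions
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce + beta * kld
```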
## Training Details
- **Dataset**: MNIST (60,000 training images, 10,000 test images)
- **Batch Size**: 128
- **Epochs**: 20
- **Optimizer**: Adam
- **Learning Rate**: 1e-3
- **Beta Parameter**: 1.0 (standard VAE)
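A minimal training loop consistent with these settings might look like the following. It reuses the `vae_loss` sketch above and the `VAE` class from the Quick Start section, and is illustrative rather than the notebook's exact code:
```python
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# MNIST in [0, 1], matching the binary cross-entropy reconstruction term
train_set = datasets.MNIST('.', train=True, download=True,
                           transform=transforms.ToTensor())
loader = DataLoader(train_set, batch_size=128, shuffle=True)

model = VAE()  # class defined in the Quick Start section below
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(20):
    for x, _ in loader:
        x = x.view(x.size(0), -1)                    # flatten 28x28 -> 784
        recon, mu, logvar = model(x)
        loss = vae_loss(recon, x, mu, logvar, beta=1.0)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```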
## Model Performance
### Metrics
- **Final Training Loss**: ~85.2
- **Final Validation Loss**: ~86.1
- **Reconstruction Loss**: ~83.5
- **KL Divergence**: ~1.7
### Capabilities
- ✅ High-quality digit reconstruction
- ✅ Smooth latent space interpolation
- ✅ Generation of new digit-like samples
- ✅ Well-organized latent space with digit clusters
## Usage
### Quick Start
The class below follows the layer sizes listed under Architecture Details; the exact training-time implementation lives in the notebook.
```python
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# Minimal VAE matching the architecture described above
# (784 -> 256 -> 128 -> 2 latent, with a mirrored decoder). If loading the
# checkpoint reports key mismatches, copy the exact class from the notebook.
class VAE(nn.Module):
    def __init__(self, input_dim=784, latent_dim=2, hidden_dim=256, beta=1.0):
        super().__init__()
        self.beta = beta
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim // 2), nn.ReLU())
        self.fc_mu = nn.Linear(hidden_dim // 2, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim // 2, latent_dim)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim // 2), nn.ReLU(),
            nn.Linear(hidden_dim // 2, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, input_dim), nn.Sigmoid())

    def encode(self, x):
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        return mu + std * torch.randn_like(std)

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# Load trained model
model = VAE()
model.load_state_dict(torch.load('vae_logs_latent2_beta1.0/best_vae_model.pth'))
model.eval()

# Generate new samples
with torch.no_grad():
    z = torch.randn(16, 2)  # 16 samples, 2D latent space
    generated_images = model.decode(z).view(-1, 28, 28)

# Plot the generated images in a 4x4 grid
fig, axes = plt.subplots(4, 4, figsize=(6, 6))
for ax, img in zip(axes.flat, generated_images):
    ax.imshow(img, cmap='gray'); ax.axis('off')
plt.show()
```
### Visualizations Available
1. **Latent Space Visualization**: 2D scatter plot showing digit clusters
2. **Reconstructions**: Original vs. reconstructed digit comparisons
3. **Generated Samples**: New digits sampled from the latent space
4. **Interpolations**: Smooth transitions between different digits (see the sketch after this list)
5. **Training Curves**: Loss components over training epochs
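As one example, the interpolation figures (item 4) can be reproduced by linearly blending two images' latent means and decoding each intermediate point. A sketch, assuming the `VAE` class and loaded `model` from the Quick Start:
```python
import torch
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

@torch.no_grad()
def interpolate(model, x1, x2, steps=10):
    """Decode points on the straight line between two images' latent means."""
    mu1, _ = model.encode(x1.view(1, -1))
    mu2, _ = model.encode(x2.view(1, -1))
    alphas = torch.linspace(0, 1, steps).unsqueeze(1)
    z = (1 - alphas) * mu1 + alphas * mu2     # (steps, latent_dim) codes
    return model.decode(z).view(-1, 28, 28)

test_set = datasets.MNIST('.', train=False, download=True,
                          transform=transforms.ToTensor())
frames = interpolate(model, test_set[0][0], test_set[1][0])
fig, axes = plt.subplots(1, len(frames), figsize=(12, 2))
for ax, f in zip(axes, frames):
    ax.imshow(f, cmap='gray'); ax.axis('off')
plt.show()
```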
## Files and Outputs
- `Untitled.ipynb`: Complete implementation with training and visualization
- `best_vae_model.pth`: Trained model weights
- `training_metrics.csv`: Detailed training metrics
- `generated_samples.png`: Grid of generated digit samples
- `latent_space_visualization.png`: 2D latent space plot
- `reconstruction_comparison.png`: Original vs reconstructed images
- `latent_interpolation.png`: Interpolation between digit pairs
- `comprehensive_training_curves.png`: Training loss curves
## Applications
This VAE implementation can be used for:
- **Generative Modeling**: Create new handwritten digit images
- **Dimensionality Reduction**: Compress images to 2D representations
- **Anomaly Detection**: Identify unusual digits using reconstruction error (sketched after this list)
- **Data Augmentation**: Generate synthetic training data
- **Representation Learning**: Learn meaningful features for downstream tasks
- **Educational Purposes**: Understand VAE concepts and implementation
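For instance, the anomaly-detection use case reduces to thresholding per-image reconstruction error. A minimal sketch, again assuming the `VAE` class and loaded `model` from the Quick Start (the 3-sigma threshold is an illustrative choice, not from the notebook):
```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

@torch.no_grad()
def reconstruction_error(model, x):
    """Per-image binary cross-entropy between input and its reconstruction."""
    x = x.view(x.size(0), -1)
    recon, _, _ = model(x)
    return F.binary_cross_entropy(recon, x, reduction='none').sum(dim=1)

test_set = datasets.MNIST('.', train=False, download=True,
                          transform=transforms.ToTensor())
batch, _ = next(iter(DataLoader(test_set, batch_size=256)))
errors = reconstruction_error(model, batch)
flagged = errors > errors.mean() + 3 * errors.std()   # simple 3-sigma rule
print(f"{flagged.sum().item()} of {len(errors)} images flagged as unusual")
```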
## Research and Educational Value
This implementation serves as an excellent educational resource for:
- Understanding Variational Autoencoders theory and practice
- Learning PyTorch implementation techniques
- Exploring generative modeling concepts
- Analyzing latent space representations
- Studying the balance between reconstruction and regularization
## Citation
If you use this implementation in your research or projects, please cite:
```bibtex
@misc{vae_mnist_implementation,
  title={Variational Autoencoder Implementation for MNIST},
  author={Gruhesh Kurra},
  year={2024},
  url={https://huggingface.co/karthik-2905/VariationalAutoencoders}
}
```
## License
This project is licensed under the MIT License - see the LICENSE file for details.
## Additional Resources
- **GitHub Repository**: [VariationalAutoencoders](https://github.com/GruheshKurra/VariationalAutoencoders)
- **Detailed Documentation**: Check `grok.md` for comprehensive VAE explanations
- **Training Logs**: Complete metrics and analysis in the log directories
---
**Tags**: deep-learning, generative-ai, pytorch, vae, mnist, computer-vision, unsupervised-learning
**Model Card Authors**: Gruhesh Kurra