---
title: Variational Autoencoder (VAE) - MNIST
emoji: 🎨
colorFrom: blue
colorTo: purple
sdk: pytorch
app_file: Untitled.ipynb
pinned: false
license: mit
tags:
- deep-learning
- generative-ai
- pytorch
- vae
- variational-autoencoder
- mnist
- computer-vision
- unsupervised-learning
- representation-learning
datasets:
- mnist
---

# Variational Autoencoder (VAE) - MNIST Implementation

A PyTorch implementation of a Variational Autoencoder (VAE) trained on the MNIST dataset, with detailed analysis and visualizations.

## Model Description

This repository contains a complete implementation of a Variational Autoencoder (VAE) trained on the MNIST handwritten digits dataset. The model learns to encode images into a 2-dimensional latent space and decode them back to reconstructed images, enabling both data compression and generation of new digit-like images.

### Architecture Details

- **Model Type**: Variational Autoencoder (VAE)
- **Framework**: PyTorch
- **Input**: 28×28 grayscale images (784 dimensions)
- **Latent Space**: 2 dimensions (for visualization)
- **Hidden Layers**: 256 → 128 (encoder), 128 → 256 (decoder)
- **Total Parameters**: ~400K
- **Model Size**: 1.8 MB
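
As a rough check on the parameter count, the sketch below tallies weights and biases for the layer sizes listed above. It assumes plain fully connected layers with biases and two separate linear heads for the mean and the log-variance. Neither detail is stated here, so the helper `linear_params` and the layer lists are illustrative rather than the notebook's exact architecture.

```python
def linear_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out

# Assumed layout: 784 -> 256 -> 128 -> (mu, logvar) and 2 -> 128 -> 256 -> 784
encoder = [(784, 256), (256, 128), (128, 2), (128, 2)]
decoder = [(2, 128), (128, 256), (256, 784)]

total_params = sum(linear_params(i, o) for i, o in encoder + decoder)
size_mib = total_params * 4 / 2**20  # float32 stores 4 bytes per parameter

print(f"parameters: {total_params:,}")
print(f"float32 size: {size_mib:.2f} MiB")
```

Under these assumptions the total comes to 469,268 parameters, or about 1.79 MiB in float32, which is roughly in line with the figures listed above.
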

### Key Components

1. **Encoder Network**: Maps input images to latent distribution parameters (μ, σ²)
2. **Reparameterization Trick**: Enables differentiable sampling from the latent distribution
3. **Decoder Network**: Reconstructs images from latent space samples
4. **Loss Function**: Combines reconstruction loss (binary cross-entropy) and KL divergence
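
The reparameterization trick and the loss can be sketched in a few lines of PyTorch. This is a minimal illustration, not the notebook's code: the function names `reparameterize` and `vae_loss` are hypothetical, the encoder is assumed to output a log-variance (`logvar`) rather than σ² directly, and the loss is assumed to be summed over pixels and latent dimensions and averaged over the batch.

```python
import torch
import torch.nn.functional as F

def reparameterize(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """Sample z = mu + sigma * eps with eps ~ N(0, I); differentiable in mu and logvar."""
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std

def vae_loss(recon_x, x, mu, logvar, beta=1.0):
    """Binary cross-entropy reconstruction term plus beta-weighted KL divergence."""
    batch = x.size(0)
    recon = F.binary_cross_entropy(recon_x, x, reduction="sum") / batch
    # Closed-form KL( N(mu, sigma^2) || N(0, I) )
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / batch
    return recon + beta * kl, recon, kl

# Toy check: random tensors stand in for a batch of 4 flattened images
torch.manual_seed(0)
x = torch.rand(4, 784)
mu = torch.zeros(4, 2, requires_grad=True)
logvar = torch.zeros(4, 2, requires_grad=True)
z = reparameterize(mu, logvar)
recon_x = torch.sigmoid(torch.randn(4, 784))  # placeholder for a decoder output
loss, recon, kl = vae_loss(recon_x, x, mu, logvar)
z.sum().backward()  # gradients reach mu and logvar through the sampled z
```

With `mu = 0` and `logvar = 0` the approximate posterior equals the prior, so the KL term is zero. The backward call shows that gradients flow through the sampled `z`, which is the point of the trick.
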

## Training Details

- **Dataset**: MNIST (60,000 training images, 10,000 test images)
- **Batch Size**: 128
- **Epochs**: 20
- **Optimizer**: Adam
- **Learning Rate**: 1e-3
- **Beta Parameter**: 1.0 (standard VAE)
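
These settings determine how many optimizer updates the model receives. The arithmetic below assumes the final partial batch of each epoch is kept, which is the default behavior of PyTorch's `DataLoader` (`drop_last=False`); whether the notebook does so is not stated.

```python
import math

train_images = 60_000
batch_size = 128
epochs = 20

steps_per_epoch = math.ceil(train_images / batch_size)  # partial batch kept
last_batch_size = train_images - (steps_per_epoch - 1) * batch_size
total_steps = steps_per_epoch * epochs

print(steps_per_epoch, last_batch_size, total_steps)
```

That gives 469 updates per epoch (the last batch holding 96 images) and 9,380 updates over the full run.
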

## Model Performance

### Metrics

- **Final Training Loss**: ~85.2
- **Final Validation Loss**: ~86.1
- **Reconstruction Loss**: ~83.5
- **KL Divergence**: ~1.7
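
The reported numbers are consistent with the usual decomposition of the VAE objective, total = reconstruction + β · KL. The check below assumes the values are per-image averages, that the reconstruction term is summed over the 784 pixels, and that the components belong to the training loss. These normalization details are assumptions, not something stated above.

```python
beta = 1.0
reconstruction = 83.5  # reported reconstruction loss
kl_divergence = 1.7    # reported KL divergence

total = reconstruction + beta * kl_divergence
per_pixel_bce = reconstruction / 784  # assumes a sum over 28 * 28 pixels

print(f"total loss: {total:.1f}")
print(f"per-pixel BCE: {per_pixel_bce:.4f}")
```

The sum reproduces the final training loss of about 85.2, and the reconstruction term works out to roughly 0.1065 per pixel under the stated assumption.
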

### Capabilities

- ✅ High-quality digit reconstruction
- ✅ Smooth latent space interpolation
- ✅ Generation of new digit-like samples
- ✅ Well-organized latent space with digit clusters

## Usage

### Quick Start

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

# Define the model class. It must match the definition used for training
# so that the saved state dict loads (full implementation in the notebook).
class VAE(nn.Module):
    def __init__(self, input_dim=784, latent_dim=2, hidden_dim=256, beta=1.0):
        super().__init__()
        # ... (full implementation in the notebook)

    def forward(self, x):
        # ... (full implementation in the notebook)
        pass

# Load trained model (after downloading the files)
model = VAE()
state_dict = torch.load(
    'vae_logs_latent2_beta1.0/best_vae_model.pth', map_location='cpu'
)
model.load_state_dict(state_dict)
model.eval()

# Generate new samples
with torch.no_grad():
    # Sample from the standard normal prior over the latent space
    z = torch.randn(16, 2)  # 16 samples, 2D latent space
    # decode() is part of the full VAE class in the notebook
    generated_images = model.decode(z)

# Reshape and visualize
generated_images = generated_images.view(-1, 28, 28)
# Plot the generated images...
```

### Visualizations Available

1. **Latent Space Visualization**: 2D scatter plot showing digit clusters
2. **Reconstructions**: Original vs. reconstructed digit comparisons
3. **Generated Samples**: New digits sampled from the latent space
4. **Interpolations**: Smooth transitions between different digits
5. **Training Curves**: Loss components over training epochs
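
Two of these plots rest on simple latent-space geometry: a grid of latent points for sampling, and a straight line between two codes for interpolation. The helpers below (`latent_grid`, `interpolate`) are hypothetical names and use only the standard library. Spacing the grid by standard normal quantiles is a common choice for 2D VAEs and is assumed here rather than taken from the notebook. Each resulting point would then be passed to the decoder.

```python
from statistics import NormalDist

def latent_grid(n: int, lo: float = 0.05, hi: float = 0.95):
    """n x n latent points whose coordinates are evenly spaced N(0, 1) quantiles."""
    probs = [lo + (hi - lo) * i / (n - 1) for i in range(n)]
    axis = [NormalDist().inv_cdf(p) for p in probs]
    return [(zx, zy) for zy in axis for zx in axis]

def interpolate(z_a, z_b, steps: int):
    """Linearly interpolate between two latent codes, endpoints included."""
    return [
        tuple(a + (b - a) * t / (steps - 1) for a, b in zip(z_a, z_b))
        for t in range(steps)
    ]

grid = latent_grid(5)                                  # 25 points to decode
path = interpolate((-1.5, 0.0), (1.5, 1.0), steps=4)   # 4 codes from z_a to z_b
print(len(grid), path)
```

Quantile spacing places more grid points where the prior has more mass, so the decoded grid tends to cover the region the decoder saw most often during training.
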

## Files and Outputs

- `Untitled.ipynb`: Complete implementation with training and visualization
- `best_vae_model.pth`: Trained model weights
- `training_metrics.csv`: Detailed training metrics
- `generated_samples.png`: Grid of generated digit samples
- `latent_space_visualization.png`: 2D latent space plot
- `reconstruction_comparison.png`: Original vs. reconstructed images
- `latent_interpolation.png`: Interpolation between digit pairs
- `comprehensive_training_curves.png`: Training loss curves

## Applications

This VAE implementation can be used for:

- **Generative Modeling**: Create new handwritten digit images
- **Dimensionality Reduction**: Compress images to 2D representations
- **Anomaly Detection**: Identify unusual digits using reconstruction error
- **Data Augmentation**: Generate synthetic training data
- **Representation Learning**: Learn meaningful features for downstream tasks
- **Educational Purposes**: Understand VAE concepts and implementation
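
For the anomaly-detection use case, one simple recipe is to score each image by its reconstruction error and flag scores well above what is typical for normal data. The sketch below is illustrative only: `flag_anomalies` is a hypothetical helper, the mean-plus-k-standard-deviations threshold is one common heuristic among several, and the error values are arbitrary numbers chosen to show the mechanics, not measurements from this model.

```python
from statistics import mean, stdev

def flag_anomalies(reference_errors, new_errors, k: float = 3.0):
    """Flag new errors above mean + k * std of errors measured on normal data."""
    threshold = mean(reference_errors) + k * stdev(reference_errors)
    return threshold, [err > threshold for err in new_errors]

# Arbitrary per-image reconstruction errors, for illustration only
reference = [82.0, 85.5, 84.1, 86.3, 83.7, 85.0, 84.4, 83.2]
incoming = [84.9, 131.0, 86.0]

threshold, flags = flag_anomalies(reference, incoming)
print(f"threshold: {threshold:.1f}", flags)
```

In practice the reference errors would come from a held-out set of ordinary digits, and `k` would be tuned against the false-positive rate that is acceptable for the task.
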

## Research and Educational Value

This implementation serves as an educational resource for:

- Understanding Variational Autoencoder theory and practice
- Learning PyTorch implementation techniques
- Exploring generative modeling concepts
- Analyzing latent space representations
- Studying the balance between reconstruction and regularization

## Citation

If you use this implementation in your research or projects, please cite:

```bibtex
@misc{vae_mnist_implementation,
  title={Variational Autoencoder Implementation for MNIST},
  author={Gruhesh Kurra},
  year={2024},
  url={https://huggingface.co/karthik-2905/VariationalAutoencoders}
}
```

## License

This project is licensed under the MIT License. See the LICENSE file for details.

## Additional Resources

- **GitHub Repository**: [VariationalAutoencoders](https://github.com/GruheshKurra/VariationalAutoencoders)
- **Detailed Documentation**: Check `grok.md` for comprehensive VAE explanations
- **Training Logs**: Complete metrics and analysis in the log directories

---

**Tags**: deep-learning, generative-ai, pytorch, vae, mnist, computer-vision, unsupervised-learning

**Model Card Authors**: Gruhesh Kurra