---
title: Variational Autoencoder (VAE) - MNIST
emoji: 🎨
colorFrom: blue
colorTo: purple
sdk: pytorch
app_file: Untitled.ipynb
pinned: false
license: mit
tags:
- deep-learning
- generative-ai
- pytorch
- vae
- variational-autoencoder
- mnist
- computer-vision
- unsupervised-learning
- representation-learning
datasets:
- mnist
---
# Variational Autoencoder (VAE) - MNIST Implementation
A comprehensive PyTorch implementation of Variational Autoencoders trained on the MNIST dataset with detailed analysis and visualizations.
## Model Description
This repository contains a complete implementation of a Variational Autoencoder (VAE) trained on the MNIST handwritten digits dataset. The model learns to encode images into a 2-dimensional latent space and decode them back to reconstructed images, enabling both data compression and generation of new digit-like images.
### Architecture Details
- **Model Type**: Variational Autoencoder (VAE)
- **Framework**: PyTorch
- **Input**: 28×28 grayscale images (784 dimensions)
- **Latent Space**: 2 dimensions (for visualization)
- **Hidden Layers**: 256 → 128 (encoder), 128 → 256 (decoder)
- **Total Parameters**: ~400K
- **Model Size**: 1.8MB
### Key Components
1. **Encoder Network**: Maps input images to latent distribution parameters (μ, σ²)
2. **Reparameterization Trick**: Enables differentiable sampling from the latent distribution
3. **Decoder Network**: Reconstructs images from latent space samples
4. **Loss Function**: Combines reconstruction loss (binary cross-entropy) and KL divergence
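The loss in item 4 is the standard evidence lower bound (ELBO) objective. The helper below is a minimal sketch of how it is commonly written in PyTorch; the notebook's exact reduction and weighting may differ:

```python
import torch
import torch.nn.functional as F

def vae_loss(recon_x, x, mu, logvar, beta=1.0):
    # Reconstruction term: per-pixel binary cross-entropy, summed over the batch
    bce = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    # KL(q(z|x) || N(0, I)) in closed form for a diagonal Gaussian posterior
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # beta = 1.0 recovers the standard VAE objective used in this repository
    return bce + beta * kld
```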
## Training Details
- **Dataset**: MNIST (60,000 training images, 10,000 test images)
- **Batch Size**: 128
- **Epochs**: 20
- **Optimizer**: Adam
- **Learning Rate**: 1e-3
- **Beta Parameter**: 1.0 (standard VAE)
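Putting these settings together, a minimal training loop might look like the sketch below. It assumes the `VAE` class from the Quick Start section and the `vae_loss` helper sketched above; the notebook's actual loop may add logging, validation, and checkpointing:

```python
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# MNIST pixels scaled to [0, 1] by ToTensor(), matching the BCE loss
train_ds = datasets.MNIST('data', train=True, download=True,
                          transform=transforms.ToTensor())
train_loader = DataLoader(train_ds, batch_size=128, shuffle=True)

model = VAE()  # class defined in the Quick Start section below
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(20):
    model.train()
    total = 0.0
    for x, _ in train_loader:  # labels are ignored (unsupervised)
        optimizer.zero_grad()
        recon, mu, logvar = model(x)
        loss = vae_loss(recon, x, mu, logvar, beta=1.0)
        loss.backward()
        optimizer.step()
        total += loss.item()
    print(f'epoch {epoch + 1}: avg loss {total / len(train_ds):.2f}')
```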
## Model Performance
### Metrics
- **Final Training Loss**: ~85.2
- **Final Validation Loss**: ~86.1
- **Reconstruction Loss**: ~83.5
- **KL Divergence**: ~1.7
### Capabilities
- ✅ High-quality digit reconstruction
- ✅ Smooth latent space interpolation
- ✅ Generation of new digit-like samples
- ✅ Well-organized latent space with digit clusters
## Usage
### Quick Start
The snippet below is a minimal, runnable sketch: the layer sizes follow the architecture described above, but the exact implementation lives in `Untitled.ipynb`.

```python
import torch
import torch.nn as nn

class VAE(nn.Module):
    def __init__(self, input_dim=784, latent_dim=2, hidden_dim=256, beta=1.0):
        super().__init__()
        self.beta = beta
        # Encoder: 784 -> 256 -> 128, then project to mean and log-variance
        self.enc = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim // 2), nn.ReLU(),
        )
        self.fc_mu = nn.Linear(hidden_dim // 2, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim // 2, latent_dim)
        # Decoder: 2 -> 128 -> 256 -> 784, sigmoid output for pixel intensities
        self.dec = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim // 2), nn.ReLU(),
            nn.Linear(hidden_dim // 2, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, input_dim), nn.Sigmoid(),
        )

    def encode(self, x):
        h = self.enc(x)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        return mu + std * torch.randn_like(std)  # differentiable sampling

    def decode(self, z):
        return self.dec(z)

    def forward(self, x):
        mu, logvar = self.encode(x.view(x.size(0), -1))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# Load trained model weights (after downloading the files)
model = VAE()
model.load_state_dict(torch.load('vae_logs_latent2_beta1.0/best_vae_model.pth'))
model.eval()

# Generate new samples
with torch.no_grad():
    z = torch.randn(16, 2)              # 16 samples, 2D latent space
    generated_images = model.decode(z)  # (16, 784)
    generated_images = generated_images.view(-1, 28, 28)
# Plot the generated images...
```
### Visualizations Available
1. **Latent Space Visualization**: 2D scatter plot showing digit clusters
2. **Reconstructions**: Original vs. reconstructed digit comparisons
3. **Generated Samples**: New digits sampled from the latent space
4. **Interpolations**: Smooth transitions between different digits (see the sketch after this list)
5. **Training Curves**: Loss components over training epochs
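For example, the interpolation figures can be reproduced in a few lines: blend two latent codes linearly and decode each intermediate point. This is an assumed sketch (random endpoints here; the notebook may interpolate between encoded digits instead):

```python
import torch
import matplotlib.pyplot as plt

with torch.no_grad():
    z_start, z_end = torch.randn(2), torch.randn(2)  # two points in latent space
    steps = torch.linspace(0, 1, 10).unsqueeze(1)    # (10, 1) blend weights
    z_path = (1 - steps) * z_start + steps * z_end   # (10, 2) interpolated codes
    frames = model.decode(z_path).view(-1, 28, 28)

fig, axes = plt.subplots(1, 10, figsize=(15, 2))
for ax, frame in zip(axes, frames):
    ax.imshow(frame, cmap='gray')
    ax.axis('off')
plt.show()
```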
## Files and Outputs
- `Untitled.ipynb`: Complete implementation with training and visualization
- `best_vae_model.pth`: Trained model weights
- `training_metrics.csv`: Detailed training metrics
- `generated_samples.png`: Grid of generated digit samples
- `latent_space_visualization.png`: 2D latent space plot
- `reconstruction_comparison.png`: Original vs reconstructed images
- `latent_interpolation.png`: Interpolation between digit pairs
- `comprehensive_training_curves.png`: Training loss curves
## Applications
This VAE implementation can be used for:
- **Generative Modeling**: Create new handwritten digit images
- **Dimensionality Reduction**: Compress images to 2D representations
- **Anomaly Detection**: Identify unusual digits using reconstruction error (see the sketch after this list)
- **Data Augmentation**: Generate synthetic training data
- **Representation Learning**: Learn meaningful features for downstream tasks
- **Educational Purposes**: Understand VAE concepts and implementation
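As a concrete example of the anomaly-detection use case, images that the model reconstructs poorly can be flagged. The sketch below is an assumption about how one might score inputs (`images` is a hypothetical `(N, 1, 28, 28)` batch, and the 3-sigma threshold is purely illustrative):

```python
import torch
import torch.nn.functional as F

def reconstruction_error(model, x):
    """Per-image reconstruction error; high values suggest unusual inputs."""
    model.eval()
    with torch.no_grad():
        recon, mu, logvar = model(x)
        # Per-image BCE, summed over pixels
        return F.binary_cross_entropy(
            recon, x.view(-1, 784), reduction='none').sum(dim=1)

scores = reconstruction_error(model, images)           # images: hypothetical batch
anomalies = scores > scores.mean() + 3 * scores.std()  # illustrative threshold
```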
## Research and Educational Value
This implementation serves as an excellent educational resource for:
- Understanding Variational Autoencoders theory and practice
- Learning PyTorch implementation techniques
- Exploring generative modeling concepts
- Analyzing latent space representations
- Studying the balance between reconstruction and regularization
## Citation
If you use this implementation in your research or projects, please cite:
```bibtex
@misc{vae_mnist_implementation,
  title={Variational Autoencoder Implementation for MNIST},
  author={Gruhesh Kurra},
  year={2024},
  url={https://huggingface.co/karthik-2905/VariationalAutoencoders}
}
```
## License
This project is licensed under the MIT License - see the LICENSE file for details.
## Additional Resources
- **GitHub Repository**: [VariationalAutoencoders](https://github.com/GruheshKurra/VariationalAutoencoders)
- **Detailed Documentation**: Check `grok.md` for comprehensive VAE explanations
- **Training Logs**: Complete metrics and analysis in the log directories
---
**Tags**: deep-learning, generative-ai, pytorch, vae, mnist, computer-vision, unsupervised-learning
**Model Card Authors**: Gruhesh Kurra