Spaces:
Sleeping
Sleeping
import streamlit as st | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
import json | |
import os | |
import numpy as np | |
from streamlit_lottie import st_lottie | |
# Page-wide stylesheet: enlarges text, centres headings and defines helper
# classes used by the st.markdown(...) HTML snippets in this script
# (big-font, center-text, slider-label, medium-text).
# NOTE(review): small-text and highlight do not appear to be referenced in the
# visible code — confirm before removing them.
_CUSTOM_CSS = """
<style>
.big-font {
font-size:60px !important;
text-align: center;
}
.slider-label {
font-size:25px !important;
font-weight: bold;
}
.small-text {
font-size:12px !important;
}
.medium-text {
font-size:16px !important;
}
.center-text {
text-align: center;
}
.highlight {
font-family: 'Courier New', Courier, monospace;
background-color: #f0f0f0;
padding: 10px;
border-radius: 5px;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
def load_lottiefile(filepath: str):
    """Read a Lottie animation definition from a JSON file.

    Args:
        filepath: Path to the ``.json`` file holding the animation.

    Returns:
        The decoded JSON document exactly as ``json.load`` produces it; this
        script hands it straight to ``st_lottie``.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # JSON text is UTF-8 by specification; don't depend on the host's locale
    # default encoding (which differs between platforms).
    with open(filepath, "r", encoding="utf-8") as f:
        return json.load(f)
def load_history(history_path):
    """Load a saved training history from a JSON file.

    In this script the result is indexed with the keys ``accuracy``,
    ``val_accuracy``, ``loss`` and ``val_loss`` and plotted per epoch.

    Args:
        history_path: Path to the history ``.json`` file.

    Returns:
        The decoded JSON document (a dict for the files this app reads).

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # Explicit UTF-8 so decoding doesn't vary with the platform locale.
    with open(history_path, 'r', encoding='utf-8') as f:
        history = json.load(f)
    return history
def smooth_data(data, window_size):
    """Return the simple moving average of ``data`` over ``window_size`` points.

    Only full windows are kept (``np.convolve(..., mode='valid')``), so the
    result has ``len(data) - window + 1`` elements.

    Args:
        data: 1-D sequence of numbers.
        window_size: Number of consecutive points averaged per output value.
            Windows longer than ``data`` are clamped to ``len(data)``, which
            yields a single value: the mean of the whole series.

    Returns:
        A float ``numpy.ndarray``; empty if ``data`` is empty.

    Raises:
        ValueError: If ``window_size`` is less than 1.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")
    values = np.asarray(data, dtype=float)
    if values.size == 0:
        return values
    # np.convolve swaps its operands when the kernel is longer than the data,
    # which in 'valid' mode would return len(kernel)-len(data)+1 bogus values.
    window = min(window_size, values.size)
    return np.convolve(values, np.ones(window) / window, mode='valid')
# --- Page header ---------------------------------------------------------------
st.markdown('<h1 class="big-font">TuNNe</h1>', unsafe_allow_html=True)
st.markdown('<h2 class="center-text">Tuning a Neural Network</h2>', unsafe_allow_html=True)
st.markdown('<p class="center-text">This app demonstrates how different hyperparameters affect the training of a neural network using the MNIST dataset (50% of the data).</p>', unsafe_allow_html=True)

# Decorative Lottie animation loaded from a JSON file next to the app.
lottie_animation = load_lottiefile("Animation - 1719728959093.json")
st_lottie(lottie_animation, height=300, key="header_animation")

# Directory holding the training-history JSON files this app plots.
# NOTE(review): the unusual spelling must match the deployed folder name.
model_dir = "modelllls"
try:
    model_files = [f for f in os.listdir(model_dir) if f.endswith('.json')]
except OSError as err:
    # Report the problem on the page and continue with no models instead of
    # letting the exception take down the whole app.
    st.error(f"Cannot read model directory '{model_dir}': {err}")
    model_files = []
def extract_hyperparameters(model_files):
    """Parse (learning rate, batch size, epochs) triples from history filenames.

    Filenames are expected to look like
    ``mnist_model_lr0.001_bs32_epochs10.json``: underscore-separated, with the
    third field ``lr<float>``, the fourth ``bs<int>`` and the fifth
    ``epochs<int>.json``. The first two fields are not inspected.

    Args:
        model_files: Iterable of filenames (not paths).

    Returns:
        A list of ``(lr, batch_size, epochs)`` tuples, in input order, for every
        filename that parses. Names that do not fit the scheme are skipped.
    """
    hyperparameters = []
    for name in model_files:
        parts = name.split('_')
        try:
            lr = float(parts[2][2:])                          # "lr0.001"       -> 0.001
            bs = int(parts[3][2:])                            # "bs32"          -> 32
            epochs = int(parts[4][6:].replace('.json', ''))   # "epochs10.json" -> 10
        except (IndexError, ValueError):
            # Not a history file in the expected naming scheme (e.g. a stray
            # notes.json); ignore it rather than crash the page.
            continue
        hyperparameters.append((lr, bs, epochs))
    return hyperparameters
def _select_hyperparameter(label, options, single_option_message):
    """Show a labelled picker for one hyperparameter and return the chosen value.

    Args:
        label: Heading rendered above the picker; also used as the widget's
            (visually collapsed) label.
        options: Sorted candidate values.
        single_option_message: Prefix written when exactly one value exists.

    Returns:
        The value picked on the slider, the only value when there is just one,
        or ``None`` when ``options`` is empty.
    """
    st.markdown(f'<p class="slider-label">{label}</p>', unsafe_allow_html=True)
    if len(options) > 1:
        return st.select_slider(label, options=options, label_visibility="collapsed")
    if not options:
        # Nothing to choose from; say so instead of claiming "only one: None".
        st.write(f"No {label.lower()} options available.")
        return None
    st.write(f"{single_option_message}: {options[0]}")
    return options[0]


hyperparameters = extract_hyperparameters(model_files)

# Each picker only offers values found together with the choices above it, so
# every reachable combination corresponds to a file listed in model_dir.
learning_rates = sorted({lr for lr, _, _ in hyperparameters})
selected_lr = _select_hyperparameter(
    "Learning Rate", learning_rates, "Only one learning rate available")

filtered_bs = sorted({bs for lr, bs, _ in hyperparameters if lr == selected_lr})
selected_bs = _select_hyperparameter(
    "Batch Size", filtered_bs, "Only one batch size available")

filtered_epochs = sorted(
    {ep for lr, bs, ep in hyperparameters if lr == selected_lr and bs == selected_bs})
selected_epochs = _select_hyperparameter(
    "Epochs", filtered_epochs, "Only one epoch option available")
# Chart display options.
enable_grid = st.checkbox("Enable Grid Lines")

# 1 means "no smoothing". Binding it unconditionally means later code never has
# to check whether the name exists; the slider only appears for runs longer
# than 20 epochs.
smoothing_window = 1
if selected_epochs and selected_epochs > 20:
    smoothing_window = st.slider("Smoothing Window (every 4 epochs)", min_value=1, max_value=5, step=1, value=1)
def _plot_train_vs_validation(train_values, val_values, *, title, ylabel,
                              train_label, val_label, window, show_grid):
    """Render one train-vs-validation line chart with Streamlit.

    Args:
        train_values: Training-metric values (plotted against an 'Epoch' axis).
        val_values: Validation-metric values (same axis).
        title: Chart title.
        ylabel: Y-axis label.
        train_label: Legend entry for the training series.
        val_label: Legend entry for the validation series.
        window: Moving-average window passed to ``smooth_data``; ``1`` plots
            the raw values.
        show_grid: Draw grid lines when true.
    """
    if window > 1:
        train_values = smooth_data(train_values, window)
        val_values = smooth_data(val_values, window)
    fig, ax = plt.subplots()
    try:
        sns.lineplot(x=range(len(train_values)), y=train_values, ax=ax, label=train_label)
        sns.lineplot(x=range(len(val_values)), y=val_values, ax=ax, label=val_label)
        ax.set_title(title, fontsize=15)
        ax.set_ylabel(ylabel, fontsize=12)
        ax.set_xlabel('Epoch', fontsize=12)
        ax.legend(loc='upper left', fontsize=10)
        if show_grid:
            ax.grid(True)
        st.pyplot(fig)
    finally:
        # pyplot keeps every figure created by plt.subplots() registered until
        # it is closed; without this, each rerun of the script leaks figures.
        plt.close(fig)


# Locate the history file for the selected combination and plot it.
if selected_lr is not None and selected_bs is not None and selected_epochs is not None:
    history_filename = f"mnist_model_lr{selected_lr}_bs{selected_bs}_epochs{selected_epochs}.json"
    history_path = os.path.join(model_dir, history_filename)
    if not os.path.exists(history_path):
        # Rebuilding the name from parsed numbers doesn't always round-trip
        # (e.g. "lr0.00001" parses to 1e-05 and formats back as "lr1e-05"), so
        # fall back to matching the parsed values of the files actually present.
        wanted = [(selected_lr, selected_bs, selected_epochs)]
        for candidate in model_files:
            if extract_hyperparameters([candidate]) == wanted:
                history_path = os.path.join(model_dir, candidate)
                break
    if os.path.exists(history_path):
        history = load_history(history_path)
        # Smoothing only applies to runs longer than 20 epochs; the slider
        # value counts blocks of 4 epochs.
        window = 1
        if selected_epochs > 20 and 'smoothing_window' in globals() and smoothing_window > 1:
            window = smoothing_window * 4
        _plot_train_vs_validation(
            history['accuracy'], history['val_accuracy'],
            title='Model Accuracy', ylabel='Accuracy',
            train_label='Train Accuracy', val_label='Validation Accuracy',
            window=window, show_grid=enable_grid,
        )
        _plot_train_vs_validation(
            history['loss'], history['val_loss'],
            title='Model Loss', ylabel='Loss',
            train_label='Train Loss', val_label='Validation Loss',
            window=window, show_grid=enable_grid,
        )
    else:
        st.error(f"History file not found: {history_path}")
else:
    st.error("Unable to load model due to missing hyperparameters")
# Closing note rendered beneath the charts.
# NOTE(review): the trailing "π" looks like a mangled emoji from the original
# source — confirm the intended character.
_FINAL_MESSAGE_HTML = """
<p class="medium-text">There is no rule of thumb for hyperparameters. The combination varies, and this is just to give an idea or an interactive way of showing how each parameter affects the model training.</p>
<p class="medium-text">Though the code is generated using AI, the tuning has to be done by a human. π</p>
"""
st.markdown(_FINAL_MESSAGE_HTML, unsafe_allow_html=True)