File size: 6,708 Bytes
f5f6805
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f776208
f5f6805
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f776208
 
 
 
 
f5f6805
 
 
 
f776208
 
 
 
 
f5f6805
 
 
 
f776208
 
 
 
 
f5f6805
 
 
f776208
f5f6805
 
 
f776208
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f5f6805
f776208
f5f6805
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import json
import os
import re

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import streamlit as st
from streamlit_lottie import st_lottie

# Custom CSS injected once at startup: enlarges the title and slider labels,
# centers headings, and defines a monospace "highlight" box style.
# Classes are referenced by the st.markdown() HTML snippets below.
st.markdown("""
    <style>
    .big-font {
        font-size:60px !important;
        text-align: center;
    }
    .slider-label {
        font-size:25px !important;
        font-weight: bold;
    }
    .small-text {
        font-size:12px !important;
    }
    .medium-text {
        font-size:16px !important;
    }
    .center-text {
        text-align: center;
    }
    .highlight {
        font-family: 'Courier New', Courier, monospace;
        background-color: #f0f0f0;
        padding: 10px;
        border-radius: 5px;
    }
    </style>
    """, unsafe_allow_html=True)

# Load Lottie animation
# Load Lottie animation
def load_lottiefile(filepath: str):
    """Load a Lottie animation JSON file and return the parsed object.

    Args:
        filepath: Path to the ``.json`` animation file.

    Returns:
        The parsed JSON content (typically a dict of Lottie animation data).

    Raises:
        FileNotFoundError: If ``filepath`` does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Explicit encoding: the platform default (e.g. cp1252 on Windows)
    # can fail on Lottie files containing non-ASCII characters.
    with open(filepath, "r", encoding="utf-8") as f:
        return json.load(f)

# Load training history
# Load training history
def load_history(history_path):
    """Load a Keras-style training history from a JSON file.

    Args:
        history_path: Path to the history JSON (metric name -> list of
            per-epoch values, e.g. ``accuracy``, ``val_loss``).

    Returns:
        The parsed history mapping.

    Raises:
        FileNotFoundError: If ``history_path`` does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Explicit encoding for cross-platform consistency (mirrors
    # load_lottiefile above).
    with open(history_path, 'r', encoding='utf-8') as f:
        return json.load(f)

# Smooth data
def smooth_data(data, window_size):
    return np.convolve(data, np.ones(window_size)/window_size, mode='valid')

# Streamlit app: page title, subtitle, and intro copy (styled via the
# CSS classes injected above).
st.markdown('<h1 class="big-font">TuNNe</h1>', unsafe_allow_html=True)
st.markdown('<h2 class="center-text">Tuning a Neural Network</h2>', unsafe_allow_html=True)
st.markdown('<p class="center-text">This app demonstrates how different hyperparameters affect the training of a neural network using the MNIST dataset (50% of the data).</p>', unsafe_allow_html=True)

# Load and display Lottie header animation; the file must sit next to the app.
lottie_animation = load_lottiefile("Animation - 1719728959093.json")
st_lottie(lottie_animation, height=300, key="header_animation")

# Directory containing the pre-trained model history files.
# NOTE(review): "modelllls" looks like a typo but must match the actual
# on-disk directory name — confirm before renaming.
model_dir = "modelllls"
model_files = [f for f in os.listdir(model_dir) if f.endswith('.json')]

# Extract available hyperparameters from model filenames
def extract_hyperparameters(model_files):
    hyperparameters = []
    for f in model_files:
        parts = f.split('_')
        lr = float(parts[2][2:])
        bs = int(parts[3][2:])
        epochs = int(parts[4][6:].replace('.json', ''))
        hyperparameters.append((lr, bs, epochs))
    return hyperparameters

# All (lr, batch_size, epochs) combinations available on disk.
hyperparameters = extract_hyperparameters(model_files)

# Get unique values for each hyperparameter
learning_rates = sorted(set(lr for lr, _, _ in hyperparameters))

# Select slider for learning rate. select_slider requires >= 2 options,
# hence the single/empty fallback branch below (same pattern for all
# three widgets).
st.markdown('<p class="slider-label">Learning Rate</p>', unsafe_allow_html=True)
if len(learning_rates) > 1:
    selected_lr = st.select_slider("Learning Rate", options=learning_rates, label_visibility="collapsed")
else:
    selected_lr = learning_rates[0] if learning_rates else None
    st.write(f"Only one learning rate available: {selected_lr}")

# Batch sizes are filtered to those that actually exist for the chosen
# learning rate, so the user can never select a missing combination.
filtered_bs = sorted(set(bs for lr, bs, _ in hyperparameters if lr == selected_lr))
st.markdown('<p class="slider-label">Batch Size</p>', unsafe_allow_html=True)
if len(filtered_bs) > 1:
    selected_bs = st.select_slider("Batch Size", options=filtered_bs, label_visibility="collapsed")
else:
    selected_bs = filtered_bs[0] if filtered_bs else None
    st.write(f"Only one batch size available: {selected_bs}")

# Epoch counts filtered the same way, on both prior selections.
filtered_epochs = sorted(set(epochs for lr, bs, epochs in hyperparameters if lr == selected_lr and bs == selected_bs))
st.markdown('<p class="slider-label">Epochs</p>', unsafe_allow_html=True)
if len(filtered_epochs) > 1:
    selected_epochs = st.select_slider("Epochs", options=filtered_epochs, label_visibility="collapsed")
else:
    selected_epochs = filtered_epochs[0] if filtered_epochs else None
    st.write(f"Only one epoch option available: {selected_epochs}")

# Options for grid lines and curve smoothing. NOTE: smoothing_window is
# only defined when selected_epochs > 20 — the plotting section below
# guards on that with an `in locals()` check.
enable_grid = st.checkbox("Enable Grid Lines")
if selected_epochs and selected_epochs > 20:
    smoothing_window = st.slider("Smoothing Window (every 4 epochs)", min_value=1, max_value=5, step=1, value=1)

# Reconstruct the history filename from the selected hyperparameters and
# plot the stored accuracy/loss curves.
# NOTE(review): the filename is rebuilt via f-string formatting of a float
# parsed from the original filename — this assumes str(float(lr)) round-trips
# to exactly the on-disk spelling (e.g. "0.01", not "1e-02"); verify for the
# learning rates actually used.
if selected_lr is not None and selected_bs is not None and selected_epochs is not None:
    history_filename = f"mnist_model_lr{selected_lr}_bs{selected_bs}_epochs{selected_epochs}.json"
    history_path = os.path.join(model_dir, history_filename)

    if os.path.exists(history_path):
        history = load_history(history_path)

        # Plot training & validation accuracy values
        fig, ax = plt.subplots()
        accuracy = history['accuracy']
        val_accuracy = history['val_accuracy']
        # `'smoothing_window' in locals()` guards against the slider not
        # existing (it is only created when selected_epochs > 20); at module
        # level locals() is the module namespace, so this works here.
        if selected_epochs > 20 and 'smoothing_window' in locals() and smoothing_window > 1:
            # Window is scaled by 4 to match the "every 4 epochs" slider label.
            accuracy = smooth_data(accuracy, smoothing_window * 4)
            val_accuracy = smooth_data(val_accuracy, smoothing_window * 4)
        sns.lineplot(x=range(len(accuracy)), y=accuracy, ax=ax, label='Train Accuracy')
        sns.lineplot(x=range(len(val_accuracy)), y=val_accuracy, ax=ax, label='Validation Accuracy')
        ax.set_title('Model Accuracy', fontsize=15)
        ax.set_ylabel('Accuracy', fontsize=12)
        ax.set_xlabel('Epoch', fontsize=12)
        ax.legend(loc='upper left', fontsize=10)
        if enable_grid:
            ax.grid(True)
        st.pyplot(fig)

        # Plot training & validation loss values (mirrors the accuracy plot).
        fig, ax = plt.subplots()
        loss = history['loss']
        val_loss = history['val_loss']
        if selected_epochs > 20 and 'smoothing_window' in locals() and smoothing_window > 1:
            loss = smooth_data(loss, smoothing_window * 4)
            val_loss = smooth_data(val_loss, smoothing_window * 4)
        sns.lineplot(x=range(len(loss)), y=loss, ax=ax, label='Train Loss')
        sns.lineplot(x=range(len(val_loss)), y=val_loss, ax=ax, label='Validation Loss')
        ax.set_title('Model Loss', fontsize=15)
        ax.set_ylabel('Loss', fontsize=12)
        ax.set_xlabel('Epoch', fontsize=12)
        ax.legend(loc='upper left', fontsize=10)
        if enable_grid:
            ax.grid(True)
        st.pyplot(fig)
    else:
        st.error(f"History file not found: {history_path}")
else:
    st.error("Unable to load model due to missing hyperparameters")


# Final message
st.markdown("""
    <p class="medium-text">There is no rule of thumb for hyperparameters. The combination varies, and this is just to give an idea or an interactive way of showing how each parameter affects the model training.</p>
    <p class="medium-text">Though the code is generated using AI, the tuning has to be done by a human. πŸ˜‚</p>
    """, unsafe_allow_html=True)