CaxtonEmeraldS commited on
Commit
0852299
·
verified ·
1 Parent(s): 43726c1

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +265 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from tensorflow import keras
3
+ import os
4
+ import matplotlib.pyplot as plt
5
+ from io import BytesIO
6
+ from NNVisualiser import NNVisualiser
7
+ import glob
8
+ import inspect
9
+ from tensorflow.keras.models import save_model
10
+ import tempfile
11
+ import re
12
+ import zipfile
13
+ import io
14
+
15
+ # Function to create a ZIP file of all PNG files
16
+ def create_zip_of_png_files():
17
+ # Get current working directory
18
+ cwd = os.getcwd()
19
+ png_files = [f for f in os.listdir(cwd) if f.endswith('.png')]
20
+
21
+ # Create a BytesIO object to hold the ZIP file in memory
22
+ zip_buffer = io.BytesIO()
23
+
24
+ with zipfile.ZipFile(zip_buffer, 'w') as zip_file:
25
+ for png_file in png_files:
26
+ zip_file.write(os.path.join(cwd, png_file), arcname=png_file)
27
+
28
+ zip_buffer.seek(0) # Seek to the beginning of the BytesIO buffer
29
+ return zip_buffer
30
+
31
+ def generate_title_from_method_name(method_name):
32
+ # Remove the "plot" prefix if it exists
33
+ if method_name.startswith("plot"):
34
+ method_name = method_name[4:] # Remove the first 4 characters ("plot")
35
+
36
+ # Split the string at camel case boundaries
37
+ words = re.findall(r'[A-Z][a-z]*', method_name)
38
+
39
+ # Join the words with spaces and format the final string
40
+ title = "Plotting " + " ".join(words[:]) + " Plot "
41
+
42
+ return title
43
+
44
+ def downloadKerasModel():
45
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".keras") as tmp_file:
46
+ save_model(model, tmp_file.name)
47
+ tmp_file.seek(0)
48
+ model_data = tmp_file.read()
49
+ return model_data
50
+
51
+ # Function to build folder hierarchy up to the 6th level (excluding files and hidden folders)
52
+ @st.cache_data
53
+ def generate_folder_hierarchy(root_folder, max_depth=6):
54
+ folder_dict = {}
55
+
56
+ # Traverse through the directory tree
57
+ for dirpath, dirnames, filenames in os.walk(root_folder):
58
+ # Get the relative path from the root folder
59
+ rel_path = os.path.relpath(dirpath, root_folder)
60
+ depth = rel_path.count(os.sep) + 1 # Calculate the depth level
61
+
62
+ # Only include directories up to the max_depth (7th level)
63
+ if depth > max_depth:
64
+ continue
65
+
66
+ # Filter out directories that start with a dot (e.g., .git)
67
+ dirnames[:] = [d for d in dirnames if not d.startswith('.') and d != '1']
68
+
69
+ sub_dict = folder_dict
70
+ # Split the relative path into parts to create a nested structure
71
+ for part in rel_path.split(os.sep):
72
+ if part == '.' or part.startswith('.'):
73
+ continue
74
+ if part not in sub_dict:
75
+ sub_dict[part] = {}
76
+ sub_dict = sub_dict[part]
77
+
78
+ return folder_dict
79
+
80
+ @st.cache_data
81
+ def getPlotMethods():
82
+ return [name for name, func in inspect.getmembers(NNVisualiser, inspect.isfunction) if name.startswith('plot')]
83
+
84
+ # Example usage
85
+ root_folder = os.getcwd(); # Replace with your folder path
86
+ folder_hierarchy = generate_folder_hierarchy(root_folder)
87
+
88
+ # Streamlit app
89
+ st.title("Repository : Simple ANN Models with UAT Architecture")
90
+ st.write(f"A Collection of ANN Models with a 1-xReLU-1 Architecture for Basic 1D Functions on Bounded Intervals")
91
+ #Commented
92
+
93
+ # col1, col2, col3 = st.columns([4, 3, 3])
94
+
95
+ # with col1:
96
+ # # Level 1: Initialisation dropdown
97
+ # initialisation = st.selectbox("Select Initialisation", list(folder_hierarchy.keys()))
98
+
99
+ # with col2:
100
+ # # Level 2: Sample size dropdown, based on selected initialisation
101
+ # sampleSize = st.selectbox("Select Sample Size", list(folder_hierarchy[initialisation].keys()))
102
+
103
+ # with col3:
104
+ # # Level 3: Batch size dropdown, based on selected sample size
105
+ # batchSize = st.selectbox("Select Batch Size", list(folder_hierarchy[initialisation][sampleSize].keys()))
106
+
107
+
108
+ # col4, col5, col6 = st.columns([3, 4, 3])
109
+
110
+ # with col4:
111
+ # # Level 4: Epochs count dropdown, based on selected batch size
112
+ # epochs = st.selectbox("Select Epochs Count", list(folder_hierarchy[initialisation][sampleSize][batchSize].keys()))
113
+
114
+ # with col5:
115
+ # # Level 5: Functions list dropdown, based on selected epochs count
116
+ # functions = st.selectbox("Select Neurons Count", list(folder_hierarchy[initialisation][sampleSize][batchSize][epochs].keys()))
117
+
118
+ # with col6:
119
+ # # Level 6: Neurons count dropdown, based on selected function
120
+ # neurons = st.selectbox("Select Neurons Count", list(folder_hierarchy[initialisation][sampleSize][batchSize][epochs][functions].keys()))
121
+
122
+
123
+ initialisation = st.sidebar.selectbox("Select Initialisation", list(folder_hierarchy.keys()))
124
+ sampleSize = st.sidebar.selectbox("Select Sample Size", list(folder_hierarchy[initialisation].keys()))
125
+ batchSize = st.sidebar.selectbox("Select Batch Size", list(folder_hierarchy[initialisation][sampleSize].keys()))
126
+ epochs = st.sidebar.selectbox("Select Epochs Count", list(folder_hierarchy[initialisation][sampleSize][batchSize].keys()))
127
+ functions = st.sidebar.selectbox("Select Neurons Count", list(folder_hierarchy[initialisation][sampleSize][batchSize][epochs].keys()))
128
+ neurons = st.sidebar.selectbox("Select Neurons Count", list(folder_hierarchy[initialisation][sampleSize][batchSize][epochs][functions].keys()))
129
+
130
+ # Display the selected values
131
+ st.write(f"You selected: {initialisation} : {sampleSize} : {batchSize} : {epochs} : {functions} : {neurons}")
132
+
133
+ modelPath = os.path.join(os.getcwd(), initialisation, sampleSize, batchSize, epochs, functions, neurons);
134
+ model = keras.models.load_model(modelPath);
135
+
136
+ visualiser = NNVisualiser(model);
137
+ visualiser.setSavePlots(True);
138
+
139
+ # Function to get layer and neuron information
140
+ def get_layer_info(model):
141
+ layer_info = []
142
+ for layer in model.layers:
143
+ layer_info.append({
144
+ 'index': len(layer_info),
145
+ 'type': layer.__class__.__name__,
146
+ 'units': getattr(layer, 'units', None), # Number of neurons
147
+ })
148
+ return layer_info
149
+
150
+ layer_info = get_layer_info(model)
151
+
152
+ # Extract layer indices and neuron counts
153
+ layer_indices = [layer['index'] for layer in layer_info]
154
+ neuron_counts = [layer['units'] for layer in layer_info]
155
+
156
+ # Dropdown for selecting layer index
157
+ #selected_layer_index = st.sidebar.selectbox("Select Layer Index", layer_indices)
158
+
159
+ # Find the number of neurons for the selected layer
160
+ #selected_layer_units = neuron_counts[selected_layer_index]
161
+
162
+ # Dropdown for selecting neuron index in the selected layer
163
+ #neuron_indices = list(range(selected_layer_units))
164
+ #selected_neuron_index = st.sidebar.selectbox("Select Neuron Index", neuron_indices)
165
+
166
+ # Dropdown for selecting plots from NNVisualiser
167
+ plotMethods = getPlotMethods()
168
+ selectedPlotMethod = st.sidebar.selectbox("Select Plot", plotMethods)
169
+
170
+ #Removing earlier plots
171
+ image_files = glob.glob("*.png")
172
+ for file in image_files:
173
+ try:
174
+ os.remove(file)
175
+ except Exception as e:
176
+ st.write("Error in removing previous plots")
177
+
178
+ st.session_state.title_text = generate_title_from_method_name(selectedPlotMethod)
179
+ st.title(st.session_state.title_text)
180
+
181
+ # Call your package's plot method (which directly plots without returning a figure)
182
+ visualiser.setSavePlots(True);
183
+ method = getattr(visualiser, selectedPlotMethod, None)
184
+
185
+ if method is not None:
186
+ if 'Neuron' in selectedPlotMethod:
187
+ selected_layer_index = st.sidebar.selectbox("Select Layer Index", layer_indices)
188
+ # Find the number of neurons for the selected layer
189
+ selected_layer_units = neuron_counts[selected_layer_index]
190
+ # Dropdown for selecting neuron index in the selected layer
191
+ neuron_indices = list(range(selected_layer_units))
192
+ selected_neuron_index = st.sidebar.selectbox("Select Neuron Index", neuron_indices)
193
+ params = (selected_layer_index, selected_neuron_index)
194
+ method(*params)
195
+ elif 'Layer' in selectedPlotMethod:
196
+ selected_layer_index = st.sidebar.selectbox("Select Layer Index", layer_indices)
197
+ params = (selected_layer_index,)
198
+ method(*params)
199
+ else:
200
+ method()
201
+
202
+ st.session_state.kerasModelToDownload = downloadKerasModel()
203
+ st.session_state.plotsToDownload = create_zip_of_png_files()
204
+
205
+ @st.fragment()
206
+ def downloads():
207
+ st.download_button(
208
+ label="Download Model",
209
+ data = downloadKerasModel(),
210
+ file_name="model.keras",
211
+ mime="application/octet-stream"
212
+ );
213
+
214
+ st.download_button(
215
+ label="Download Plots",
216
+ data=create_zip_of_png_files(),
217
+ file_name="images.zip",
218
+ mime="application/zip"
219
+ );
220
+ # column = st.columns (2)
221
+
222
+ # column[0].download_button(
223
+ # label="Download Model",
224
+ # data = downloadKerasModel(),
225
+ # file_name="model.keras",
226
+ # mime="application/octet-stream"
227
+ # );
228
+
229
+ # column[1].download_button(
230
+ # label="Download Plots",
231
+ # data=create_zip_of_png_files(),
232
+ # file_name="images.zip",
233
+ # mime="application/zip"
234
+ # );
235
+
236
+ with st.sidebar:
237
+ downloads()
238
+
239
+ # visualiser.plotFlowForNetwork();
240
+
241
+ image_files = glob.glob("*.png")
242
+
243
+ # Use Streamlit to display the image from the buffer
244
+ st.image(image_files)
245
+
246
+ # if st.sidebar.button("Download Keras model"):
247
+ # downloadKerasModel()
248
+
249
+
250
+ # if st.sidebar.download_button(
251
+ # label="Download Keras Model",
252
+ # data = downloadKerasModel(),
253
+ # file_name="model.keras",
254
+ # mime="application/octet-stream"
255
+ # ):
256
+ # st.sidebar.success(f"Model Downloaded Successfully")
257
+
258
+ # # Button to create and download the ZIP file
259
+ # if st.sidebar.download_button(
260
+ # label="Download Plots",
261
+ # data=create_zip_of_png_files(),
262
+ # file_name="images.zip",
263
+ # mime="application/zip"
264
+ # ):
265
+ # st.sidebar.success(f"Plots Downloaded Successfully")
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ numpy==1.23.5
2
+ keras==2.14.0
3
+ matplotlib==3.7.1
4
+ tensorflow==2.14.0
5
+ NeuralNetworkCoordinates==1.0.0
6
+ NNVisualiser==1.0.0