Sephfox commited on
Commit
6a25926
1 Parent(s): 289ccd4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -17
app.py CHANGED
@@ -1,4 +1,3 @@
1
- # Imports
2
  import streamlit as st
3
  import numpy as np
4
  import torch
@@ -67,6 +66,36 @@ def setup_cyberpunk_style():
67
  transition: width 0.5s ease;
68
  }
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  </style>
71
  """, unsafe_allow_html=True)
72
 
@@ -196,23 +225,35 @@ def main():
196
  # Load Dataset
197
  train_dataset = load_dataset(data_source, tokenizer, uploaded_file=uploaded_file)
198
 
199
- # Start Training with Progress Bar
200
- progress_placeholder = st.empty()
201
- st.markdown("### Model Training Progress")
 
 
 
 
202
 
203
- for epoch in range(training_epochs):
204
- train_model(model, train_dataset, tokenizer, epochs=1, batch_size=batch_size)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
 
206
- # Update Progress Bar
207
- progress = (epoch + 1) / training_epochs * 100
208
- progress_placeholder.markdown(f"""
209
- <div class="progress-bar-container">
210
- <div class="progress-bar" style="width: {progress}%;"></div>
211
- </div>
212
- """, unsafe_allow_html=True)
213
-
214
- st.success("Training Complete!")
215
 
216
  if __name__ == "__main__":
217
- main()
218
-
 
 
1
  import streamlit as st
2
  import numpy as np
3
  import torch
 
66
  transition: width 0.5s ease;
67
  }
68
 
69
+ .go-button {
70
+ font-family: 'Orbitron', sans-serif;
71
+ background: linear-gradient(45deg, #00ff9d, #00b8ff);
72
+ color: #000;
73
+ font-size: 1.1em;
74
+ padding: 10px 20px;
75
+ border: none;
76
+ border-radius: 8px;
77
+ transition: all 0.3s ease;
78
+ cursor: pointer;
79
+ }
80
+
81
+ .go-button:hover {
82
+ transform: scale(1.1);
83
+ box-shadow: 0 0 20px rgba(0, 255, 157, 0.5);
84
+ }
85
+
86
+ .loading-animation {
87
+ display: inline-block;
88
+ width: 20px;
89
+ height: 20px;
90
+ border: 3px solid #00ff9d;
91
+ border-radius: 50%;
92
+ border-top-color: transparent;
93
+ animation: spin 1s ease-in-out infinite;
94
+ }
95
+
96
+ @keyframes spin {
97
+ to {transform: rotate(360deg);}
98
+ }
99
  </style>
100
  """, unsafe_allow_html=True)
101
 
 
225
  # Load Dataset
226
  train_dataset = load_dataset(data_source, tokenizer, uploaded_file=uploaded_file)
227
 
228
+ # Go Button to Start Training
229
+ if st.button("Go"):
230
+ progress_placeholder = st.empty()
231
+ loading_animation = st.empty()
232
+ st.markdown("### Model Training Progress")
233
+
234
+ dashboard = TrainingDashboard()
235
 
236
+ for epoch in range(training_epochs):
237
+ loading_animation.markdown("""
238
+ <div class="loading-animation"></div>
239
+ """, unsafe_allow_html=True)
240
+
241
+ train_model(model, train_dataset, tokenizer, epochs=1, batch_size=batch_size)
242
+
243
+ # Update Progress Bar
244
+ progress = (epoch + 1) / training_epochs * 100
245
+ progress_placeholder.markdown(f"""
246
+ <div class="progress-bar-container">
247
+ <div class="progress-bar" style="width: {progress}%;"></div>
248
+ </div>
249
+ """, unsafe_allow_html=True)
250
+
251
+ dashboard.update(loss=0, generation=epoch + 1, individual=batch_size)
252
 
253
+ loading_animation.empty()
254
+ st.success("Training Complete!")
255
+ st.write("Training Metrics:")
256
+ st.write(dashboard.metrics)
 
 
 
 
 
257
 
258
  if __name__ == "__main__":
259
+ main()