Sephfox committed on
Commit
f5b3aed
1 Parent(s): e95cda2

Update app.py

Files changed (1)
  1. app.py +79 -327
app.py CHANGED
@@ -1,20 +1,11 @@
  import streamlit as st
  import numpy as np
- import random
  import torch
- import transformers
  from transformers import GPT2LMHeadModel, GPT2Tokenizer, Trainer, TrainingArguments, DataCollatorForLanguageModeling
  from datasets import Dataset
- from huggingface_hub import HfApi
- import os
- import traceback
- from contextlib import contextmanager
- import plotly.graph_objects as go
- import plotly.express as px
- from datetime import datetime
  import time
- import json
- import pandas as pd

  # Advanced Cyberpunk Styling
  def setup_advanced_cyberpunk_style():
@@ -22,352 +13,113 @@ def setup_advanced_cyberpunk_style():
      <style>
      @import url('https://fonts.googleapis.com/css2?family=Orbitron:wght@400;500;700&display=swap');
      @import url('https://fonts.googleapis.com/css2?family=Share+Tech+Mono&display=swap');
-
-     .stApp {
-         background: linear-gradient(
-             45deg,
-             rgba(0, 0, 0, 0.9) 0%,
-             rgba(0, 30, 60, 0.9) 50%,
-             rgba(0, 0, 0, 0.9) 100%
-         );
-         color: #00ff9d;
-     }
-
-     .main-title {
-         font-family: 'Orbitron', sans-serif;
-         background: linear-gradient(45deg, #00ff9d, #00b8ff);
-         -webkit-background-clip: text;
-         -webkit-text-fill-color: transparent;
-         text-align: center;
-         font-size: 3.5em;
-         margin-bottom: 30px;
-         text-transform: uppercase;
-         letter-spacing: 3px;
-         animation: glow 2s ease-in-out infinite alternate;
-     }
-
-     @keyframes glow {
-         from {
-             text-shadow: 0 0 5px #00ff9d, 0 0 10px #00ff9d, 0 0 15px #00ff9d;
-         }
-         to {
-             text-shadow: 0 0 10px #00b8ff, 0 0 20px #00b8ff, 0 0 30px #00b8ff;
-         }
-     }
-
-     .cyber-box {
-         background: rgba(0, 0, 0, 0.7);
-         border: 2px solid #00ff9d;
-         border-radius: 10px;
-         padding: 20px;
-         margin: 10px 0;
-         position: relative;
-         overflow: hidden;
-     }
-
-     .cyber-box::before {
-         content: '';
-         position: absolute;
-         top: -2px;
-         left: -2px;
-         right: -2px;
-         bottom: -2px;
-         background: linear-gradient(45deg, #00ff9d, #00b8ff);
-         z-index: -1;
-         filter: blur(10px);
-         opacity: 0.5;
-     }
-
-     .metric-container {
-         background: rgba(0, 0, 0, 0.8);
-         border: 2px solid #00ff9d;
-         border-radius: 10px;
-         padding: 20px;
-         margin: 10px 0;
-         position: relative;
-         overflow: hidden;
-         transition: all 0.3s ease;
-     }
-
-     .metric-container:hover {
-         transform: translateY(-5px);
-         box-shadow: 0 5px 15px rgba(0, 255, 157, 0.3);
-     }
-
-     .status-text {
-         font-family: 'Share Tech Mono', monospace;
-         color: #00ff9d;
-         font-size: 1.2em;
-         margin: 0;
-         text-shadow: 0 0 5px #00ff9d;
-     }
-
-     .sidebar .stSelectbox, .sidebar .stSlider {
-         background-color: rgba(0, 0, 0, 0.5);
-         border-radius: 5px;
-         padding: 15px;
-         margin: 10px 0;
-         border: 1px solid #00ff9d;
-     }
-
-     .stButton>button {
-         font-family: 'Orbitron', sans-serif;
-         background: linear-gradient(45deg, #00ff9d, #00b8ff);
-         color: black;
-         border: none;
-         padding: 15px 30px;
-         border-radius: 5px;
-         text-transform: uppercase;
-         font-weight: bold;
-         letter-spacing: 2px;
-         transition: all 0.3s ease;
-         position: relative;
-         overflow: hidden;
-     }
-
-     .stButton>button:hover {
-         transform: scale(1.05);
-         box-shadow: 0 0 20px rgba(0, 255, 157, 0.5);
-     }
-
-     .stButton>button::after {
-         content: '';
-         position: absolute;
-         top: -50%;
-         left: -50%;
-         width: 200%;
-         height: 200%;
-         background: linear-gradient(
-             45deg,
-             transparent,
-             rgba(255, 255, 255, 0.1),
-             transparent
-         );
-         transform: rotate(45deg);
-         animation: shine 3s infinite;
-     }
-
-     @keyframes shine {
-         0% {
-             transform: translateX(-100%) rotate(45deg);
-         }
-         100% {
-             transform: translateX(100%) rotate(45deg);
-         }
-     }
-
-     .custom-info-box {
-         background: rgba(0, 255, 157, 0.1);
-         border-left: 5px solid #00ff9d;
-         padding: 15px;
-         margin: 10px 0;
-         font-family: 'Share Tech Mono', monospace;
-     }
-
-     .progress-bar-container {
-         width: 100%;
-         height: 30px;
-         background: rgba(0, 0, 0, 0.5);
-         border: 2px solid #00ff9d;
-         border-radius: 15px;
-         overflow: hidden;
-         position: relative;
-     }
-
-     .progress-bar {
-         height: 100%;
-         background: linear-gradient(45deg, #00ff9d, #00b8ff);
-         transition: width 0.3s ease;
-     }
      </style>
      """, unsafe_allow_html=True)

- # Fixed prepare_dataset function
- def prepare_dataset(data, tokenizer, block_size=128):
-     with error_handling("dataset preparation"):
-         def tokenize_function(examples):
-             return tokenizer(examples['text'], truncation=True, max_length=block_size, padding='max_length')

-         raw_dataset = Dataset.from_dict({'text': data})
-         tokenized_dataset = raw_dataset.map(tokenize_function, batched=True, remove_columns=['text'])
-         tokenized_dataset = tokenized_dataset.map(
-             lambda examples: {'labels': examples['input_ids']},
-             batched=True
-         )
-         tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
-         return tokenized_dataset

- # Advanced Metrics Visualization
- def create_training_metrics_plot(fitness_history):
-     fig = go.Figure()
-     fig.add_trace(go.Scatter(
-         y=fitness_history,
-         mode='lines+markers',
-         name='Loss',
-         line=dict(color='#00ff9d', width=2),
-         marker=dict(size=8, symbol='diamond'),
-     ))
-
-     fig.update_layout(
-         title={
-             'text': 'Training Progress',
-             'y':0.95,
-             'x':0.5,
-             'xanchor': 'center',
-             'yanchor': 'top',
-             'font': {'family': 'Orbitron', 'size': 24, 'color': '#00ff9d'}
-         },
-         paper_bgcolor='rgba(0,0,0,0.5)',
-         plot_bgcolor='rgba(0,0,0,0.3)',
-         font=dict(family='Share Tech Mono', color='#00ff9d'),
-         xaxis=dict(
-             title='Generation',
-             gridcolor='rgba(0,255,157,0.1)',
-             zerolinecolor='#00ff9d'
-         ),
-         yaxis=dict(
-             title='Loss',
-             gridcolor='rgba(0,255,157,0.1)',
-             zerolinecolor='#00ff9d'
-         ),
-         hovermode='x unified'
      )
-     return fig

- # Advanced Training Dashboard
  class TrainingDashboard:
      def __init__(self):
          self.metrics = {
              'current_loss': 0,
              'best_loss': float('inf'),
              'generation': 0,
-             'individual': 0,
              'start_time': time.time(),
              'training_speed': 0
          }
          self.history = []

-     def update(self, loss, generation, individual):
          self.metrics['current_loss'] = loss
          self.metrics['generation'] = generation
-         self.metrics['individual'] = individual
          if loss < self.metrics['best_loss']:
              self.metrics['best_loss'] = loss
-
          elapsed_time = time.time() - self.metrics['start_time']
-         self.metrics['training_speed'] = (generation * individual) / elapsed_time
-         self.history.append({
-             'loss': loss,
-             'timestamp': datetime.now().strftime('%H:%M:%S')
-         })

      def display(self):
-         col1, col2, col3 = st.columns(3)
-
-         with col1:
-             st.markdown("""
-             <div class="metric-container">
-                 <h3 style="color: #00ff9d;">Current Status</h3>
-                 <p class="status-text">Generation: {}/{}</p>
-                 <p class="status-text">Individual: {}/{}</p>
-             </div>
-             """.format(
-                 self.metrics['generation'],
-                 self.metrics['total_generations'],
-                 self.metrics['individual'],
-                 self.metrics['population_size']
-             ), unsafe_allow_html=True)
-
-         with col2:
-             st.markdown("""
-             <div class="metric-container">
-                 <h3 style="color: #00ff9d;">Performance</h3>
-                 <p class="status-text">Current Loss: {:.4f}</p>
-                 <p class="status-text">Best Loss: {:.4f}</p>
-             </div>
-             """.format(
-                 self.metrics['current_loss'],
-                 self.metrics['best_loss']
-             ), unsafe_allow_html=True)
-
-         with col3:
-             st.markdown("""
-             <div class="metric-container">
-                 <h3 style="color: #00ff9d;">Training Metrics</h3>
-                 <p class="status-text">Speed: {:.2f} iter/s</p>
-                 <p class="status-text">Runtime: {:.2f}m</p>
-             </div>
-             """.format(
-                 self.metrics['training_speed'],
-                 (time.time() - self.metrics['start_time']) / 60
-             ), unsafe_allow_html=True)

  def main():
      setup_advanced_cyberpunk_style()
-
      st.markdown('<h1 class="main-title">Neural Evolution GPT-2 Training Hub</h1>', unsafe_allow_html=True)
-
-     # Initialize dashboard
-     dashboard = TrainingDashboard()
-
-     # Advanced Sidebar
-     with st.sidebar:
-         st.markdown("""
-         <div style="text-align: center; padding: 20px;">
-             <h2 style="font-family: 'Orbitron'; color: #00ff9d;">Control Panel</h2>
-         </div>
-         """, unsafe_allow_html=True)
-
-         # Configuration Tabs
-         tab1, tab2, tab3 = st.tabs(["🔧 Setup", "⚙️ Parameters", "📊 Monitoring"])
-
-         with tab1:
-             hf_token = st.text_input("🔑 HuggingFace Token", type="password")
-             repo_name = st.text_input("📁 Repository Name", "my-gpt2-model")
-             data_source = st.selectbox('📊 Data Source', ('DEMO', 'Upload Text File'))
-
-         with tab2:
-             population_size = st.slider("Population Size", 4, 20, 6)
-             num_generations = st.slider("Generations", 1, 10, 3)
-             num_parents = st.slider("Parents", 2, population_size, 2)
-             mutation_rate = st.slider("Mutation Rate", 0.0, 1.0, 0.1)
-
-             # Advanced Parameters
-             with st.expander("🔬 Advanced Settings"):
-                 learning_rate_min = st.number_input("Min Learning Rate", 1e-6, 1e-4, 1e-5)
-                 learning_rate_max = st.number_input("Max Learning Rate", 1e-5, 1e-3, 5e-5)
-                 batch_size_options = st.multiselect("Batch Sizes", [2, 4, 8, 16], default=[2, 4, 8])
-
-         with tab3:
-             st.markdown("""
-             <div class="cyber-box">
-                 <h3 style="color: #00ff9d;">System Status</h3>
-                 <p>GPU: {}</p>
-                 <p>Memory Usage: {:.2f}GB</p>
-             </div>
-             """.format(
-                 'CUDA' if torch.cuda.is_available() else 'CPU',
-                 torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0
-             ), unsafe_allow_html=True)

-     # [Rest of your existing main() function code here, integrated with the dashboard]
-     # Make sure to update the dashboard metrics during training

-     # Example of updating dashboard during training:
-     for generation in range(num_generations):
-         for idx, individual in enumerate(population):
-             # Your existing training code
-             fitness = fitness_function(individual, train_dataset, model_clone, tokenizer)
-             dashboard.update(fitness, generation + 1, idx + 1)
-             dashboard.display()
-
-             # Update progress
-             progress = (generation * len(population) + idx + 1) / (num_generations * len(population))
-             st.markdown(f"""
-             <div class="progress-bar-container">
-                 <div class="progress-bar" style="width: {progress * 100}%"></div>
-             </div>
-             """, unsafe_allow_html=True)

  if __name__ == "__main__":
-     main()

app.py (new version)

  import streamlit as st
  import numpy as np
  import torch
  from transformers import GPT2LMHeadModel, GPT2Tokenizer, Trainer, TrainingArguments, DataCollatorForLanguageModeling
  from datasets import Dataset
  import time
+ from datetime import datetime
+ import plotly.graph_objects as go

  # Advanced Cyberpunk Styling
  def setup_advanced_cyberpunk_style():
      <style>
      @import url('https://fonts.googleapis.com/css2?family=Orbitron:wght@400;500;700&display=swap');
      @import url('https://fonts.googleapis.com/css2?family=Share+Tech+Mono&display=swap');
+     /* Additional styling as provided previously */
      </style>
      """, unsafe_allow_html=True)

+ # Initialize Model and Tokenizer
+ def initialize_model():
+     model = GPT2LMHeadModel.from_pretrained("gpt2")
+     tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+     return model, tokenizer

+ # Prepare Dataset
+ def prepare_dataset(data, tokenizer, block_size=128):
+     def tokenize_function(examples):
+         return tokenizer(examples['text'], truncation=True, max_length=block_size, padding='max_length')

+     raw_dataset = Dataset.from_dict({'text': data})
+     tokenized_dataset = raw_dataset.map(tokenize_function, batched=True, remove_columns=['text'])
+     tokenized_dataset = tokenized_dataset.map(
+         lambda examples: {'labels': examples['input_ids']},
+         batched=True
      )
+     tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
+     return tokenized_dataset
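
Note: GPT-2's tokenizer ships without a pad token, so the padding='max_length' call in tokenize_function will fail with a "tokenizer does not have a padding token" error as soon as the dataset is mapped. A minimal sketch of one common workaround, not part of this commit, is to reuse the EOS token as the pad token when the model and tokenizer are created:

    def initialize_model():
        model = GPT2LMHeadModel.from_pretrained("gpt2")
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        tokenizer.pad_token = tokenizer.eos_token           # GPT-2 has no dedicated PAD token
        model.config.pad_token_id = tokenizer.eos_token_id  # keep the model config consistent
        return model, tokenizer

Because the padded positions are copied verbatim into labels, they still contribute to the loss; replacing them with -100 in the labels would be a further refinement.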
 
+ # Training Dashboard Class
  class TrainingDashboard:
      def __init__(self):
          self.metrics = {
              'current_loss': 0,
              'best_loss': float('inf'),
              'generation': 0,
              'start_time': time.time(),
              'training_speed': 0
          }
          self.history = []

+     def update(self, loss, generation):
          self.metrics['current_loss'] = loss
          self.metrics['generation'] = generation
          if loss < self.metrics['best_loss']:
              self.metrics['best_loss'] = loss
          elapsed_time = time.time() - self.metrics['start_time']
+         self.metrics['training_speed'] = generation / elapsed_time
+         self.history.append({'loss': loss, 'timestamp': datetime.now().strftime('%H:%M:%S')})

      def display(self):
+         st.write(f"**Generation:** {self.metrics['generation']}")
+         st.write(f"**Current Loss:** {self.metrics['current_loss']:.4f}")
+         st.write(f"**Best Loss:** {self.metrics['best_loss']:.4f}")
+         st.write(f"**Training Speed:** {self.metrics['training_speed']:.2f} generations/sec")
+
+ # Display Progress Bar
+ def display_progress(progress):
+     st.markdown(f"""
+     <div class="progress-bar-container">
+         <div class="progress-bar" style="width: {progress * 100}%"></div>
+     </div>
+     """, unsafe_allow_html=True)

+ # Fitness Calculation (Placeholder for actual loss computation)
+ def compute_loss(model, dataset):
+     # Placeholder for real loss computation with Trainer API or custom logic
+     trainer = Trainer(
+         model=model,
+         args=TrainingArguments(output_dir="./results", per_device_train_batch_size=2, num_train_epochs=1),
+         train_dataset=dataset,
+         data_collator=DataCollatorForLanguageModeling(tokenizer=model.config._name_or_path, mlm=False),
+     )
+     train_result = trainer.train()
+     return train_result.training_loss
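
Note: DataCollatorForLanguageModeling expects the tokenizer object itself, but here it is handed model.config._name_or_path, which is a string, so trainer.train() would fail as soon as the collator tries to pad a batch. A sketch of a corrected version, assuming the tokenizer returned by initialize_model() is threaded through (training_loop and the call in main would then need the extra argument as well):

    def compute_loss(model, dataset, tokenizer):
        # One short training pass whose reported loss serves as the fitness score.
        trainer = Trainer(
            model=model,
            args=TrainingArguments(output_dir="./results", per_device_train_batch_size=2, num_train_epochs=1),
            train_dataset=dataset,
            # The collator needs the tokenizer instance, not the model name string.
            data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
        )
        return trainer.train().training_loss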
+ # Training Loop with Loading Screen
+ def training_loop(dashboard, model, dataset, num_generations, population_size):
+     with st.spinner("Training in progress..."):
+         for generation in range(1, num_generations + 1):
+             # Simulated population loop
+             for individual in range(population_size):
+                 loss = compute_loss(model, dataset)
+                 dashboard.update(loss, generation)
+                 progress = generation / num_generations
+                 display_progress(progress)
+                 dashboard.display()
+                 time.sleep(1)  # Simulate delay for each individual training
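
Note: each call to compute_loss launches a full Trainer run and keeps updating the same model object, so later individuals are scored on a model the earlier ones have already trained. If the goal is only to rank individuals, a lighter-weight alternative is to score fitness with a few forward passes and no weight updates. The sketch below assumes that intent; quick_eval_loss is a hypothetical helper, not part of this commit:

    import torch
    from torch.utils.data import DataLoader

    def quick_eval_loss(model, dataset, batch_size=2, max_batches=8):
        # Average language-modelling loss over a handful of batches; no training step.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        model.eval()
        loader = DataLoader(dataset, batch_size=batch_size)
        total, batches = 0.0, 0
        with torch.no_grad():
            for i, batch in enumerate(loader):
                if i >= max_batches:
                    break
                batch = {k: v.to(device) for k, v in batch.items()}
                total += model(**batch).loss.item()  # labels were added in prepare_dataset
                batches += 1
        return total / max(batches, 1)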
+ # Main Function
  def main():
      setup_advanced_cyberpunk_style()
      st.markdown('<h1 class="main-title">Neural Evolution GPT-2 Training Hub</h1>', unsafe_allow_html=True)

+     # Load Model and Tokenizer
+     model, tokenizer = initialize_model()
+
+     # Prepare Data
+     data = ["Sample training text"] * 10  # Replace with real data
+     train_dataset = prepare_dataset(data, tokenizer)

+     # Initialize Dashboard
+     dashboard = TrainingDashboard()
+
+     # Sidebar Configuration
+     st.sidebar.markdown("### Training Parameters")
+     num_generations = st.sidebar.slider("Generations", 1, 20, 5)
+     population_size = st.sidebar.slider("Population Size", 4, 20, 6)
+
+     # Run Training
+     if st.button("Start Training"):
+         training_loop(dashboard, model, train_dataset, num_generations, population_size)

  if __name__ == "__main__":
+     main()