# Spaces: Running
import json
import os
import random
import time
import traceback
from contextlib import contextmanager
from datetime import datetime

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
import torch
import transformers
from datasets import Dataset
from huggingface_hub import HfApi
from transformers import (
    DataCollatorForLanguageModeling,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    Trainer,
    TrainingArguments,
)
# Advanced Cyberpunk Styling
# The stylesheet lives in a module-level constant so the function body is a
# single call. `<style>` starts at column 0 so Markdown treats it as raw HTML.
_CYBERPUNK_CSS = """
<style>
@import url('https://fonts.googleapis.com/css2?family=Orbitron:wght@400;500;700&display=swap');
@import url('https://fonts.googleapis.com/css2?family=Share+Tech+Mono&display=swap');
.stApp {
  background: linear-gradient(
    45deg,
    rgba(0, 0, 0, 0.9) 0%,
    rgba(0, 30, 60, 0.9) 50%,
    rgba(0, 0, 0, 0.9) 100%
  );
  color: #00ff9d;
}
.main-title {
  font-family: 'Orbitron', sans-serif;
  background: linear-gradient(45deg, #00ff9d, #00b8ff);
  -webkit-background-clip: text;
  -webkit-text-fill-color: transparent;
  text-align: center;
  font-size: 3.5em;
  margin-bottom: 30px;
  text-transform: uppercase;
  letter-spacing: 3px;
  animation: glow 2s ease-in-out infinite alternate;
}
@keyframes glow {
  from {
    text-shadow: 0 0 5px #00ff9d, 0 0 10px #00ff9d, 0 0 15px #00ff9d;
  }
  to {
    text-shadow: 0 0 10px #00b8ff, 0 0 20px #00b8ff, 0 0 30px #00b8ff;
  }
}
.cyber-box {
  background: rgba(0, 0, 0, 0.7);
  border: 2px solid #00ff9d;
  border-radius: 10px;
  padding: 20px;
  margin: 10px 0;
  position: relative;
  overflow: hidden;
}
.cyber-box::before {
  content: '';
  position: absolute;
  top: -2px;
  left: -2px;
  right: -2px;
  bottom: -2px;
  background: linear-gradient(45deg, #00ff9d, #00b8ff);
  z-index: -1;
  filter: blur(10px);
  opacity: 0.5;
}
.metric-container {
  background: rgba(0, 0, 0, 0.8);
  border: 2px solid #00ff9d;
  border-radius: 10px;
  padding: 20px;
  margin: 10px 0;
  position: relative;
  overflow: hidden;
  transition: all 0.3s ease;
}
.metric-container:hover {
  transform: translateY(-5px);
  box-shadow: 0 5px 15px rgba(0, 255, 157, 0.3);
}
.status-text {
  font-family: 'Share Tech Mono', monospace;
  color: #00ff9d;
  font-size: 1.2em;
  margin: 0;
  text-shadow: 0 0 5px #00ff9d;
}
.sidebar .stSelectbox, .sidebar .stSlider {
  background-color: rgba(0, 0, 0, 0.5);
  border-radius: 5px;
  padding: 15px;
  margin: 10px 0;
  border: 1px solid #00ff9d;
}
.stButton>button {
  font-family: 'Orbitron', sans-serif;
  background: linear-gradient(45deg, #00ff9d, #00b8ff);
  color: black;
  border: none;
  padding: 15px 30px;
  border-radius: 5px;
  text-transform: uppercase;
  font-weight: bold;
  letter-spacing: 2px;
  transition: all 0.3s ease;
  position: relative;
  overflow: hidden;
}
.stButton>button:hover {
  transform: scale(1.05);
  box-shadow: 0 0 20px rgba(0, 255, 157, 0.5);
}
.stButton>button::after {
  content: '';
  position: absolute;
  top: -50%;
  left: -50%;
  width: 200%;
  height: 200%;
  background: linear-gradient(
    45deg,
    transparent,
    rgba(255, 255, 255, 0.1),
    transparent
  );
  transform: rotate(45deg);
  animation: shine 3s infinite;
}
@keyframes shine {
  0% {
    transform: translateX(-100%) rotate(45deg);
  }
  100% {
    transform: translateX(100%) rotate(45deg);
  }
}
.custom-info-box {
  background: rgba(0, 255, 157, 0.1);
  border-left: 5px solid #00ff9d;
  padding: 15px;
  margin: 10px 0;
  font-family: 'Share Tech Mono', monospace;
}
.progress-bar-container {
  width: 100%;
  height: 30px;
  background: rgba(0, 0, 0, 0.5);
  border: 2px solid #00ff9d;
  border-radius: 15px;
  overflow: hidden;
  position: relative;
}
.progress-bar {
  height: 100%;
  background: linear-gradient(45deg, #00ff9d, #00b8ff);
  transition: width 0.3s ease;
}
</style>
"""


def setup_advanced_cyberpunk_style():
    """Inject the app's cyberpunk stylesheet into the current Streamlit page.

    The CSS defines the page background, the glowing `.main-title`, the
    `.cyber-box` / `.metric-container` / `.custom-info-box` panels, button
    styling and the `.progress-bar*` classes used by the HTML snippets that
    the rest of this file renders.
    """
    st.markdown(_CYBERPUNK_CSS, unsafe_allow_html=True)
# Fixed prepare_dataset function
def prepare_dataset(data, tokenizer, block_size=128):
    """Tokenize raw text samples into a torch-formatted causal-LM dataset.

    Args:
        data: Sequence of text strings, one training sample per item.
        tokenizer: Hugging Face tokenizer used to encode the samples. If it
            has no padding token (the GPT-2 default), its EOS token is
            assigned as the padding token so fixed-length padding works.
        block_size: Length every sample is truncated or padded to.

    Returns:
        A `datasets.Dataset` with `input_ids`, `attention_mask` and `labels`
        columns formatted as torch tensors. `labels` mirrors `input_ids`
        except that padded positions are set to -100.

    Raises:
        ValueError: If `data` is empty. It is raised inside the
            `error_handling` context, which is defined outside this block
            and decides how the failure is reported.
    """
    with error_handling("dataset preparation"):
        if len(data) == 0:
            raise ValueError("Cannot prepare a dataset from empty data.")

        # padding='max_length' needs a pad token; GPT-2 ships without one.
        if getattr(tokenizer, 'pad_token', None) is None:
            tokenizer.pad_token = tokenizer.eos_token

        def tokenize_function(examples):
            encoded = tokenizer(
                examples['text'],
                truncation=True,
                max_length=block_size,
                padding='max_length',
            )
            # -100 is the index the cross-entropy loss ignores, so padded
            # positions do not contribute to the language-modeling loss.
            encoded['labels'] = [
                [token if keep else -100 for token, keep in zip(ids, mask)]
                for ids, mask in zip(encoded['input_ids'], encoded['attention_mask'])
            ]
            return encoded

        raw_dataset = Dataset.from_dict({'text': data})
        tokenized_dataset = raw_dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=['text'],
        )
        tokenized_dataset.set_format(
            type='torch',
            columns=['input_ids', 'attention_mask', 'labels'],
        )
        return tokenized_dataset
# Advanced Metrics Visualization
def create_training_metrics_plot(fitness_history):
    """Build the themed Plotly line chart for a sequence of loss values.

    Args:
        fitness_history: Loss values, plotted in order as the y series of a
            single "Loss" trace (x axis labelled "Generation").

    Returns:
        The configured `plotly.graph_objects.Figure`.
    """
    accent = '#00ff9d'
    grid = 'rgba(0,255,157,0.1)'

    def themed_axis(label):
        # Both axes share grid and zero-line colours; only the title differs.
        return dict(title=label, gridcolor=grid, zerolinecolor=accent)

    loss_trace = go.Scatter(
        y=fitness_history,
        mode='lines+markers',
        name='Loss',
        line=dict(color=accent, width=2),
        marker=dict(size=8, symbol='diamond'),
    )
    chart_title = {
        'text': 'Training Progress',
        'y': 0.95,
        'x': 0.5,
        'xanchor': 'center',
        'yanchor': 'top',
        'font': {'family': 'Orbitron', 'size': 24, 'color': accent},
    }

    fig = go.Figure(data=[loss_trace])
    fig.update_layout(
        title=chart_title,
        paper_bgcolor='rgba(0,0,0,0.5)',
        plot_bgcolor='rgba(0,0,0,0.3)',
        font=dict(family='Share Tech Mono', color=accent),
        xaxis=themed_axis('Generation'),
        yaxis=themed_axis('Loss'),
        hovermode='x unified',
    )
    return fig
# Advanced Training Dashboard
class TrainingDashboard:
    """Track and render live metrics for an evolutionary training run.

    Attributes:
        metrics: Dict holding the latest loss, best loss, current
            generation/individual, the run totals, the start time and the
            measured evaluation speed.
        history: One `{'loss', 'timestamp'}` entry per `update()` call.
    """

    def __init__(self, total_generations=0, population_size=0):
        """Start the run clock and initialise all metrics.

        Args:
            total_generations: Planned number of generations, shown as the
                denominator of "Generation: x/y". 0 means unknown.
            population_size: Individuals per generation, shown as the
                denominator of "Individual: x/y". 0 means unknown.
        """
        self.metrics = {
            'current_loss': 0,
            'best_loss': float('inf'),
            'generation': 0,
            'individual': 0,
            # display() reads these two keys, so they must always exist.
            'total_generations': total_generations,
            'population_size': population_size,
            'start_time': time.time(),
            'training_speed': 0,
        }
        self.history = []

    def update(self, loss, generation, individual):
        """Record the result of one evaluated individual.

        Args:
            loss: Loss (fitness) of the individual just evaluated.
            generation: 1-based index of the current generation.
            individual: 1-based index of the individual in its generation.
        """
        self.metrics['current_loss'] = loss
        self.metrics['generation'] = generation
        self.metrics['individual'] = individual
        if loss < self.metrics['best_loss']:
            self.metrics['best_loss'] = loss

        self.history.append({
            'loss': loss,
            'timestamp': datetime.now().strftime('%H:%M:%S'),
        })

        # Speed is evaluations completed per second. One history entry is
        # appended per evaluation, so its length is the completed count
        # (generation * individual is not).
        elapsed_time = time.time() - self.metrics['start_time']
        if elapsed_time > 0:
            self.metrics['training_speed'] = len(self.history) / elapsed_time
        else:
            # Guard against a zero or negative interval from a coarse or
            # adjusted clock.
            self.metrics['training_speed'] = 0.0

    def display(self):
        """Render the metrics as three styled columns in the Streamlit page."""
        # Fall back to "?" when a total is unknown (missing or 0).
        total_generations = self.metrics.get('total_generations') or '?'
        population_size = self.metrics.get('population_size') or '?'
        runtime_minutes = (time.time() - self.metrics['start_time']) / 60

        col1, col2, col3 = st.columns(3)
        with col1:
            st.markdown(f"""
            <div class="metric-container">
            <h3 style="color: #00ff9d;">Current Status</h3>
            <p class="status-text">Generation: {self.metrics['generation']}/{total_generations}</p>
            <p class="status-text">Individual: {self.metrics['individual']}/{population_size}</p>
            </div>
            """, unsafe_allow_html=True)
        with col2:
            st.markdown(f"""
            <div class="metric-container">
            <h3 style="color: #00ff9d;">Performance</h3>
            <p class="status-text">Current Loss: {self.metrics['current_loss']:.4f}</p>
            <p class="status-text">Best Loss: {self.metrics['best_loss']:.4f}</p>
            </div>
            """, unsafe_allow_html=True)
        with col3:
            st.markdown(f"""
            <div class="metric-container">
            <h3 style="color: #00ff9d;">Training Metrics</h3>
            <p class="status-text">Speed: {self.metrics['training_speed']:.2f} iter/s</p>
            <p class="status-text">Runtime: {runtime_minutes:.2f}m</p>
            </div>
            """, unsafe_allow_html=True)
def main():
    """Render the training hub: styling, sidebar controls and live dashboard.

    NOTE: the training loop at the end uses `population`,
    `fitness_function`, `train_dataset`, `model_clone` and `tokenizer`.
    None of them is defined in this function; they must be provided by the
    existing training code this template is merged into.
    """
    setup_advanced_cyberpunk_style()
    st.markdown(
        '<h1 class="main-title">Neural Evolution GPT-2 Training Hub</h1>',
        unsafe_allow_html=True,
    )

    # Initialize dashboard
    dashboard = TrainingDashboard()

    # Advanced Sidebar
    with st.sidebar:
        st.markdown("""
        <div style="text-align: center; padding: 20px;">
        <h2 style="font-family: 'Orbitron'; color: #00ff9d;">Control Panel</h2>
        </div>
        """, unsafe_allow_html=True)

        # Configuration Tabs
        # NOTE(review): the emoji in these labels were corrupted by an
        # encoding error. The wrench, gear and microscope were recoverable;
        # the others are best-effort replacements.
        tab1, tab2, tab3 = st.tabs(["🔧 Setup", "⚙️ Parameters", "📊 Monitoring"])

        with tab1:
            hf_token = st.text_input("🔑 HuggingFace Token", type="password")
            repo_name = st.text_input("📁 Repository Name", "my-gpt2-model")
            data_source = st.selectbox('📚 Data Source', ('DEMO', 'Upload Text File'))

        with tab2:
            population_size = st.slider("Population Size", 4, 20, 6)
            num_generations = st.slider("Generations", 1, 10, 3)
            num_parents = st.slider("Parents", 2, population_size, 2)
            mutation_rate = st.slider("Mutation Rate", 0.0, 1.0, 0.1)

            # Advanced Parameters
            with st.expander("🔬 Advanced Settings"):
                # Explicit step/format: the float defaults (step 0.01,
                # "%0.2f") would show these values as 0.00 and make them
                # impossible to adjust.
                learning_rate_min = st.number_input(
                    "Min Learning Rate", 1e-6, 1e-4, 1e-5,
                    step=1e-6, format="%.1e",
                )
                learning_rate_max = st.number_input(
                    "Max Learning Rate", 1e-5, 1e-3, 5e-5,
                    step=1e-5, format="%.1e",
                )
                batch_size_options = st.multiselect(
                    "Batch Sizes", [2, 4, 8, 16], default=[2, 4, 8]
                )

        with tab3:
            cuda_available = torch.cuda.is_available()
            device_label = 'CUDA' if cuda_available else 'CPU'
            memory_gb = torch.cuda.memory_allocated() / 1e9 if cuda_available else 0
            st.markdown(f"""
            <div class="cyber-box">
            <h3 style="color: #00ff9d;">System Status</h3>
            <p>GPU: {device_label}</p>
            <p>Memory Usage: {memory_gb:.2f}GB</p>
            </div>
            """, unsafe_allow_html=True)

    # Give the dashboard the totals it renders as "x/y". They are written
    # straight into the metrics dict, which display() reads.
    dashboard.metrics['total_generations'] = num_generations
    dashboard.metrics['population_size'] = population_size

    # [Rest of your existing main() function code here, integrated with the dashboard]
    # Make sure to update the dashboard metrics during training

    # Placeholders let every iteration overwrite the previous dashboard and
    # progress bar instead of appending a new copy to the page.
    dashboard_slot = st.empty()
    progress_slot = st.empty()

    # Example of updating dashboard during training:
    for generation in range(num_generations):
        for idx, individual in enumerate(population):
            # Your existing training code
            fitness = fitness_function(individual, train_dataset, model_clone, tokenizer)
            dashboard.update(fitness, generation + 1, idx + 1)
            with dashboard_slot.container():
                dashboard.display()

            # Update progress
            progress = (generation * len(population) + idx + 1) / (num_generations * len(population))
            progress_slot.markdown(f"""
            <div class="progress-bar-container">
            <div class="progress-bar" style="width: {progress * 100:.1f}%"></div>
            </div>
            """, unsafe_allow_html=True)
# Run the app only when executed as a script, not when imported.
if __name__ == "__main__":
    main()