# E.L.N / app.py — Hugging Face Space by Sephfox
# Last commit: ba908ff "Update app.py" (verified), 12.8 kB
import streamlit as st
import numpy as np
import random
import torch
import transformers
from transformers import GPT2LMHeadModel, GPT2Tokenizer, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import Dataset
from huggingface_hub import HfApi
import os
import traceback
from contextlib import contextmanager
import plotly.graph_objects as go
import plotly.express as px
from datetime import datetime
import time
import json
import pandas as pd
# Advanced Cyberpunk Styling
def setup_advanced_cyberpunk_style():
    """Inject the app-wide neon/cyberpunk stylesheet into the Streamlit page.

    The stylesheet loads the Orbitron and Share Tech Mono web fonts and
    defines the CSS classes used elsewhere in the page markup
    (``main-title``, ``cyber-box``, ``metric-container``, ``status-text``,
    ``progress-bar-container``, ``progress-bar``, ...). It is emitted as raw
    HTML, hence ``unsafe_allow_html=True``.
    """
    stylesheet = """
<style>
@import url('https://fonts.googleapis.com/css2?family=Orbitron:wght@400;500;700&display=swap');
@import url('https://fonts.googleapis.com/css2?family=Share+Tech+Mono&display=swap');
.stApp {
background: linear-gradient(
45deg,
rgba(0, 0, 0, 0.9) 0%,
rgba(0, 30, 60, 0.9) 50%,
rgba(0, 0, 0, 0.9) 100%
);
color: #00ff9d;
}
.main-title {
font-family: 'Orbitron', sans-serif;
background: linear-gradient(45deg, #00ff9d, #00b8ff);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
text-align: center;
font-size: 3.5em;
margin-bottom: 30px;
text-transform: uppercase;
letter-spacing: 3px;
animation: glow 2s ease-in-out infinite alternate;
}
@keyframes glow {
from {
text-shadow: 0 0 5px #00ff9d, 0 0 10px #00ff9d, 0 0 15px #00ff9d;
}
to {
text-shadow: 0 0 10px #00b8ff, 0 0 20px #00b8ff, 0 0 30px #00b8ff;
}
}
.cyber-box {
background: rgba(0, 0, 0, 0.7);
border: 2px solid #00ff9d;
border-radius: 10px;
padding: 20px;
margin: 10px 0;
position: relative;
overflow: hidden;
}
.cyber-box::before {
content: '';
position: absolute;
top: -2px;
left: -2px;
right: -2px;
bottom: -2px;
background: linear-gradient(45deg, #00ff9d, #00b8ff);
z-index: -1;
filter: blur(10px);
opacity: 0.5;
}
.metric-container {
background: rgba(0, 0, 0, 0.8);
border: 2px solid #00ff9d;
border-radius: 10px;
padding: 20px;
margin: 10px 0;
position: relative;
overflow: hidden;
transition: all 0.3s ease;
}
.metric-container:hover {
transform: translateY(-5px);
box-shadow: 0 5px 15px rgba(0, 255, 157, 0.3);
}
.status-text {
font-family: 'Share Tech Mono', monospace;
color: #00ff9d;
font-size: 1.2em;
margin: 0;
text-shadow: 0 0 5px #00ff9d;
}
.sidebar .stSelectbox, .sidebar .stSlider {
background-color: rgba(0, 0, 0, 0.5);
border-radius: 5px;
padding: 15px;
margin: 10px 0;
border: 1px solid #00ff9d;
}
.stButton>button {
font-family: 'Orbitron', sans-serif;
background: linear-gradient(45deg, #00ff9d, #00b8ff);
color: black;
border: none;
padding: 15px 30px;
border-radius: 5px;
text-transform: uppercase;
font-weight: bold;
letter-spacing: 2px;
transition: all 0.3s ease;
position: relative;
overflow: hidden;
}
.stButton>button:hover {
transform: scale(1.05);
box-shadow: 0 0 20px rgba(0, 255, 157, 0.5);
}
.stButton>button::after {
content: '';
position: absolute;
top: -50%;
left: -50%;
width: 200%;
height: 200%;
background: linear-gradient(
45deg,
transparent,
rgba(255, 255, 255, 0.1),
transparent
);
transform: rotate(45deg);
animation: shine 3s infinite;
}
@keyframes shine {
0% {
transform: translateX(-100%) rotate(45deg);
}
100% {
transform: translateX(100%) rotate(45deg);
}
}
.custom-info-box {
background: rgba(0, 255, 157, 0.1);
border-left: 5px solid #00ff9d;
padding: 15px;
margin: 10px 0;
font-family: 'Share Tech Mono', monospace;
}
.progress-bar-container {
width: 100%;
height: 30px;
background: rgba(0, 0, 0, 0.5);
border: 2px solid #00ff9d;
border-radius: 15px;
overflow: hidden;
position: relative;
}
.progress-bar {
height: 100%;
background: linear-gradient(45deg, #00ff9d, #00b8ff);
transition: width 0.3s ease;
}
</style>
"""
    st.markdown(stylesheet, unsafe_allow_html=True)
# Dataset preparation: tokenize raw text into fixed-length causal-LM examples
def prepare_dataset(data, tokenizer, block_size=128):
    """Tokenize raw text strings into a causal-language-model training dataset.

    Every text is truncated or padded to exactly ``block_size`` tokens.
    ``labels`` is a copy of ``input_ids`` in which padded positions are set to
    -100. Hugging Face causal-LM heads ignore that value in the loss, so the
    model is not trained to predict padding.

    Args:
        data: Sequence of strings, stored in a ``datasets.Dataset`` under a
            ``'text'`` column.
        tokenizer: Hugging Face tokenizer (e.g. ``GPT2Tokenizer``). If it has
            no pad token, its EOS token is assigned as the pad token. This
            mutates the tokenizer object passed in.
        block_size: Fixed sequence length after truncation/padding.

    Returns:
        A ``datasets.Dataset`` whose ``input_ids``, ``attention_mask`` and
        ``labels`` columns are formatted as torch tensors.
    """
    # NOTE(review): `error_handling` is not defined in the visible part of this
    # file. It is presumably a context manager defined elsewhere; confirm.
    with error_handling("dataset preparation"):
        # GPT-2 tokenizers ship without a pad token, and padding='max_length'
        # raises without one. Reuse EOS, the conventional choice for GPT-2.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        def tokenize_function(examples):
            return tokenizer(examples['text'], truncation=True, max_length=block_size, padding='max_length')

        def add_labels(examples):
            # Copy input_ids, replacing padded positions (attention_mask == 0)
            # with -100 so they are excluded from the loss.
            return {
                'labels': [
                    [token if keep else -100 for token, keep in zip(ids, mask)]
                    for ids, mask in zip(examples['input_ids'], examples['attention_mask'])
                ]
            }

        raw_dataset = Dataset.from_dict({'text': data})
        tokenized_dataset = raw_dataset.map(tokenize_function, batched=True, remove_columns=['text'])
        tokenized_dataset = tokenized_dataset.map(add_labels, batched=True)
        tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
        return tokenized_dataset
# Advanced Metrics Visualization
def create_training_metrics_plot(fitness_history):
    """Build a neon-themed Plotly line chart of loss values.

    Args:
        fitness_history: Sequence of loss values. They are plotted in order,
            with the x-axis being each value's position in the sequence.

    Returns:
        A ``plotly.graph_objects.Figure`` containing one "Loss" trace.
    """
    neon_green = '#00ff9d'
    faint_grid = 'rgba(0,255,157,0.1)'

    loss_trace = go.Scatter(
        y=fitness_history,
        mode='lines+markers',
        name='Loss',
        line={'color': neon_green, 'width': 2},
        marker={'size': 8, 'symbol': 'diamond'},
    )

    title_spec = {
        'text': 'Training Progress',
        'y': 0.95,
        'x': 0.5,
        'xanchor': 'center',
        'yanchor': 'top',
        'font': {'family': 'Orbitron', 'size': 24, 'color': neon_green},
    }

    def axis_spec(label):
        # Both axes share grid/zero-line styling; only the title differs.
        return {'title': label, 'gridcolor': faint_grid, 'zerolinecolor': neon_green}

    figure = go.Figure()
    figure.add_trace(loss_trace)
    figure.update_layout(
        title=title_spec,
        paper_bgcolor='rgba(0,0,0,0.5)',
        plot_bgcolor='rgba(0,0,0,0.3)',
        font={'family': 'Share Tech Mono', 'color': neon_green},
        xaxis=axis_spec('Generation'),
        yaxis=axis_spec('Loss'),
        hovermode='x unified',
    )
    return figure
# Advanced Training Dashboard
class TrainingDashboard:
    """Track and render live metrics for the evolutionary training loop.

    Metrics are kept in ``self.metrics``. Every ``update`` call also appends a
    ``{'loss', 'timestamp'}`` record to ``self.history``.
    """

    def __init__(self, total_generations=None, population_size=None):
        """Initialise metrics and start the runtime clock.

        Args:
            total_generations: Planned number of generations. It is the
                denominator of "Generation: x/y"; ``None`` renders as ``?``.
            population_size: Individuals per generation. It is the denominator
                of "Individual: x/y"; ``None`` renders as ``?``.
        """
        self.metrics = {
            'current_loss': 0,
            'best_loss': float('inf'),
            'generation': 0,
            'individual': 0,
            'start_time': time.time(),
            'training_speed': 0,
            # display() reads these two keys. They were previously never
            # initialised, so display() always raised KeyError.
            'total_generations': total_generations,
            'population_size': population_size,
            # Number of update() calls so far; drives the evaluations/sec rate.
            'evaluations': 0,
        }
        self.history = []

    def update(self, loss, generation, individual):
        """Record the result of one fitness evaluation.

        Args:
            loss: Loss of the evaluated individual. ``best_loss`` tracks the
                minimum, so lower is better.
            generation: Generation number (main() passes it 1-based).
            individual: Index of the individual within its generation
                (main() passes it 1-based).
        """
        self.metrics['current_loss'] = loss
        self.metrics['generation'] = generation
        self.metrics['individual'] = individual
        if loss < self.metrics['best_loss']:
            self.metrics['best_loss'] = loss

        self.metrics['evaluations'] += 1
        elapsed_time = time.time() - self.metrics['start_time']
        # generation * individual is not the number of evaluations performed
        # (e.g. generation 2, individual 1 follows a whole first generation),
        # so use the running count. Also guard against a zero clock delta.
        self.metrics['training_speed'] = (
            self.metrics['evaluations'] / elapsed_time if elapsed_time > 0 else 0
        )

        self.history.append({
            'loss': loss,
            'timestamp': datetime.now().strftime('%H:%M:%S')
        })

    def display(self):
        """Render status, performance and timing cards in three columns."""
        def _total(key):
            # Unknown totals are shown as '?' instead of raising KeyError.
            value = self.metrics.get(key)
            return '?' if value is None else value

        col1, col2, col3 = st.columns(3)
        with col1:
            st.markdown("""
<div class="metric-container">
<h3 style="color: #00ff9d;">Current Status</h3>
<p class="status-text">Generation: {}/{}</p>
<p class="status-text">Individual: {}/{}</p>
</div>
""".format(
                self.metrics['generation'],
                _total('total_generations'),
                self.metrics['individual'],
                _total('population_size')
            ), unsafe_allow_html=True)
        with col2:
            st.markdown("""
<div class="metric-container">
<h3 style="color: #00ff9d;">Performance</h3>
<p class="status-text">Current Loss: {:.4f}</p>
<p class="status-text">Best Loss: {:.4f}</p>
</div>
""".format(
                self.metrics['current_loss'],
                self.metrics['best_loss']
            ), unsafe_allow_html=True)
        with col3:
            st.markdown("""
<div class="metric-container">
<h3 style="color: #00ff9d;">Training Metrics</h3>
<p class="status-text">Speed: {:.2f} iter/s</p>
<p class="status-text">Runtime: {:.2f}m</p>
</div>
""".format(
                self.metrics['training_speed'],
                (time.time() - self.metrics['start_time']) / 60
            ), unsafe_allow_html=True)
def main():
    """Render the training hub page and drive the evolutionary training loop.

    The function does three things:
    - Applies the page theme and title.
    - Builds the sidebar: setup, GA parameters, and a system-status card.
    - Evaluates every individual of every generation, refreshing a single live
      metrics panel and a single progress bar in place.

    NOTE(review): the training loop uses ``population``, ``fitness_function``,
    ``train_dataset``, ``model_clone`` and ``tokenizer``. None of these are
    defined in the visible file; the original marks this spot as a
    placeholder for elided code. Until that code is restored, the loop raises
    NameError.
    """
    setup_advanced_cyberpunk_style()
    st.markdown('<h1 class="main-title">Neural Evolution GPT-2 Training Hub</h1>', unsafe_allow_html=True)

    # Initialize dashboard
    dashboard = TrainingDashboard()

    # Advanced Sidebar
    with st.sidebar:
        st.markdown("""
<div style="text-align: center; padding: 20px;">
<h2 style="font-family: 'Orbitron'; color: #00ff9d;">Control Panel</h2>
</div>
""", unsafe_allow_html=True)

        # Configuration Tabs (emoji labels repaired from mis-decoded UTF-8)
        tab1, tab2, tab3 = st.tabs(["🔧 Setup", "⚙️ Parameters", "📊 Monitoring"])
        with tab1:
            hf_token = st.text_input("🔑 HuggingFace Token", type="password")
            repo_name = st.text_input("📁 Repository Name", "my-gpt2-model")
            data_source = st.selectbox('📊 Data Source', ('DEMO', 'Upload Text File'))
        with tab2:
            population_size = st.slider("Population Size", 4, 20, 6)
            num_generations = st.slider("Generations", 1, 10, 3)
            num_parents = st.slider("Parents", 2, population_size, 2)
            mutation_rate = st.slider("Mutation Rate", 0.0, 1.0, 0.1)
            # Advanced Parameters
            with st.expander("🔬 Advanced Settings"):
                # Learning rates are ~1e-5, so give them scientific-notation
                # display and a step on their own scale.
                learning_rate_min = st.number_input(
                    "Min Learning Rate", 1e-6, 1e-4, 1e-5, step=1e-6, format="%.1e"
                )
                learning_rate_max = st.number_input(
                    "Max Learning Rate", 1e-5, 1e-3, 5e-5, step=1e-5, format="%.1e"
                )
                batch_size_options = st.multiselect("Batch Sizes", [2, 4, 8, 16], default=[2, 4, 8])
        with tab3:
            st.markdown("""
<div class="cyber-box">
<h3 style="color: #00ff9d;">System Status</h3>
<p>GPU: {}</p>
<p>Memory Usage: {:.2f}GB</p>
</div>
""".format(
                'CUDA' if torch.cuda.is_available() else 'CPU',
                torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0
            ), unsafe_allow_html=True)

    # [Rest of your existing main() function code here, integrated with the dashboard]

    # TrainingDashboard.display() reads 'total_generations' and
    # 'population_size'; set them so the "x/y" counters have denominators.
    dashboard.metrics['total_generations'] = num_generations

    # Create the output slots once and overwrite them on every evaluation.
    # Calling st.columns()/st.markdown() directly in the loop would append a
    # new row of cards and a new progress bar each iteration.
    metrics_placeholder = st.empty()
    progress_placeholder = st.empty()

    for generation in range(num_generations):
        dashboard.metrics['population_size'] = len(population)
        for idx, individual in enumerate(population):
            # Your existing training code
            fitness = fitness_function(individual, train_dataset, model_clone, tokenizer)
            dashboard.update(fitness, generation + 1, idx + 1)
            with metrics_placeholder.container():
                dashboard.display()

            # Update progress (fraction of all planned evaluations completed)
            progress = (generation * len(population) + idx + 1) / (num_generations * len(population))
            progress_placeholder.markdown(f"""
<div class="progress-bar-container">
<div class="progress-bar" style="width: {progress * 100}%"></div>
</div>
""", unsafe_allow_html=True)
# Entry point: build the page when this file is executed as a script.
if __name__ == "__main__":
    main()