import streamlit as st
import numpy as np
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import Dataset
import time
from datetime import datetime
import plotly.graph_objects as go

# Advanced Cyberpunk Styling
def setup_advanced_cyberpunk_style():
    st.markdown("""
    <style>
    @import url('https://fonts.googleapis.com/css2?family=Orbitron:wght@400;500;700&display=swap');
    @import url('https://fonts.googleapis.com/css2?family=Share+Tech+Mono&display=swap');
    /* Additional styling as provided previously */
    </style>
    """, unsafe_allow_html=True)

# Initialize Model and Tokenizer
def initialize_model():
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    # GPT-2 ships without a pad token; reuse EOS so padding in prepare_dataset works.
    tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer
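
# Quick sanity check (sketch only, not called by the app): generate a few tokens
# to confirm the model and tokenizer load and decode correctly.
#   model, tokenizer = initialize_model()
#   ids = tokenizer("Hello", return_tensors="pt").input_ids
#   print(tokenizer.decode(model.generate(ids, max_new_tokens=10)[0]))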

# Prepare Dataset
def prepare_dataset(data, tokenizer, block_size=128):
    def tokenize_function(examples):
        return tokenizer(examples['text'], truncation=True, max_length=block_size, padding='max_length')

    raw_dataset = Dataset.from_dict({'text': data})
    tokenized_dataset = raw_dataset.map(tokenize_function, batched=True, remove_columns=['text'])
    tokenized_dataset = tokenized_dataset.map(
        lambda examples: {'labels': examples['input_ids']},
        batched=True
    )
    tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
    return tokenized_dataset
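
# Shape note: with padding='max_length', every example is a fixed-length tensor of
# block_size tokens. The mlm=False collator used in compute_loss also rebuilds labels
# from input_ids (masking pad positions with -100 in current transformers versions),
# so the explicit labels column above is mainly a belt-and-braces step.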

# Training Dashboard Class
class TrainingDashboard:
    def __init__(self):
        self.metrics = {
            'current_loss': 0,
            'best_loss': float('inf'),
            'generation': 0,
            'start_time': time.time(),
            'training_speed': 0
        }
        self.history = []

    def update(self, loss, generation):
        self.metrics['current_loss'] = loss
        self.metrics['generation'] = generation
        if loss < self.metrics['best_loss']:
            self.metrics['best_loss'] = loss
        elapsed_time = time.time() - self.metrics['start_time']
        self.metrics['training_speed'] = generation / elapsed_time
        self.history.append({'loss': loss, 'timestamp': datetime.now().strftime('%H:%M:%S')})

    def display(self):
        st.write(f"**Generation:** {self.metrics['generation']}")
        st.write(f"**Current Loss:** {self.metrics['current_loss']:.4f}")
        st.write(f"**Best Loss:** {self.metrics['best_loss']:.4f}")
        st.write(f"**Training Speed:** {self.metrics['training_speed']:.2f} generations/sec")

# Display Progress Bar
def display_progress(progress):
    # Relies on .progress-bar-container / .progress-bar rules from the style block above.
    st.markdown(f"""
    <div class="progress-bar-container">
        <div class="progress-bar" style="width: {progress * 100}%"></div>
    </div>
    """, unsafe_allow_html=True)

# Fitness Calculation (Placeholder for actual loss computation)
def compute_loss(model, tokenizer, dataset):
    # Placeholder for real loss computation with the Trainer API or custom logic.
    # Note: this runs a full one-epoch training pass on every call, which is slow.
    trainer = Trainer(
        model=model,
        args=TrainingArguments(
            output_dir="./results",
            per_device_train_batch_size=2,
            num_train_epochs=1,
            report_to="none",  # avoid spinning up external loggers inside the Space
        ),
        train_dataset=dataset,
        # The collator needs the tokenizer object itself, not the model's name string.
        data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
    )
    train_result = trainer.train()
    return train_result.training_loss
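
# Cheaper alternative (sketch, assumed addition not used by the app): score the
# model with a single forward pass instead of a full Trainer run, which is far
# lighter per "individual" in the evolutionary loop.
def estimate_loss(model, dataset, batch_size=2):
    model.eval()
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
    losses = []
    with torch.no_grad():
        for batch in loader:
            outputs = model(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                labels=batch['labels'],
            )
            losses.append(outputs.loss.item())
    return float(np.mean(losses))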

# Training Loop with Loading Screen
def training_loop(dashboard, model, tokenizer, dataset, num_generations, population_size):
    with st.spinner("Training in progress..."):
        for generation in range(1, num_generations + 1):
            # Simulated population loop
            for individual in range(population_size):
                loss = compute_loss(model, tokenizer, dataset)
                dashboard.update(loss, generation)
                progress = generation / num_generations
                display_progress(progress)
                dashboard.display()
                time.sleep(1)  # Simulate delay for each individual's training

# Main Function
def main():
    setup_advanced_cyberpunk_style()
    st.markdown('<h1 class="main-title">Neural Evolution GPT-2 Training Hub</h1>', unsafe_allow_html=True)

    # Load Model and Tokenizer
    model, tokenizer = initialize_model()

    # Prepare Data
    data = ["Sample training text"] * 10  # Replace with real data
    train_dataset = prepare_dataset(data, tokenizer)

    # Initialize Dashboard
    dashboard = TrainingDashboard()

    # Sidebar Configuration
    st.sidebar.markdown("### Training Parameters")
    num_generations = st.sidebar.slider("Generations", 1, 20, 5)
    population_size = st.sidebar.slider("Population Size", 4, 20, 6)

    # Run Training
    if st.button("Start Training"):
        training_loop(dashboard, model, tokenizer, train_dataset, num_generations, population_size)


if __name__ == "__main__":
    main()
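
# Launch locally with: streamlit run app.py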