import random
import traceback
from contextlib import contextmanager

import streamlit as st
import torch
from transformers import (
    GPT2LMHeadModel,
    GPT2Tokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
)
from datasets import Dataset
from huggingface_hub import HfApi
# Error-handling context manager: shows the error in the UI and appends it to a log file
@contextmanager
def error_handling(operation_name):
    try:
        yield
    except Exception as e:
        error_msg = f"Error during {operation_name}: {str(e)}\n{traceback.format_exc()}"
        st.error(error_msg)
        with open("error_log.txt", "a") as f:
            f.write(f"\n{error_msg}")
# Cyberpunk styling
def setup_cyberpunk_style():
    st.markdown("""
        <style>
        @import url('https://fonts.googleapis.com/css2?family=Orbitron:wght@400;500;700&display=swap');

        .stApp {
            background: linear-gradient(45deg, #000428, #004e92);
        }
        .main-title {
            font-family: 'Orbitron', sans-serif;
            color: #00ff9d;
            text-align: center;
            text-shadow: 0 0 10px #00ff9d;
            padding: 20px;
            font-size: 2.5em;
            margin-bottom: 30px;
        }
        .stButton>button {
            background: linear-gradient(45deg, #00ff9d, #00b8ff);
            color: black;
            font-family: 'Orbitron', sans-serif;
            border: none;
            padding: 10px 20px;
            border-radius: 5px;
            text-transform: uppercase;
            font-weight: bold;
            transition: all 0.3s ease;
        }
        .stButton>button:hover {
            transform: scale(1.05);
            box-shadow: 0 0 15px #00ff9d;
        }
        .metric-container {
            background: rgba(0, 0, 0, 0.5);
            border: 2px solid #00ff9d;
            border-radius: 10px;
            padding: 15px;
            margin: 10px 0;
        }
        .status-text {
            color: #00ff9d;
            font-family: 'Orbitron', sans-serif;
            font-size: 1.2em;
        }
        .sidebar .stSelectbox, .sidebar .stSlider {
            background-color: rgba(0, 0, 0, 0.3);
            border-radius: 5px;
            padding: 10px;
            margin: 5px 0;
        }
        </style>
    """, unsafe_allow_html=True)
# Demo data generation: random subject-verb-object sentences for quick tests
def generate_demo_data(num_samples=60):
    with error_handling("demo data generation"):
        subjects = [
            'Artificial intelligence', 'Climate change', 'Renewable energy',
            'Space exploration', 'Quantum computing', 'Genetic engineering',
            'Blockchain technology', 'Virtual reality', 'Cybersecurity',
            'Biotechnology', 'Nanotechnology', 'Astrophysics'
        ]
        verbs = [
            'is transforming', 'is influencing', 'is revolutionizing',
            'is challenging', 'is advancing', 'is reshaping', 'is impacting',
            'is enhancing', 'is disrupting', 'is redefining'
        ]
        objects = [
            'modern science', 'global economies', 'healthcare systems',
            'communication methods', 'educational approaches',
            'environmental policies', 'social interactions', 'the job market',
            'data security', 'the entertainment industry'
        ]
        data = []
        for _ in range(num_samples):
            subject = random.choice(subjects)
            verb = random.choice(verbs)
            obj = random.choice(objects)
            sentence = f"{subject} {verb} {obj}."
            data.append(sentence)
        return data
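
# NOTE: load_data() and prepare_dataset() are called from main() but were not
# defined in this file. The versions below are minimal sketches inferred from
# the call sites (signatures, return types, and max_length=128 are assumptions);
# adapt them to your actual data format.

def load_data(uploaded_file):
    # Read a Streamlit UploadedFile and split it into non-empty lines
    with error_handling("data loading"):
        text = uploaded_file.read().decode("utf-8")
        return [line.strip() for line in text.splitlines() if line.strip()]

def prepare_dataset(data, tokenizer, max_length=128):
    # Tokenize the raw sentences into a datasets.Dataset that Trainer accepts.
    # Labels are added later by DataCollatorForLanguageModeling, so only
    # input_ids and attention_mask are produced here.
    with error_handling("dataset preparation"):
        encodings = tokenizer(
            data,
            truncation=True,
            max_length=max_length,
            padding="max_length",
        )
        return Dataset.from_dict({
            "input_ids": encodings["input_ids"],
            "attention_mask": encodings["attention_mask"],
        })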
def upload_to_huggingface(model_path, token, repo_name):
    with error_handling("HuggingFace upload"):
        api = HfApi()
        # exist_ok=True so re-uploads don't fail when the repo already exists
        api.create_repo(repo_name, token=token, private=True, exist_ok=True)
        api.upload_folder(
            folder_path=model_path,
            repo_id=repo_name,
            token=token,
        )
        return True
def fitness_function(individual, train_dataset, model, tokenizer):
    # Train with the individual's hyperparameters and return the final
    # logged training loss (lower is fitter)
    with error_handling("fitness evaluation"):
        training_args = TrainingArguments(
            output_dir='./results',
            overwrite_output_dir=True,
            num_train_epochs=individual['epochs'],
            per_device_train_batch_size=individual['batch_size'],
            learning_rate=individual['learning_rate'],
            logging_steps=10,
            save_steps=10,
            save_total_limit=2,
            report_to='none',
        )
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer, mlm=False
        )
        trainer = Trainer(
            model=model,
            args=training_args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=None,
        )
        trainer.train()
        logs = [log for log in trainer.state.log_history if 'loss' in log]
        return logs[-1]['loss'] if logs else float('inf')
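
# NOTE: The GA helpers below (create_population, select_mating_pool, crossover,
# mutation) are called from main() but were not defined in this file. They are
# minimal sketches of a standard genetic algorithm over the hyperparameter
# dictionaries; names and signatures are inferred from the call sites.

def create_population(size, bounds):
    # Each individual is a dict of hyperparameters sampled within the bounds
    population = []
    for _ in range(size):
        population.append({
            'learning_rate': random.uniform(*bounds['learning_rate']),
            'epochs': random.randint(*bounds['epochs']),
            'batch_size': random.choice(bounds['batch_size']),
        })
    return population

def select_mating_pool(population, fitnesses, num_parents):
    # Keep the num_parents individuals with the lowest loss
    ranked = sorted(zip(fitnesses, population), key=lambda pair: pair[0])
    return [individual for _, individual in ranked[:num_parents]]

def crossover(parents, offspring_size):
    # Each child inherits every hyperparameter from one of two random parents
    offspring = []
    for _ in range(offspring_size):
        parent_a, parent_b = random.sample(parents, 2)
        child = {key: random.choice([parent_a[key], parent_b[key]])
                 for key in parent_a}
        offspring.append(child)
    return offspring

def mutation(offspring, bounds, mutation_rate):
    # Re-sample each gene independently with probability mutation_rate
    for child in offspring:
        if random.random() < mutation_rate:
            child['learning_rate'] = random.uniform(*bounds['learning_rate'])
        if random.random() < mutation_rate:
            child['epochs'] = random.randint(*bounds['epochs'])
        if random.random() < mutation_rate:
            child['batch_size'] = random.choice(bounds['batch_size'])
    return offspring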
def main():
    setup_cyberpunk_style()
    st.markdown('<h1 class="main-title">Neural Evolution GPT-2 Training Hub</h1>', unsafe_allow_html=True)

    # Sidebar configuration
    with st.sidebar:
        st.markdown("### Configuration")
        hf_token = st.text_input("HuggingFace Token", type="password")
        repo_name = st.text_input("Repository Name", "my-gpt2-model")
        data_source = st.selectbox(
            'Data Source',
            ('DEMO', 'Upload Text File')
        )

        st.markdown("### Evolution Parameters")
        population_size = st.slider("Population Size", 4, 20, 6)
        num_generations = st.slider("Generations", 1, 10, 3)
        num_parents = st.slider("Parents", 2, population_size, 2)
        mutation_rate = st.slider("Mutation Rate", 0.0, 1.0, 0.1)

    # Hyperparameter bounds for the genetic algorithm
    param_bounds = {
        'learning_rate': (1e-5, 5e-5),
        'epochs': (1, 3),
        'batch_size': [2, 4, 8]
    }
    # Main content area
    with error_handling("main application flow"):
        if data_source == 'DEMO':
            st.info("Using demo data...")
            data = generate_demo_data()
        else:
            uploaded_file = st.file_uploader("Upload Training Data", type="txt")
            if uploaded_file:
                data = load_data(uploaded_file)
            else:
                st.warning("Please upload a text file")
                st.stop()

        # Model setup
        with st.spinner("Loading GPT-2..."):
            tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
            model = GPT2LMHeadModel.from_pretrained('gpt2')
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            model.to(device)
            # GPT-2 has no pad token by default; reuse EOS for padding
            tokenizer.pad_token = tokenizer.eos_token
            model.config.pad_token_id = model.config.eos_token_id

        # Dataset preparation
        with st.spinner("Preparing dataset..."):
            train_dataset = prepare_dataset(data, tokenizer)
if st.button("π Start Training", key="start_training"): | |
progress_bar = st.progress(0) | |
status_text = st.empty() | |
# Metrics Display | |
col1, col2, col3 = st.columns(3) | |
with col1: | |
metrics_loss = st.empty() | |
with col2: | |
metrics_generation = st.empty() | |
with col3: | |
metrics_status = st.empty() | |
            try:
                # Initialize the genetic algorithm
                population = create_population(population_size, param_bounds)
                best_individual = None
                best_fitness = float('inf')
                best_model = None
                fitness_history = []
                total_evaluations = num_generations * len(population)
                current_evaluation = 0

                for generation in range(num_generations):
                    metrics_generation.markdown(f"""
                        <div class="metric-container">
                            <p class="status-text">Generation: {generation + 1}/{num_generations}</p>
                        </div>
                    """, unsafe_allow_html=True)

                    fitnesses = []
                    for idx, individual in enumerate(population):
                        status_text.text(f"Evaluating individual {idx + 1}/{len(population)} in generation {generation + 1}")

                        # Train a fresh clone of the base model for each individual
                        model_clone = GPT2LMHeadModel.from_pretrained('gpt2')
                        model_clone.config.pad_token_id = model_clone.config.eos_token_id
                        model_clone.to(device)
                        fitness = fitness_function(individual, train_dataset, model_clone, tokenizer)
                        if fitness is None:
                            # error_handling() swallows exceptions, making
                            # fitness_function return None; treat as worst score
                            fitness = float('inf')
                        fitnesses.append(fitness)

                        if fitness < best_fitness:
                            best_fitness = fitness
                            best_individual = individual.copy()
                            # Keep the best fine-tuned weights; the base `model`
                            # is never trained, so saving it would discard results
                            best_model = model_clone
                            metrics_loss.markdown(f"""
                                <div class="metric-container">
                                    <p class="status-text">Best Loss: {best_fitness:.4f}</p>
                                </div>
                            """, unsafe_allow_html=True)

                        current_evaluation += 1
                        progress_bar.progress(current_evaluation / total_evaluations)

                    # Evolution steps
                    parents = select_mating_pool(population, fitnesses, num_parents)
                    offspring_size = population_size - num_parents
                    offspring = crossover(parents, offspring_size)
                    offspring = mutation(offspring, param_bounds, mutation_rate)
                    population = parents + offspring
                    fitness_history.append(min(fitnesses))
                # Training complete: stash the best model so the save/upload
                # button below survives Streamlit's rerun when clicked
                st.session_state['best_model'] = best_model
                st.success("Training completed!")
                st.write("Best Hyperparameters:", best_individual)
                st.write("Best Fitness (Loss):", best_fitness)

                # Plot fitness history
                st.line_chart(fitness_history)

            except Exception as e:
                st.error(f"Training error: {str(e)}")
                with open("error_log.txt", "a") as f:
                    f.write(f"\nTraining error: {str(e)}\n{traceback.format_exc()}")

        # Save & upload lives outside the training branch: an st.button nested
        # inside another button's branch never fires, because clicking it
        # reruns the script and the outer button's state is lost
        if st.session_state.get('best_model') is not None:
            if st.button("Save & Upload Model"):
                with st.spinner("Saving model..."):
                    st.session_state['best_model'].save_pretrained('./fine_tuned_model')
                    tokenizer.save_pretrained('./fine_tuned_model')
                if hf_token:
                    if upload_to_huggingface('./fine_tuned_model', hf_token, repo_name):
                        st.success(f"Model uploaded to HuggingFace: {repo_name}")
                    else:
                        st.error("Failed to upload model")
                else:
                    st.warning("No HuggingFace token provided. Model saved locally only.")
if __name__ == "__main__":
    main()