import streamlit as st
import numpy as np
import torch
import random
from transformers import GPT2LMHeadModel, GPT2Tokenizer, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import Dataset
from huggingface_hub import login
import plotly.graph_objects as go
import time
from datetime import datetime

# Cyberpunk and Loading Animation Styling
def setup_cyberpunk_style():
    st.markdown("""
        <style>
        @import url('https://fonts.googleapis.com/css2?family=Orbitron:wght@400;500;700&display=swap');
        @import url('https://fonts.googleapis.com/css2?family=Share+Tech+Mono&display=swap');

        .stApp {
            background: radial-gradient(circle, rgba(0, 0, 0, 0.95) 20%, rgba(0, 50, 80, 0.95) 90%);
            color: #00ff9d;
            font-family: 'Orbitron', sans-serif;
        }
        .main-title {
            text-align: center;
            font-size: 4em;
            color: #00ff9d;
            letter-spacing: 4px;
            animation: glow 2s ease-in-out infinite alternate;
        }
        @keyframes glow {
            from {text-shadow: 0 0 5px #00ff9d, 0 0 10px #00ff9d;}
            to {text-shadow: 0 0 15px #00b8ff, 0 0 20px #00b8ff;}
        }
        .stButton > button {
            font-family: 'Orbitron', sans-serif;
            background: linear-gradient(45deg, #00ff9d, #00b8ff);
            color: #000;
            font-size: 1.1em;
            padding: 10px 20px;
            border: none;
            border-radius: 8px;
            transition: all 0.3s ease;
        }
        .stButton > button:hover {
            transform: scale(1.1);
            box-shadow: 0 0 20px rgba(0, 255, 157, 0.5);
        }
        .progress-bar-container {
            background: rgba(0, 0, 0, 0.5);
            border-radius: 15px;
            overflow: hidden;
            width: 100%;
            height: 30px;
            position: relative;
            margin: 10px 0;
        }
        .progress-bar {
            height: 100%;
            width: 0%;
            background: linear-gradient(45deg, #00ff9d, #00b8ff);
            transition: width 0.5s ease;
        }
        .go-button {
            font-family: 'Orbitron', sans-serif;
            background: linear-gradient(45deg, #00ff9d, #00b8ff);
            color: #000;
            font-size: 1.1em;
            padding: 10px 20px;
            border: none;
            border-radius: 8px;
            transition: all 0.3s ease;
            cursor: pointer;
        }
        .go-button:hover {
            transform: scale(1.1);
            box-shadow: 0 0 20px rgba(0, 255, 157, 0.5);
        }
        .loading-animation {
            display: inline-block;
            width: 20px;
            height: 20px;
            border: 3px solid #00ff9d;
            border-radius: 50%;
            border-top-color: transparent;
            animation: spin 1s ease-in-out infinite;
        }
        @keyframes spin {
            to {transform: rotate(360deg);}
        }
        </style>
    """, unsafe_allow_html=True)

# Prepare Dataset Function with Padding Token Fix
def prepare_dataset(data, tokenizer, block_size=128):
    tokenizer.pad_token = tokenizer.eos_token

    def tokenize_function(examples):
        return tokenizer(examples['text'], truncation=True, max_length=block_size, padding='max_length')

    raw_dataset = Dataset.from_dict({'text': data})
    tokenized_dataset = raw_dataset.map(tokenize_function, batched=True, remove_columns=['text'])
    tokenized_dataset = tokenized_dataset.map(lambda examples: {'labels': examples['input_ids']}, batched=True)
    tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
    return tokenized_dataset
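
# Illustrative usage sketch (kept as a comment so it does not run on app load):
# with a GPT-2 tokenizer, each sample is truncated/padded to block_size tokens
# and exposed as PyTorch tensors.
#   tok = GPT2Tokenizer.from_pretrained("gpt2")
#   ds = prepare_dataset(["sample text one", "sample text two"], tok)
#   ds[0]["input_ids"].shape  # torch.Size([128])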

# Training Dashboard Class with Enhanced Display
class TrainingDashboard:
    def __init__(self):
        self.metrics = {
            'current_loss': 0,
            'best_loss': float('inf'),
            'generation': 0,
            'individual': 0,
            'start_time': time.time(),
            'training_speed': 0
        }
        self.history = []

    def update(self, loss, generation, individual):
        self.metrics['current_loss'] = loss
        self.metrics['generation'] = generation
        self.metrics['individual'] = individual
        if loss < self.metrics['best_loss']:
            self.metrics['best_loss'] = loss
        elapsed_time = time.time() - self.metrics['start_time']
        self.metrics['training_speed'] = (generation * individual) / elapsed_time
        self.history.append({'loss': loss, 'timestamp': datetime.now().strftime('%H:%M:%S')})
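
# Illustrative dashboard usage (comment-only sketch): update() records the
# latest loss, tracks the best loss seen so far, and appends to the history.
#   dash = TrainingDashboard()
#   dash.update(loss=2.31, generation=1, individual=4)
#   dash.metrics["best_loss"]  # 2.31
#   dash.history[-1]["loss"]   # 2.31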

# Define Model Initialization
def initialize_model(model_name="gpt2"):
    model = GPT2LMHeadModel.from_pretrained(model_name)
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer

# Load Dataset Function with Uploaded File Option
def load_dataset(data_source="demo", tokenizer=None, uploaded_file=None):
    if data_source == "demo":
        data = ["Sample text data for model training. This can be replaced with actual data for better performance."]
    elif uploaded_file is not None:
        if uploaded_file.name.endswith(".txt"):
            data = [uploaded_file.read().decode("utf-8")]
        elif uploaded_file.name.endswith(".csv"):
            import pandas as pd
            df = pd.read_csv(uploaded_file)
            data = df[df.columns[0]].tolist()  # assuming the first column holds the text data
        else:
            data = ["Unsupported file type. Please upload a .txt or .csv file."]
    else:
        data = ["No file uploaded. Please upload a dataset."]
    dataset = prepare_dataset(data, tokenizer)
    return dataset
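
# Illustrative usage sketch (comment only): the "demo" source yields a single
# tokenized sample; for CSV uploads the first column is assumed to hold the text.
#   demo_ds = load_dataset("demo", tokenizer)
#   len(demo_ds)  # 1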

# Train Model Function with Customized Progress Bar
def train_model(model, train_dataset, tokenizer, epochs=3, batch_size=4,
                learning_rate=3e-5, warmup_steps=100, weight_decay=0.01):
    training_args = TrainingArguments(
        output_dir="./results",
        overwrite_output_dir=True,
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        learning_rate=learning_rate,
        warmup_steps=warmup_steps,
        weight_decay=weight_decay,
        save_steps=10_000,
        save_total_limit=2,
        logging_dir="./logs",
        logging_steps=100,
        report_to="none",  # avoid sending logs to external trackers
    )
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
    )
    # Return the TrainOutput so callers can read the average training loss
    return trainer.train()
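
# Illustrative standalone call (comment only), assuming `ds` came from
# prepare_dataset: the returned TrainOutput carries the average training loss.
#   result = train_model(model, ds, tokenizer, epochs=1, batch_size=2)
#   result.training_loss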

# Main App Logic
def main():
    setup_cyberpunk_style()
    st.markdown('<h1 class="main-title">Cyberpunk Neural Training Hub</h1>', unsafe_allow_html=True)
    # Sidebar Configuration with Additional Options
    with st.sidebar:
        st.markdown("### Configuration Panel")

        # Hugging Face API Token Input
        hf_token = st.text_input("Enter your Hugging Face Token", type="password")
        if hf_token:
            login(token=hf_token)
            st.success("Hugging Face token added successfully!")
        # Training Parameters
        training_epochs = st.slider("Training Epochs", min_value=1, max_value=5, value=3)
        batch_size = st.slider("Batch Size", min_value=2, max_value=8, value=4)
        model_choice = st.selectbox("Model Selection", ("gpt2", "distilgpt2", "gpt2-medium"))

        # Dataset Source Selection
        data_source = st.selectbox("Data Source", ("demo", "uploaded file"))
        uploaded_file = st.file_uploader("Upload a text file", type=["txt", "csv"]) if data_source == "uploaded file" else None

        custom_learning_rate = st.slider("Learning Rate", min_value=1e-6, max_value=5e-4, value=3e-5, step=1e-6)

        # Advanced Settings Toggle
        advanced_toggle = st.checkbox("Advanced Training Settings")
        if advanced_toggle:
            warmup_steps = st.slider("Warmup Steps", min_value=0, max_value=500, value=100)
            weight_decay = st.slider("Weight Decay", min_value=0.0, max_value=0.1, step=0.01, value=0.01)
        else:
            warmup_steps = 100
            weight_decay = 0.01
    # Initialize model and tokenizer based on the selected architecture
    model, tokenizer = initialize_model(model_choice)

    # Load Dataset
    train_dataset = load_dataset(data_source, tokenizer, uploaded_file=uploaded_file)
    # Go Button to Start Training
    if st.button("Go"):
        progress_placeholder = st.empty()
        loading_animation = st.empty()
        st.markdown("### Model Training Progress")

        dashboard = TrainingDashboard()
        for epoch in range(training_epochs):
            loading_animation.markdown("""
                <div class="loading-animation"></div>
            """, unsafe_allow_html=True)

            train_output = train_model(
                model,
                train_dataset,
                tokenizer,
                epochs=1,
                batch_size=batch_size,
                learning_rate=custom_learning_rate,
                warmup_steps=warmup_steps,
                weight_decay=weight_decay,
            )

            # Update Progress Bar
            progress = (epoch + 1) / training_epochs * 100
            progress_placeholder.markdown(f"""
                <div class="progress-bar-container">
                    <div class="progress-bar" style="width: {progress}%;"></div>
                </div>
            """, unsafe_allow_html=True)

            dashboard.update(loss=train_output.training_loss, generation=epoch + 1, individual=batch_size)

        loading_animation.empty()
        st.success("Training Complete!")
        st.write("Training Metrics:")
        st.write(dashboard.metrics)

if __name__ == "__main__":
    main()
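
# To launch the app locally (assuming this file is saved as app.py and the
# dependencies imported above are installed):
#   streamlit run app.py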