# CodeCraftLab / training_utils.py
import streamlit as st
import threading
import random
import time
from utils import add_log, timestamp
# Handle missing dependencies
try:
import torch
    from transformers import TrainingArguments as HFTrainingArguments
    from transformers import (
        Trainer,
        TrainerCallback,
        AutoModelForCausalLM,
        AutoTokenizer,
        DataCollatorForLanguageModeling,
    )
from datasets import Dataset, DatasetDict
TRANSFORMERS_AVAILABLE = True
except ImportError:
    TRANSFORMERS_AVAILABLE = False
    # For demo purposes: placeholders so later references don't raise NameError.
    # The real training path is only entered when TRANSFORMERS_AVAILABLE is True;
    # otherwise the app falls back to simulate_training below.
    HFTrainingArguments = None
    Trainer = None
    TrainerCallback = object
    AutoModelForCausalLM = None
    AutoTokenizer = None
    DataCollatorForLanguageModeling = None
    Dataset = None
    DatasetDict = None
def initialize_training_progress(model_id):
"""
Initialize training progress tracking for a model.
Args:
model_id: Identifier for the model
"""
if 'training_progress' not in st.session_state:
st.session_state.training_progress = {}
st.session_state.training_progress[model_id] = {
'status': 'initialized',
'current_epoch': 0,
'total_epochs': 0,
'loss_history': [],
'started_at': timestamp(),
'completed_at': None,
'progress': 0.0
}
def update_training_progress(model_id, epoch=None, loss=None, status=None, progress=None, total_epochs=None):
"""
Update training progress for a model.
Args:
model_id: Identifier for the model
epoch: Current epoch
loss: Current loss value
status: Training status
progress: Progress percentage (0-100)
total_epochs: Total number of epochs
"""
if 'training_progress' not in st.session_state or model_id not in st.session_state.training_progress:
initialize_training_progress(model_id)
progress_data = st.session_state.training_progress[model_id]
if epoch is not None:
progress_data['current_epoch'] = epoch
if loss is not None:
progress_data['loss_history'].append(loss)
if status is not None:
progress_data['status'] = status
if status == 'completed':
progress_data['completed_at'] = timestamp()
progress_data['progress'] = 100.0
if progress is not None:
progress_data['progress'] = progress
if total_epochs is not None:
progress_data['total_epochs'] = total_epochs
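# Illustrative usage sketch (not wired into the app): how a Streamlit page could
# drive the progress-tracking helpers above. Assumes a running Streamlit script;
# "demo-model" and the numbers are hypothetical.
def _example_progress_tracking():
    initialize_training_progress("demo-model")
    update_training_progress("demo-model", status="running", total_epochs=3)
    update_training_progress("demo-model", epoch=1, loss=1.82, progress=33.3)
    progress = st.session_state.training_progress["demo-model"]
    st.progress(progress["progress"] / 100.0)  # st.progress expects 0.0-1.0
    st.write(f"Epoch {progress['current_epoch']}/{progress['total_epochs']}, status: {progress['status']}")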
def tokenize_dataset(dataset, tokenizer, max_length=512):
"""
Tokenize a dataset for model training.
Args:
        dataset: The dataset to tokenize (must contain a 'code' column)
tokenizer: The tokenizer to use
max_length: Maximum sequence length
Returns:
Dataset: Tokenized dataset
"""
def tokenize_function(examples):
return tokenizer(examples['code'], padding='max_length', truncation=True, max_length=max_length)
tokenized_dataset = dataset.map(tokenize_function, batched=True)
return tokenized_dataset
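# Illustrative sketch: tokenize_dataset expects a 'code' column, as the map
# function above shows. A minimal run, assuming transformers/datasets are
# installed; "gpt2" is a stand-in base model.
def _example_tokenize():
    toy = Dataset.from_dict({"code": ["print('hi')", "def add(a, b): return a + b"]})
    tok = AutoTokenizer.from_pretrained("gpt2")
    tok.pad_token = tok.eos_token  # GPT-2 ships without a pad token
    tokenized = tokenize_dataset(toy, tok, max_length=32)
    # Rows now carry 'input_ids' and 'attention_mask' alongside 'code'
    return tokenized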
def train_model_thread(model_id, dataset_name, base_model_name, training_args, device, stop_event):
"""
Thread function for training a model.
Args:
model_id: Identifier for the model
dataset_name: Name of the dataset to use
base_model_name: Base model from Hugging Face
training_args: Training arguments
device: Device to use for training (cpu/cuda)
stop_event: Threading event to signal stopping
"""
try:
# Get dataset
dataset = st.session_state.datasets[dataset_name]['data']
# Initialize model and tokenizer
add_log(f"Initializing model {base_model_name} for training")
        tokenizer = AutoTokenizer.from_pretrained(base_model_name)
        # Load the model and move it to the requested device (cpu/cuda)
        model = AutoModelForCausalLM.from_pretrained(base_model_name).to(device)
# Check if tokenizer has padding token
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id
# Tokenize dataset
add_log(f"Tokenizing dataset {dataset_name}")
train_dataset = tokenize_dataset(dataset['train'], tokenizer)
val_dataset = tokenize_dataset(dataset['validation'], tokenizer)
# Update training progress
update_training_progress(
model_id,
status='running',
total_epochs=training_args.num_train_epochs
)
# Define custom callback to track progress
        class CustomCallback(TrainerCallback):
def on_epoch_end(self, args, state, control, **kwargs):
                current_epoch = int(state.epoch)
                # log_history can be empty at the first epoch end, so guard the lookup
                epoch_loss = state.log_history[-1].get('loss', 0.0) if state.log_history else 0.0
update_training_progress(
model_id,
epoch=current_epoch,
loss=epoch_loss,
progress=(current_epoch / training_args.num_train_epochs) * 100
)
add_log(f"Epoch {current_epoch}/{training_args.num_train_epochs} completed. Loss: {epoch_loss:.4f}")
# Check if training should be stopped
if stop_event.is_set():
add_log(f"Training for model {model_id} was manually stopped")
control.should_training_stop = True
# Configure training arguments
        args = HFTrainingArguments(
            output_dir=f"./results/{model_id}",
            evaluation_strategy="epoch",
            logging_strategy="epoch",  # log loss once per epoch so the callback can read it
            learning_rate=training_args.learning_rate,
            per_device_train_batch_size=training_args.batch_size,
            per_device_eval_batch_size=training_args.batch_size,
            num_train_epochs=training_args.num_train_epochs,
            weight_decay=0.01,
            save_total_limit=1,
            report_to="none",  # avoid implicit wandb/tensorboard integrations
        )
        # Initialize trainer; the collator derives causal-LM labels from input_ids,
        # without which the model would return no loss during training
        trainer = Trainer(
            model=model,
            args=args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            tokenizer=tokenizer,
            data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
            callbacks=[CustomCallback]
        )
# Train the model
add_log(f"Starting training for model {model_id}")
trainer.train()
# Save the model
if not stop_event.is_set():
add_log(f"Training completed for model {model_id}")
update_training_progress(model_id, status='completed')
            # Save to session state (create the registry on first use)
            if 'trained_models' not in st.session_state:
                st.session_state.trained_models = {}
            st.session_state.trained_models[model_id] = {
'model': model,
'tokenizer': tokenizer,
'info': {
'id': model_id,
'base_model': base_model_name,
'dataset': dataset_name,
'created_at': timestamp(),
'epochs': training_args.num_train_epochs,
'learning_rate': training_args.learning_rate,
'batch_size': training_args.batch_size
}
}
except Exception as e:
add_log(f"Error during training model {model_id}: {str(e)}", "ERROR")
update_training_progress(model_id, status='failed')
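# Illustrative sketch: train_model_thread assumes st.session_state.datasets maps a
# name to {'data': DatasetDict} with 'train' and 'validation' splits, each holding
# a 'code' column. A minimal registration with hypothetical rows:
def _example_register_dataset():
    data = DatasetDict({
        "train": Dataset.from_dict({"code": ["x = 1", "y = x + 1"]}),
        "validation": Dataset.from_dict({"code": ["z = 2"]}),
    })
    if 'datasets' not in st.session_state:
        st.session_state.datasets = {}
    st.session_state.datasets["toy-python"] = {'data': data}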
class TrainingArguments:
    """Lightweight container for UI-selected hyperparameters (distinct from
    transformers' TrainingArguments, imported above as HFTrainingArguments)."""
    def __init__(self, learning_rate, batch_size, num_train_epochs):
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.num_train_epochs = num_train_epochs
def start_model_training(model_id, dataset_name, base_model_name, learning_rate, batch_size, epochs):
"""
Start model training in a separate thread.
Args:
model_id: Identifier for the model
dataset_name: Name of the dataset to use
base_model_name: Base model from Hugging Face
learning_rate: Learning rate for training
batch_size: Batch size for training
epochs: Number of training epochs
Returns:
threading.Event: Event to signal stopping the training
"""
# Use simulate_training instead if transformers isn't available
if not TRANSFORMERS_AVAILABLE:
add_log("No transformers library available, using simulation mode")
return simulate_training(model_id, dataset_name, base_model_name, epochs)
# Create training arguments
training_args = TrainingArguments(
learning_rate=learning_rate,
batch_size=batch_size,
num_train_epochs=epochs
)
# Determine device
device = "cuda" if torch.cuda.is_available() else "cpu"
add_log(f"Using device: {device}")
# Initialize training progress
initialize_training_progress(model_id)
# Create stop event
stop_event = threading.Event()
# Start training thread
training_thread = threading.Thread(
target=train_model_thread,
args=(model_id, dataset_name, base_model_name, training_args, device, stop_event)
)
training_thread.start()
return stop_event
def stop_model_training(model_id, stop_event):
"""
Stop model training.
Args:
model_id: Identifier for the model
stop_event: Threading event to signal stopping
"""
if stop_event.is_set():
return
add_log(f"Stopping training for model {model_id}")
stop_event.set()
# Update training progress
if 'training_progress' in st.session_state and model_id in st.session_state.training_progress:
progress_data = st.session_state.training_progress[model_id]
if progress_data['status'] == 'running':
progress_data['status'] = 'stopped'
progress_data['completed_at'] = timestamp()
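# Illustrative sketch: the start/stop lifecycle as a page might wire it up.
# Button labels, IDs, and hyperparameters here are hypothetical.
def _example_training_lifecycle():
    if st.button("Start training"):
        st.session_state.stop_event = start_model_training(
            model_id="demo-model",
            dataset_name="toy-python",
            base_model_name="gpt2",
            learning_rate=5e-5,
            batch_size=4,
            epochs=3,
        )
    if st.button("Stop training") and 'stop_event' in st.session_state:
        stop_model_training("demo-model", st.session_state.stop_event)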
def get_running_training_jobs():
"""
Get list of currently running training jobs.
Returns:
list: List of model IDs with running training jobs
"""
running_jobs = []
if 'training_progress' in st.session_state:
for model_id, progress in st.session_state.training_progress.items():
if progress['status'] == 'running':
running_jobs.append(model_id)
return running_jobs
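# Illustrative sketch: rendering the running jobs in a sidebar.
def _example_show_running_jobs():
    jobs = get_running_training_jobs()
    if jobs:
        st.sidebar.write("Running training jobs:")
        for job_id in jobs:
            st.sidebar.write(f"- {job_id}")
    else:
        st.sidebar.write("No training jobs running")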
# For demo purposes - Simulate training progress without actual model training
def simulate_training_thread(model_id, dataset_name, base_model_name, epochs, stop_event):
"""
Simulate training progress for demonstration purposes.
Args:
model_id: Identifier for the model
dataset_name: Name of the dataset to use
base_model_name: Base model from Hugging Face
epochs: Number of training epochs
stop_event: Threading event to signal stopping
"""
add_log(f"Starting simulated training for model {model_id}")
update_training_progress(model_id, status='running', total_epochs=epochs)
for epoch in range(1, epochs + 1):
if stop_event.is_set():
add_log(f"Simulated training for model {model_id} was manually stopped")
update_training_progress(model_id, status='stopped')
return
# Simulate epoch time
time.sleep(2)
# Generate random loss that decreases over time
loss = max(0.1, 2.0 - (epoch / epochs) * 1.5 + random.uniform(-0.1, 0.1))
# Update progress
update_training_progress(
model_id,
epoch=epoch,
loss=loss,
progress=(epoch / epochs) * 100
)
add_log(f"Epoch {epoch}/{epochs} completed. Loss: {loss:.4f}")
# Training completed
add_log(f"Simulated training completed for model {model_id}")
update_training_progress(model_id, status='completed')
    # Create a dummy model and tokenizer when transformers is installed; in pure
    # simulation mode (transformers missing) store None placeholders instead,
    # since AutoTokenizer/AutoModelForCausalLM are unavailable
    if TRANSFORMERS_AVAILABLE:
        tokenizer = AutoTokenizer.from_pretrained(base_model_name)
        model = AutoModelForCausalLM.from_pretrained(base_model_name)
    else:
        tokenizer = None
        model = None
    # Save to session state (create the registry on first use)
    if 'trained_models' not in st.session_state:
        st.session_state.trained_models = {}
    st.session_state.trained_models[model_id] = {
'model': model,
'tokenizer': tokenizer,
'info': {
'id': model_id,
'base_model': base_model_name,
'dataset': dataset_name,
'created_at': timestamp(),
'epochs': epochs,
'simulated': True
}
}
def simulate_training(model_id, dataset_name, base_model_name, epochs):
"""
Start simulated training in a separate thread.
Args:
model_id: Identifier for the model
dataset_name: Name of the dataset to use
base_model_name: Base model from Hugging Face
epochs: Number of training epochs
Returns:
threading.Event: Event to signal stopping the training
"""
# Initialize training progress
initialize_training_progress(model_id)
# Create stop event
stop_event = threading.Event()
# Start training thread
training_thread = threading.Thread(
target=simulate_training_thread,
args=(model_id, dataset_name, base_model_name, epochs, stop_event)
)
training_thread.start()
return stop_event
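# Illustrative sketch: simulation mode can also be invoked directly (e.g. to demo
# the UI without downloading a model); arguments mirror start_model_training.
def _example_forced_simulation():
    stop_event = simulate_training(
        model_id="demo-sim",
        dataset_name="toy-python",
        base_model_name="gpt2",
        epochs=3,
    )
    return stop_event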