# CodeCraftLab / pages/02_Model_Training.py
# Uploaded by S-Dreamer ("Upload 2 files"), commit 9ad4afc (verified)
import streamlit as st
import pandas as pd
import time
import threading
from data_utils import list_available_datasets, get_dataset_info
from model_utils import list_available_huggingface_models
from training_utils import (
start_model_training,
stop_model_training,
get_running_training_jobs,
simulate_training
)
from utils import (
set_page_config,
display_sidebar,
add_log,
display_logs,
plot_training_progress
)
# ---------------------------------------------------------------------------
# Page scaffolding: shared page configuration and sidebar (project helpers),
# the page heading, and the two tabs that the rest of this script fills in:
#   tab1 -> launch a new training job
#   tab2 -> monitor existing training jobs
# ---------------------------------------------------------------------------
set_page_config()
display_sidebar()

st.title("Model Training")
st.markdown("Configure and train code generation models on your datasets.")

tab1, tab2 = st.tabs(["Configure Training", "Monitor Jobs"])
# ---------------------------------------------------------------------------
# "Configure Training" tab: collect a model ID, dataset, base model and
# hyper-parameters, then start a real or simulated training job.
# ---------------------------------------------------------------------------
with tab1:
    st.subheader("Train a New Model")

    # Strip surrounding whitespace so that e.g. " my_model" cannot slip past
    # the empty-ID check or register as a near-duplicate of "my_model".
    model_id = st.text_input("Model ID", placeholder="e.g., my_codegen_model_v1").strip()

    # Dataset selection. With no datasets, dataset_name stays None and the
    # "Start Training" button below is rendered disabled.
    available_datasets = list_available_datasets()
    if not available_datasets:
        st.warning("No datasets available. Please upload a dataset in the Dataset Management section.")
        dataset_name = None
    else:
        dataset_name = st.selectbox("Select Dataset", available_datasets)

    # Base model to train from.
    model_options = list_available_huggingface_models()
    base_model = st.selectbox("Select Base Model", model_options)

    # Training parameters. learning_rate and batch_size are only passed to
    # start_model_training; simulate_training receives epochs only.
    st.markdown("### Training Parameters")
    col1, col2 = st.columns(2)
    with col1:
        # Streamlit's default step for float inputs is 0.01, which is larger
        # than the whole allowed range [1e-6, 1e-3]. Use a step that fits the
        # range so the +/- controls yield valid values.
        learning_rate = st.number_input(
            "Learning Rate",
            min_value=1e-6,
            max_value=1e-3,
            value=2e-5,
            step=1e-6,
            format="%.2e",
        )
        batch_size = st.slider("Batch Size", min_value=1, max_value=32, value=8, step=1)
    with col2:
        epochs = st.slider("Number of Epochs", min_value=1, max_value=10, value=3, step=1)
        use_simulation = st.checkbox("Use Simulation Mode (for demonstration)", value=True)

    # Start training button (disabled when no dataset is available).
    if st.button("Start Training", disabled=not dataset_name):
        if not model_id:
            st.error("Please provide a model ID")
        elif model_id in st.session_state.get('trained_models', {}):
            st.error(f"Model with ID '{model_id}' already exists. Please choose a different ID.")
        elif model_id in st.session_state.get('training_progress', {}):
            st.error(f"A training job for model '{model_id}' already exists.")
        else:
            # stop_events maps model ID -> the handle returned by the training
            # starter; the "Monitor Jobs" tab passes it to stop_model_training.
            if 'stop_events' not in st.session_state:
                st.session_state.stop_events = {}

            if use_simulation:
                st.session_state.stop_events[model_id] = simulate_training(
                    model_id, dataset_name, base_model, epochs
                )
                add_log(f"Started simulated training for model '{model_id}'")
            else:
                st.session_state.stop_events[model_id] = start_model_training(
                    model_id, dataset_name, base_model, learning_rate, batch_size, epochs
                )
                add_log(f"Started training for model '{model_id}'")

            st.success(f"Training job started for model '{model_id}'")
            # Pause briefly so the success message can render before the
            # rerun refreshes the page.
            time.sleep(1)
            st.rerun()
# ---------------------------------------------------------------------------
# "Monitor Jobs" tab: pick a job from st.session_state.training_progress
# (keyed by model ID), show its status/progress/details, and offer stop or
# delete controls depending on its status.
# ---------------------------------------------------------------------------
with tab2:
    st.subheader("Training Jobs")

    # Covers both "key missing" and "empty mapping".
    if not st.session_state.get('training_progress'):
        st.info("No training jobs found. Start a new training job in the 'Configure Training' tab.")
    else:
        all_jobs = list(st.session_state.training_progress)
        selected_job = st.selectbox("Select Training Job", all_jobs)

        if selected_job:
            job_progress = st.session_state.training_progress[selected_job]

            # Status headline, colored with Streamlit's :color[text] markdown.
            status = job_progress['status']
            status_color = {
                'initialized': 'blue',
                'running': 'green',
                'completed': 'green',
                'failed': 'red',
                'stopped': 'orange',
            }.get(status, 'gray')
            st.markdown(f"### Status: :{status_color}[{status.upper()}]")

            # 'progress' is read as a percentage (divided by 100). st.progress
            # only accepts floats in [0.0, 1.0], so clamp to avoid an exception
            # if the reported value drifts outside 0-100.
            progress = job_progress['progress']
            st.progress(min(max(progress / 100, 0.0), 1.0))

            col1, col2 = st.columns(2)
            with col1:
                st.markdown("### Job Details")
                st.markdown(f"**Model ID:** {selected_job}")
                st.markdown(f"**Current Epoch:** {job_progress['current_epoch']}/{job_progress['total_epochs']}")
                st.markdown(f"**Started At:** {job_progress['started_at']}")
                if job_progress['completed_at']:
                    st.markdown(f"**Completed At:** {job_progress['completed_at']}")

            with col2:
                st.markdown("### Controls")

                # Stopping needs the handle stored in stop_events when the job
                # was started, so only offer it for running jobs that have one.
                if status == 'running' and selected_job in st.session_state.get('stop_events', {}):
                    if st.button("Stop Training"):
                        stop_event = st.session_state.stop_events[selected_job]
                        stop_model_training(selected_job, stop_event)
                        st.success(f"Stopping training for model '{selected_job}'")
                        time.sleep(1)
                        st.rerun()

                # Finished jobs can be removed along with their stop handle.
                if status in ['completed', 'failed', 'stopped']:
                    if st.button("Delete Job"):
                        del st.session_state.training_progress[selected_job]
                        if selected_job in st.session_state.get('stop_events', {}):
                            del st.session_state.stop_events[selected_job]
                        add_log(f"Deleted training job for model '{selected_job}'")
                        st.success(f"Training job for model '{selected_job}' deleted")
                        time.sleep(1)
                        st.rerun()

            st.markdown("### Training Progress")
            plot_training_progress(selected_job)

            st.markdown("### Training Logs")
            display_logs()
# ---------------------------------------------------------------------------
# Running-jobs summary shown at the bottom of the page, below both tabs:
# one row per job returned by get_running_training_jobs() with its epoch
# counter and a progress bar.
# ---------------------------------------------------------------------------
st.markdown("---")
st.subheader("Running Jobs Summary")

running_jobs = get_running_training_jobs()
if not running_jobs:
    st.info("No active training jobs")
else:
    for job in running_jobs:
        # Per-job progress dict; named job_state so it is not confused with
        # the numeric percentage it contains.
        job_state = st.session_state.training_progress[job]
        col1, col2, col3 = st.columns([2, 1, 1])
        with col1:
            st.markdown(f"**{job}**")
        with col2:
            st.markdown(f"Epoch {job_state['current_epoch']}/{job_state['total_epochs']}")
        with col3:
            # st.progress rejects floats outside [0.0, 1.0]; clamp the percentage.
            st.progress(min(max(job_state['progress'] / 100, 0.0), 1.0))