"""Streamlit page for configuring, starting and monitoring model training jobs.

The page has two tabs:

* "Configure Training" collects a model ID, dataset, base model and
  hyper-parameters, then starts either a simulated or a real training job.
* "Monitor Jobs" shows the status, progress and controls for the jobs
  recorded in ``st.session_state.training_progress``.

A summary of the jobs reported by ``get_running_training_jobs()`` is rendered
below the tabs.
"""

import threading
import time

import pandas as pd
import streamlit as st

from data_utils import list_available_datasets, get_dataset_info
from model_utils import list_available_huggingface_models
from training_utils import (
    start_model_training,
    stop_model_training,
    get_running_training_jobs,
    simulate_training,
)
from utils import (
    set_page_config,
    display_sidebar,
    add_log,
    display_logs,
    plot_training_progress,
)


def _progress_fraction(percent: float) -> float:
    """Convert a 0-100 percentage into the 0.0-1.0 fraction for ``st.progress``.

    ``st.progress`` rejects floats outside ``[0.0, 1.0]``, so the value is
    clamped to keep a slightly out-of-range report from breaking the page.
    """
    return min(1.0, max(0.0, percent / 100))


# Set page configuration
set_page_config()

# Display sidebar
display_sidebar()

# Title
st.title("Model Training")
st.markdown("Configure and train code generation models on your datasets.")

# Training configuration tab
tab1, tab2 = st.tabs(["Configure Training", "Monitor Jobs"])

with tab1:
    st.subheader("Train a New Model")

    # Model ID input. Surrounding whitespace is stripped so that a blank-looking
    # ID is rejected and "abc" / " abc " are treated as the same job key.
    model_id = st.text_input(
        "Model ID", placeholder="e.g., my_codegen_model_v1"
    ).strip()

    # Dataset selection
    available_datasets = list_available_datasets()
    if not available_datasets:
        st.warning("No datasets available. Please upload a dataset in the Dataset Management section.")
        dataset_name = None
    else:
        dataset_name = st.selectbox("Select Dataset", available_datasets)

    # Model selection
    model_options = list_available_huggingface_models()
    base_model = st.selectbox("Select Base Model", model_options)

    # Training parameters
    st.markdown("### Training Parameters")

    col1, col2 = st.columns(2)

    with col1:
        # An explicit step is required: the default float step (0.01) is larger
        # than the whole allowed range, which makes the +/- buttons useless.
        learning_rate = st.number_input(
            "Learning Rate",
            min_value=1e-6,
            max_value=1e-3,
            value=2e-5,
            step=1e-6,
            format="%.2e"
        )
        batch_size = st.slider("Batch Size", min_value=1, max_value=32, value=8, step=1)

    with col2:
        epochs = st.slider("Number of Epochs", min_value=1, max_value=10, value=3, step=1)
        use_simulation = st.checkbox("Use Simulation Mode (for demonstration)", value=True)

    # Start training button (disabled until a dataset is selected)
    if st.button("Start Training", disabled=not dataset_name):
        if not model_id:
            st.error("Please provide a model ID")
        elif model_id in st.session_state.get('trained_models', {}):
            st.error(f"Model with ID '{model_id}' already exists. Please choose a different ID.")
        elif model_id in st.session_state.get('training_progress', {}):
            st.error(f"A training job for model '{model_id}' already exists.")
        else:
            # Initialize stop_events if not present
            if 'stop_events' not in st.session_state:
                st.session_state.stop_events = {}

            # Start training (real or simulated). The value returned by the
            # launcher is kept per model so it can later be handed to
            # stop_model_training() from the "Monitor Jobs" tab.
            if use_simulation:
                st.session_state.stop_events[model_id] = simulate_training(
                    model_id, dataset_name, base_model, epochs
                )
                add_log(f"Started simulated training for model '{model_id}'")
            else:
                st.session_state.stop_events[model_id] = start_model_training(
                    model_id, dataset_name, base_model, learning_rate, batch_size, epochs
                )
                add_log(f"Started training for model '{model_id}'")

            st.success(f"Training job started for model '{model_id}'")
            # Brief pause before the rerun, presumably so the success message
            # stays visible for a moment.
            time.sleep(1)
            st.rerun()

with tab2:
    st.subheader("Training Jobs")

    # Check if there are any training jobs
    if 'training_progress' not in st.session_state or not st.session_state.training_progress:
        st.info("No training jobs found. Start a new training job in the 'Configure Training' tab.")
    else:
        # List all training jobs
        all_jobs = list(st.session_state.training_progress.keys())
        selected_job = st.selectbox("Select Training Job", all_jobs)

        if selected_job:
            # Get job progress
            job_progress = st.session_state.training_progress[selected_job]

            # Display job status; unknown statuses fall back to gray
            status = job_progress['status']
            status_color = {
                'initialized': 'blue',
                'running': 'green',
                'completed': 'green',
                'failed': 'red',
                'stopped': 'orange'
            }.get(status, 'gray')

            st.markdown(f"### Status: :{status_color}[{status.upper()}]")

            # Display progress bar (progress is stored as a percentage)
            progress = job_progress['progress']
            st.progress(_progress_fraction(progress))

            # Display job details
            col1, col2 = st.columns(2)

            with col1:
                st.markdown("### Job Details")
                st.markdown(f"**Model ID:** {selected_job}")
                st.markdown(f"**Current Epoch:** {job_progress['current_epoch']}/{job_progress['total_epochs']}")
                st.markdown(f"**Started At:** {job_progress['started_at']}")
                if job_progress['completed_at']:
                    st.markdown(f"**Completed At:** {job_progress['completed_at']}")

            with col2:
                # Training controls
                st.markdown("### Controls")

                # Only show stop button for running jobs that have a stored
                # stop handle
                if status == 'running' and selected_job in st.session_state.get('stop_events', {}):
                    if st.button("Stop Training"):
                        stop_event = st.session_state.stop_events[selected_job]
                        stop_model_training(selected_job, stop_event)
                        st.success(f"Stopping training for model '{selected_job}'")
                        time.sleep(1)
                        st.rerun()

                # Add delete button for completed/failed/stopped jobs
                if status in ['completed', 'failed', 'stopped']:
                    if st.button("Delete Job"):
                        del st.session_state.training_progress[selected_job]
                        if selected_job in st.session_state.get('stop_events', {}):
                            del st.session_state.stop_events[selected_job]
                        add_log(f"Deleted training job for model '{selected_job}'")
                        st.success(f"Training job for model '{selected_job}' deleted")
                        time.sleep(1)
                        st.rerun()

            # Display training progress plot
            st.markdown("### Training Progress")
            plot_training_progress(selected_job)

            # Display logs
            st.markdown("### Training Logs")
            display_logs()

# Display running jobs summary at the bottom
st.markdown("---")
st.subheader("Running Jobs Summary")

running_jobs = get_running_training_jobs()
if not running_jobs:
    st.info("No active training jobs")
else:
    for job in running_jobs:
        progress = st.session_state.training_progress[job]
        col1, col2, col3 = st.columns([2, 1, 1])
        with col1:
            st.markdown(f"**{job}**")
        with col2:
            st.markdown(f"Epoch {progress['current_epoch']}/{progress['total_epochs']}")
        with col3:
            st.progress(_progress_fraction(progress['progress']))