import streamlit as st
import pandas as pd
import time
import threading
from data_utils import list_available_datasets, get_dataset_info
from model_utils import list_available_huggingface_models
from training_utils import (
    start_model_training, 
    stop_model_training, 
    get_running_training_jobs,
    simulate_training
)
from utils import (
    set_page_config, 
    display_sidebar, 
    add_log, 
    display_logs, 
    plot_training_progress
)

# Page-level setup via the project's `utils` helpers, run before any of this
# page's own content is drawn.
set_page_config()
display_sidebar()

# Header shown above both tabs.
st.title("Model Training")
st.markdown("Configure and train code generation models on your datasets.")

# "Configure Training" launches new jobs; "Monitor Jobs" inspects existing ones.
tab1, tab2 = st.tabs(["Configure Training", "Monitor Jobs"])

with tab1:
    st.subheader("Train a New Model")

    # Model ID input. Surrounding whitespace is stripped so a blank ("   ")
    # entry is rejected as empty and " m1" is not treated as distinct from "m1".
    model_id = st.text_input("Model ID", placeholder="e.g., my_codegen_model_v1").strip()

    # Dataset selection; without any dataset the Start button stays disabled.
    available_datasets = list_available_datasets()
    if not available_datasets:
        st.warning("No datasets available. Please upload a dataset in the Dataset Management section.")
        dataset_name = None
    else:
        dataset_name = st.selectbox("Select Dataset", available_datasets)

    # Base model selection. st.selectbox yields None when it has no options,
    # so an empty list must also keep the Start button disabled.
    model_options = list_available_huggingface_models()
    base_model = st.selectbox("Select Base Model", model_options)
    if not model_options:
        st.warning("No base models available.")

    # Training parameters
    st.markdown("### Training Parameters")
    col1, col2 = st.columns(2)

    with col1:
        learning_rate = st.number_input(
            "Learning Rate",
            min_value=1e-6,
            max_value=1e-3,
            value=2e-5,
            # Streamlit's default float step (0.01) exceeds the entire allowed
            # range, making the +/- buttons useless; step in 1e-6 increments.
            step=1e-6,
            format="%.2e"
        )
        batch_size = st.slider("Batch Size", min_value=1, max_value=32, value=8, step=1)

    with col2:
        epochs = st.slider("Number of Epochs", min_value=1, max_value=10, value=3, step=1)
        use_simulation = st.checkbox("Use Simulation Mode (for demonstration)", value=True)

    # Start training button: requires both a dataset and a base model.
    can_start = bool(dataset_name) and bool(base_model)
    if st.button("Start Training", disabled=not can_start):
        if not model_id:
            st.error("Please provide a model ID")
        elif model_id in st.session_state.get('trained_models', {}):
            st.error(f"Model with ID '{model_id}' already exists. Please choose a different ID.")
        elif model_id in st.session_state.get('training_progress', {}):
            st.error(f"A training job for model '{model_id}' already exists.")
        else:
            # stop_events maps model ID -> whatever handle the training helper
            # returned; the Monitor tab passes it back to stop_model_training.
            if 'stop_events' not in st.session_state:
                st.session_state.stop_events = {}

            # Start training (real or simulated). The simulated path does not
            # take learning_rate / batch_size.
            if use_simulation:
                st.session_state.stop_events[model_id] = simulate_training(
                    model_id, dataset_name, base_model, epochs
                )
                add_log(f"Started simulated training for model '{model_id}'")
            else:
                st.session_state.stop_events[model_id] = start_model_training(
                    model_id, dataset_name, base_model, learning_rate, batch_size, epochs
                )
                add_log(f"Started training for model '{model_id}'")

            st.success(f"Training job started for model '{model_id}'")
            # Brief pause so the success message is visible before the rerun
            # redraws the page with the new job.
            time.sleep(1)
            st.rerun()

with tab2:
    st.subheader("Training Jobs")

    # training_progress maps model ID -> per-job progress dict. It may be absent
    # from session state until the first job is started (presumably written by
    # the training helpers — not visible here).
    training_progress = st.session_state.get('training_progress')
    if not training_progress:
        st.info("No training jobs found. Start a new training job in the 'Configure Training' tab.")
    else:
        # List all training jobs
        all_jobs = list(training_progress.keys())
        selected_job = st.selectbox("Select Training Job", all_jobs)

        if selected_job:
            job_progress = training_progress[selected_job]

            # Display job status with a color per known status; unknown -> gray.
            status = job_progress['status']
            status_color = {
                'initialized': 'blue',
                'running': 'green',
                'completed': 'green',
                'failed': 'red',
                'stopped': 'orange'
            }.get(status, 'gray')

            st.markdown(f"### Status: :{status_color}[{status.upper()}]")

            # 'progress' is a percentage. st.progress rejects floats outside
            # [0.0, 1.0], so clamp before scaling to avoid crashing the page on
            # a slight overshoot.
            progress = job_progress['progress']
            st.progress(min(max(progress, 0), 100) / 100)

            # Display job details
            col1, col2 = st.columns(2)

            with col1:
                st.markdown("### Job Details")
                st.markdown(f"**Model ID:** {selected_job}")
                st.markdown(f"**Current Epoch:** {job_progress['current_epoch']}/{job_progress['total_epochs']}")
                st.markdown(f"**Started At:** {job_progress['started_at']}")

                # Only finished jobs have a completion time; tolerate the key
                # being missing rather than raising KeyError.
                completed_at = job_progress.get('completed_at')
                if completed_at:
                    st.markdown(f"**Completed At:** {completed_at}")

            with col2:
                st.markdown("### Controls")

                # Stopping needs the handle stored when the job was launched.
                if status == 'running' and selected_job in st.session_state.get('stop_events', {}):
                    if st.button("Stop Training"):
                        stop_event = st.session_state.stop_events[selected_job]
                        stop_model_training(selected_job, stop_event)
                        st.success(f"Stopping training for model '{selected_job}'")
                        time.sleep(1)
                        st.rerun()

                # Only jobs in a terminal state may be deleted.
                if status in ['completed', 'failed', 'stopped']:
                    if st.button("Delete Job"):
                        del training_progress[selected_job]
                        st.session_state.get('stop_events', {}).pop(selected_job, None)
                        add_log(f"Deleted training job for model '{selected_job}'")
                        st.success(f"Training job for model '{selected_job}' deleted")
                        time.sleep(1)
                        st.rerun()

            # Display training progress plot
            st.markdown("### Training Progress")
            plot_training_progress(selected_job)

            # Display logs (display_logs takes no job argument here)
            st.markdown("### Training Logs")
            display_logs()

# Display running jobs summary at the bottom
st.markdown("---")
st.subheader("Running Jobs Summary")
running_jobs = get_running_training_jobs()

if not running_jobs:
    st.info("No active training jobs")
else:
    # Look jobs up defensively: session state may lack 'training_progress' or
    # an entry for an ID get_running_training_jobs() returned.
    all_job_progress = st.session_state.get('training_progress', {})
    for job in running_jobs:
        job_state = all_job_progress.get(job)
        if job_state is None:
            # NOTE(review): running job without a progress entry — skipped
            # instead of raising KeyError and breaking the whole page.
            continue

        name_col, epoch_col, bar_col = st.columns([2, 1, 1])

        with name_col:
            st.markdown(f"**{job}**")

        with epoch_col:
            st.markdown(f"Epoch {job_state['current_epoch']}/{job_state['total_epochs']}")

        with bar_col:
            # 'progress' is a percentage; clamp into range accepted by st.progress.
            st.progress(min(max(job_state['progress'], 0), 100) / 100)