File size: 7,319 Bytes
9ad4afc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
import streamlit as st
import pandas as pd
import time
import threading
from data_utils import list_available_datasets, get_dataset_info
from model_utils import list_available_huggingface_models
from training_utils import (
    start_model_training, 
    stop_model_training, 
    get_running_training_jobs,
    simulate_training
)
from utils import (
    set_page_config, 
    display_sidebar, 
    add_log, 
    display_logs, 
    plot_training_progress
)

# Page-level setup: the project helpers configure the page and render the
# shared sidebar before any page-specific content is drawn.
set_page_config()
display_sidebar()

# Page heading and short description
st.title("Model Training")
st.markdown("Configure and train code generation models on your datasets.")

# Two tabs: one for launching a new job, one for watching existing jobs
_tab_labels = ["Configure Training", "Monitor Jobs"]
tab1, tab2 = st.tabs(_tab_labels)

# "Configure Training" tab: collect the settings for a new job, validate the
# requested model ID against existing models/jobs, and launch the training run.
with tab1:
    st.subheader("Train a New Model")

    # Model ID input. Surrounding whitespace is removed so that a blank-looking
    # entry is rejected below and "name" / "name " cannot become separate jobs.
    model_id = st.text_input("Model ID", placeholder="e.g., my_codegen_model_v1").strip()

    # Dataset selection
    available_datasets = list_available_datasets()
    if not available_datasets:
        st.warning("No datasets available. Please upload a dataset in the Dataset Management section.")
        dataset_name = None
    else:
        dataset_name = st.selectbox("Select Dataset", available_datasets)

    # Base model selection (the selectbox yields None when there are no options)
    model_options = list_available_huggingface_models()
    base_model = st.selectbox("Select Base Model", model_options)

    # Training parameters
    st.markdown("### Training Parameters")
    col1, col2 = st.columns(2)

    with col1:
        learning_rate = st.number_input(
            "Learning Rate",
            min_value=1e-6,
            max_value=1e-3,
            value=2e-5,
            # Without an explicit step Streamlit uses 0.01 for float inputs,
            # which is larger than the whole allowed range.
            step=1e-6,
            format="%.2e",
        )
        batch_size = st.slider("Batch Size", min_value=1, max_value=32, value=8, step=1)

    with col2:
        epochs = st.slider("Number of Epochs", min_value=1, max_value=10, value=3, step=1)
        use_simulation = st.checkbox("Use Simulation Mode (for demonstration)", value=True)

    # The button is unavailable until both a dataset and a base model are selected.
    if st.button("Start Training", disabled=not dataset_name or base_model is None):
        if not model_id:
            st.error("Please provide a model ID")
        elif model_id in st.session_state.get('trained_models', {}):
            st.error(f"Model with ID '{model_id}' already exists. Please choose a different ID.")
        elif model_id in st.session_state.get('training_progress', {}):
            st.error(f"A training job for model '{model_id}' already exists.")
        else:
            # Initialize stop_events if not present
            if 'stop_events' not in st.session_state:
                st.session_state.stop_events = {}

            # Start training (real or simulated). The value returned by the
            # launcher is kept per model ID; the Monitor Jobs tab later hands
            # it to stop_model_training().
            if use_simulation:
                st.session_state.stop_events[model_id] = simulate_training(
                    model_id, dataset_name, base_model, epochs
                )
                add_log(f"Started simulated training for model '{model_id}'")
            else:
                st.session_state.stop_events[model_id] = start_model_training(
                    model_id, dataset_name, base_model, learning_rate, batch_size, epochs
                )
                add_log(f"Started training for model '{model_id}'")

            st.success(f"Training job started for model '{model_id}'")
            # Leave the success message visible briefly before re-running the script
            time.sleep(1)
            st.rerun()

# "Monitor Jobs" tab: inspect one training job from session state, show its
# status/progress, and offer stop or delete controls depending on the status.
with tab2:
    st.subheader("Training Jobs")

    # Check if there are any training jobs
    if 'training_progress' not in st.session_state or not st.session_state.training_progress:
        st.info("No training jobs found. Start a new training job in the 'Configure Training' tab.")
    else:
        # List all training jobs
        all_jobs = list(st.session_state.training_progress.keys())
        selected_job = st.selectbox("Select Training Job", all_jobs)

        if selected_job:
            # Get job progress
            job_progress = st.session_state.training_progress[selected_job]

            # Display job status, colored by state (unknown states fall back to gray)
            status = job_progress['status']
            status_color = {
                'initialized': 'blue',
                'running': 'green',
                'completed': 'green',
                'failed': 'red',
                'stopped': 'orange'
            }.get(status, 'gray')

            st.markdown(f"### Status: :{status_color}[{status.upper()}]")

            # Display progress bar. The stored value is divided by 100 and
            # clamped to [0.0, 1.0] because st.progress raises on values
            # outside that range.
            progress = job_progress['progress']
            st.progress(min(max(progress / 100, 0.0), 1.0))

            # Display job details
            col1, col2 = st.columns(2)

            with col1:
                st.markdown("### Job Details")
                st.markdown(f"**Model ID:** {selected_job}")
                st.markdown(f"**Current Epoch:** {job_progress['current_epoch']}/{job_progress['total_epochs']}")
                st.markdown(f"**Started At:** {job_progress['started_at']}")

                # A job record may not carry a completion time yet; treat a
                # missing key the same as an empty value.
                completed_at = job_progress.get('completed_at')
                if completed_at:
                    st.markdown(f"**Completed At:** {completed_at}")

            with col2:
                # Training controls
                st.markdown("### Controls")

                # Only show stop button for running jobs that have a stored stop event
                if status == 'running' and selected_job in st.session_state.get('stop_events', {}):
                    if st.button("Stop Training"):
                        stop_event = st.session_state.stop_events[selected_job]
                        stop_model_training(selected_job, stop_event)
                        st.success(f"Stopping training for model '{selected_job}'")
                        time.sleep(1)
                        st.rerun()

                # Add delete button for completed/failed/stopped jobs
                if status in ['completed', 'failed', 'stopped']:
                    if st.button("Delete Job"):
                        # Remove the progress record and any stop event kept for the job
                        del st.session_state.training_progress[selected_job]
                        if selected_job in st.session_state.get('stop_events', {}):
                            del st.session_state.stop_events[selected_job]
                        add_log(f"Deleted training job for model '{selected_job}'")
                        st.success(f"Training job for model '{selected_job}' deleted")
                        time.sleep(1)
                        st.rerun()

            # Display training progress plot
            st.markdown("### Training Progress")
            plot_training_progress(selected_job)

            # Display logs
            st.markdown("### Training Logs")
            display_logs()

# Display running jobs summary at the bottom: one row per active job with
# its ID, epoch counter, and a progress bar.
st.markdown("---")
st.subheader("Running Jobs Summary")
running_jobs = get_running_training_jobs()

if not running_jobs:
    st.info("No active training jobs")
else:
    # 'training_progress' may not have been created in session state yet
    # (the Monitor Jobs tab guards against that too), so read it defensively.
    training_progress = st.session_state.get('training_progress', {})

    for job in running_jobs:
        progress = training_progress.get(job)
        if progress is None:
            # NOTE(review): get_running_training_jobs() is defined elsewhere;
            # a job ID without a progress record is skipped rather than
            # crashing the page - confirm this cannot normally happen.
            continue

        col1, col2, col3 = st.columns([2, 1, 1])

        with col1:
            st.markdown(f"**{job}**")

        with col2:
            st.markdown(f"Epoch {progress['current_epoch']}/{progress['total_epochs']}")

        with col3:
            # Clamp to [0.0, 1.0]; st.progress raises on out-of-range values.
            st.progress(min(max(progress['progress'] / 100, 0.0), 1.0))