import time
import traceback

from transformers import TrainerCallback

from ..globals import Global
from ..utils.eta_predictor import ETAPredictor


def reset_training_status():
    """Reset every training-related field on ``Global`` to its initial value.

    ``train_started_at`` is stamped with the current wall-clock time and a
    fresh ``ETAPredictor`` instance replaces any previous one.
    """
    Global.is_train_starting = False
    Global.is_training = False
    Global.should_stop_training = False
    Global.train_started_at = time.time()
    Global.training_error_message = None
    Global.training_error_detail = None
    Global.training_total_epochs = 1
    Global.training_current_epoch = 0.0
    Global.training_total_steps = 1
    Global.training_current_step = 0
    Global.training_progress = 0.0
    Global.training_log_history = []
    Global.training_status_text = ""
    Global.training_eta_predictor = ETAPredictor()
    Global.training_eta = None
    Global.train_output = None
    Global.train_output_str = None
    Global.training_params_info_text = ""


def get_progress_text(current_epoch, total_epochs, last_loss):
    """Build the human-readable progress line shown while training.

    Args:
        current_epoch: Current (possibly fractional) epoch; formatted with
            two decimals.
        total_epochs: Total number of epochs.
        last_loss: Most recent training loss, or ``None`` to omit it.

    Returns:
        A string such as ``"Training... (Epoch 1.50/3, Loss: 0.1234)"``.
    """
    progress_detail = f"Epoch {current_epoch:.2f}/{total_epochs}"
    if last_loss is not None:
        progress_detail += f", Loss: {last_loss:.4f}"
    return f"Training... ({progress_detail})"


def set_train_output(output):
    """Record the final training output and set the end-of-training status.

    Args:
        output: Object returned by the training run; stored as-is and also
            as its ``str()`` form.

    Returns:
        ``"Training aborted"`` if a stop was requested, otherwise
        ``"Training completed"``.
    """
    end_by = 'aborted' if Global.should_stop_training else 'completed'
    result_message = f"Training {end_by}"
    Global.training_status_text = result_message

    Global.train_output = output
    Global.train_output_str = str(output)

    return result_message


def _find_last_loss(log_history):
    """Return the most recent non-None ``'loss'`` value in *log_history*.

    Entries are scanned newest-first, so a trailing entry without a
    ``'loss'`` key (for instance an evaluation record) does not hide the
    latest training loss. Returns ``None`` if no entry carries a loss or if
    *log_history* is empty/None.
    """
    for entry in reversed(log_history or []):
        loss = entry.get('loss')
        if loss is not None:
            return loss
    return None


def update_training_states(
        current_step, total_steps, current_epoch, total_epochs, log_history):
    """Mirror the trainer's progress into ``Global`` for the UI.

    Updates step/epoch counters, fractional progress, the log history and
    the ETA. Unless a stop has been requested, also refreshes the status
    text.

    Args:
        current_step: Number of optimizer steps completed so far.
        total_steps: Total number of steps planned for the run.
        current_epoch: Current (possibly fractional) epoch.
        total_epochs: Total number of epochs planned for the run.
        log_history: List of per-log dicts reported by the trainer.
    """
    Global.training_total_steps = total_steps
    Global.training_current_step = current_step
    Global.training_total_epochs = total_epochs
    Global.training_current_epoch = current_epoch
    # Avoid ZeroDivisionError when the step count is not (yet) known.
    Global.training_progress = (
        current_step / total_steps if total_steps else 0.0)
    Global.training_log_history = log_history
    Global.training_eta = Global.training_eta_predictor.predict_eta(
        current_step, total_steps)

    if Global.should_stop_training:
        # Skip the progress text once a stop is requested — presumably so a
        # stop/abort message set elsewhere is not overwritten.
        return

    Global.training_status_text = get_progress_text(
        total_epochs=total_epochs,
        current_epoch=current_epoch,
        last_loss=_find_last_loss(log_history),
    )


class UiTrainerCallback(TrainerCallback):
    """Trainer callback that publishes progress to ``Global`` and honours
    stop requests coming from the UI."""

    def _on_progress(self, args, state, control):
        """Forward a pending stop request and refresh the UI progress state.

        Any exception raised while updating the UI state is printed with a
        traceback and otherwise swallowed, so UI bookkeeping cannot abort
        the training run.
        """
        if Global.should_stop_training:
            control.should_training_stop = True

        try:
            # NOTE(review): `transformers.TrainerState` does not appear to
            # define `steps_per_epoch`; this fallback only runs when
            # `max_steps` is None and would then raise (caught below).
            total_steps = (
                state.max_steps if state.max_steps is not None
                else state.num_train_epochs * state.steps_per_epoch)
            update_training_states(
                total_steps=total_steps,
                current_step=state.global_step,
                total_epochs=args.num_train_epochs,
                current_epoch=state.epoch,
                log_history=state.log_history,
            )
        except Exception as e:
            print("Error occurred while updating UI status:", e)
            traceback.print_exc()

    def on_epoch_begin(self, args, state, control, **kwargs):
        """Refresh UI progress at the start of every epoch."""
        self._on_progress(args, state, control)

    def on_step_end(self, args, state, control, **kwargs):
        """Refresh UI progress after every training step."""
        self._on_progress(args, state, control)