File size: 1,089 Bytes
7c4332a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import logging
from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl

from .training_status import TrainingStatus


# Logger for this module, named after the module's import path.
logger: logging.Logger = logging.getLogger(__name__)
# NOTE(review): this sets the logger's own level to DEBUG at import time, so its
# effective level no longer follows the level configured on parent/root loggers.
# Consider leaving level configuration to the application.
logger.setLevel(logging.DEBUG)

class ProgressCallback(TrainerCallback):
    """Trainer callback that reports per-step progress to a ``TrainingStatus``.

    After every optimizer step it either asks the Trainer to stop (when the
    status reports an abort) or forwards a progress value and message to
    ``TrainingStatus.update_status``.

    Step progress is mapped linearly onto the window
    ``[startPercentage, endPercentage]``. With the defaults (21-89), this
    callback never reports values in 0-21 or 89-100; presumably those ranges
    belong to work done before/after training (confirm with the caller).
    """

    # Class-level default; every instance overwrites it in __init__.
    __trainingStatus: TrainingStatus = None

    def __init__(
        self,
        trainingStatus: TrainingStatus,
        startPercentage: float = 21,
        endPercentage: float = 89,
    ):
        """Create the callback.

        Args:
            trainingStatus: Queried for abort requests and updated with progress.
            startPercentage: Progress reported at step 0 (lower end of the window).
            endPercentage: Progress reported when ``global_step`` reaches
                ``max_steps`` (upper end of the window).
        """
        self.__trainingStatus = trainingStatus
        self.__startPercentage = startPercentage
        self.__endPercentage = endPercentage

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Log the finished step, honour abort requests, and publish progress.

        If ``trainingStatus.is_training_aborted()`` is true, sets
        ``control.should_training_stop`` and returns without updating progress.
        Otherwise reports a progress value inside
        ``[startPercentage, endPercentage]``.
        """
        # Lazy %-style args: the string is only built if INFO is enabled.
        logger.info("Completed step %s of %s", state.global_step, state.max_steps)

        if self.__trainingStatus.is_training_aborted():
            # The Trainer reads this flag after the callback returns and stops training.
            control.should_training_stop = True
            logger.info("Training aborted")
            return

        # TrainerState.max_steps defaults to 0; avoid ZeroDivisionError and
        # keep the reported value inside the configured window.
        if state.max_steps > 0:
            fraction = min(max(state.global_step / state.max_steps, 0.0), 1.0)
        else:
            fraction = 0.0

        scope = self.__endPercentage - self.__startPercentage
        progress = self.__startPercentage + fraction * scope

        self.__trainingStatus.update_status(
            progress,
            f"Training model, completed step {state.global_step} of {state.max_steps}",
        )