File size: 1,494 Bytes
7c4332a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import logging
import asyncio
from .abstract_trainer import AbstractTrainer
from .training_status import TrainingStatus
from concurrent.futures import ThreadPoolExecutor


# Per-module logger.  Its verbosity is pinned to DEBUG here, independently of
# whatever level the application configures on the root logger.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.DEBUG)



class TrainingManager:
    """Run one trainer's ``start_training`` in the background, one run at a time.

    The blocking ``start_training`` call is executed on a worker thread so the
    asyncio event loop stays responsive while training is in progress.
    """

    # asyncio task wrapping the current (or most recent) run; None until the
    # first call to start_training.
    __training_task = None
    # Trainer driven by this manager; replaced per instance in __init__.
    __trainer: AbstractTrainer = None

    # NOTE(review): mutable class-level dict, shared by every instance and not
    # read or written anywhere in this class (status is served by
    # get_task_status).  Kept only for backward compatibility with callers.
    task_status = {"progress": 0, "status": "Not started"}

    def __init__(self, trainer: AbstractTrainer):
        """Create a manager that drives *trainer*."""
        self.__trainer = trainer

    async def __do_start_training(self, parameters):
        """Run ``trainer.start_training(parameters)`` on a worker thread.

        Exceptions raised by the trainer are logged here and then re-raised,
        so they are also recorded on the task.
        """
        logger.info('do start training')

        loop = asyncio.get_running_loop()
        # Dedicated pool; a single worker suffices because only one job is
        # ever submitted to it.
        pool = ThreadPoolExecutor(max_workers=1)
        try:
            await loop.run_in_executor(pool, self.__trainer.start_training, parameters)
        except asyncio.CancelledError:
            # Propagate cancellation untouched (on Python 3.7 CancelledError
            # is still an Exception subclass and would hit the branch below).
            raise
        except Exception:
            # Nobody awaits this task, so without logging here a training
            # failure would surface only when the task is garbage-collected.
            logger.exception('training failed')
            raise
        finally:
            # wait=False: if this coroutine is cancelled while the thread is
            # still training, a blocking shutdown would stall the event loop.
            # On normal or error completion the job has already finished, so
            # this behaves like the previous ``with`` block.
            pool.shutdown(wait=False)

        logger.info('done')

    async def start_training(self, parameters):
        """Schedule a background training run with *parameters*.

        Returns right after scheduling; the run proceeds in an asyncio task.

        Raises:
            RuntimeError: if a previous run is still in progress.
        """
        logger.info('start training')

        # No await between the check and the assignment, so this
        # check-and-set cannot interleave with another coroutine on this loop.
        if self.__training_task is not None and not self.__training_task.done():
            raise RuntimeError("Training already running")
        self.__training_task = asyncio.create_task(self.__do_start_training(parameters))

    def get_task_status(self) -> TrainingStatus:
        """Return the trainer's current status object."""
        return self.__trainer.get_status()

    def stop_task(self):
        """Request that the running training run abort.

        Raises:
            RuntimeError: if no training run is in progress.
        """
        if self.__training_task is None or self.__training_task.done():
            raise RuntimeError("No task running")
        # The asyncio task is deliberately not cancelled: cancelling it would
        # not interrupt the worker thread executing start_training.  Instead
        # an abort is requested through the trainer's status object.
        # NOTE(review): presumably the trainer polls this status and returns
        # early; confirm in AbstractTrainer / TrainingStatus.
        self.__trainer.get_status().abort_training("Stopping training")