File size: 2,099 Bytes
8a35bc0
 
 
 
 
 
 
 
 
 
 
 
 
 
7c4332a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264e02e
7c4332a
264e02e
7c4332a
264e02e
7c4332a
 
264e02e
7c4332a
 
264e02e
 
 
 
7c4332a
 
 
264e02e
7c4332a
 
 
 
264e02e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# -------------------------------------------------------------------
# Pimcore
#
# This source file is available under two different licenses:
#  - GNU General Public License version 3 (GPLv3)
#  - Pimcore Commercial License (PCL)
# Full copyright and license information is available in
# LICENSE.md which is distributed with this source code.
#
#  @copyright  Copyright (c) Pimcore GmbH (http://www.pimcore.org)
#  @license    http://www.pimcore.org/license     GPLv3 and PCL
# -------------------------------------------------------------------


import logging
import asyncio
from .abstract_trainer import AbstractTrainer
from .training_status import TrainingStatus
from concurrent.futures import ThreadPoolExecutor


# Per-module logger. Its own threshold is pinned to DEBUG, so records at
# DEBUG and above are not dropped by this logger itself (attached handlers
# may still apply their own level).
# NOTE(review): library modules usually leave levels to the application's
# logging configuration — consider moving this there.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.DEBUG)



class TrainingManager:
    """Run at most one trainer at a time in a background worker thread.

    The trainer's blocking ``start_training(parameters)`` call is moved off
    the event loop with ``loop.run_in_executor`` and driven by an asyncio
    task, so ``start_training`` on this manager returns without waiting for
    training to finish.
    """

    # Task driving the current (or most recent) run; None until the first run.
    __training_task = None
    # Trainer of the current (or most recent) run; None until the first run.
    __trainer: "AbstractTrainer | None" = None

    async def __do_start_training(self, parameters):
        """Execute the trainer's blocking ``start_training`` in a worker thread.

        Failures are logged here instead of being left on the task: nothing
        in this class awaits or inspects the task's result, so an exception
        stored on it would otherwise only surface as asyncio's
        "Task exception was never retrieved" warning once the task object
        is garbage-collected.
        """
        logger.info('do start training')

        loop = asyncio.get_running_loop()
        # One dedicated worker per run. It is shut down with wait=False so
        # that, if this coroutine is cancelled, the event loop thread is not
        # blocked until training finishes (a `with` block would call
        # shutdown(wait=True) right here on the loop thread).
        pool = ThreadPoolExecutor(max_workers=1, thread_name_prefix='training')
        try:
            await loop.run_in_executor(pool, self.__trainer.start_training, parameters)
        except Exception:
            logger.exception('training failed')
        else:
            logger.info('done')
        finally:
            pool.shutdown(wait=False)

    async def start_training(self, trainer: "AbstractTrainer", parameters):
        """Start ``trainer`` in the background with ``parameters``.

        Returns as soon as the background task has been scheduled.

        Raises:
            RuntimeError: if a previously started run has not finished yet.
        """
        logger.info('start training')

        if self.__training_task is not None and not self.__training_task.done():
            raise RuntimeError("Training already running.")

        # There is no await between the check above and these assignments,
        # so two callers on the same event loop cannot both pass the check.
        self.__trainer = trainer
        self.__training_task = asyncio.create_task(self.__do_start_training(parameters))

    def get_task_status(self) -> "TrainingStatus":
        """Return the status of the current or most recent trainer.

        Before any training has been started, a fresh ``TrainingStatus()`` is
        returned.
        """
        if self.__trainer is None:
            return TrainingStatus()

        return self.__trainer.get_status()

    def stop_task(self):
        """Signal the running trainer to abort via its status object.

        This does not wait for the run to end; presumably the trainer checks
        its status and stops on its own (not visible here — confirm in the
        trainer implementations).

        Raises:
            RuntimeError: if no training run is currently in progress.
        """
        running = (
            self.__training_task is not None
            and not self.__training_task.done()
            and self.__trainer is not None
        )
        if not running:
            raise RuntimeError("No task running.")

        # The asyncio task is deliberately not cancelled: cancelling it would
        # not interrupt the worker thread already executing start_training.
        self.__trainer.get_status().abort_training("Stopping training")