# RVC-Speakers / speakers/tasks/base_task.py
# glide-the
# Add large files to Git LFS
# 04ffec9
from abc import abstractmethod
from typing import List, Dict
from speakers.load.serializable import Serializable
from speakers.processors import ProcessorData, BaseProcessor
from speakers.server.model.flow_data import PayLoad
import logging
class FlowData(Serializable):
    """Task parameters of the current runner.

    Concrete subclasses must implement the abstract ``type`` property,
    which names the message kind for serialization.
    """

    @property
    @abstractmethod
    def type(self) -> str:
        """Type of the Message, used for serialization."""

    @property
    def lc_serializable(self) -> bool:
        """Report that this class opts in to (Processor) serialization."""
        is_serializable = True
        return is_serializable
class Runner(Serializable):
    """Serializable runner message: a task id plus that task's flow data."""

    # Identifier of the runner's task.
    task_id: str
    # Task parameters for this runner.
    flow_data: FlowData

    @property
    def type(self) -> str:
        """Type of the Runner Message, used for serialization."""
        message_kind = 'runner'
        return message_kind

    @property
    def lc_serializable(self) -> bool:
        """Report that this class opts in to (Processor) serialization."""
        is_serializable = True
        return is_serializable
# Define a base class for tasks
class BaseTask:
    """Base task processor that executes a runner flow.

    Created by the task manager to run a runner's flow; subclasses implement
    the concrete processing steps. This class defines the lifecycle of a
    runner task: ``prepare`` (pre-processing), ``dispatch`` (processing the
    runner's flow data) and ``complete`` (post-processing), plus progress
    notification through registered hooks (``report_progress``).
    """

    def __init__(self, preprocess_dict: Dict[str, BaseProcessor]):
        """Initialise the task.

        :param preprocess_dict: processors available to this task, keyed by
            string (presumably the processor name -- confirm against callers).
        """
        self._progress_hooks = []
        # Create the logger before registering the default hook so that the
        # hook, and any subclass override of ``_add_logger_hook``, can rely
        # on ``self.logger`` already existing.
        self.logger = logging.getLogger('base_task_runner')
        self._preprocess_dict = preprocess_dict
        self._add_logger_hook()

    @classmethod
    def from_config(cls, cfg=None):
        """Build a task with an empty processor mapping.

        :param cfg: accepted for interface compatibility; currently unused.
        :return: a new instance of ``cls``.
        """
        return cls(preprocess_dict={})

    def _add_logger_hook(self):
        """Register the default progress listener, which logs known states.

        States found in the tables below are logged at INFO, WARNING or
        ERROR level respectively; any other state is ignored.
        :return: None
        """
        LOG_MESSAGES = {
            'dispatch_voice_task': 'dispatch_voice_task',
            'saved': 'Saving results',
        }
        LOG_MESSAGES_SKIP = {
            'skip-no-text': 'No text regions with text! - Skipping',
        }
        LOG_MESSAGES_ERROR = {
            'error': 'task error',
        }

        async def ph(task_id: str, runner_stat: str, state: str, finished: bool = False):
            # Only ``state`` is used here; the other arguments are part of the
            # hook signature that ``report_progress`` calls with.
            if state in LOG_MESSAGES:
                self.logger.info(LOG_MESSAGES[state])
            elif state in LOG_MESSAGES_SKIP:
                # ``Logger.warn`` is a deprecated alias of ``Logger.warning``.
                self.logger.warning(LOG_MESSAGES_SKIP[state])
            elif state in LOG_MESSAGES_ERROR:
                self.logger.error(LOG_MESSAGES_ERROR[state])

        self.add_progress_hook(ph)

    def add_progress_hook(self, ph):
        """Register a progress listener.

        :param ph: async callable invoked as
            ``await ph(task_id, runner_stat, state, finished)``.
        :return: None
        """
        self._progress_hooks.append(ph)

    async def report_progress(self, task_id: str, runner_stat: str, state: str, finished: bool = False):
        """Notify every registered listener, awaiting each in registration order.

        An exception raised by a listener propagates to the caller, and the
        remaining listeners are not called.

        :param task_id: task id
        :param runner_stat: where in the runner the task currently is
        :param state: state key (see ``_add_logger_hook`` for the keys the
            default listener logs)
        :param finished: whether the task has finished
        :return: None
        """
        for ph in self._progress_hooks:
            await ph(task_id, runner_stat, state, finished)

    @classmethod
    def prepare(cls, payload: PayLoad) -> Runner:
        """Pre-processing: build the Runner for a request payload.

        Args:
            payload (PayLoad): runner flow data
        Returns:
            Runner: the runner to dispatch.
        Raises:
            NotImplementedError: This method should be overridden by subclasses.
        """
        raise NotImplementedError

    @classmethod
    async def dispatch(cls, runner: Runner):
        """Process the concrete flow data of the current runner task.

        Args:
            runner (Runner): runner flow data
        Raises:
            NotImplementedError: This method should be overridden by subclasses.
        """
        raise NotImplementedError

    @classmethod
    def complete(cls, runner: Runner):
        """Post-processing for a runner task.

        Args:
            runner (Runner): runner flow data
        Raises:
            NotImplementedError: This method should be overridden by subclasses.
        """
        raise NotImplementedError