|
|
|
|
|
|
|
|
|
|
|
import logging |
|
from fairseq.tasks import register_task |
|
from fairseq.tasks.speech_to_text import SpeechToTextTask |
|
from fairseq.tasks.translation import TranslationTask, TranslationConfig |
|
|
|
# The simultaneous-translation examples package is optional at import time.
# Record whether it could be loaded so the tasks below can fail with a clear
# message at construction time instead of breaking the import of this module.
try:
    import examples.simultaneous_translation  # noqa: F401
except Exception:
    # Deliberately broad: any ordinary failure while importing the examples
    # package marks it as unavailable. BaseException is not caught, so
    # KeyboardInterrupt and SystemExit still propagate.
    import_successful = False
else:
    import_successful = True


logger = logging.getLogger(__name__)
|
|
|
|
|
def check_import(flag):
    """Raise ``ImportError`` unless *flag* is truthy.

    The task constructors in this file call this with the module-level
    ``import_successful`` value before doing any other work.

    Args:
        flag: Truthy when ``examples.simultaneous_translation`` was imported
            successfully, falsy otherwise.

    Raises:
        ImportError: If *flag* is falsy.
    """
    if flag:
        return

    message = (
        "'examples.simultaneous_translation' is not correctly imported. "
        "Please considering `pip install -e $FAIRSEQ_DIR`."
    )
    raise ImportError(message)
|
|
|
|
|
@register_task("simul_speech_to_text")
class SimulSpeechToTextTask(SpeechToTextTask):
    """Speech-to-text task registered under the name ``simul_speech_to_text``.

    Adds nothing to ``SpeechToTextTask`` except a construction-time check
    that ``examples.simultaneous_translation`` was imported successfully.
    """

    def __init__(self, args, tgt_dict):
        """Verify the examples package is available, then build the task.

        Args:
            args: Forwarded unchanged to ``SpeechToTextTask.__init__``.
            tgt_dict: Forwarded unchanged to ``SpeechToTextTask.__init__``.

        Raises:
            ImportError: If ``examples.simultaneous_translation`` failed to
                import when this module was loaded.
        """
        # Check first so a missing examples package is reported before the
        # parent constructor runs.
        check_import(import_successful)
        super().__init__(args, tgt_dict)
|
|
|
|
|
@register_task("simul_text_to_text", dataclass=TranslationConfig)
class SimulTextToTextTask(TranslationTask):
    """Translation task registered under the name ``simul_text_to_text``.

    Registered with ``TranslationConfig`` as its dataclass. Adds nothing to
    ``TranslationTask`` except a construction-time check that
    ``examples.simultaneous_translation`` was imported successfully.
    """

    def __init__(self, cfg, src_dict, tgt_dict):
        """Verify the examples package is available, then build the task.

        Args:
            cfg: Forwarded unchanged to ``TranslationTask.__init__``.
            src_dict: Forwarded unchanged to ``TranslationTask.__init__``.
            tgt_dict: Forwarded unchanged to ``TranslationTask.__init__``.

        Raises:
            ImportError: If ``examples.simultaneous_translation`` failed to
                import when this module was loaded.
        """
        # Check first so a missing examples package is reported before the
        # parent constructor runs.
        check_import(import_successful)
        super().__init__(cfg, src_dict, tgt_dict)
|
|