|
from typing import Dict, List, Any |
|
from pipeline import PreTrainedPipeline |
|
|
|
class PreTrainedPipelineHandler:
    """Callable adapter that forwards translation requests to a ``PreTrainedPipeline``."""

    def __init__(self, path=""):
        """Create the wrapped pipeline.

        Args:
            path: Passed through unchanged as the ``path`` keyword of
                ``PreTrainedPipeline``.
        """
        self.pipeline = PreTrainedPipeline(path=path)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Translate the text described by ``data`` using the wrapped pipeline.

        Args:
            data: Request dictionary. Recognised keys:
                - ``text``: the text to translate (defaults to ``""``).
                - ``src_lang``: source language code (defaults to ``"spa_Latn"``).
                - ``tgt_lang``: target language code (defaults to ``"agr_Latn"``).

        Returns:
            A single-element list holding a dictionary with two keys:
            ``original`` (the input text) and ``translation`` (the value
            stored under ``"translation"`` in the pipeline's result).
        """
        source_text = data.get("text", "")

        # Missing language codes fall back to the handler's defaults.
        request = {
            "text": source_text,
            "src_lang": data.get("src_lang", "spa_Latn"),
            "tgt_lang": data.get("tgt_lang", "agr_Latn"),
        }
        outcome = self.pipeline(request)

        record = {"original": source_text}
        record["translation"] = outcome["translation"]
        return [record]