1-800-BAD-CODE's picture
Create pipeline.py
70c8b16
raw
history blame
861 Bytes
from typing import Dict, List, Any
from punctuators.models.punc_cap_seg_model import PunctCapSegConfigONNX, PunctCapSegModelONNX
class PreTrainedPipeline:
    """Inference wrapper around the ``punctuators`` ONNX model ``PunctCapSegModelONNX``.

    The model artifacts are loaded from a directory given to the constructor.
    Calling the instance on a string returns the model output packed in the
    ``[{"generated_text": ...}]`` shape consumed by the text-generation widget.
    """

    # Glue placed between the output strings the model returns for one input.
    # This is a space, a literal backslash, the letter "n", and a space. It is
    # NOT a real newline: the author could not get the text-generation widget
    # to print multiple lines, so a visible "\n" marker stands in for one.
    _SENTENCE_SEPARATOR = " \\n "

    def __init__(self, path: str):
        """Load the model from the artifacts stored under ``path``.

        Args:
            path: Directory containing ``sp.model`` (passed as ``spe_filename``),
                ``model.onnx`` (``model_filename``) and ``config.yaml``
                (``config_filename``).
        """
        artifact_names = dict(
            spe_filename="sp.model",
            model_filename="model.onnx",
            config_filename="config.yaml",
        )
        config: PunctCapSegConfigONNX = PunctCapSegConfigONNX(directory=path, **artifact_names)
        self._punctuator: PunctCapSegModelONNX = PunctCapSegModelONNX(config)

    def __call__(self, data: str) -> List[Dict]:
        """Run the model on one input string.

        Args:
            data: The text to process.

        Returns:
            A one-element list ``[{"generated_text": text}]``. Here ``text`` is
            the list of strings the model produced for ``data``, joined by
            ``_SENTENCE_SEPARATOR``.
        """
        # ``infer`` takes a batch, so wrap the single input as a batch of one
        # and keep only that entry's list of output strings.
        batch_predictions: List[List[str]] = self._punctuator.infer([data])
        segments = batch_predictions[0]
        return [{"generated_text": self._SENTENCE_SEPARATOR.join(segments)}]