Spaces:
Running
Running
# -*- coding: utf-8 -*- | |
import torch | |
from transformers import AutoTokenizer | |
from torchnlp.encoders.text.text_encoder import TextEncoder | |
from .tokenizer_base import TextEncoderBase | |
class HFTextEncoder(TextEncoderBase):
    """Wrapper around the HuggingFace ``transformers`` AutoTokenizer.

    See: https://huggingface.co/transformers/model_doc/auto.html#autotokenizer

    :param model: Name or path of the pretrained model whose tokenizer is loaded.
    """

    def __init__(self, model: str) -> None:
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model)
        # Vocabulary maps required by the base class: token -> id and id -> token.
        vocab = self.tokenizer.get_vocab()
        self.stoi = vocab
        self.itos = {index: token for token, index in vocab.items()}
        # Mirror the special-token ids of the underlying tokenizer so the
        # base-class properties resolve correctly.
        self._bos_index = self.tokenizer.cls_token_id
        self._pad_index = self.tokenizer.pad_token_id
        self._eos_index = self.tokenizer.sep_token_id
        self._unk_index = self.tokenizer.unk_token_id
        self._mask_index = self.tokenizer.mask_token_id

    def encode(self, sequence: str) -> torch.Tensor:
        """Encode ``sequence`` into a tensor of token ids.

        :param sequence: String ``sequence`` to encode.

        Returns:
            torch.Tensor: Encoding of the ``sequence``.
        """
        # Run the torchnlp base-class preprocessing hook before tokenizing.
        prepared = TextEncoder.encode(self, sequence)
        token_ids = self.tokenizer(prepared, truncation=False)["input_ids"]
        return torch.tensor(token_ids)