|
|
|
|
|
|
|
|
|
|
|
from dataclasses import dataclass, field |
|
|
|
from fairseq import file_utils |
|
from fairseq.data.encoders import register_bpe |
|
from fairseq.dataclass import FairseqDataclass |
|
|
|
from .gpt2_bpe_utils import get_encoder |
|
|
|
|
|
# Default locations of the two GPT-2 BPE data files. They are used as the
# config defaults below and are resolved with ``file_utils.cached_path``.
DEFAULT_ENCODER_JSON = (
    "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/"
    "encoder.json"
)
DEFAULT_VOCAB_BPE = (
    "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/"
    "vocab.bpe"
)
|
|
|
|
|
@dataclass
class GPT2BPEConfig(FairseqDataclass):
    """Options for the ``gpt2`` BPE: where to find its two data files.

    Each value may be a local path or a URL. The defaults point at the
    fairseq-hosted copies.
    """

    # Location of encoder.json (default: DEFAULT_ENCODER_JSON).
    gpt2_encoder_json: str = field(
        metadata={"help": "path to encoder.json"},
        default=DEFAULT_ENCODER_JSON,
    )
    # Location of vocab.bpe (default: DEFAULT_VOCAB_BPE).
    gpt2_vocab_bpe: str = field(
        metadata={"help": "path to vocab.bpe"},
        default=DEFAULT_VOCAB_BPE,
    )
|
|
|
|
|
@register_bpe("gpt2", dataclass=GPT2BPEConfig)
class GPT2BPE:
    """GPT-2 BPE wrapper registered under the name ``"gpt2"``.

    Encoding maps text to a space-separated string of integer BPE ids.
    Decoding maps such a string back to text.
    """

    def __init__(self, cfg):
        """Build the underlying encoder from the files named in *cfg*.

        Args:
            cfg: config object providing ``gpt2_encoder_json`` and
                ``gpt2_vocab_bpe``. Each is passed through
                ``file_utils.cached_path`` before use.
        """
        resolve = file_utils.cached_path
        # encoder.json is resolved before vocab.bpe.
        json_path = resolve(cfg.gpt2_encoder_json)
        merges_path = resolve(cfg.gpt2_vocab_bpe)
        self.bpe = get_encoder(json_path, merges_path)

    def encode(self, x: str) -> str:
        """Return the BPE ids for *x*, rendered as one space-separated string."""
        token_ids = self.bpe.encode(x)
        return " ".join(str(token_id) for token_id in token_ids)

    def decode(self, x: str) -> str:
        """Convert a space-separated string of BPE ids back into text.

        The literal tokens ``<unk>`` and ``<mask>`` are handed to the
        underlying decoder unchanged. Every other token goes through
        ``int()``, so any other non-integer token raises ``ValueError``.
        """
        decode_ids = self.bpe.decode
        passthrough = ("<unk>", "<mask>")

        def _as_id(tok):
            # Keep the two special literals as strings and parse the rest.
            return tok if tok in passthrough else int(tok)

        return decode_ids(list(map(_as_id, x.split())))

    def is_beginning_of_word(self, x: str) -> bool:
        """Return True when the decoded form of *x* begins with a space.

        Any ``ValueError`` raised by ``decode`` is propagated to the caller.
        """
        # NOTE(review): relies on the encoder keeping the preceding space on
        # word-initial tokens. This is presumed from the check, not shown here.
        text = self.decode(x)
        return text.startswith(" ")
|
|