lmzjms's picture
Upload 1162 files
0b32ad6 verified
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import re
import os
import typing as tp
from collections import Counter
from multiprocessing import Pool
import torch
# Matches one or more consecutive whitespace characters; tokenize_line uses it
# to collapse every such run into a single space before splitting.
SPACE_NORMALIZER = re.compile(r"\s+")
def _safe_readline(fd) -> str:
    """Read the next line from ``fd``, tolerating a start inside a character.

    The position reported by ``fd.tell()`` on entry is remembered. Whenever
    ``fd.readline()`` raises ``UnicodeDecodeError`` the handle is moved back
    by one from that remembered position and the read is attempted again,
    until a line is decoded successfully.

    Args:
        fd: an open text-mode file object supporting tell/seek/readline.

    Returns:
        The line returned by ``fd.readline()`` (empty string at end of file).
    """
    offset = fd.tell()
    while True:
        try:
            text = fd.readline()
        except UnicodeDecodeError:
            # Back up one position and retry: we are looking for the place
            # where the partially-read character actually begins.
            offset -= 1
            fd.seek(offset)
        else:
            return text
def find_offsets(filename: str, num_chunks: int) -> tp.List[int]:
    """Compute line-aligned chunk boundaries for a text file.

    The file size is divided evenly into ``num_chunks`` pieces; each interior
    boundary is then pushed forward to the position reported right after the
    line that contains it has been read, so that chunks start on full lines.

    Args:
        filename: path of the UTF-8 text file to split.
        num_chunks: number of chunks to produce.

    Returns:
        A list of ``num_chunks + 1`` offsets: the first is 0, the last is the
        file size, and consecutive pairs delimit one chunk each.
    """
    with open(filename, "r", encoding="utf-8") as handle:
        total_bytes = os.fstat(handle.fileno()).st_size
        step = total_bytes // num_chunks
        boundaries = [0] * (num_chunks + 1)
        for chunk_idx in range(1, num_chunks):
            # Jump to the naive boundary, consume the rest of that line, and
            # record where the following line starts.
            handle.seek(step * chunk_idx)
            _safe_readline(handle)
            boundaries[chunk_idx] = handle.tell()
        boundaries[-1] = total_bytes
        return boundaries
class ChunkLineIterator:
    """
    Iterable yielding the lines of one chunk of an already-open text file.
    """

    def __init__(self, fd, start_offset: int, end_offset: int):
        self._fd = fd
        self._start_offset = start_offset
        self._end_offset = end_offset

    def __iter__(self) -> tp.Iterable[str]:
        self._fd.seek(self._start_offset)
        # readline() is used throughout because next(f) breaks f.tell()
        current = _safe_readline(self._fd)
        while current:
            where = self._fd.tell()
            # f.tell() does not always report the byte position in the file;
            # occasionally it jumps to a very large number. A normal read is
            # unlikely to move from the end offset to more than 2**32 bytes
            # (4 GB) past it, so positions beyond that window are treated as
            # such spurious values and do not terminate the chunk.
            overshoot = (
                self._end_offset > 0
                and self._end_offset < where < self._end_offset + 2**32
            )
            if overshoot:
                return
            yield current
            current = self._fd.readline()
class Chunker:
    """
    Context manager that opens a file and exposes one chunk of it as an
    iterable of lines; the file is closed when the context exits.
    """

    def __init__(self, path: str, start_offset: int, end_offset: int):
        self.path, self.start_offset, self.end_offset = (
            path,
            start_offset,
            end_offset,
        )

    def __enter__(self) -> ChunkLineIterator:
        # The handle is kept on the instance so __exit__ can close it.
        self.fd = open(self.path, "r", encoding="utf-8")
        chunk_lines = ChunkLineIterator(self.fd, self.start_offset, self.end_offset)
        return chunk_lines

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.fd.close()
def tokenize_line(line):
    """Split ``line`` into tokens on whitespace.

    Runs of whitespace are first collapsed to a single space and the ends are
    stripped, then the result is split into a list of tokens.
    """
    collapsed = SPACE_NORMALIZER.sub(" ", line)
    return collapsed.strip().split()
def item(tensor):
    """Extract a scalar-like value from ``tensor``.

    - A torch tensor on an "xla" device is returned detached (no conversion).
    - Otherwise, an object exposing ``item`` yields ``tensor.item()``.
    - Otherwise, an object exposing ``__getitem__`` yields ``tensor[0]``.
    - Anything else is returned unchanged.
    """
    on_xla = torch.is_tensor(tensor) and tensor.device.type == "xla"
    if not on_xla:
        if hasattr(tensor, "item"):
            return tensor.item()
        return tensor[0] if hasattr(tensor, "__getitem__") else tensor
    # tpu-comment: making this a no-op for xla devices.
    return tensor.detach()
def post_process(sentence: str, symbol: str):
    """Undo a sub-word / segmentation scheme on a space-joined token string.

    Args:
        sentence: the space-separated token string to clean up.
        symbol: name of the scheme to undo. One of "sentencepiece",
            "wordpiece", "letter", "silence", "_EOW", "subword_nmt", "@@ ",
            "@@", "none", or ``None`` (the last two leave the text untouched).

    Returns:
        The processed sentence.

    Raises:
        NotImplementedError: if ``symbol`` is not one of the options above.
    """
    if symbol is None or symbol == "none":
        return sentence

    # Schemes that drop every space and then turn a word-boundary marker
    # back into a space.
    boundary_marker = {
        "sentencepiece": "\u2581",
        "wordpiece": "_",
        "letter": "|",
        "_EOW": "_EOW",
    }
    if symbol in boundary_marker:
        glued = sentence.replace(" ", "")
        return glued.replace(boundary_marker[symbol], " ").strip()

    if symbol == "silence":
        import re

        without_sil = sentence.replace("<SIL>", "")
        return re.sub(" +", " ", without_sil).strip()

    if symbol in {"subword_nmt", "@@ ", "@@"}:
        continuation = "@@ " if symbol == "subword_nmt" else symbol
        # A trailing space is appended so a continuation marker on the final
        # token is removed as well.
        return (sentence + " ").replace(continuation, "").rstrip()

    raise NotImplementedError(f"Unknown post_process option: {symbol}")
class Dictionary:
    """A mapping from symbols to consecutive integers"""

    def __init__(
        self,
        *,  # begin keyword-only arguments
        bos="<s>",
        pad="<pad>",
        eos="</s>",
        unk="<unk>",
        extra_special_symbols=None,
    ):
        """Create a dictionary holding only the special symbols.

        The symbols are registered in the order bos, pad, eos, unk, followed
        by any ``extra_special_symbols``; ``nspecial`` records how many
        symbols exist once they have all been added.
        """
        self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
        self.symbols = []  # index -> symbol
        self.count = []  # index -> count
        self.indices = {}  # symbol -> index
        self.bos_index = self.add_symbol(bos)
        self.pad_index = self.add_symbol(pad)
        self.eos_index = self.add_symbol(eos)
        self.unk_index = self.add_symbol(unk)
        if extra_special_symbols:
            for s in extra_special_symbols:
                self.add_symbol(s)
        self.nspecial = len(self.symbols)

    def __eq__(self, other):
        """Two dictionaries are equal when their symbol -> index maps match."""
        if not hasattr(other, "indices"):
            # Let Python fall back to its default comparison instead of
            # raising AttributeError for unrelated objects.
            return NotImplemented
        return self.indices == other.indices

    def __getitem__(self, idx):
        """Return the symbol at ``idx``, or the unk word if ``idx`` is too large."""
        if idx < len(self.symbols):
            return self.symbols[idx]
        return self.unk_word

    def get_count(self, idx):
        """Return the count stored for the symbol at ``idx``."""
        return self.count[idx]

    def __len__(self):
        """Returns the number of symbols in the dictionary"""
        return len(self.symbols)

    def __contains__(self, sym):
        """Return True if ``sym`` is a known symbol."""
        return sym in self.indices

    def index(self, sym):
        """Returns the index of the specified symbol (unk index if unknown)"""
        assert isinstance(sym, str)
        if sym in self.indices:
            return self.indices[sym]
        return self.unk_index

    def string(
        self,
        tensor,
        bpe_symbol=None,
        escape_unk=False,
        extra_symbols_to_ignore=None,
        unk_string=None,
        include_eos=False,
        separator=" ",
    ):
        """Helper for converting a tensor of token indices to a string.

        Can optionally remove BPE symbols or escape <unk> words.

        Args:
            tensor: iterable of token indices; a 2-D torch tensor is rendered
                row by row, with rows joined by newlines.
            bpe_symbol: option forwarded to ``post_process``.
            escape_unk: when True (and ``unk_string`` is None) unknown tokens
                are rendered in escaped form, see ``unk_string()``.
            extra_symbols_to_ignore: indices to omit from the output.
            unk_string: replacement text for unknown tokens, if given.
            include_eos: keep the eos token in the output when True.
            separator: string placed between tokens.

        Returns:
            The rendered string.
        """
        if torch.is_tensor(tensor) and tensor.dim() == 2:
            # Forward every option, including unk_string and separator, so
            # that batched input is rendered exactly like single sentences.
            return "\n".join(
                self.string(
                    t,
                    bpe_symbol,
                    escape_unk,
                    extra_symbols_to_ignore,
                    unk_string=unk_string,
                    include_eos=include_eos,
                    separator=separator,
                )
                for t in tensor
            )

        extra_symbols_to_ignore = set(extra_symbols_to_ignore or [])
        if not include_eos:
            extra_symbols_to_ignore.add(self.eos())

        def token_string(i):
            if i == self.unk():
                if unk_string is not None:
                    return unk_string
                else:
                    return self.unk_string(escape_unk)
            else:
                return self[i]

        if hasattr(self, "bos_index"):
            extra_symbols_to_ignore.add(self.bos())

        sent = separator.join(
            token_string(i) for i in tensor if item(i) not in extra_symbols_to_ignore
        )

        return post_process(sent, bpe_symbol)

    def unk_string(self, escape=False):
        """Return unknown string, optionally escaped as: <<unk>>"""
        if escape:
            return "<{}>".format(self.unk_word)
        else:
            return self.unk_word

    def add_symbol(self, word, n=1, overwrite=False):
        """Adds a word to the dictionary

        If ``word`` already exists and ``overwrite`` is False its count is
        increased by ``n``; otherwise a new entry with count ``n`` is appended
        and ``word`` is mapped to that new index.

        Returns:
            The index of ``word``.
        """
        if word in self.indices and not overwrite:
            idx = self.indices[word]
            self.count[idx] = self.count[idx] + n
            return idx
        else:
            idx = len(self.symbols)
            self.indices[word] = idx
            self.symbols.append(word)
            self.count.append(n)
            return idx

    def update(self, new_dict):
        """Updates counts from new dictionary."""
        for word in new_dict.symbols:
            idx2 = new_dict.indices[word]
            if word in self.indices:
                idx = self.indices[word]
                self.count[idx] = self.count[idx] + new_dict.count[idx2]
            else:
                idx = len(self.symbols)
                self.indices[word] = idx
                self.symbols.append(word)
                self.count.append(new_dict.count[idx2])

    def finalize(self, threshold=-1, nwords=-1, padding_factor=8):
        """Sort symbols by frequency in descending order, ignoring special ones.

        Args:
            - threshold defines the minimum word count
            - nwords defines the total number of words in the final dictionary,
                including special symbols
            - padding_factor can be used to pad the dictionary size to be a
                multiple of 8, which is important on some hardware (e.g., Nvidia
                Tensor Cores).
        """
        if nwords <= 0:
            nwords = len(self)

        # Special symbols keep their positions at the front.
        new_indices = dict(zip(self.symbols[: self.nspecial], range(self.nspecial)))
        new_symbols = self.symbols[: self.nspecial]
        new_count = self.count[: self.nspecial]

        # Symbols are inserted in sorted order before ranking by count.
        c = Counter(
            dict(
                sorted(zip(self.symbols[self.nspecial :], self.count[self.nspecial :]))
            )
        )
        for symbol, count in c.most_common(nwords - self.nspecial):
            if count >= threshold:
                new_indices[symbol] = len(new_symbols)
                new_symbols.append(symbol)
                new_count.append(count)
            else:
                # most_common is ordered by count, so nothing later qualifies.
                break

        assert len(new_symbols) == len(new_indices)

        self.count = list(new_count)
        self.symbols = list(new_symbols)
        self.indices = new_indices

        self.pad_to_multiple_(padding_factor)

    def pad_to_multiple_(self, padding_factor):
        """Pad Dictionary size to be a multiple of *padding_factor*."""
        if padding_factor > 1:
            i = 0
            while len(self) % padding_factor != 0:
                symbol = "madeupword{:04d}".format(i)
                self.add_symbol(symbol, n=0)
                i += 1

    def bos(self):
        """Helper to get index of beginning-of-sentence symbol"""
        return self.bos_index

    def pad(self):
        """Helper to get index of pad symbol"""
        return self.pad_index

    def eos(self):
        """Helper to get index of end-of-sentence symbol"""
        return self.eos_index

    def unk(self):
        """Helper to get index of unk symbol"""
        return self.unk_index

    @classmethod
    def load(cls, f):
        """Loads the dictionary from a text file with the format:

        ```
        <symbol0> <count0>
        <symbol1> <count1>
        ...
        ```
        """
        d = cls()
        d.add_from_file(f)
        return d

    def add_from_file(self, f):
        """
        Loads a pre-existing dictionary from a text file and adds its symbols
        to this instance.

        Args:
            f: a path (str) or an open text file object.

        Raises:
            FileNotFoundError: if ``f`` is a path that does not exist.
            Exception: if the file at path ``f`` cannot be decoded as UTF-8.
            ValueError: if a row is not of the form '<token> <cnt> [flags]'.
            RuntimeError: if a word is duplicated without the
                ``#fairseq:overwrite`` flag.
        """
        if isinstance(f, str):
            try:
                with open(f, "r", encoding="utf-8") as fd:
                    self.add_from_file(fd)
            except UnicodeError as err:
                raise Exception(
                    "Incorrect encoding detected in {}, please "
                    "rebuild the dataset".format(f)
                ) from err
            return

        lines = f.readlines()
        indices_start_line = self._load_meta(lines)

        for line in lines[indices_start_line:]:
            # Keep the untouched row so the error message can show all of it.
            row = line.rstrip()
            try:
                word, field = row.rsplit(" ", 1)
                if field == "#fairseq:overwrite":
                    overwrite = True
                    word, field = word.rsplit(" ", 1)
                else:
                    overwrite = False
                count = int(field)
                if word in self and not overwrite:
                    raise RuntimeError(
                        "Duplicate word found when loading Dictionary: '{}'. "
                        "Duplicate words can overwrite earlier ones by adding the "
                        "#fairseq:overwrite flag at the end of the corresponding row "
                        "in the dictionary file. If using the Camembert model, please "
                        "download an updated copy of the model file.".format(word)
                    )
                self.add_symbol(word, n=count, overwrite=overwrite)
            except ValueError as err:
                raise ValueError(
                    f"Incorrect dictionary format, expected '<token> <cnt> [flags]': \"{row}\""
                ) from err

    def _save(self, f, kv_iterator):
        """Write ``key value`` rows to ``f``.

        When ``f`` is a path, its parent directory is created if needed, the
        file is opened for writing and ``save`` is re-invoked on the handle.
        """
        if isinstance(f, str):
            parent = os.path.dirname(f)
            if parent:
                # os.makedirs("") raises, so skip it for bare file names.
                os.makedirs(parent, exist_ok=True)
            with open(f, "w", encoding="utf-8") as fd:
                return self.save(fd)
        for k, v in kv_iterator:
            print("{} {}".format(k, v), file=f)

    def _get_meta(self):
        """Return extra (keys, values) rows to write before the symbols."""
        return [], []

    def _load_meta(self, lines):
        """Return the index of the first line in ``lines`` that holds a symbol."""
        return 0

    def save(self, f):
        """Stores dictionary into a text file"""
        ex_keys, ex_vals = self._get_meta()
        # Special symbols are not written; only entries from nspecial onward.
        self._save(
            f,
            zip(
                ex_keys + self.symbols[self.nspecial :],
                ex_vals + self.count[self.nspecial :],
            ),
        )

    def dummy_sentence(self, length):
        """Return a random index tensor of ``length`` items ending in eos.

        Values are drawn with ``uniform_(nspecial + 1, len(self))`` and cast
        to long; the last position is set to the eos index.
        """
        t = torch.Tensor(length).uniform_(self.nspecial + 1, len(self)).long()
        t[-1] = self.eos()
        return t

    def encode_line(
        self,
        line,
        line_tokenizer=tokenize_line,
        add_if_not_exist=True,
        consumer=None,
        append_eos=True,
        reverse_order=False,
    ) -> torch.IntTensor:
        """Convert a line of text into a tensor of symbol indices.

        Args:
            line: the text to encode.
            line_tokenizer: callable splitting ``line`` into words.
            add_if_not_exist: add unseen words to the dictionary (and count
                every occurrence) when True; otherwise look words up with
                ``index`` so unseen words map to the unk index.
            consumer: optional callable invoked as ``consumer(word, idx)``.
            append_eos: append the eos index after the words.
            reverse_order: encode the words in reverse order.

        Returns:
            An IntTensor with one entry per word (plus eos if requested).
        """
        words = line_tokenizer(line)
        if reverse_order:
            words = list(reversed(words))
        nwords = len(words)
        ids = torch.IntTensor(nwords + 1 if append_eos else nwords)

        for i, word in enumerate(words):
            if add_if_not_exist:
                idx = self.add_symbol(word)
            else:
                idx = self.index(word)
            if consumer is not None:
                consumer(word, idx)
            ids[i] = idx
        if append_eos:
            ids[nwords] = self.eos_index
        return ids

    @staticmethod
    def _add_file_to_dictionary_single_worker(
        filename,
        tokenize,
        eos_word,
        start_offset,
        end_offset,
    ):
        """Count the words in one chunk of ``filename``.

        ``eos_word`` is counted once per line read from the chunk.

        Returns:
            A ``Counter`` mapping word -> number of occurrences.
        """
        counter = Counter()
        with Chunker(filename, start_offset, end_offset) as line_iterator:
            for line in line_iterator:
                for word in tokenize(line):
                    counter.update([word])
                counter.update([eos_word])
        return counter

    @staticmethod
    def add_file_to_dictionary(filename, dict, tokenize, num_workers):
        """Add the word counts of ``filename`` to ``dict``.

        Args:
            filename: path of the text file to count.
            dict: the Dictionary receiving the counts.
            tokenize: callable splitting a line into words.
            num_workers: number of chunks; with more than one, chunks are
                counted in a multiprocessing pool of that many processes.
        """

        def merge_result(counter):
            # Sorted so symbols are added in the same order on every run.
            for w, c in sorted(counter.items()):
                dict.add_symbol(w, c)

        local_file = filename
        offsets = find_offsets(local_file, num_workers)
        if num_workers > 1:
            chunks = zip(offsets, offsets[1:])
            # The context manager guarantees the pool is torn down even if
            # submitting work fails part-way through.
            with Pool(processes=num_workers) as pool:
                results = [
                    pool.apply_async(
                        Dictionary._add_file_to_dictionary_single_worker,
                        (
                            local_file,
                            tokenize,
                            dict.eos_word,
                            start_offset,
                            end_offset,
                        ),
                    )
                    for (start_offset, end_offset) in chunks
                ]
                pool.close()
                pool.join()
            for r in results:
                merge_result(r.get())
        else:
            merge_result(
                Dictionary._add_file_to_dictionary_single_worker(
                    local_file, tokenize, dict.eos_word, offsets[0], offsets[1]
                )
            )