#!/usr/bin/python3
# -*- coding: utf-8 -*-
from collections import defaultdict, OrderedDict
import os
from typing import Any, Callable, Dict, Iterable, List, Set
def namespace_match(pattern: str, namespace: str):
    """
    Check whether ``namespace`` is matched by ``pattern``.

    A pattern starting with ``*`` matches any namespace ending with the rest
    of the pattern, so ``*tags`` matches ``passage_tags`` and
    ``question_tags``.  Any other pattern must equal the namespace exactly,
    so ``tokens`` matches ``tokens`` but not ``stemmed_tokens``.
    """
    is_wildcard = pattern[0] == '*'
    if is_wildcard and namespace.endswith(pattern[1:]):
        return True
    return pattern == namespace
class _NamespaceDependentDefaultDict(defaultdict):
def __init__(self,
non_padded_namespaces: Set[str],
padded_function: Callable[[], Any],
non_padded_function: Callable[[], Any]) -> None:
self._non_padded_namespaces = set(non_padded_namespaces)
self._padded_function = padded_function
self._non_padded_function = non_padded_function
super(_NamespaceDependentDefaultDict, self).__init__()
def __missing__(self, key: str):
if any(namespace_match(pattern, key) for pattern in self._non_padded_namespaces):
value = self._non_padded_function()
else:
value = self._padded_function()
dict.__setitem__(self, key, value)
return value
def add_non_padded_namespaces(self, non_padded_namespaces: Set[str]):
# add non_padded_namespaces which weren't already present
self._non_padded_namespaces.update(non_padded_namespaces)
class _TokenToIndexDefaultDict(_NamespaceDependentDefaultDict):
    """
    token -> index mapping per namespace.  Padded namespaces start with the
    padding token at index 0 and the OOV token at index 1; non-padded
    namespaces start empty.
    """

    def __init__(self, non_padded_namespaces: Set[str], padding_token: str, oov_token: str) -> None:
        # Py3 zero-argument super() instead of the legacy two-argument form.
        super().__init__(non_padded_namespaces,
                         lambda: {padding_token: 0, oov_token: 1},
                         lambda: {})
class _IndexToTokenDefaultDict(_NamespaceDependentDefaultDict):
    """
    index -> token mapping per namespace, the inverse of
    ``_TokenToIndexDefaultDict``: padded namespaces start with indices 0/1
    mapped to the padding/OOV tokens; non-padded namespaces start empty.
    """

    def __init__(self, non_padded_namespaces: Set[str], padding_token: str, oov_token: str) -> None:
        # Py3 zero-argument super() instead of the legacy two-argument form.
        super().__init__(non_padded_namespaces,
                         lambda: {0: padding_token, 1: oov_token},
                         lambda: {})
# Wildcard patterns for namespaces holding tag/label vocabularies, which do
# not get reserved padding/OOV entries.
DEFAULT_NON_PADDED_NAMESPACES = ("*tags", "*labels")
# Reserved token mapped to index 0 in padded namespaces.
DEFAULT_PADDING_TOKEN = '[PAD]'
# Reserved out-of-vocabulary token mapped to index 1 in padded namespaces.
DEFAULT_OOV_TOKEN = '[UNK]'
# File inside a saved vocabulary directory that lists the non-padded patterns.
NAMESPACE_PADDING_FILE = 'non_padded_namespaces.txt'
class Vocabulary(object):
    """
    A bidirectional token <-> integer-id mapping, partitioned into namespaces.

    Each namespace (e.g. ``tokens``, ``pos_tags``) is an independent
    vocabulary, created lazily on first use.  Padded namespaces reserve
    index 0 for the padding token and index 1 for the OOV token; namespaces
    matching one of the ``non_padded_namespaces`` wildcard patterns
    (e.g. ``*tags``) start empty, with no reserved entries.
    """

    def __init__(self, non_padded_namespaces: Iterable[str] = DEFAULT_NON_PADDED_NAMESPACES):
        """
        :param non_padded_namespaces: wildcard patterns naming namespaces
            that should not get padding/OOV entries.
        """
        self._non_padded_namespaces = set(non_padded_namespaces)
        self._padding_token = DEFAULT_PADDING_TOKEN
        self._oov_token = DEFAULT_OOV_TOKEN
        # Per-namespace maps; entries are created lazily on first access.
        self._token_to_index = _TokenToIndexDefaultDict(self._non_padded_namespaces,
                                                        self._padding_token,
                                                        self._oov_token)
        self._index_to_token = _IndexToTokenDefaultDict(self._non_padded_namespaces,
                                                        self._padding_token,
                                                        self._oov_token)

    def add_token_to_namespace(self, token: str, namespace: str = 'tokens') -> int:
        """Add ``token`` to ``namespace`` (if absent) and return its index."""
        if token not in self._token_to_index[namespace]:
            index = len(self._token_to_index[namespace])
            self._token_to_index[namespace][token] = index
            self._index_to_token[namespace][index] = token
            return index
        return self._token_to_index[namespace][token]

    def get_index_to_token_vocabulary(self, namespace: str = 'tokens') -> Dict[int, str]:
        """Return the (live) index -> token mapping for ``namespace``."""
        return self._index_to_token[namespace]

    def get_token_to_index_vocabulary(self, namespace: str = 'tokens') -> Dict[str, int]:
        """Return the (live) token -> index mapping for ``namespace``."""
        return self._token_to_index[namespace]

    def get_token_index(self, token: str, namespace: str = 'tokens') -> int:
        """
        Return the index of ``token``, falling back to the OOV index.

        Raises ``KeyError`` for an unknown token in a namespace that has no
        OOV entry (i.e. a non-padded namespace).
        """
        try:
            return self._token_to_index[namespace][token]
        except KeyError:
            return self._token_to_index[namespace][self._oov_token]

    def get_token_from_index(self, index: int, namespace: str = 'tokens') -> str:
        """Return the token at ``index``; raises ``KeyError`` if unknown."""
        return self._index_to_token[namespace][index]

    def get_vocab_size(self, namespace: str = 'tokens') -> int:
        """Return the number of entries in ``namespace`` (incl. PAD/OOV)."""
        return len(self._token_to_index[namespace])

    def save_to_files(self, directory: str):
        """
        Write the vocabulary to ``directory``: one ``<namespace>.txt`` file
        per namespace with one token per line in index order, plus a file
        listing the non-padded namespace patterns.

        NOTE(review): tokens containing newlines would corrupt this
        line-oriented format — assumed not to occur.
        """
        os.makedirs(directory, exist_ok=True)
        with open(os.path.join(directory, NAMESPACE_PADDING_FILE), 'w', encoding='utf-8') as f:
            for namespace_str in self._non_padded_namespaces:
                f.write('{}\n'.format(namespace_str))
        for namespace, token_to_index in self._token_to_index.items():
            filename = os.path.join(directory, '{}.txt'.format(namespace))
            with open(filename, 'w', encoding='utf-8') as f:
                # dicts preserve insertion order, which equals index order here.
                for token in token_to_index:
                    f.write('{}\n'.format(token))

    @classmethod
    def from_files(cls, directory: str) -> 'Vocabulary':
        """Load a vocabulary previously written by ``save_to_files``."""
        with open(os.path.join(directory, NAMESPACE_PADDING_FILE), 'r', encoding='utf-8') as f:
            non_padded_namespaces = [namespace_str.strip() for namespace_str in f]
        vocab = cls(non_padded_namespaces=non_padded_namespaces)
        for namespace_filename in os.listdir(directory):
            if namespace_filename == NAMESPACE_PADDING_FILE:
                continue
            if namespace_filename.startswith("."):
                continue
            # splitext strips only the trailing extension; the previous
            # str.replace('.txt', '') corrupted names containing '.txt'
            # anywhere in the middle.
            namespace = os.path.splitext(namespace_filename)[0]
            is_padded = not any(namespace_match(pattern, namespace)
                                for pattern in non_padded_namespaces)
            filename = os.path.join(directory, namespace_filename)
            vocab.set_from_file(filename, is_padded, namespace=namespace)
        return vocab

    def set_from_file(self,
                      filename: str,
                      is_padded: bool = True,
                      oov_token: str = DEFAULT_OOV_TOKEN,
                      namespace: str = "tokens"
                      ):
        """
        Replace ``namespace`` with the tokens in ``filename`` (one per line).

        :param is_padded: when True, index 0 is reserved for the padding
            token and file tokens start at index 1; the file is expected to
            contain ``oov_token``, which is renamed to this vocabulary's OOV
            token — TODO confirm every padded vocab file includes it.
        :param oov_token: spelling of the OOV token inside the file.
        """
        if is_padded:
            self._token_to_index[namespace] = {self._padding_token: 0}
            self._index_to_token[namespace] = {0: self._padding_token}
        else:
            self._token_to_index[namespace] = {}
            self._index_to_token[namespace] = {}
        with open(filename, 'r', encoding='utf-8') as f:
            index = 1 if is_padded else 0
            for row in f:
                token = str(row).strip()
                if token == oov_token:
                    # Normalize the file's OOV spelling to our own.
                    token = self._oov_token
                self._token_to_index[namespace][token] = index
                self._index_to_token[namespace][index] = token
                index += 1

    def convert_tokens_to_ids(self, tokens: List[str], namespace: str = "tokens") -> List[int]:
        """Map tokens to indices, substituting the OOV index for unknowns."""
        return [self.get_token_index(token, namespace) for token in tokens]

    def convert_ids_to_tokens(self, ids: List[int], namespace: str = "tokens") -> List[str]:
        """Map indices back to tokens; raises ``KeyError`` on unknown ids."""
        return [self._index_to_token[namespace][idx] for idx in ids]

    def pad_or_truncate_ids_by_max_length(self, ids: List[int], max_length: int, namespace: str = "tokens") -> List[int]:
        """Return ``ids`` truncated, or right-padded with the PAD index, to
        exactly ``max_length`` entries."""
        pad_idx = self._token_to_index[namespace][self._padding_token]
        if len(ids) > max_length:
            return ids[:max_length]
        return ids + [pad_idx] * (max_length - len(ids))
def demo1():
    """Small usage demo: build a vocabulary, encode, pad, and decode text."""
    import jieba

    vocabulary = Vocabulary()
    for word in ('白天', '晚上'):
        vocabulary.add_token_to_namespace(word, 'tokens')

    text = '不是在白天, 就是在晚上'
    tokens = jieba.lcut(text)
    print(tokens)

    ids = vocabulary.convert_tokens_to_ids(tokens)
    print(ids)

    padded_idx = vocabulary.pad_or_truncate_ids_by_max_length(ids, 10)
    print(padded_idx)

    tokens = vocabulary.convert_ids_to_tokens(padded_idx)
    print(tokens)
    return
# Run the demo only when executed as a script, not on import.
if __name__ == '__main__':
    demo1()