#!/usr/bin/python3 # -*- coding: utf-8 -*- from collections import defaultdict, OrderedDict import os from typing import Any, Callable, Dict, Iterable, List, Set def namespace_match(pattern: str, namespace: str): """ Matches a namespace pattern against a namespace string. For example, ``*tags`` matches ``passage_tags`` and ``question_tags`` and ``tokens`` matches ``tokens`` but not ``stemmed_tokens``. """ if pattern[0] == '*' and namespace.endswith(pattern[1:]): return True elif pattern == namespace: return True return False class _NamespaceDependentDefaultDict(defaultdict): def __init__(self, non_padded_namespaces: Set[str], padded_function: Callable[[], Any], non_padded_function: Callable[[], Any]) -> None: self._non_padded_namespaces = set(non_padded_namespaces) self._padded_function = padded_function self._non_padded_function = non_padded_function super(_NamespaceDependentDefaultDict, self).__init__() def __missing__(self, key: str): if any(namespace_match(pattern, key) for pattern in self._non_padded_namespaces): value = self._non_padded_function() else: value = self._padded_function() dict.__setitem__(self, key, value) return value def add_non_padded_namespaces(self, non_padded_namespaces: Set[str]): # add non_padded_namespaces which weren't already present self._non_padded_namespaces.update(non_padded_namespaces) class _TokenToIndexDefaultDict(_NamespaceDependentDefaultDict): def __init__(self, non_padded_namespaces: Set[str], padding_token: str, oov_token: str) -> None: super(_TokenToIndexDefaultDict, self).__init__(non_padded_namespaces, lambda: {padding_token: 0, oov_token: 1}, lambda: {}) class _IndexToTokenDefaultDict(_NamespaceDependentDefaultDict): def __init__(self, non_padded_namespaces: Set[str], padding_token: str, oov_token: str) -> None: super(_IndexToTokenDefaultDict, self).__init__(non_padded_namespaces, lambda: {0: padding_token, 1: oov_token}, lambda: {}) DEFAULT_NON_PADDED_NAMESPACES = ("*tags", "*labels") DEFAULT_PADDING_TOKEN = '[PAD]' 
DEFAULT_OOV_TOKEN = '[UNK]'
NAMESPACE_PADDING_FILE = 'non_padded_namespaces.txt'


class Vocabulary(object):
    """
    A collection of token <-> integer-id mappings, one per namespace.

    Padded namespaces reserve index 0 for the padding token and index 1 for
    the out-of-vocabulary (OOV) token; namespaces matching one of the
    ``non_padded_namespaces`` patterns (e.g. ``*tags``) start empty.
    """

    def __init__(self, non_padded_namespaces: Iterable[str] = DEFAULT_NON_PADDED_NAMESPACES):
        self._non_padded_namespaces = set(non_padded_namespaces)
        self._padding_token = DEFAULT_PADDING_TOKEN
        self._oov_token = DEFAULT_OOV_TOKEN
        self._token_to_index = _TokenToIndexDefaultDict(self._non_padded_namespaces,
                                                        self._padding_token,
                                                        self._oov_token)
        self._index_to_token = _IndexToTokenDefaultDict(self._non_padded_namespaces,
                                                        self._padding_token,
                                                        self._oov_token)

    def add_token_to_namespace(self, token: str, namespace: str = 'tokens') -> int:
        """Add ``token`` to ``namespace`` if absent; return its index either way."""
        token_to_index = self._token_to_index[namespace]
        if token in token_to_index:
            return token_to_index[token]
        index = len(token_to_index)
        token_to_index[token] = index
        self._index_to_token[namespace][index] = token
        return index

    def get_index_to_token_vocabulary(self, namespace: str = 'tokens') -> Dict[int, str]:
        """Return the (live) index -> token mapping for ``namespace``."""
        return self._index_to_token[namespace]

    def get_token_to_index_vocabulary(self, namespace: str = 'tokens') -> Dict[str, int]:
        """Return the (live) token -> index mapping for ``namespace``."""
        return self._token_to_index[namespace]

    def get_token_index(self, token: str, namespace: str = 'tokens') -> int:
        """Return the index of ``token``, falling back to the OOV index.

        NOTE: in a non-padded namespace there is no OOV entry, so an unknown
        token raises KeyError (same as the original behavior).
        """
        token_to_index = self._token_to_index[namespace]
        if token in token_to_index:
            return token_to_index[token]
        return token_to_index[self._oov_token]

    def get_token_from_index(self, index: int, namespace: str = 'tokens') -> str:
        """Return the token stored at ``index``; raises KeyError if unknown."""
        return self._index_to_token[namespace][index]

    def get_vocab_size(self, namespace: str = 'tokens') -> int:
        """Return the number of tokens in ``namespace`` (including PAD/OOV)."""
        return len(self._token_to_index[namespace])

    def save_to_files(self, directory: str):
        """
        Write this vocabulary to ``directory``: one ``<namespace>.txt`` file
        per namespace (one token per line, in index order) plus the list of
        non-padded namespace patterns.
        """
        os.makedirs(directory, exist_ok=True)
        with open(os.path.join(directory, NAMESPACE_PADDING_FILE), 'w', encoding='utf-8') as f:
            for namespace_str in self._non_padded_namespaces:
                f.write('{}\n'.format(namespace_str))
        for namespace, token_to_index in self._token_to_index.items():
            filename = os.path.join(directory, '{}.txt'.format(namespace))
            with open(filename, 'w', encoding='utf-8') as f:
                # Dicts preserve insertion order, so iterating the keys emits
                # tokens in index order; the index values are not needed.
                for token in token_to_index:
                    f.write('{}\n'.format(token))

    @classmethod
    def from_files(cls, directory: str) -> 'Vocabulary':
        """Load a vocabulary previously written by ``save_to_files``."""
        with open(os.path.join(directory, NAMESPACE_PADDING_FILE), 'r', encoding='utf-8') as f:
            non_padded_namespaces = [namespace_str.strip() for namespace_str in f]
        vocab = cls(non_padded_namespaces=non_padded_namespaces)
        for namespace_filename in os.listdir(directory):
            if namespace_filename == NAMESPACE_PADDING_FILE:
                continue
            if namespace_filename.startswith("."):
                continue
            # Strip only a trailing '.txt'; str.replace('.txt', '') would also
            # mangle a namespace whose own name contains '.txt'.
            if namespace_filename.endswith('.txt'):
                namespace = namespace_filename[:-len('.txt')]
            else:
                namespace = namespace_filename
            is_padded = not any(namespace_match(pattern, namespace)
                                for pattern in non_padded_namespaces)
            filename = os.path.join(directory, namespace_filename)
            vocab.set_from_file(filename, is_padded, namespace=namespace)
        return vocab

    def set_from_file(self,
                      filename: str,
                      is_padded: bool = True,
                      oov_token: str = DEFAULT_OOV_TOKEN,
                      namespace: str = "tokens"):
        """
        Replace the mappings for ``namespace`` with the tokens in ``filename``
        (one token per line; line position determines the index).

        When ``is_padded`` is True, index 0 is reserved for the padding token
        and the file's ``oov_token`` line is re-mapped to this vocabulary's
        own OOV token.
        """
        if is_padded:
            self._token_to_index[namespace] = {self._padding_token: 0}
            self._index_to_token[namespace] = {0: self._padding_token}
        else:
            self._token_to_index[namespace] = {}
            self._index_to_token[namespace] = {}
        with open(filename, 'r', encoding='utf-8') as f:
            # Index 0 is already taken by the padding token in padded namespaces.
            for index, row in enumerate(f, start=1 if is_padded else 0):
                token = row.strip()
                if token == oov_token:
                    token = self._oov_token
                self._token_to_index[namespace][token] = index
                self._index_to_token[namespace][index] = token

    def convert_tokens_to_ids(self, tokens: List[str], namespace: str = "tokens") -> List[int]:
        """Map each token to its index, substituting the OOV index for unknowns."""
        token_to_index = self._token_to_index[namespace]
        result = []
        for token in tokens:
            idx = token_to_index.get(token)
            if idx is None:
                idx = token_to_index[self._oov_token]
            result.append(idx)
        return result

    def convert_ids_to_tokens(self, ids: List[int], namespace: str = "tokens") -> List[str]:
        """Map each index back to its token; raises KeyError for unknown ids."""
        index_to_token = self._index_to_token[namespace]
        return [index_to_token[idx] for idx in ids]

    def pad_or_truncate_ids_by_max_length(self, ids: List[int], max_length: int,
                                          namespace: str = "tokens") -> List[int]:
        """Truncate ``ids`` to ``max_length``, or right-pad with the padding index."""
        pad_idx = self._token_to_index[namespace][self._padding_token]
        if len(ids) > max_length:
            return ids[:max_length]
        return ids + [pad_idx] * (max_length - len(ids))


def demo1():
    # NOTE: requires the third-party ``jieba`` tokenizer; demo only.
    import jieba

    vocabulary = Vocabulary()
    vocabulary.add_token_to_namespace('白天', 'tokens')
    vocabulary.add_token_to_namespace('晚上', 'tokens')

    text = '不是在白天, 就是在晚上'
    tokens = jieba.lcut(text)
    print(tokens)

    ids = vocabulary.convert_tokens_to_ids(tokens)
    print(ids)

    padded_idx = vocabulary.pad_or_truncate_ids_by_max_length(ids, 10)
    print(padded_idx)

    tokens = vocabulary.convert_ids_to_tokens(padded_idx)
    print(tokens)
    return


if __name__ == '__main__':
    demo1()