Spaces:
Running
Running
import os | |
import sys | |
import ast | |
import torch | |
import itertools | |
import collections | |
sys.path.append(os.getcwd()) | |
from main.library.speaker_diarization.speechbrain import if_main_process, ddp_barrier | |
from main.library.speaker_diarization.features import register_checkpoint_hooks, mark_as_saver, mark_as_loader | |
# Default special-label strings. Only DEFAULT_UNK is used in this class (as the
# default for `CategoricalEncoder.add_unk`); the others are presumably used by
# code elsewhere -- confirm before removing.
DEFAULT_UNK = "<unk>"
DEFAULT_BOS = "<bos>"
DEFAULT_EOS = "<eos>"
DEFAULT_BLANK = "<blank>"


# NOTE(review): `register_checkpoint_hooks`, `mark_as_saver` and `mark_as_loader`
# are imported by this file but not applied anywhere in this class. `save` and
# `load_if_possible(path, end_of_epoch)` look like checkpoint saver/loader hooks
# -- confirm whether those decorators were meant to be applied here.
class CategoricalEncoder:
    """Bidirectional mapping between hashable labels and integer indices.

    New labels receive the lowest unused index >= ``starting_index`` (see
    ``_next_index``). The mapping can be saved to and loaded from a UTF-8 text
    file (format described in ``_save_literal``).

    Keyword arguments to ``__init__`` are treated as special labels; only
    ``unk_label`` is recognised. It registers a fallback label that
    ``encode_label`` returns for unknown inputs.
    """

    VALUE_SEPARATOR = " => "
    EXTRAS_SEPARATOR = "================\n"

    def __init__(self, starting_index=0, **special_labels):
        self.lab2ind = {}
        self.ind2lab = {}
        self.starting_index = starting_index
        self.handle_special_labels(special_labels)

    def handle_special_labels(self, special_labels):
        """Register the special labels in ``special_labels`` (only ``unk_label``)."""
        if "unk_label" in special_labels:
            self.add_unk(special_labels["unk_label"])

    def __len__(self):
        return len(self.lab2ind)

    @classmethod
    def from_saved(cls, path):
        """Create a new encoder populated from the file at ``path``."""
        obj = cls()
        obj.load(path)
        return obj

    def update_from_iterable(self, iterable, sequence_input=False):
        """Add every not-yet-known label, in iteration order.

        If ``sequence_input`` is true, ``iterable`` yields sequences of labels
        which are flattened first.
        """
        labels = itertools.chain.from_iterable(iterable) if sequence_input else iter(iterable)
        for label in labels:
            self.ensure_label(label)

    def update_from_didataset(self, didataset, output_key, sequence_input=False):
        """Add the labels stored under ``output_key`` in each item of ``didataset``.

        ``didataset`` must provide an ``output_keys_as`` context manager and
        yield subscriptable data points (presumably a SpeechBrain
        DynamicItemDataset -- confirm).
        """
        with didataset.output_keys_as([output_key]):
            self.update_from_iterable(
                (data_point[output_key] for data_point in didataset),
                sequence_input=sequence_input,
            )

    def limited_labelset_from_iterable(self, iterable, sequence_input=False, n_most_common=None, min_count=1):
        """Add only the most frequent labels.

        Labels are visited in decreasing-count order (at most ``n_most_common``
        of them, or all when None). They are added until one occurs fewer
        than ``min_count`` times. ``add_label`` is used, so a label that is
        already present raises ``KeyError``.

        Returns:
            The ``collections.Counter`` of every label seen.
        """
        labels = itertools.chain.from_iterable(iterable) if sequence_input else iter(iterable)
        counts = collections.Counter(labels)
        for label, count in counts.most_common(n_most_common):
            if count < min_count:
                break
            self.add_label(label)
        return counts

    def load_or_create(self, path, from_iterables=None, from_didatasets=None, sequence_input=False, output_key=None,
                       special_labels=None):
        """Load the encoder from ``path``, building and saving it first if needed.

        Where ``if_main_process()`` is true, the file is loaded if possible.
        Otherwise the encoder is filled from ``from_iterables`` and
        ``from_didatasets``, special labels are registered, and the result is
        saved to ``path``. Then ``ddp_barrier()`` is always called (presumably
        a distributed sync point) and the file is loaded.

        Raises:
            ValueError: if ``from_didatasets`` is non-empty but ``output_key`` is None.
        """
        # None sentinels instead of mutable []/{} defaults shared across calls.
        from_iterables = [] if from_iterables is None else from_iterables
        from_didatasets = [] if from_didatasets is None else from_didatasets
        special_labels = {} if special_labels is None else special_labels
        try:
            if if_main_process() and not self.load_if_possible(path):
                for iterable in from_iterables:
                    self.update_from_iterable(iterable, sequence_input)
                for didataset in from_didatasets:
                    if output_key is None:
                        raise ValueError("output_key must be given when building from a didataset")
                    self.update_from_didataset(didataset, output_key, sequence_input)
                self.handle_special_labels(special_labels)
                self.save(path)
        finally:
            # Runs even if building failed, so every process reaches the barrier.
            ddp_barrier()
        self.load(path)

    def add_label(self, label):
        """Give ``label`` the next free index and return that index.

        Raises:
            KeyError: if ``label`` is already present.
        """
        if label in self.lab2ind:
            raise KeyError(f"Label already present: {label!r}")
        index = self._next_index()
        self.lab2ind[label] = index
        self.ind2lab[index] = label
        return index

    def ensure_label(self, label):
        """Return the index of ``label``, adding it first if it is unknown."""
        if label in self.lab2ind:
            return self.lab2ind[label]
        return self.add_label(label)

    def insert_label(self, label, index):
        """Place a new ``label`` at ``index`` (see ``enforce_label``).

        Raises:
            KeyError: if ``label`` is already present.
        """
        if label in self.lab2ind:
            raise KeyError(f"Label already present: {label!r}")
        self.enforce_label(label, index)

    def enforce_label(self, label, index):
        """Force ``label`` to live at ``index``.

        If ``label`` already sits elsewhere, its old slot is freed. If a
        different label occupies ``index``, that label moves to the next free
        index.
        """
        index = int(index)
        if label in self.lab2ind:
            if index == self.lab2ind[label]:
                return
            del self.ind2lab[self.lab2ind[label]]
        displaced = index in self.ind2lab
        if displaced:
            displaced_label = self.ind2lab[index]
        self.lab2ind[label] = index
        self.ind2lab[index] = label
        if displaced:
            # Computed after the swap so the slot freed above can be reused.
            new_index = self._next_index()
            self.lab2ind[displaced_label] = new_index
            self.ind2lab[new_index] = displaced_label

    def add_unk(self, unk_label=DEFAULT_UNK):
        """Register ``unk_label`` as the fallback for unknown labels; return its index."""
        self.unk_label = unk_label
        return self.add_label(unk_label)

    def _next_index(self):
        """Return the lowest unused index >= ``starting_index``."""
        index = self.starting_index
        while index in self.ind2lab:
            index += 1
        return index

    def is_continuous(self):
        """True if the indices form a gap-free run beginning at ``starting_index``."""
        indices = sorted(self.ind2lab)
        return self.starting_index in indices and all(j - i == 1 for i, j in zip(indices, indices[1:]))

    def encode_label(self, label, allow_unk=True):
        """Return the index of ``label``.

        An unknown label maps to the unk label's index when ``allow_unk`` is
        true and an unk label has been registered.

        Raises:
            KeyError: if ``label`` is unknown and cannot fall back to unk.
            RuntimeError: from ``_assert_len`` if the expected size does not match.
        """
        self._assert_len()
        try:
            return self.lab2ind[label]
        except KeyError:
            if allow_unk and hasattr(self, "unk_label"):
                return self.lab2ind[self.unk_label]
            raise KeyError(f"Unknown label: {label!r}") from None

    def encode_label_torch(self, label, allow_unk=True):
        """Encode one label as a one-element ``torch.LongTensor``."""
        return torch.LongTensor([self.encode_label(label, allow_unk)])

    def encode_sequence(self, sequence, allow_unk=True):
        """Encode every label in ``sequence``; return a list of indices."""
        self._assert_len()
        return [self.encode_label(label, allow_unk) for label in sequence]

    def encode_sequence_torch(self, sequence, allow_unk=True):
        """Encode ``sequence`` as a 1-D ``torch.LongTensor``."""
        return torch.LongTensor([self.encode_label(label, allow_unk) for label in sequence])

    def decode_torch(self, x):
        """Decode a tensor of indices into (nested) lists of labels.

        A 1-D tensor yields a flat list. Higher-rank tensors are decoded
        recursively along their first dimension.
        """
        self._assert_len()
        if x.ndim == 1:
            return [self.ind2lab[int(element)] for element in x]
        return [self.decode_torch(subtensor) for subtensor in x]

    def decode_ndim(self, x):
        """Decode an arbitrarily nested iterable of indices into labels.

        Anything that cannot be iterated is treated as a single index.
        """
        self._assert_len()
        try:
            return [self.decode_ndim(item) for item in x]
        except TypeError:
            return self.ind2lab[int(x)]

    def save(self, path):
        """Write the mapping plus extras (``starting_index``, ``unk_label``) to ``path``."""
        self._save_literal(path, self.lab2ind, self._get_extras())

    def load(self, path):
        """Replace the mapping and extras with the contents of ``path``."""
        lab2ind, ind2lab, extras = self._load_literal(path)
        self.lab2ind = lab2ind
        self.ind2lab = ind2lab
        self._set_extras(extras)

    def load_if_possible(self, path, end_of_epoch=False):
        """Try ``load(path)``; return False if the file is missing or unparsable."""
        del end_of_epoch  # unused here (presumably part of a loader-hook signature)
        try:
            self.load(path)
        except (FileNotFoundError, ValueError, SyntaxError):
            return False
        return True

    def expect_len(self, expected_len):
        """Require ``len(self) == expected_len`` on later encode/decode calls."""
        self.expected_len = expected_len

    def ignore_len(self):
        """Disable the length check."""
        self.expected_len = None

    def _assert_len(self):
        """Check the size set by ``expect_len``.

        If neither ``expect_len`` nor ``ignore_len`` was ever called, the check
        is disabled from now on.

        Raises:
            RuntimeError: if an expected length was set and does not match.
        """
        if not hasattr(self, "expected_len"):
            self.ignore_len()
            return
        if self.expected_len is not None and len(self) != self.expected_len:
            raise RuntimeError(f"CategoricalEncoder has {len(self)} labels, expected {self.expected_len}")

    def _get_extras(self):
        """Collect the non-mapping state that ``save`` persists."""
        extras = {"starting_index": self.starting_index}
        if hasattr(self, "unk_label"):
            extras["unk_label"] = self.unk_label
        return extras

    def _set_extras(self, extras):
        """Restore state produced by ``_get_extras``."""
        if "unk_label" in extras:
            self.unk_label = extras["unk_label"]
        self.starting_index = extras["starting_index"]

    @staticmethod
    def _save_literal(path, lab2ind, extras):
        """Write ``lab2ind`` and ``extras`` to ``path`` as UTF-8 text.

        Format: one ``repr(label) => index`` line per label, then
        ``EXTRAS_SEPARATOR``, then one ``repr(key) => repr(value)`` line per
        extra. Read back with ``_load_literal``.
        """
        sep = CategoricalEncoder.VALUE_SEPARATOR
        with open(path, "w", encoding="utf-8") as f:
            for label, ind in lab2ind.items():
                f.write(repr(label) + sep + str(ind) + "\n")
            f.write(CategoricalEncoder.EXTRAS_SEPARATOR)
            for key, value in extras.items():
                f.write(repr(key) + sep + repr(value) + "\n")

    @staticmethod
    def _load_literal(path):
        """Parse a file written by ``_save_literal``.

        Values are parsed with ``ast.literal_eval``, which accepts Python
        literals only and never executes code.

        Returns:
            ``(lab2ind, ind2lab, extras)`` with int indices in both maps.
        """
        lab2ind, ind2lab, extras = {}, {}, {}
        sep = CategoricalEncoder.VALUE_SEPARATOR
        with open(path, encoding="utf-8") as f:
            for line in f:
                if line == CategoricalEncoder.EXTRAS_SEPARATOR:
                    break
                literal, ind = line.strip().split(sep, maxsplit=1)
                label = ast.literal_eval(literal)
                ind = int(ind)
                lab2ind[label] = ind
                # Key by int so lookups via int(...) in decode_*/_next_index match.
                ind2lab[ind] = label
            for line in f:
                literal_key, literal_value = line.strip().split(sep, maxsplit=1)
                extras[ast.literal_eval(literal_key)] = ast.literal_eval(literal_value)
        return lab2ind, ind2lab, extras