|
import json |
|
import os |
|
import pickle |
|
import random |
|
|
|
import kan |
|
import torch |
|
from tqdm import tqdm |
|
|
|
from Modules.ToucanTTS.InferenceToucanTTS import ToucanTTS |
|
from Utility.utils import load_json_from_path |
|
|
|
|
|
class MetricsCombiner(torch.nn.Module):
    """Learnable scoring function that maps a vector of three metric values to one score.

    The mapping is a ``kan.KAN`` network with layer widths ``[3, 5, 1]``
    (three inputs, five hidden nodes, one output).
    """

    def __init__(self, m):
        """Build the scoring network.

        Args:
            m: Value forwarded to ``kan.KAN`` as its ``seed`` argument, so
                different values of ``m`` yield differently seeded networks.
        """
        super().__init__()
        layer_widths = [3, 5, 1]
        self.scoring_function = kan.KAN(width=layer_widths, grid=5, k=5, seed=m)

    def forward(self, x):
        """Score ``x`` after dropping all of its size-1 dimensions.

        Args:
            x: Tensor of metric values; presumably shaped ``(batch, 3)`` to
                match the network's input width — confirm against callers.

        Returns:
            Whatever the underlying ``kan.KAN`` returns for the squeezed input.
        """
        squeezed_input = x.squeeze()
        return self.scoring_function(squeezed_input)
|
|
|
|
|
class EnsembleModel(torch.nn.Module):
    """Average the outputs of several models that all accept the same input."""

    def __init__(self, models):
        """Wrap ``models`` into a single module.

        Args:
            models: Iterable of ``torch.nn.Module`` instances. They are stored
                in a ``torch.nn.ModuleList`` so that they are registered as
                submodules; this makes ``parameters()``, ``state_dict()``,
                ``.to(...)`` and ``.eval()`` on the ensemble reach every member,
                which a plain Python list would not.
        """
        super().__init__()
        self.models = torch.nn.ModuleList(models)

    def forward(self, x):
        """Return the element-wise mean of all member outputs for ``x``.

        Args:
            x: Input that is passed unchanged to every member model.

        Returns:
            The sum of all member outputs divided by the number of members.

        Raises:
            ValueError: If the ensemble contains no models.
        """
        if len(self.models) == 0:
            raise ValueError("EnsembleModel needs at least one model to compute a prediction.")
        predictions = [model(x) for model in self.models]
        return sum(predictions) / len(predictions)
|
|
|
|
|
def _lookup_symmetric(table, lang_1, lang_2):
    """Return the value stored for a language pair in a nested lookup table.

    ``table[lang_2][lang_1]`` is tried first; if that raises ``KeyError`` the
    swapped order ``table[lang_1][lang_2]`` is used. A ``KeyError`` raised by
    the swapped lookup propagates to the caller.
    """
    try:
        return table[lang_2][lang_1]
    except KeyError:
        return table[lang_1][lang_2]


def _metric_features(lang_1, lang_2, tree_dist, map_dist, largest_value_map_dist, asp_sim, asp_index):
    """Build the float32 feature vector ``[tree, normalized map, 1 - asp similarity]`` for a pair.

    Raises:
        KeyError: If the pair is missing from ``tree_dist`` / ``map_dist`` or
            ``lang_1`` is missing from ``asp_sim``.
        ValueError: If ``lang_2`` has no position in ``asp_index`` (same
            exception type that ``list.index`` raises for a missing item).
    """
    tree = _lookup_symmetric(tree_dist, lang_1, lang_2)
    normalized_map = _lookup_symmetric(map_dist, lang_1, lang_2) / largest_value_map_dist
    asp_row = asp_sim[lang_1]
    try:
        column = asp_index[lang_2]
    except KeyError:
        raise ValueError(f"{lang_2!r} is not in list") from None
    asp = 1.0 - asp_row[column]
    return torch.tensor([tree, normalized_map, asp], dtype=torch.float32)


def _sample_batch(train_set, batch_size, embedding_provider, iso_codes_to_ids, featurize):
    """Draw ``batch_size`` random language pairs and return ``(features, targets)``.

    ``targets`` holds, per pair, the mean squared difference between the two
    language embeddings; ``features`` holds the vectors produced by
    ``featurize(lang_1, lang_2)``. Both stacked tensors are squeezed.
    """
    embedding_distances = []
    metric_features = []
    for _ in range(batch_size):
        lang_1 = random.choice(train_set)
        lang_2 = random.choice(train_set)
        embedding_1 = embedding_provider(torch.LongTensor([iso_codes_to_ids[lang_1]])).squeeze()
        embedding_2 = embedding_provider(torch.LongTensor([iso_codes_to_ids[lang_2]])).squeeze()
        embedding_distances.append(torch.nn.functional.mse_loss(embedding_1, embedding_2))
        metric_features.append(featurize(lang_1, lang_2))
    return torch.stack(metric_features).squeeze(), torch.stack(embedding_distances).squeeze()


def create_learned_cache(model_path, cache_root=".", *, n_models=5, n_epochs=35, steps_per_epoch=1000,
                         learning_rate=0.0005, print_intermediate_results=False):
    """Train an ensemble that predicts language-embedding distance and cache its predictions.

    Each ensemble member is a ``MetricsCombiner`` trained to map three metric
    values of a language pair (tree distance, map distance divided by the
    largest map distance, and one minus the ASP similarity) to the mean squared
    difference between the two languages' embeddings taken from the checkpoint.
    The averaged ensemble is then applied to every pair of keys in the tree
    distance lookup and the result is written to
    ``lang_1_to_lang_2_to_learned_dist.json`` inside ``cache_root``.

    Args:
        model_path: Path to a checkpoint with ``"model"`` and ``"config"``
            entries that ``ToucanTTS`` can be built from.
        cache_root: Directory that holds ``supervised_languages.json``,
            ``lang_1_to_lang_2_to_tree_dist.json``,
            ``lang_1_to_lang_2_to_map_dist.json``, ``asp_dict.pkl`` and
            ``iso_lookup.json``; the output file is written here as well.
        n_models: Number of ensemble members to train.
        n_epochs: Number of epochs per ensemble member.
        steps_per_epoch: Number of randomly sampled batches per epoch.
        learning_rate: Learning rate of the Adam optimizer.
        print_intermediate_results: If true, print the first predictions and
            targets of every training batch.

    Raises:
        ValueError: If ``n_models`` or ``steps_per_epoch`` is smaller than 1,
            ``n_epochs`` is negative, or the map distance lookup contains no
            positive value to normalize by.
        FileNotFoundError: If one of the distance lookups or the ASP
            dictionary is missing from ``cache_root``.
    """
    if n_models < 1 or steps_per_epoch < 1 or n_epochs < 0:
        raise ValueError("n_models and steps_per_epoch must be at least 1 and n_epochs must not be negative.")

    checkpoint = torch.load(model_path, map_location='cpu')
    embedding_provider = ToucanTTS(weights=checkpoint["model"], config=checkpoint["config"]).encoder.language_embedding
    # The embeddings are only used as regression targets and must stay fixed.
    embedding_provider.requires_grad_(False)

    language_list = load_json_from_path(os.path.join(cache_root, "supervised_languages.json"))
    tree_lookup_path = os.path.join(cache_root, "lang_1_to_lang_2_to_tree_dist.json")
    map_lookup_path = os.path.join(cache_root, "lang_1_to_lang_2_to_map_dist.json")
    asp_dict_path = os.path.join(cache_root, "asp_dict.pkl")
    if not (os.path.exists(tree_lookup_path) and os.path.exists(map_lookup_path)):
        raise FileNotFoundError("Please ensure the caches exist!")
    if not os.path.exists(asp_dict_path):
        raise FileNotFoundError(f"{asp_dict_path} must be downloaded separately.")
    tree_dist = load_json_from_path(tree_lookup_path)
    map_dist = load_json_from_path(map_lookup_path)
    # NOTE(security): unpickling can execute arbitrary code. Only use an
    # asp_dict.pkl that was obtained from a trusted source.
    with open(asp_dict_path, 'rb') as dictfile:
        asp_sim = pickle.load(dictfile)

    # Rows of asp_sim are indexed by the position of the other language in the
    # key order of asp_sim. A dict gives that position in O(1) instead of the
    # linear scan of list.index, which matters in the all-pairs loop below.
    asp_index = {language: position for position, language in enumerate(asp_sim)}

    largest_value_map_dist = max(
        (value for values in map_dist.values() for value in values.values()),
        default=0.0,
    )
    largest_value_map_dist = max(0.0, largest_value_map_dist)
    if largest_value_map_dist == 0:
        raise ValueError("The map distance lookup contains no positive distance to normalize by.")

    iso_codes_to_ids = load_json_from_path(os.path.join(cache_root, "iso_lookup.json"))[-1]
    train_set = language_list
    batch_size = 128
    evaluation_steps = 100

    def featurize(lang_1, lang_2):
        return _metric_features(lang_1, lang_2, tree_dist, map_dist, largest_value_map_dist, asp_sim, asp_index)

    def next_batch():
        return _sample_batch(train_set, batch_size, embedding_provider, iso_codes_to_ids, featurize)

    # Train the ensemble members one after another, each with its own seed.
    model_list = []
    print(f"Training ensemble of {n_models} models for learned distance metric.")
    for m in range(n_models):
        model = MetricsCombiner(m)
        model_list.append(model)
        optim = torch.optim.Adam(model.parameters(), lr=learning_rate)
        for _ in tqdm(range(n_epochs), desc=f"MetricsCombiner {m + 1}/{n_models} - Epoch"):
            running_loss = []
            for _ in range(steps_per_epoch):
                # There is no dataloader, every step samples a fresh batch.
                features, targets = next_batch()
                scores = model(features)
                if print_intermediate_results:
                    print("==================================")
                    print(scores.detach().squeeze()[:9])
                    print(targets[:9])
                # Squared error relative to the target, so that small
                # distances are not dominated by large ones.
                loss = torch.nn.functional.mse_loss(scores.squeeze(), targets, reduction="none")
                loss = loss / (targets + 0.0001)
                loss = loss.mean()

                running_loss.append(loss.item())
                optim.zero_grad()
                loss.backward()
                optim.step()

            print("\n\n")
            print(sum(running_loss) / len(running_loss))
            print("\n\n")

    ensemble = EnsembleModel(model_list)

    # Sanity check of the averaged ensemble on freshly sampled batches. Nothing
    # is optimized here, so gradient tracking is switched off.
    running_loss = []
    with torch.no_grad():
        for _ in range(evaluation_steps):
            features, targets = next_batch()
            scores = ensemble(features)
            print("==================================")
            print(scores.detach().squeeze()[:9])
            print(targets[:9])
            loss = torch.nn.functional.mse_loss(scores.squeeze(), targets)
            running_loss.append(loss.item())

    print("\n\n")
    print(sum(running_loss) / len(running_loss))

    language_to_language_to_learned_distance = {}

    with torch.inference_mode():
        for lang_1 in tqdm(tree_dist):
            for lang_2 in tree_dist:
                # The try deliberately spans the whole pair computation: any
                # KeyError or ValueError means the pair is skipped.
                try:
                    if lang_1 in language_to_language_to_learned_distance.get(lang_2, ()):
                        continue  # the distance is symmetric, the pair is already stored
                    language_to_language_to_learned_distance.setdefault(lang_1, {})
                    metric_distance = featurize(lang_1, lang_2)
                    predicted_distance = ensemble(metric_distance.unsqueeze(0)).squeeze()
                    language_to_language_to_learned_distance[lang_1][lang_2] = predicted_distance.item()
                except (ValueError, KeyError):
                    continue

    with open(os.path.join(cache_root, 'lang_1_to_lang_2_to_learned_dist.json'), 'w', encoding='utf-8') as f:
        json.dump(language_to_language_to_learned_distance, f, ensure_ascii=False, indent=4)
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: build the learned-distance cache from the meta
    # checkpoint, using the function's default cache_root.
    create_learned_cache(model_path="../../Models/ToucanTTS_Meta/best.pt")
|
|