import os
import pickle

import matplotlib.pyplot as plt
import networkx as nx
import torch
from tqdm import tqdm

from Modules.ToucanTTS.InferenceToucanTTS import ToucanTTS
from Utility.utils import load_json_from_path
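
# configuration: which distance measure to visualize and how.
# "plot_neighbors" shows the closest languages to `neighbor`,
# "plot_all" draws every pair below `edge_threshold`.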
distance_types = ["tree", "asp", "map", "learned", "l1"]
modes = ["plot_all", "plot_neighbors"]
neighbor = "Latin"  # focus language for "plot_neighbors" mode
num_neighbors = 12  # how many nearest languages to pull towards the focus language
distance_type = distance_types[0]
mode = modes[1]
edge_threshold = 0.01  # maximum normalized distance for an edge in "plot_all" mode

cache_root = "."
supervised_iso_codes = load_json_from_path(os.path.join(cache_root, "supervised_languages.json"))
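
# "l1": pairwise L1 distance between the language embeddings of a trained
# ToucanTTS meta checkpoint (upper triangle only), scaled to [0, 1] by the
# largest value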
if distance_type == "l1":
    iso_codes_to_ids = load_json_from_path(os.path.join(cache_root, "iso_lookup.json"))[-1]
    model_path = "../../Models/ToucanTTS_Meta/best.pt"
    checkpoint = torch.load(model_path, map_location='cpu')
    embedding_provider = ToucanTTS(weights=checkpoint["model"], config=checkpoint["config"]).encoder.language_embedding
    embedding_provider.requires_grad_(False)
    l1_dist = dict()
    seen_langs = set()
    for lang_1 in supervised_iso_codes:
        if lang_1 not in seen_langs:
            seen_langs.add(lang_1)
            l1_dist[lang_1] = dict()
            for lang_2 in supervised_iso_codes:
                if lang_2 not in seen_langs:
                    embedding_1 = embedding_provider(torch.LongTensor([iso_codes_to_ids[lang_1]])).squeeze()
                    embedding_2 = embedding_provider(torch.LongTensor([iso_codes_to_ids[lang_2]])).squeeze()
                    # mean absolute difference between the two language embeddings
                    l1_dist[lang_1][lang_2] = torch.nn.functional.l1_loss(embedding_1, embedding_2).item()
    # scale all distances to [0, 1] by dividing by the largest one
    largest_value_l1_dist = 0.0
    for _, values in l1_dist.items():
        for _, value in values.items():
            largest_value_l1_dist = max(largest_value_l1_dist, value)
    for key1 in l1_dist:
        for key2 in l1_dist[key1]:
            l1_dist[key1][key2] = l1_dist[key1][key2] / largest_value_l1_dist
    distance_measure = l1_dist
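
# "tree": precomputed tree-based pairwise distances, loaded from a JSON lookup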
if distance_type == "tree":
    tree_lookup_path = os.path.join(cache_root, "lang_1_to_lang_2_to_tree_dist.json")
    tree_dist = load_json_from_path(tree_lookup_path)
    distance_measure = tree_dist
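
# "map": precomputed map distances, scaled to [0, 1] by the largest value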
if distance_type == "map":
    map_lookup_path = os.path.join(cache_root, "lang_1_to_lang_2_to_map_dist.json")
    map_dist = load_json_from_path(map_lookup_path)
    largest_value_map_dist = 0.0
    for _, values in map_dist.items():
        for _, value in values.items():
            largest_value_map_dist = max(largest_value_map_dist, value)
    for key1 in map_dist:
        for key2 in map_dist[key1]:
            map_dist[key1][key2] = map_dist[key1][key2] / largest_value_map_dist
    distance_measure = map_dist
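
# "learned": precomputed learned distances, scaled the same way as the map distances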
if distance_type == "learned":
    # assumption: the learned distances have their own lookup file,
    # analogous to the tree and map lookups
    learned_lookup_path = os.path.join(cache_root, "lang_1_to_lang_2_to_learned_dist.json")
    learned_dist = load_json_from_path(learned_lookup_path)
    largest_value_learned_dist = 0.0
    for _, values in learned_dist.items():
        for _, value in values.items():
            largest_value_learned_dist = max(largest_value_learned_dist, value)
    for key1 in learned_dist:
        for key2 in learned_dist[key1]:
            learned_dist[key1][key2] = learned_dist[key1][key2] / largest_value_learned_dist
    distance_measure = learned_dist
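
# "asp": precomputed ASP similarity scores (pickled), converted to distances
# via 1 - similarity (upper triangle only)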
if distance_type == "asp":
    asp_dict_path = os.path.join(cache_root, "asp_dict.pkl")
    with open(asp_dict_path, 'rb') as dictfile:
        asp_sim = pickle.load(dictfile)
    lang_list = list(asp_sim.keys())
    asp_dist = dict()
    seen_langs = set()
    for lang_1 in lang_list:
        if lang_1 not in seen_langs:
            seen_langs.add(lang_1)
            asp_dist[lang_1] = dict()
            for index, lang_2 in enumerate(lang_list):
                if lang_2 not in seen_langs:
                    asp_dist[lang_1][lang_2] = 1 - asp_sim[lang_1][index]
    distance_measure = asp_dist
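
# resolve ISO codes to full language names and collect the pairwise distances
# between all languages that are either supervised or the focus language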
iso_codes_to_names = load_json_from_path(os.path.join(cache_root, "iso_to_fullname.json"))
distances = list()

for lang_1 in distance_measure:
    if lang_1 not in iso_codes_to_names:
        continue
    if lang_1 not in supervised_iso_codes and iso_codes_to_names[lang_1] != neighbor:
        continue
    for lang_2 in distance_measure[lang_1]:
        try:
            if lang_2 not in supervised_iso_codes and iso_codes_to_names[lang_2] != neighbor:
                continue
        except KeyError:
            continue
        distances.append((iso_codes_to_names[lang_1], iso_codes_to_names[lang_2], distance_measure[lang_1][lang_2]))
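
# build the graph; distances are min-max normalized so that both plotting
# modes work on the same [0, 1] scale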
G = nx.Graph()

min_dist = min(d for _, _, d in distances)
max_dist = max(d for _, _, d in distances)
normalized_distances = [(entity1, entity2, (d - min_dist) / (max_dist - min_dist)) for entity1, entity2, d in distances]
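
# "plot_neighbors": the focus language pulls its num_neighbors closest
# languages towards itself with strong springs; the neighbors attract each
# other with weaker springs so that related languages cluster together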
if mode == "plot_neighbors":
    # switch the supervised list from ISO codes to full names, since the
    # distance tuples use full names
    fullnames = list()
    fullnames.append(neighbor)
    for code in supervised_iso_codes:
        fullnames.append(iso_codes_to_names[code])
    supervised_iso_codes = fullnames

    # find the distance threshold that admits the num_neighbors closest languages
    d_dist = list()
    for entity1, entity2, d in tqdm(normalized_distances):
        if (neighbor == entity2 or neighbor == entity1) and (entity1 in supervised_iso_codes and entity2 in supervised_iso_codes):
            if entity1 != entity2:
                d_dist.append(d)
    thresh = sorted(d_dist)[num_neighbors]

    # collect which languages fall within the threshold
    neighbors = list()
    for entity1, entity2, d in tqdm(normalized_distances):
        if (d < thresh and (neighbor == entity2 or neighbor == entity1)) and (entity1 in supervised_iso_codes and entity2 in supervised_iso_codes):
            neighbors.append(entity1)
            neighbors.append(entity2)
    unique_neighbors = list(set(neighbors))
    unique_neighbors.remove(neighbor)

    # strong springs between the focus language and each neighbor
    for entity1, entity2, d in tqdm(normalized_distances):
        if (neighbor == entity2 or neighbor == entity1) and (entity1 in supervised_iso_codes and entity2 in supervised_iso_codes):
            if entity1 != entity2 and d < thresh:
                spring_tension = ((thresh - d) ** 2) * 20000
                print(f"{d}-->{spring_tension}")
                G.add_edge(entity1, entity2, weight=spring_tension)

    # weaker springs among the neighbors themselves
    for entity1, entity2, d in tqdm(normalized_distances):
        if (entity2 in unique_neighbors and entity1 in unique_neighbors) and (entity1 in supervised_iso_codes and entity2 in supervised_iso_codes):
            if entity1 != entity2:
                spring_tension = 1 - d
                G.add_edge(entity1, entity2, weight=spring_tension)

    pos = nx.spring_layout(G, weight="weight")
    edges = G.edges(data=True)

    nx.draw_networkx_nodes(G, pos, node_size=1, alpha=0.01)

    # highlight the edges connected to the focus language
    edges_connected_to_specific_node = [(u, v) for u, v in G.edges() if u == neighbor or v == neighbor]
    nx.draw_networkx_edges(G, pos, edgelist=edges_connected_to_specific_node, edge_color='red', alpha=0.3, width=3)

    # label those edges with the normalized distance (recovered by inverting
    # the spring-tension mapping above, scaled by 10 for readability)
    for u, v, d in edges:
        if u == neighbor or v == neighbor:
            nx.draw_networkx_edge_labels(G, pos, edge_labels={(u, v): round((thresh - (d['weight'] / 20000) ** (1 / 2)) * 10, 2)}, font_color="red", alpha=0.3)

    nx.draw_networkx_labels(G, pos, font_size=14, font_family='sans-serif', font_color='green')
    nx.draw_networkx_labels(G, pos, labels={neighbor: neighbor}, font_size=14, font_family='sans-serif', font_color='red')

    plt.title(f'Graph of {distance_type} Distances')
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
    plt.tight_layout(pad=0)
    plt.savefig("avg.png", dpi=300)
    plt.show()
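
# "plot_all": connect every pair of languages whose normalized distance is
# below edge_threshold; closer pairs get proportionally stronger springs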
elif mode == "plot_all":
    for entity1, entity2, d in tqdm(normalized_distances):
        if d < edge_threshold and entity1 != entity2:
            spring_tension = edge_threshold - d
            G.add_edge(entity1, entity2, weight=spring_tension)

    pos = nx.spring_layout(G, weight="weight")

    nx.draw_networkx_nodes(G, pos, node_size=1, alpha=0.01)
    nx.draw_networkx_edges(G, pos, alpha=0.1, edge_color="blue")
    nx.draw_networkx_labels(G, pos, font_size=10, font_family='sans-serif')

    plt.title(f'Graph of {distance_type} Distances')
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
    plt.tight_layout(pad=0)
    plt.show()