|
import os.path |
|
|
|
import torch |
|
|
|
from Preprocessing.multilinguality.create_distance_lookups import CacheCreator |
|
from Utility.utils import load_json_from_path |
|
|
|
|
|
class LanguageEmbeddingSpaceStructureLoss(torch.nn.Module):
    """Encourages language embeddings to mirror linguistic distances.

    Pairwise L1 distances between language embeddings in a batch are pulled
    towards a metric distance derived from two precomputed lookup tables:
    a language-family tree distance and a geographic map distance (the
    latter normalized by its largest value so both lie in a comparable range).
    """

    def __init__(self):
        super().__init__()
        cc = CacheCreator(cache_root="Preprocessing/multilinguality")
        if not os.path.exists('Preprocessing/multilinguality/lang_1_to_lang_2_to_tree_dist.json'):
            cc.create_tree_cache(cache_root="Preprocessing/multilinguality")
        # BUGFIX: this previously re-checked the *tree* distance cache, so the
        # map-distance cache was never created when only it was missing, and
        # the unconditional load below would fail.
        if not os.path.exists('Preprocessing/multilinguality/lang_1_to_lang_2_to_map_dist.json'):
            cc.create_map_cache(cache_root="Preprocessing/multilinguality")

        # Nested dicts: iso_code -> iso_code -> distance. The tables may only
        # contain one orientation of each pair, hence the try/except in forward.
        self.tree_dist = load_json_from_path('Preprocessing/multilinguality/lang_1_to_lang_2_to_tree_dist.json')
        self.map_dist = load_json_from_path('Preprocessing/multilinguality/lang_1_to_lang_2_to_map_dist.json')

        # Largest raw map distance, used to normalize map distances into [0, 1].
        self.largest_value_map_dist = 0.0
        for values in self.map_dist.values():
            for value in values.values():
                self.largest_value_map_dist = max(self.largest_value_map_dist, value)

        # iso_lookup.json's last element maps iso code -> integer id; invert it
        # so incoming language ids can be translated back to iso codes.
        self.iso_codes_to_ids = load_json_from_path("Preprocessing/multilinguality/iso_lookup.json")[-1]
        self.ids_to_iso_codes = {v: k for k, v in self.iso_codes_to_ids.items()}

    def forward(self, language_ids, language_embeddings):
        """
        Args:
            language_ids (Tensor): IDs of languages in the same order as the embeddings to calculate the distances according to the metrics.
            language_embeddings (Tensor): Batch of language embeddings, of which the distances will be compared to the distances according to the metrics.

        Returns:
            Tensor: Language Embedding Structure Loss Value
        """
        losses = list()
        # NOTE(review): each unordered pair is visited twice ((a,b) and (b,a));
        # since both the embedding distance and the metric distance are
        # symmetric, this does not change the averaged result.
        for language_id_1, language_embedding_1 in zip(language_ids, language_embeddings):
            for language_id_2, language_embedding_2 in zip(language_ids, language_embeddings):
                if language_id_1 != language_id_2:
                    embed_dist = torch.nn.functional.l1_loss(language_embedding_1, language_embedding_2)
                    # assumes ids are hashable keys present in ids_to_iso_codes
                    # (plain ints, not 0-dim tensors) — TODO confirm at call site
                    lang_1 = self.ids_to_iso_codes[language_id_1]
                    lang_2 = self.ids_to_iso_codes[language_id_2]

                    # The lookup tables store each pair in only one orientation.
                    try:
                        tree_dist = self.tree_dist[lang_1][lang_2]
                    except KeyError:
                        tree_dist = self.tree_dist[lang_2][lang_1]

                    try:
                        map_dist = self.map_dist[lang_1][lang_2] / self.largest_value_map_dist
                    except KeyError:
                        map_dist = self.map_dist[lang_2][lang_1] / self.largest_value_map_dist

                    # Target distance: mean of the two (normalized) metrics.
                    metric_distance = (torch.tensor(tree_dist) + torch.tensor(map_dist)) / 2
                    losses.append(torch.nn.functional.l1_loss(embed_dist, metric_distance))

        # Guard: a batch where all ids are identical produces no pairs; the
        # original sum([]) / 0 would raise ZeroDivisionError.
        if not losses:
            return torch.zeros(1, device=language_embeddings.device).squeeze()
        return sum(losses) / len(losses)
|
|