import argparse
import glob

import numpy as np


DIM = 1024  # dimensionality of each sentence embedding vector


def compute_dist(source_embs, target_embs, k=5, return_sim_mat=False):
    """For each source sentence, return the ids of its k nearest target
    sentences under cosine similarity (or the raw similarity matrix)."""
    target_ids = list(target_embs)
    source_mat = np.stack(list(source_embs.values()), axis=0)
    # L2-normalize both sides so the dot product below is cosine similarity.
    normalized_source_mat = source_mat / np.linalg.norm(
        source_mat, axis=1, keepdims=True
    )
    target_mat = np.stack(list(target_embs.values()), axis=0)
    normalized_target_mat = target_mat / np.linalg.norm(
        target_mat, axis=1, keepdims=True
    )
    sim_mat = normalized_source_mat.dot(normalized_target_mat.T)
    if return_sim_mat:
        return sim_mat
    neighbors_map = {}
    for i, sentence_id in enumerate(source_embs):
        # Indices of the k most similar target sentences, best first.
        idx = np.argsort(sim_mat[i, :])[::-1][:k]
        neighbors_map[sentence_id] = [target_ids[tid] for tid in idx]
    return neighbors_map
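
# A minimal sketch of compute_dist on toy 3-d embeddings (the ids and
# vectors here are illustrative, not real data):
#
#     src = {"s1": np.array([1.0, 0.0, 0.0], dtype=np.float32)}
#     tgt = {
#         "s1": np.array([0.9, 0.1, 0.0], dtype=np.float32),
#         "s2": np.array([0.0, 1.0, 0.0], dtype=np.float32),
#     }
#     compute_dist(src, tgt, k=2)  # -> {"s1": ["s1", "s2"]}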


def load_embeddings(directory, langs):
    """Load each language's embedding shards plus the sentence ids and
    texts that index them."""
    sentence_embeddings = {}
    sentence_texts = {}
    for lang in langs:
        sentence_embeddings[lang] = {}
        sentence_texts[lang] = {}
        lang_dir = f"{directory}/{lang}"
        embedding_files = glob.glob(f"{lang_dir}/all_avg_pool.{lang}.*")
        for embed_file in embedding_files:
            shard_id = embed_file.split(".")[-1]
            # Each shard is a flat binary file of float32 values,
            # DIM values per sentence.
            embeddings = np.fromfile(embed_file, dtype=np.float32)
            num_rows = embeddings.shape[0] // DIM
            embeddings = embeddings.reshape((num_rows, DIM))

            with open(f"{lang_dir}/sentences.{lang}.{shard_id}") as sentence_file:
                for idx, line in enumerate(sentence_file):
                    sentence_id, sentence = line.strip().split("\t")
                    sentence_texts[lang][sentence_id] = sentence
                    sentence_embeddings[lang][sentence_id] = embeddings[idx, :]

    return sentence_embeddings, sentence_texts
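
# The on-disk layout this loader expects, reconstructed from the glob and
# open calls above (the shard number 0 is illustrative):
#
#   <directory>/<lang>/all_avg_pool.<lang>.0   raw float32 embeddings, DIM per row
#   <directory>/<lang>/sentences.<lang>.0      "<sentence_id>\t<sentence>" per line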


def compute_accuracy(directory, langs):
    """Compute the top-1 retrieval accuracy matrix over all language pairs
    and write it to <directory>/accuracy."""
    sentence_embeddings, sentence_texts = load_embeddings(directory, langs)

    top_1_accuracy = {}

    top1_str = " ".join(langs) + "\n"
    for source_lang in langs:
        top_1_accuracy[source_lang] = {}
        top1_str += f"{source_lang} "
        for target_lang in langs:
            top1 = 0
            top5 = 0
            neighbors_map = compute_dist(
                sentence_embeddings[source_lang], sentence_embeddings[target_lang]
            )
            for sentence_id, neighbors in neighbors_map.items():
                # A retrieval is correct when the nearest target sentence
                # carries the same sentence id as the source sentence.
                if sentence_id == neighbors[0]:
                    top1 += 1
                # Top-5 hits are tallied as well, though only top-1 is reported.
                if sentence_id in neighbors[:5]:
                    top5 += 1
            n = len(sentence_embeddings[target_lang])
            top_1_accuracy[source_lang][target_lang] = top1 / n
            top1_str += f"{top1 / n} "
        top1_str += "\n"

    print(top1_str)
    with open(f"{directory}/accuracy", "w") as accuracy_file:
        print(top1_str, file=accuracy_file)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Analyze encoder outputs")
    parser.add_argument(
        "directory", help="Directory containing the embedding and sentence shards"
    )
    parser.add_argument(
        "--langs", required=True, help="Comma-separated list of languages"
    )
    args = parser.parse_args()
    langs = args.langs.split(",")
    compute_accuracy(args.directory, langs)
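
# Example invocation (script name, path, and language codes are illustrative):
#   python encoder_analysis.py /path/to/embeddings --langs en,fr,de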