|
import json |
|
import random |
|
import argparse |
|
import numpy as np |
|
from tqdm import tqdm |
|
from h5py import File |
|
import sklearn.metrics |
|
|
|
# Fix the RNG seed so the random caption choice below is reproducible run-to-run.
random.seed(1)

# All four CLI arguments are required positional string paths; declare them
# uniformly in one pass.
parser = argparse.ArgumentParser()
for arg_name in ("train_feature", "train_corpus", "pred_feature", "output_json"):
    parser.add_argument(arg_name, type=str)

args = parser.parse_args()
|
# Load every training embedding from the HDF5 store, remembering the HDF5 key
# (audio id) at each row index so similarity row indices can be mapped back.
train_embs = []
train_idx_to_audioid = []
with File(args.train_feature, "r") as store:
    for audio_id, embedding in tqdm(store.items(), ascii=True):
        # embedding[()] materializes the dataset as a numpy array.
        train_embs.append(embedding[()])
        train_idx_to_audioid.append(audio_id)

# Fix: read the corpus inside a `with` block — the original `json.load(open(...))`
# leaked the file handle. JSON is specified as UTF-8, so pin the encoding.
with open(args.train_corpus, "r", encoding="utf-8") as corpus_fp:
    train_annotation = json.load(corpus_fp)["audios"]

# audio_id -> list of token sequences, one per reference caption.
train_audioid_to_tokens = {
    item["audio_id"]: [cap_item["tokens"] for cap_item in item["captions"]]
    for item in train_annotation
}
# (num_train, emb_dim) matrix for the pairwise similarity below.
train_embs = np.stack(train_embs)
|
|
|
|
|
# Load the embeddings we want captions for, keeping the row-index -> audio id map.
pred_data = []
pred_embs = []
pred_idx_to_audioids = []
with File(args.pred_feature, "r") as store:
    for audio_id, embedding in tqdm(store.items(), ascii=True):
        pred_embs.append(embedding[()])
        pred_idx_to_audioids.append(audio_id)
pred_embs = np.stack(pred_embs)

# similarity[i, j] = cosine similarity between pred clip i and train clip j.
similarity = sklearn.metrics.pairwise.cosine_similarity(pred_embs, train_embs)
# Vectorized nearest-neighbour lookup: one argmax over axis 1 instead of a
# per-row argmax call inside the loop.
nearest_train_idx = similarity.argmax(axis=1)
for idx, audio_id in enumerate(pred_idx_to_audioids):
    train_idx = nearest_train_idx[idx]
    pred_data.append({
        "filename": audio_id,
        # Transfer a randomly chosen reference caption from the nearest train clip.
        "tokens": random.choice(train_audioid_to_tokens[train_idx_to_audioid[train_idx]])
    })

# Fix: write through a `with` block — the original `open(...)` was never closed,
# risking an unflushed/truncated output file. JSON output is written as UTF-8.
with open(args.output_json, "w", encoding="utf-8") as out_fp:
    json.dump({"predictions": pred_data}, out_fp, ensure_ascii=False, indent=4)
|
|