|
import pickle |
|
import fire |
|
import numpy as np |
|
import pandas as pd |
|
from tqdm import tqdm |
|
|
|
|
|
class EmbeddingExtractor(object):
    """Command-line entry points that encode captions into sentence embeddings.

    Each public method loads a sentence encoder, encodes every caption found
    in an input file, and writes the resulting vectors to ``output``.
    """

    def extract_sentbert(self, caption_file: str, output: str, dev: bool = True, zh: bool = False):
        """Encode captions with a pretrained ``SentenceTransformer`` model.

        Args:
            caption_file: JSON caption table, read by :meth:`extract`.
            output: Path of the pickle file to write.
            dev: Forwarded to :meth:`extract`; selects the output layout.
            zh: If true, use the multilingual model
                (``distiluse-base-multilingual-cased``); otherwise use
                ``bert-base-nli-mean-tokens``.
        """
        # Imported lazily so the other commands work without this package.
        from sentence_transformers import SentenceTransformer

        lang2model = {
            "zh": "distiluse-base-multilingual-cased",
            "en": "bert-base-nli-mean-tokens",
        }
        lang = "zh" if zh else "en"
        model = SentenceTransformer(lang2model[lang])

        self.extract(caption_file, model, output, dev)

    def extract_originbert(self, caption_file: str, output: str, dev: bool = True, ip: str = "localhost"):
        """Encode captions through a ``bert_serving`` server.

        Args:
            caption_file: JSON caption table, read by :meth:`extract`.
            output: Path of the pickle file to write.
            dev: Forwarded to :meth:`extract`; selects the output layout.
            ip: Address passed to ``BertClient``.
        """
        # Imported lazily so the other commands work without this package.
        from bert_serving.client import BertClient

        # The context manager closes the client's connections even when
        # extraction fails part-way through.
        with BertClient(ip) as client:
            self.extract(caption_file, client, output, dev)

    def extract(self, caption_file: str, model, output, dev: bool):
        """Encode every caption in ``caption_file`` and pickle the result.

        Args:
            caption_file: JSON file loaded with ``pandas.read_json``; the
                ``key`` column is read as ``str``. Rows must provide
                ``caption`` and ``key``, plus ``caption_index`` when ``dev``
                is true.
            model: Any object with an ``encode(list_of_str)`` method.
            output: Path of the pickle file to write.
            dev: If true, store one flat vector per caption under
                ``"{key}_{caption_index}"``; otherwise store, per ``key``,
                all of its caption vectors stacked into one array.
        """
        caption_df = pd.read_json(caption_file, dtype={"key": str})
        rows = (row for _, row in caption_df.iterrows())

        with tqdm(total=caption_df.shape[0], ascii=True) as pbar:
            embeddings = self._collect_embeddings(rows, model, dev, pbar.update)

        with open(output, "wb") as f:
            pickle.dump(embeddings, f)

    @staticmethod
    def _collect_embeddings(rows, model, dev: bool, on_row=None):
        """Encode the caption of each row and gather the vectors in a dict.

        Args:
            rows: Iterable of mappings indexable by ``"caption"`` and
                ``"key"`` (and ``"caption_index"`` when ``dev`` is true).
            model: Any object with an ``encode(list_of_str)`` method.
            dev: Selects the layout, as described in :meth:`extract`.
            on_row: Optional zero-argument callable invoked after each row
                (used to advance the progress bar).

        Returns:
            ``dict`` mapping ``"{key}_{caption_index}"`` to a flat vector when
            ``dev`` is true; otherwise mapping ``key`` to the ``np.stack`` of
            that key's vectors, in row order.
        """
        embeddings = {}
        for row in rows:
            key = row["key"]
            # The encoder receives a one-element batch; flatten the result to
            # a single 1-D vector.
            vector = np.array(model.encode([row["caption"]])).reshape(-1)
            if dev:
                cap_idx = row["caption_index"]
                embeddings[f"{key}_{cap_idx}"] = vector
            else:
                embeddings.setdefault(key, []).append(vector)
            if on_row is not None:
                on_row()

        if not dev:
            embeddings = {key: np.stack(vectors) for key, vectors in embeddings.items()}
        return embeddings

    def extract_sbert(self,
                      input_json: str,
                      output: str):
        """Encode captions with ``paraphrase-MiniLM-L6-v2`` into an HDF5 file.

        Args:
            input_json: JSON file whose top-level ``"audios"`` entry is a list
                of samples, each with an ``"audio_id"`` and a ``"captions"``
                list of items holding ``"cap_id"`` and ``"caption"``.
            output: Path of the HDF5 file to create (opened in ``"w"`` mode);
                each embedding is stored under ``"{audio_id}_{cap_id}"``.
        """
        # Imported lazily so the other commands work without these packages.
        from sentence_transformers import SentenceTransformer
        import json
        import torch
        from h5py import File

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = SentenceTransformer("paraphrase-MiniLM-L6-v2")
        model = model.to(device)
        model.eval()

        # Close the input file deterministically instead of leaking the handle.
        with open(input_json) as f:
            data = json.load(f)["audios"]

        with torch.no_grad(), tqdm(total=len(data), ascii=True) as pbar, File(output, "w") as store:
            for sample in data:
                audio_id = sample["audio_id"]
                for cap in sample["captions"]:
                    cap_id = cap["cap_id"]
                    store[f"{audio_id}_{cap_id}"] = model.encode(cap["caption"])
                # One tick per audio sample, matching total=len(data).
                pbar.update()
|
|
|
|
|
# Script entry point: hand the class to python-fire so it can be driven from
# the command line.
if __name__ == "__main__":
    fire.Fire(component=EmbeddingExtractor)
|
|