|
"""Downstream expert for Query-by-Example Spoken Term Detection on QUESST 2014.""" |
|
|
|
from collections import defaultdict |
|
from concurrent.futures import ProcessPoolExecutor, as_completed |
|
from functools import partial |
|
from pathlib import Path |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
from dtw import dtw |
|
from lxml import etree |
|
from scipy.spatial import distance |
|
from torch.utils.data import DataLoader |
|
from tqdm import tqdm |
|
|
|
from .dataset import QUESST14Dataset |
|
|
|
|
|
class DownstreamExpert(nn.Module):
    """Downstream expert for query-by-example spoken term detection (QUESST14).

    ``forward`` accumulates the upstream features of every utterance into
    ``records``; ``log_records`` then scores every query against every
    document with DTW, z-normalises the scores per query and writes the
    detections to ``benchmark.stdlist.xml`` inside ``expdir``.
    """

    # Queries with fewer frames than this are not aligned; every document
    # simply receives a raw score of 0 for them.
    _MIN_QUERY_FRAMES = 5
    # A (per-query z-normalised) score above this value is a "YES" decision.
    _SCORE_THRESHOLD = 2.0
    # Name of the result file written into ``expdir``.
    _STDLIST_FILENAME = "benchmark.stdlist.xml"

    def __init__(
        self, upstream_dim: int, downstream_expert: dict, expdir: str, **kwargs
    ):
        """Read the downstream configuration and validate the DTW settings.

        Args:
            upstream_dim: Dimension of the upstream features (stored only).
            downstream_expert: Config dict; the keys ``max_workers``,
                ``feature_normalization``, ``silence_frame``, ``datarc`` and
                ``dtwrc`` are read here.
            expdir: Directory the result XML is written to.
            **kwargs: Accepted for interface compatibility and ignored.

        Raises:
            AssertionError: If feature normalization is combined with the
                ``cosine_neg_log`` distance, or if subsequence matching is
                requested with a step pattern other than ``asymmetric``.
        """
        super().__init__()
        self.upstream_dim = upstream_dim
        self.max_workers = downstream_expert["max_workers"]
        self.feature_normalization = downstream_expert["feature_normalization"]
        self.silence_frame = downstream_expert["silence_frame"]
        self.datarc = downstream_expert["datarc"]
        self.dtwrc = downstream_expert["dtwrc"]
        self.expdir = Path(expdir)
        # Set by get_dataloader(); log_records() needs its n_queries.
        self.test_dataset = None

        assert not (
            self.feature_normalization and self.dtwrc["dist_method"] == "cosine_neg_log"
        ), "Upstream features normalization cannot be used with cosine_neg_log."

        assert (
            self.dtwrc["step_pattern"] == "asymmetric" or not self.dtwrc["subsequence"]
        ), "Subsequence finding only works under asymmetric setting."

    def get_dataloader(self, mode):
        """Build the dataset for ``mode`` and return a DataLoader over it.

        ``"dev"`` selects the dev split; any other value selects ``"eval"``.
        The dataset is kept on ``self.test_dataset`` for ``log_records``.
        """
        split = "dev" if mode == "dev" else "eval"
        self.test_dataset = QUESST14Dataset(split, **self.datarc)

        return DataLoader(
            self.test_dataset,
            shuffle=False,
            batch_size=self.datarc["batch_size"],
            drop_last=False,
            num_workers=self.datarc["num_workers"],
            collate_fn=self.test_dataset.collate_fn,
        )

    def forward(
        self,
        mode,
        features,
        audio_names,
        records,
        **kwargs,
    ):
        """Append the features and names of one batch to ``records``.

        Args:
            mode: Unused here.
            features: Iterable of tensors, one per utterance
                (assumed ``(frames, dim)`` -- TODO confirm with the upstream).
            audio_names: Names aligned with ``features``.
            records: Mapping with list values under ``"features"`` and
                ``"audio_names"``; both lists are extended in place.
            **kwargs: Ignored.
        """
        for feature, audio_name in zip(features, audio_names):
            feature = feature.detach().cpu()
            if self.silence_frame is not None:
                # Drop frames whose largest component is the silence index.
                feature = feature[feature.argmax(1) != self.silence_frame]
            records["features"].append(feature)
            records["audio_names"].append(audio_name)

    def log_records(self, mode, records, **kwargs):
        """Perform DTW and save results.

        The first ``self.test_dataset.n_queries`` entries of ``records`` are
        treated as queries and the remaining ones as documents.

        Raises:
            RuntimeError: If ``get_dataloader`` has not been called yet, so
                the number of queries is unknown.
        """
        if self.test_dataset is None:
            raise RuntimeError(
                "get_dataloader() must be called before log_records()."
            )
        n_queries = self.test_dataset.n_queries

        audio_names = records["audio_names"]
        query_names = audio_names[:n_queries]
        doc_names = audio_names[n_queries:]

        features = self._normalize_features(records["features"])
        queries = features[:n_queries]
        docs = features[n_queries:]

        results = self._run_dtw(
            queries,
            query_names,
            docs,
            doc_names,
            self._select_dist_fn(),
            self._dtw_kwargs(),
        )
        results = self._normalize_scores(results)
        self._write_stdlist(results)

    def _normalize_features(self, features):
        """Return every feature as a NumPy array, standardised if configured."""
        feature_mean, feature_std = 0.0, 1.0
        if self.feature_normalization:
            # Statistics are pooled over the frames of queries and docs alike.
            stacked = torch.cat(features)
            feature_mean = stacked.mean(0)
            feature_std = torch.clamp(stacked.std(0), 1e-9)
        return [((feat - feature_mean) / feature_std).numpy() for feat in features]

    def _select_dist_fn(self):
        """Return the picklable frame-distance function named in ``dtwrc``."""
        method = self.dtwrc["dist_method"]
        if method == "cosine_exp":
            return cosine_exp
        if method == "cosine_neg_log":
            return cosine_neg_log
        # Anything else is handed to scipy as a cdist metric name.
        return partial(distance.cdist, metric=method)

    def _dtw_kwargs(self):
        """Translate ``dtwrc`` into the keyword arguments passed to ``dtw``."""
        subsequence = bool(self.dtwrc["subsequence"])
        return {
            "step_pattern": self.dtwrc["step_pattern"],
            "keep_internals": False,
            "distance_only": not subsequence,
            "open_begin": subsequence,
            "open_end": subsequence,
        }

    def _run_dtw(self, queries, query_names, docs, doc_names, dist_fn, dtw_kwargs):
        """Score every query/doc pair in a process pool; map query -> [(doc, score)]."""
        results = defaultdict(list)
        with ProcessPoolExecutor(max_workers=self.max_workers) as executor:
            futures = []
            for query, query_name in zip(queries, query_names):
                if len(query) < self._MIN_QUERY_FRAMES:
                    # Too short to align: give every doc the same raw score.
                    results[query_name] = [(doc_name, 0) for doc_name in doc_names]
                    continue
                futures.extend(
                    executor.submit(
                        match,
                        query,
                        doc,
                        query_name,
                        doc_name,
                        dist_fn,
                        self.dtwrc["minmax_norm"],
                        dtw_kwargs,
                    )
                    for doc, doc_name in zip(docs, doc_names)
                )
            for future in tqdm(
                as_completed(futures), total=len(futures), ncols=0, desc="DTW"
            ):
                query_name, doc_name, score = future.result()
                results[query_name].append((doc_name, score))
        return results

    @staticmethod
    def _normalize_scores(results):
        """Z-normalise the scores of each query; empty score lists are kept as is."""
        normalized = {}
        for query_name, doc_scores in results.items():
            if not doc_scores:
                # zip(*[]) cannot be unpacked; nothing to normalise anyway.
                normalized[query_name] = []
                continue
            names, scores = zip(*doc_scores)
            scores = np.array(scores)
            scores = (scores - scores.mean()) / np.clip(scores.std(), 1e-9, np.inf)
            normalized[query_name] = list(zip(names, scores))
        return normalized

    def _write_stdlist(self, results):
        """Serialise the normalised scores to the stdlist XML file in ``expdir``."""
        root = etree.Element(
            "stdlist",
            termlist_filename="benchmark.stdlist.xml",
            indexing_time="1.00",
            language="english",
            index_size="1",
            system_id="benchmark",
        )
        for query_name, doc_scores in results.items():
            term_list = etree.SubElement(
                root,
                "detected_termlist",
                termid=query_name,
                term_search_time="1.0",
                oov_term_count="1",
            )
            for doc_name, score in doc_scores:
                etree.SubElement(
                    term_list,
                    "term",
                    file=doc_name,
                    channel="1",
                    tbeg="0.000",
                    dur="0.00",
                    score=f"{score:.4f}",
                    decision="YES" if score > self._SCORE_THRESHOLD else "NO",
                )

        # Do not lose a finished DTW run because the directory is missing.
        self.expdir.mkdir(parents=True, exist_ok=True)
        etree.ElementTree(root).write(
            str(self.expdir / self._STDLIST_FILENAME),
            encoding="UTF-8",
            pretty_print=True,
        )
|
|
|
|
|
def match(query, doc, query_name, doc_name, dist_fn, minmax_norm, dtwrc):
    """Score one query against one document with DTW.

    Args:
        query: Query features, passed as first argument to ``dist_fn``.
        doc: Document features, passed as second argument to ``dist_fn``.
        query_name: Returned unchanged so the caller can group results.
        doc_name: Returned unchanged so the caller can group results.
        dist_fn: Callable producing the 2-D local distance matrix.
        minmax_norm: If truthy, rescale each row of the matrix to ``[0, 1]``.
        dtwrc: Keyword arguments forwarded to ``dtw``.

    Returns:
        ``(query_name, doc_name, score)`` where ``score`` is the negated
        normalised DTW distance.
    """
    cost_matrix = dist_fn(query, doc)

    if minmax_norm:
        row_min = cost_matrix.min(axis=1, keepdims=True)
        row_span = cost_matrix.max(axis=1, keepdims=True) - row_min
        # The clip keeps constant rows from dividing by zero.
        cost_matrix = (cost_matrix - row_min) / np.clip(row_span, 1e-9, np.inf)

    alignment = dtw(x=cost_matrix, **dtwrc)
    return query_name, doc_name, -1 * alignment.normalizedDistance
|
|
|
|
|
def cosine_exp(query, doc):
    """Return ``exp(d) - 1`` of the pairwise cosine distance ``d``.

    The result has one row per row of ``query`` and one column per row of
    ``doc``.
    """
    cosine_dist = distance.cdist(query, doc, metric="cosine")
    return np.exp(cosine_dist) - 1
|
|
|
|
|
def cosine_neg_log(query, doc):
    """Return ``-log(s)`` of the pairwise cosine similarity ``s``.

    The result has one row per row of ``query`` and one column per row of
    ``doc``. The similarity is clamped to at least ``1e-9`` before the
    logarithm, so orthogonal or opposed frames give a large finite distance
    (about 20.7) instead of ``inf``/``NaN``.
    """
    similarity = 1 - distance.cdist(query, doc, "cosine")
    # log() of a non-positive similarity would yield inf/NaN, which would
    # then propagate through the DTW cost and the score normalisation.
    similarity = np.clip(similarity, 1e-9, np.inf)
    return -1 * np.log(similarity)
|
|