"""Downstream expert for Query-by-Example Spoken Term Detection on QUESST 2014."""
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor, as_completed
from functools import partial
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from dtw import dtw
from lxml import etree
from scipy.spatial import distance
from torch.utils.data import DataLoader
from tqdm import tqdm
from .dataset import QUESST14Dataset


class DownstreamExpert(nn.Module):
    """
    Used to handle downstream-specific operations,
    e.g. downstream forward, metric computation, and contents to log.
    """

    def __init__(
        self, upstream_dim: int, downstream_expert: dict, expdir: str, **kwargs
    ):
        super(DownstreamExpert, self).__init__()
        self.upstream_dim = upstream_dim
        self.max_workers = downstream_expert["max_workers"]
        self.feature_normalization = downstream_expert["feature_normalization"]
        self.silence_frame = downstream_expert["silence_frame"]
        self.datarc = downstream_expert["datarc"]
        self.dtwrc = downstream_expert["dtwrc"]
        self.expdir = Path(expdir)
        self.test_dataset = None

        assert not (
            self.feature_normalization and self.dtwrc["dist_method"] == "cosine_neg_log"
        ), "Upstream feature normalization cannot be used with cosine_neg_log."
        assert (
            self.dtwrc["step_pattern"] == "asymmetric" or not self.dtwrc["subsequence"]
        ), "Subsequence finding only works with the asymmetric step pattern."
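
        # The keys consumed above typically come from the downstream task's config
        # under "downstream_expert". Illustrative sketch only (values are examples,
        # not benchmark defaults; datarc also carries the QUESST14Dataset arguments):
        #   max_workers: 12
        #   feature_normalization: False
        #   silence_frame: null           # or an int: index of the silence class whose frames are dropped
        #   datarc:
        #     batch_size: 8
        #     num_workers: 6
        #   dtwrc:
        #     dist_method: cosine_exp     # or cosine_neg_log, or any scipy cdist metric
        #     step_pattern: asymmetric
        #     subsequence: True
        #     minmax_norm: False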

    # Interface
    def get_dataloader(self, mode):
        if mode == "dev":
            self.test_dataset = QUESST14Dataset("dev", **self.datarc)
        else:  # eval
            self.test_dataset = QUESST14Dataset("eval", **self.datarc)
        return DataLoader(
            self.test_dataset,
            shuffle=False,
            batch_size=self.datarc["batch_size"],
            drop_last=False,
            num_workers=self.datarc["num_workers"],
            collate_fn=self.test_dataset.collate_fn,
        )

    # Interface
    def forward(
        self,
        mode,
        features,
        audio_names,
        records,
        **kwargs,
    ):
        """Collect upstream features (optionally with silence frames removed) for later DTW matching."""
        for feature, audio_name in zip(features, audio_names):
            feature = feature.detach().cpu()
            if self.silence_frame is not None:  # remove silence frames
                feature = feature[feature.argmax(1) != self.silence_frame]
            records["features"].append(feature)
            records["audio_names"].append(audio_name)

    # Interface
    def log_records(self, mode, records, **kwargs):
        """Perform DTW between each query and each doc, then save results."""
        # Get precomputed queries & docs
        queries = records["features"][: self.test_dataset.n_queries]
        docs = records["features"][self.test_dataset.n_queries :]
        query_names = records["audio_names"][: self.test_dataset.n_queries]
        doc_names = records["audio_names"][self.test_dataset.n_queries :]

        # Normalize upstream features
        feature_mean, feature_std = 0.0, 1.0
        if self.feature_normalization:
            feats = torch.cat(records["features"])
            feature_mean = feats.mean(0)
            feature_std = torch.clamp(feats.std(0), min=1e-9)
        queries = [((query - feature_mean) / feature_std).numpy() for query in queries]
        docs = [((doc - feature_mean) / feature_std).numpy() for doc in docs]

        # Define distance function for DTW
        if self.dtwrc["dist_method"] == "cosine_exp":
            dist_fn = cosine_exp
        elif self.dtwrc["dist_method"] == "cosine_neg_log":
            dist_fn = cosine_neg_log
        else:
            dist_fn = partial(distance.cdist, metric=self.dtwrc["dist_method"])

        # Define DTW configurations
        dtwrc = {
            "step_pattern": self.dtwrc["step_pattern"],
            "keep_internals": False,
            "distance_only": not self.dtwrc["subsequence"],
            "open_begin": self.dtwrc["subsequence"],
            "open_end": self.dtwrc["subsequence"],
        }
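
        # Note (dtw-python semantics): with the asymmetric step pattern, open_begin
        # and open_end let the alignment start/end anywhere inside the doc, so the
        # query is matched against an arbitrary contiguous span (subsequence DTW).
        # distance_only=True skips backtracking when only the cost is needed.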

        # Calculate matching scores
        results = defaultdict(list)
        with ProcessPoolExecutor(max_workers=self.max_workers) as executor:
            futures = []
            for query, query_name in zip(queries, query_names):
                if len(query) < 5:  # do not consider queries that are too short
                    results[query_name] = [(doc_name, 0) for doc_name in doc_names]
                    continue
                for doc, doc_name in zip(docs, doc_names):
                    futures.append(
                        executor.submit(
                            match,
                            query,
                            doc,
                            query_name,
                            doc_name,
                            dist_fn,
                            self.dtwrc["minmax_norm"],
                            dtwrc,
                        )
                    )
            for future in tqdm(
                as_completed(futures), total=len(futures), ncols=0, desc="DTW"
            ):
                query_name, doc_name, score = future.result()
                results[query_name].append((doc_name, score))

        # Normalize scores with regard to each query
        for query_name, doc_scores in results.items():
            names, scores = zip(*doc_scores)
            scores = np.array(scores)
            scores = (scores - scores.mean()) / np.clip(scores.std(), 1e-9, np.inf)
            results[query_name] = list(zip(names, scores))

        # Scores above 2 STDs are treated as detections (roughly the top 2.5% as YES)
        score_thresh = 2.0

        # Build XML tree
        root = etree.Element(
            "stdlist",
            termlist_filename="benchmark.stdlist.xml",
            indexing_time="1.00",
            language="english",
            index_size="1",
            system_id="benchmark",
        )
        for query_name, doc_scores in results.items():
            term_list = etree.SubElement(
                root,
                "detected_termlist",
                termid=query_name,
                term_search_time="1.0",
                oov_term_count="1",
            )
            for doc_name, score in doc_scores:
                etree.SubElement(
                    term_list,
                    "term",
                    file=doc_name,
                    channel="1",
                    tbeg="0.000",
                    dur="0.00",
                    score=f"{score:.4f}",
                    decision="YES" if score > score_thresh else "NO",
                )

        # Output XML
        etree.ElementTree(root).write(
            str(self.expdir / "benchmark.stdlist.xml"),
            encoding="UTF-8",
            pretty_print=True,
        )
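

# The helpers below are defined at module level so that ProcessPoolExecutor can
# pickle them and run them in worker processes.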
def match(query, doc, query_name, doc_name, dist_fn, minmax_norm, dtwrc):
    """Match a single query against a single doc; returns a similarity score."""
    dist = dist_fn(query, doc)

    if minmax_norm:
        # Min-max normalize each query frame's distances across the doc frames
        dist_min = dist.min(1)[:, np.newaxis]
        dist_max = dist.max(1)[:, np.newaxis]
        dist = (dist - dist_min) / np.clip(dist_max - dist_min, 1e-9, np.inf)

    dtw_result = dtw(x=dist, **dtwrc)
    cost = dtw_result.normalizedDistance
    return query_name, doc_name, -1 * cost  # lower alignment cost => higher score


def cosine_exp(query, doc):
    """Frame-wise distance: exp(cosine distance) - 1."""
    dist = distance.cdist(query, doc, "cosine")
    dist = np.exp(dist) - 1
    return dist


def cosine_neg_log(query, doc):
    """Frame-wise distance: -log(cosine similarity); only valid for positive similarities."""
    dist = distance.cdist(query, doc, "cosine")
    dist = -1 * np.log(1 - dist)
    return dist
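

if __name__ == "__main__":
    # Minimal, illustrative sketch (not part of the s3prl runner flow): match one
    # random "query" against one random "doc" with the same subsequence-DTW setup
    # that log_records builds above. Shapes and settings here are arbitrary examples.
    rng = np.random.default_rng(0)
    query = rng.standard_normal((20, 256))  # 20 frames of 256-dim features
    doc = rng.standard_normal((200, 256))   # 200 frames of 256-dim features
    _, _, score = match(
        query,
        doc,
        "example_query",
        "example_doc",
        cosine_exp,
        True,  # minmax_norm
        {
            "step_pattern": "asymmetric",
            "keep_internals": False,
            "distance_only": False,
            "open_begin": True,
            "open_end": True,
        },
    )
    print(f"DTW matching score (higher is better): {score:.4f}")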