import json
import pickle
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path

import pandas as pd
import torch
from omegaconf import MISSING

from s3prl.dataio.dataset import EncodeMultiLabel, LoadAudio
from s3prl.dataio.encoder import CategoryEncoder
from s3prl.dataio.sampler import FixedBatchSizeBatchSampler
from s3prl.nn.hear import HearFullyConnectedPrediction
from s3prl.task.scene_prediction import ScenePredictionTask

from ._hear_util import resample_hear_corpus
from .superb_sid import SuperbSID

__all__ = ["HearFSD"]


def hear_scene_trainvaltest(
    target_dir: str,
    cache_dir: str,
    dataset_root: str,
    get_path_only: bool = False,
):
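    """Prepare train/valid/test CSVs for a HEAR-style scene prediction corpus.

    The corpus under ``dataset_root`` is resampled to 16 kHz (the resampled
    audio is read from ``dataset_root/16000``), and each split's JSON metadata
    (``train.json``, ``valid.json``, ``test.json``) is flattened into a CSV
    with ``id``, ``wav_path``, and ``labels`` columns under ``target_dir``.
    Multiple labels are joined with ``" ; "``. ``cache_dir`` is accepted for
    interface consistency but not used here.
    """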
    target_dir = Path(target_dir)

    resample_hear_corpus(dataset_root, target_sr=16000)

    dataset_root = Path(dataset_root)
    wav_root: Path = dataset_root / "16000"

    train_csv = target_dir / "train.csv"
    valid_csv = target_dir / "valid.csv"
    test_csv = target_dir / "test.csv"

    if get_path_only:
        return train_csv, valid_csv, [test_csv]

    def load_json(filepath):
        with open(filepath, "r") as fp:
            return json.load(fp)

    def split_to_df(split: str) -> pd.DataFrame:
        # Flatten the split's {filename: [label, ...]} metadata into one row
        # per clip, joining multiple labels with " ; " so they survive the
        # CSV round trip.
        meta = load_json(dataset_root / f"{split}.json")
        data = defaultdict(list)
        for k in list(meta.keys()):
            data["id"].append(k)
            data["wav_path"].append(wav_root / split / k)
            data["labels"].append(
                " ; ".join([str(label).strip() for label in meta[k]])
            )
        return pd.DataFrame(data=data)

    split_to_df("train").to_csv(train_csv, index=False)
    split_to_df("valid").to_csv(valid_csv, index=False)
    split_to_df("test").to_csv(test_csv, index=False)

    return train_csv, valid_csv, [test_csv]


class HearFSD(SuperbSID):
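    """Multi-label scene prediction recipe for the HEAR FSD task.

    Reuses the :class:`SuperbSID` pipeline but overrides data preparation, the
    label encoder, dataset, batch sampler, downstream model, and task so that
    multi-label clips are trained and evaluated with mAP and related
    scene-prediction metrics.
    """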
    def default_config(self) -> dict:
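        """Return the default configuration for this recipe.

        The values below (batch sizes, the Adam optimizer with ``lr=1e-3``,
        40000 training steps, ``mAP`` as the validation metric, etc.) are
        defaults and can be overridden by the caller; ``MISSING`` fields must
        be provided before running.
        """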
        return dict(
            start=0,
            stop=None,
            target_dir=MISSING,
            cache_dir=None,
            remove_all_cache=False,
            prepare_data=dict(
                dataset_root=MISSING,
            ),
            build_batch_sampler=dict(
                train=dict(
                    batch_size=10,
                    shuffle=True,
                ),
                valid=dict(
                    batch_size=1,
                ),
                test=dict(
                    batch_size=1,
                ),
            ),
            build_upstream=dict(
                name=MISSING,
            ),
            build_featurizer=dict(
                layer_selections=None,
                normalize=False,
            ),
            build_downstream=dict(
                hidden_layers=2,
                pooling_type="MeanPooling",
            ),
            build_model=dict(
                upstream_trainable=False,
            ),
            build_task=dict(
                prediction_type="multilabel",
                scores=["mAP", "top1_acc", "d_prime", "aucroc"],
            ),
            build_optimizer=dict(
                name="Adam",
                conf=dict(
                    lr=1.0e-3,
                ),
            ),
            build_scheduler=dict(
                name="ExponentialLR",
                gamma=0.9,
            ),
            save_model=dict(),
            save_task=dict(),
            train=dict(
                total_steps=40000,
                log_step=100,
                eval_step=1000,
                save_step=100,
                gradient_clipping=1.0,
                gradient_accumulate=1,
                valid_metric="mAP",
                valid_higher_better=True,
                auto_resume=True,
                resume_ckpt_dir=None,
            ),
            evaluate=dict(),
        )

    def prepare_data(
        self,
        prepare_data: dict,
        target_dir: str,
        cache_dir: str,
        get_path_only: bool = False,
    ):
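        """Delegate data preparation to :func:`hear_scene_trainvaltest`.

        The current arguments (with the ``prepare_data`` sub-dict flattened)
        are forwarded to :func:`hear_scene_trainvaltest`, which returns the
        train CSV, the valid CSV, and a list of test CSVs.
        """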
        return hear_scene_trainvaltest(
            **self._get_current_arguments(flatten_dict="prepare_data")
        )

    def build_encoder(
        self,
        build_encoder: dict,
        target_dir: str,
        cache_dir: str,
        train_csv_path: str,
        valid_csv_path: str,
        test_csv_paths: list,
        get_path_only: bool = False,
    ):
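        """Fit and pickle a :class:`CategoryEncoder` over all labels.

        The label strings from the train, valid, and test CSVs are split on
        ``";"``, stripped, and fed to a :class:`CategoryEncoder`, which is
        pickled to ``target_dir/encoder.pkl``. Only the path is returned when
        ``get_path_only`` is True.
        """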
        encoder_path = Path(target_dir) / "encoder.pkl"
        if get_path_only:
            return encoder_path

        train_csv = pd.read_csv(train_csv_path)
        valid_csv = pd.read_csv(valid_csv_path)
        test_csvs = [pd.read_csv(path) for path in test_csv_paths]
        all_csv = pd.concat([train_csv, valid_csv, *test_csvs])

        # Collect every label string across all splits so the encoder covers
        # the full category set.
        all_labels = []
        for _, row in all_csv.iterrows():
            labels = str(row["labels"]).split(";")
            labels = [label.strip() for label in labels]
            all_labels.extend(labels)

        encoder = CategoryEncoder(all_labels)
        with open(encoder_path, "wb") as f:
            pickle.dump(encoder, f)

        return encoder_path

    def build_dataset(
        self,
        build_dataset: dict,
        target_dir: str,
        cache_dir: str,
        mode: str,
        data_csv: str,
        encoder_path: str,
        frame_shift: int,
    ):
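        """Build a map-style dataset of waveforms and encoded labels.

        Each item provides the loaded waveform (``x``, ``x_len``), the encoded
        binary label vector (``y``), the raw label list (``labels``), and the
        clip id (``unique_name``). ``frame_shift`` is accepted for interface
        consistency but not used by this clip-level task.
        """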
        df = pd.read_csv(data_csv)
        ids = df["id"].tolist()
        wav_paths = df["wav_path"].tolist()
        labels = [
            [single_label.strip() for single_label in str(label_str).split(";")]
            for label_str in df["labels"].tolist()
        ]
        with open(encoder_path, "rb") as f:
            encoder = pickle.load(f)

        audio_loader = LoadAudio(wav_paths)
        label_encoder = EncodeMultiLabel(labels, encoder)

        class Dataset:
            def __len__(self):
                return len(audio_loader)

            def __getitem__(self, index: int):
                audio = audio_loader[index]
                label = label_encoder[index]
                return {
                    "x": audio["wav"],
                    "x_len": audio["wav_len"],
                    "y": label["binary_labels"],
                    "labels": label["labels"],
                    "unique_name": ids[index],
                }

        dataset = Dataset()
        return dataset

    def build_batch_sampler(
        self,
        build_batch_sampler: dict,
        target_dir: str,
        cache_dir: str,
        mode: str,
        data_csv: str,
        dataset,
    ):
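        """Build a fixed-batch-size sampler for the current split.

        ``build_batch_sampler`` may hold separate ``train``/``valid``/``test``
        settings (see :meth:`default_config`); the settings matching ``mode``
        are passed to :class:`FixedBatchSizeBatchSampler`.
        """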
        @dataclass
        class Config:
            train: dict = None
            valid: dict = None
            test: dict = None

        conf = Config(**build_batch_sampler)

        # Select the sampler settings for the current split instead of always
        # using the train settings; fall back to defaults when unset.
        split_conf = getattr(conf, mode, None)
        return FixedBatchSizeBatchSampler(dataset, **(split_conf or {}))

    def build_downstream(
        self,
        build_downstream: dict,
        downstream_input_size: int,
        downstream_output_size: int,
        downstream_input_stride: int,
    ):
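        """Build the fully connected prediction head used for HEAR tasks.

        ``downstream_input_size`` and ``downstream_output_size`` are supplied
        by the surrounding pipeline (typically the upstream feature size and
        the number of label categories); ``build_downstream`` options such as
        ``hidden_layers`` and ``pooling_type`` are forwarded to
        :class:`HearFullyConnectedPrediction`.
        """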
        return HearFullyConnectedPrediction(
            downstream_input_size, downstream_output_size, **build_downstream
        )

    def build_task(
        self,
        build_task: dict,
        model: torch.nn.Module,
        encoder,
        valid_df: pd.DataFrame = None,
        test_df: pd.DataFrame = None,
    ):
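        """Wrap the model and encoder into a :class:`ScenePredictionTask`.

        ``build_task`` carries the task options from the config (e.g.
        ``prediction_type="multilabel"`` and the ``scores`` to report).
        ``valid_df`` and ``test_df`` are accepted for interface consistency
        but not used here.
        """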
        return ScenePredictionTask(model, encoder, **build_task)