diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..f5bcc7f166ba0c842726773694539df2baebfbf8 --- /dev/null +++ b/.gitignore @@ -0,0 +1,50 @@ +data/ +data2/ +preds/ +preprocessed_data/ +wandb/ + +.ruff_cache/ + +*.sif +*.log +*.out + +.DS_Store + +__pycache__/ +*.py[cod] +*$py.class + +*.so + +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +.ipynb_checkpoints + +pip-log.txt +pip-delete-this-directory.txt + +.python-version + +.env + +.mypy_cache + diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..277267f50e2f1d7d433975463f1ec501d14195b8 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 Saruwatari&Saito laboratory, The University of Tokyo + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
\ No newline at end of file diff --git a/README.md b/README.md index 81694235a57c3c0d5dd207b32e4568ae260fee3d..76ce250b72bbffa3cde821e3b92ef5439e5f31b6 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,12 @@ --- title: UTMOSv2 -emoji: 👁 -colorFrom: red -colorTo: red +emoji: 🌖 +colorFrom: yellow +colorTo: green sdk: gradio -sdk_version: 4.37.2 +sdk_version: 4.38.1 app_file: app.py pinned: false -license: mit --- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..cf7bd97f2be675f7e3a0d50e8d785ffc3ae55f48 --- /dev/null +++ b/app.py @@ -0,0 +1,73 @@ +import importlib +from types import SimpleNamespace + +import gradio as gr +import pandas as pd + +# import spaces +import torch + +from utmosv2.utils import get_dataset, get_model + +description = ( + "# 🚀 UTMOSv2 demo\n\n" + "This is a demonstration of MOS prediction using UTMOSv2. " + "This demonstration only accepts `.wav` format. Best at 16 kHz sampling rate." 
# @spaces.GPU
def predict_mos(audio_path: str, domain: str) -> float:
    """Predict the MOS of one audio file by averaging folds and repeated passes.

    For each of the 5 folds a model is obtained from ``get_model`` and run on
    5 freshly built test samples; the 25 scores are averaged.

    Args:
        audio_path: Path of the audio file supplied by the Gradio audio widget.
        domain: Data-domain ID, stored in the ``dataset`` column of the
            one-row frame handed to ``get_dataset``.

    Returns:
        The mean predicted score as a plain Python ``float``.

    Raises:
        gr.Error: If no audio file or no data-domain ID was provided.
    """
    # Gradio passes None for an empty audio widget / unselected dropdown;
    # fail with a readable message instead of crashing inside the dataset.
    if not audio_path:
        raise gr.Error("Please provide an audio file.")
    if not domain:
        raise gr.Error("Please select a data-domain ID.")

    num_folds = 5
    num_tta = 5

    data = pd.DataFrame({"file_path": [audio_path]})
    data["dataset"] = domain
    # Placeholder score column; presumably the dataset expects a target and
    # yields it as the last item, which is dropped below -- TODO confirm.
    data["mos"] = 0

    total = 0.0
    # Inference only: no autograd graph is needed for the forward passes.
    with torch.no_grad():
        for fold in range(num_folds):
            # get_model is given the shared cfg, so the fold is selected
            # by mutating cfg before the call.
            cfg.now_fold = fold
            # NOTE(review): the model is rebuilt for every fold on every
            # request; consider caching per fold if memory allows.
            model = get_model(cfg, device)
            model.eval()
            for _ in range(num_tta):
                # The dataset is rebuilt for each pass, as in the original.
                test_dataset = get_dataset(cfg, data, "test")
                inputs = [
                    torch.as_tensor(t).unsqueeze(0).to(device)
                    for t in test_dataset[0][:-1]
                ]
                # .item() turns the single-element output into a Python
                # float so the sum never holds a tensor.
                total += model(*inputs)[0].item()
    return total / (num_folds * num_tta)
a/models/fusion_stage3/fold1_s42_best_model.pth b/models/fusion_stage3/fold1_s42_best_model.pth new file mode 100644 index 0000000000000000000000000000000000000000..ebb109eec7db9858d57ff971315a6ae184378161 --- /dev/null +++ b/models/fusion_stage3/fold1_s42_best_model.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:06756833fa743cb6683f109252ce178236464c33fbbee69d4e45cdf1ae7ad0cc +size 818531314 diff --git a/models/fusion_stage3/fold2_s42_best_model.pth b/models/fusion_stage3/fold2_s42_best_model.pth new file mode 100644 index 0000000000000000000000000000000000000000..c6b6cce5fb68ee483493801b437981292e745d04 --- /dev/null +++ b/models/fusion_stage3/fold2_s42_best_model.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d784c0cdf9f4e4cbf697a8f755c8b0f5a0b842d18ad3d2bb42bbae3802d17a78 +size 818531314 diff --git a/models/fusion_stage3/fold3_s42_best_model.pth b/models/fusion_stage3/fold3_s42_best_model.pth new file mode 100644 index 0000000000000000000000000000000000000000..b9f0bfee8dacdb04bf5fc45fb87f592bd34968c3 --- /dev/null +++ b/models/fusion_stage3/fold3_s42_best_model.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ed133fe9b3e5cb78d0037aa784c7e23650a3e7b4f8ba00a73644c02aeb627758 +size 818531314 diff --git a/models/fusion_stage3/fold4_s42_best_model.pth b/models/fusion_stage3/fold4_s42_best_model.pth new file mode 100644 index 0000000000000000000000000000000000000000..c6f959011f6aa58b31895d07ed0af31d3a19b6f8 --- /dev/null +++ b/models/fusion_stage3/fold4_s42_best_model.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:12659b9db09d753654a7744b6b93ff6f90759a7429da793137b8cce107355967 +size 818531314 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..92ce23f835e21ca921b88f6353b5b35fa3630430 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,10 @@ +numpy>=1.24.4 +pandas>=2.2.2 +torch>=2.3.1 
+timm>=1.0.7 +librosa>=0.10.2 +tqdm>=4.66.4 +scikit-learn>=1.3.2 +transformers>=4.42.4 +wandb>=0.17.0 +python-dotenv>=1.0.1 \ No newline at end of file diff --git a/utmosv2/config/c_fusion_stage2.py b/utmosv2/config/c_fusion_stage2.py new file mode 100755 index 0000000000000000000000000000000000000000..64b09d02afcb40019bbee9834a5681e7328acfc9 --- /dev/null +++ b/utmosv2/config/c_fusion_stage2.py @@ -0,0 +1,149 @@ +from pathlib import Path +from types import SimpleNamespace + +from torchvision import transforms + +from utmosv2.transform.xymasking import XYMasking + +batch_size = 16 +num_folds = 5 + +sr = 16000 + +preprocess = SimpleNamespace( + top_db=30, min_seconds=None, save_path=Path("preprocessed_data") +) + +split = SimpleNamespace( + type="stratified_group", + target="mos", + group="sys_id", +) + +external_data = "all" +use_bvcc = True + + +validation_dataset = "sarulab" + +dataset = SimpleNamespace( + name="ssl_multispec_ext", + specs=[ + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=4096, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=2048, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=1024, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=512, + n_mels=512, + shape=(512, 512), + norm=80, + ), + ], + spec_frames=SimpleNamespace( + num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile" + ), + ssl=SimpleNamespace( + duration=3, + ), +) +transform = dict( + train=transforms.Compose( + [ + transforms.Resize((512, 512)), + XYMasking( + num_masks_x=(0, 2), + num_masks_y=(0, 2), + mask_x_length=(10, 40), + mask_y_length=(10, 30), + fill_value=0, + p=0.5, + ), + # transforms.ToTensor(), + ] + ), + valid=transforms.Compose( + [ + transforms.Resize((512, 512)), + # 
transforms.ToTensor() + ] + ), +) + +loss = [ + (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), + (SimpleNamespace(name="mse"), 0.2), +] + +optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4) + +scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-5) + +model = SimpleNamespace( + name="ssl_multispec_ext", + multi_spec=SimpleNamespace( + backbone="tf_efficientnetv2_s.in21k_ft_in1k", + pretrained=True, + num_classes=1, + pool_type="catavgmax", + # feature_height=16, + atten=True, + # classifier=None, + ), + ssl=SimpleNamespace( + name="facebook/wav2vec2-base", + attn=1, + freeze=True, + num_classes=1, + ), + ssl_spec=SimpleNamespace( + ssl_weight="c_ssl_only_stage2", + spec_weight="c_spec_only_stage2", + num_classes=1, + freeze=True, + ), +) + +run = SimpleNamespace( + mixup=True, + mixup_alpha=0.4, + num_epochs=8, +) + +main_metric = "sys_srcc" +id_name = None + + +inference = SimpleNamespace( + save_path=Path("preds"), + submit_save_path=Path("submissions"), + num_tta=5, + batch_size=8, + extend="tile", +) diff --git a/utmosv2/config/c_fusion_stage3.py b/utmosv2/config/c_fusion_stage3.py new file mode 100755 index 0000000000000000000000000000000000000000..cc03a2c758540084ea0b31eac630b1528ee9cf71 --- /dev/null +++ b/utmosv2/config/c_fusion_stage3.py @@ -0,0 +1,149 @@ +from pathlib import Path +from types import SimpleNamespace + +from torchvision import transforms + +from utmosv2.transform.xymasking import XYMasking + +batch_size = 8 +num_folds = 5 + +sr = 16000 + +preprocess = SimpleNamespace( + top_db=30, min_seconds=None, save_path=Path("preprocessed_data") +) + +split = SimpleNamespace( + type="stratified_group", + target="mos", + group="sys_id", +) + +external_data = "all" +use_bvcc = True + + +validation_dataset = "sarulab" + +dataset = SimpleNamespace( + name="ssl_multispec_ext", + specs=[ + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=4096, + n_mels=512, + shape=(512, 512), 
+ norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=2048, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=1024, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=512, + n_mels=512, + shape=(512, 512), + norm=80, + ), + ], + spec_frames=SimpleNamespace( + num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile" + ), + ssl=SimpleNamespace( + duration=3, + ), +) +transform = dict( + train=transforms.Compose( + [ + transforms.Resize((512, 512)), + XYMasking( + num_masks_x=(0, 2), + num_masks_y=(0, 2), + mask_x_length=(10, 40), + mask_y_length=(10, 30), + fill_value=0, + p=0.5, + ), + # transforms.ToTensor(), + ] + ), + valid=transforms.Compose( + [ + transforms.Resize((512, 512)), + # transforms.ToTensor() + ] + ), +) + +loss = [ + (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), + (SimpleNamespace(name="mse"), 0.2), +] + +optimizer = SimpleNamespace(name="adamw", lr=5e-5, weight_decay=1e-4) + +scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-8) + +model = SimpleNamespace( + name="ssl_multispec_ext", + multi_spec=SimpleNamespace( + backbone="tf_efficientnetv2_s.in21k_ft_in1k", + pretrained=True, + num_classes=1, + pool_type="catavgmax", + # feature_height=16, + atten=True, + # classifier=None, + ), + ssl=SimpleNamespace( + name="facebook/wav2vec2-base", + attn=1, + freeze=False, + num_classes=1, + ), + ssl_spec=SimpleNamespace( + ssl_weight="c_ssl_only_stage2", + spec_weight="c_spec_only_stage2", + num_classes=1, + freeze=False, + ), +) + +run = SimpleNamespace( + mixup=True, + mixup_alpha=0.4, + num_epochs=2, +) + +main_metric = "sys_srcc" +id_name = None + + +inference = SimpleNamespace( + save_path=Path("preds"), + submit_save_path=Path("submissions"), + num_tta=5, + batch_size=8, + extend="tile", +) 
diff --git a/utmosv2/config/c_spec_only_stage1.py b/utmosv2/config/c_spec_only_stage1.py new file mode 100755 index 0000000000000000000000000000000000000000..31f923976b8f103530c82eb5accbd5ac3083ab24 --- /dev/null +++ b/utmosv2/config/c_spec_only_stage1.py @@ -0,0 +1,134 @@ +from pathlib import Path +from types import SimpleNamespace + +from torchvision import transforms + +from utmosv2.transform.xymasking import XYMasking + +batch_size = 10 +num_folds = 5 + +sr = 16000 + +preprocess = SimpleNamespace( + top_db=30, min_seconds=None, save_path=Path("preprocessed_data") +) + +split = SimpleNamespace( + type="stratified_group", + target="mos", + group="sys_id", +) + +external_data = [] +use_bvcc = True + + +validation_dataset = "bvcc" + +dataset = SimpleNamespace( + name="multi_spec", + specs=[ + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=4096, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=2048, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=1024, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=512, + n_mels=512, + shape=(512, 512), + norm=80, + ), + ], + spec_frames=SimpleNamespace( + num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile" + ), +) +transform = dict( + train=transforms.Compose( + [ + transforms.Resize((512, 512)), + XYMasking( + num_masks_x=(0, 2), + num_masks_y=(0, 2), + mask_x_length=(10, 40), + mask_y_length=(10, 30), + fill_value=0, + p=0.5, + ), + # transforms.ToTensor(), + ] + ), + valid=transforms.Compose( + [ + transforms.Resize((512, 512)), + # transforms.ToTensor() + ] + ), +) + +loss = [ + (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), + (SimpleNamespace(name="mse"), 0.2), +] + +optimizer = 
SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4) + +scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-7) + +model = SimpleNamespace( + name="multi_specv2", + multi_spec=SimpleNamespace( + backbone="tf_efficientnetv2_s.in21k_ft_in1k", + pretrained=True, + num_classes=1, + pool_type="catavgmax", + # feature_height=16, + atten=True, + # classifier=None, + ), +) + +run = SimpleNamespace( + mixup=True, + mixup_alpha=0.4, + num_epochs=20, +) + +main_metric = "sys_srcc" +id_name = None + + +inference = SimpleNamespace( + save_path=Path("preds"), + submit_save_path=Path("submissions"), + num_tta=5, + batch_size=8, + extend="tile", +) diff --git a/utmosv2/config/c_spec_only_stage2.py b/utmosv2/config/c_spec_only_stage2.py new file mode 100755 index 0000000000000000000000000000000000000000..29db1cb1dbd7e93916d97687a8cf1044199ac132 --- /dev/null +++ b/utmosv2/config/c_spec_only_stage2.py @@ -0,0 +1,134 @@ +from pathlib import Path +from types import SimpleNamespace + +from torchvision import transforms + +from utmosv2.transform.xymasking import XYMasking + +batch_size = 10 +num_folds = 5 + +sr = 16000 + +preprocess = SimpleNamespace( + top_db=30, min_seconds=None, save_path=Path("preprocessed_data") +) + +split = SimpleNamespace( + type="stratified_group", + target="mos", + group="sys_id", +) + +external_data = ["sarulab"] +use_bvcc = False + + +validation_dataset = "sarulab" + +dataset = SimpleNamespace( + name="multi_spec", + specs=[ + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=4096, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=2048, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=1024, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=512, + n_mels=512, + shape=(512, 
512), + norm=80, + ), + ], + spec_frames=SimpleNamespace( + num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile" + ), +) +transform = dict( + train=transforms.Compose( + [ + transforms.Resize((512, 512)), + XYMasking( + num_masks_x=(0, 2), + num_masks_y=(0, 2), + mask_x_length=(10, 40), + mask_y_length=(10, 30), + fill_value=0, + p=0.5, + ), + # transforms.ToTensor(), + ] + ), + valid=transforms.Compose( + [ + transforms.Resize((512, 512)), + # transforms.ToTensor() + ] + ), +) + +loss = [ + (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), + (SimpleNamespace(name="mse"), 0.2), +] + +optimizer = SimpleNamespace(name="adamw", lr=5e-5, weight_decay=1e-4) + +scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-9) + +model = SimpleNamespace( + name="multi_specv2", + multi_spec=SimpleNamespace( + backbone="tf_efficientnetv2_s.in21k_ft_in1k", + pretrained=True, + num_classes=1, + pool_type="catavgmax", + # feature_height=16, + atten=True, + # classifier=None, + ), +) + +run = SimpleNamespace( + mixup=True, + mixup_alpha=0.4, + num_epochs=5, +) + +main_metric = "sys_srcc" +id_name = None + + +inference = SimpleNamespace( + save_path=Path("preds"), + submit_save_path=Path("submissions"), + num_tta=5, + batch_size=8, + extend="tile", +) diff --git a/utmosv2/config/c_ssl_only_stage1.py b/utmosv2/config/c_ssl_only_stage1.py new file mode 100755 index 0000000000000000000000000000000000000000..ef684fcecb481e4584d89cc5226a88a297efdd14 --- /dev/null +++ b/utmosv2/config/c_ssl_only_stage1.py @@ -0,0 +1,68 @@ +from pathlib import Path +from types import SimpleNamespace + +batch_size = 32 +num_folds = 5 + +sr = 16000 + +preprocess = SimpleNamespace( + top_db=30, min_seconds=None, save_path=Path("preprocessed_data") +) + +split = SimpleNamespace( + type="stratified_group", + target="mos", + group="sys_id", +) + +dataset = SimpleNamespace( + name="sslext", + ssl=SimpleNamespace( + duration=3, + ), +) + +external_data = "all" 
+use_bvcc = True + + +validation_dataset = "sarulab" + +loss = [ + (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), + (SimpleNamespace(name="mse"), 0.2), +] + +optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4) + +scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-7) + +model_path = "model" +model = SimpleNamespace( + name="sslext", + ssl=SimpleNamespace( + name="facebook/wav2vec2-base", + attn=1, + freeze=True, + num_classes=1, + ), +) + +run = SimpleNamespace( + mixup=True, + mixup_alpha=0.4, + num_epochs=20, +) + +main_metric = "sys_srcc" +id_name = None + + +inference = SimpleNamespace( + save_path=Path("preds"), + submit_save_path=Path("submissions"), + num_tta=5, + batch_size=8, + # extend="tile", +) diff --git a/utmosv2/config/c_ssl_only_stage2.py b/utmosv2/config/c_ssl_only_stage2.py new file mode 100755 index 0000000000000000000000000000000000000000..30d7594fd2abc513df95d71404fd2133796248a4 --- /dev/null +++ b/utmosv2/config/c_ssl_only_stage2.py @@ -0,0 +1,68 @@ +from pathlib import Path +from types import SimpleNamespace + +batch_size = 32 +num_folds = 5 + +sr = 16000 + +preprocess = SimpleNamespace( + top_db=30, min_seconds=None, save_path=Path("preprocessed_data") +) + +split = SimpleNamespace( + type="stratified_group", + target="mos", + group="sys_id", +) + +dataset = SimpleNamespace( + name="sslext", + ssl=SimpleNamespace( + duration=3, + ), +) + +external_data = "all" +use_bvcc = True + + +validation_dataset = "sarulab" + +loss = [ + (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), + (SimpleNamespace(name="mse"), 0.2), +] + +optimizer = SimpleNamespace(name="adamw", lr=3e-5, weight_decay=1e-4) + +scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-9) + +model_path = "model" +model = SimpleNamespace( + name="sslext", + ssl=SimpleNamespace( + name="facebook/wav2vec2-base", + attn=1, + freeze=False, + num_classes=1, + ), +) + +run = SimpleNamespace( + mixup=True, + 
mixup_alpha=0.4, + num_epochs=5, +) + +main_metric = "sys_srcc" +id_name = None + + +inference = SimpleNamespace( + save_path=Path("preds"), + submit_save_path=Path("submissions"), + num_tta=5, + batch_size=8, + # extend="tile", +) diff --git a/utmosv2/config/fusion_stage2.py b/utmosv2/config/fusion_stage2.py new file mode 100644 index 0000000000000000000000000000000000000000..3c5a72065fb1172680d8dae500212b65082f1996 --- /dev/null +++ b/utmosv2/config/fusion_stage2.py @@ -0,0 +1,151 @@ +from types import SimpleNamespace + +from torchvision import transforms + +from utmosv2.transform.xymasking import XYMasking + +batch_size = 16 +num_folds = 5 + +sr = 16000 + +preprocess = SimpleNamespace( + top_db=30, min_seconds=None, save_path="preprocessed_data/clip_audio" +) + +split = SimpleNamespace( + type="sgkf_kind", + target="mos", + group="sys_id", + kind="dataset", +) + +external_data = "all" +use_bvcc = True + +predict_dataset = "ysaito" +# predict_dataset = "bvcc" + +validation_dataset = "each" + +dataset = SimpleNamespace( + name="ssl_multispec_ext", + specs=[ + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=4096, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=2048, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=1024, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=512, + n_mels=512, + shape=(512, 512), + norm=80, + ), + ], + spec_frames=SimpleNamespace( + num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile" + ), + ssl=SimpleNamespace( + duration=3, + ), +) +transform = dict( + train=transforms.Compose( + [ + transforms.Resize((512, 512)), + XYMasking( + num_masks_x=(0, 2), + num_masks_y=(0, 2), + mask_x_length=(10, 40), + mask_y_length=(10, 30), + fill_value=0, + 
p=0.5, + ), + # transforms.ToTensor(), + ] + ), + valid=transforms.Compose( + [ + transforms.Resize((512, 512)), + # transforms.ToTensor() + ] + ), +) + +loss = [ + (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), + (SimpleNamespace(name="mse"), 0.2), +] + +optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4) + +scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-5) + +model = SimpleNamespace( + name="ssl_multispec_ext_v2", + multi_spec=SimpleNamespace( + backbone="tf_efficientnetv2_s.in21k_ft_in1k", + pretrained=True, + num_classes=1, + pool_type="catavgmax", + # feature_height=16, + atten=True, + # classifier=None, + ), + ssl=SimpleNamespace( + name="facebook/wav2vec2-base", + attn=1, + freeze=True, + num_classes=1, + ), + ssl_spec=SimpleNamespace( + ssl_weight="ssl_only_stage2", + spec_weight="spec_only", + num_classes=1, + freeze=True, + ), +) + +run = SimpleNamespace( + mixup=True, + mixup_alpha=0.4, + num_epochs=8, +) + +main_metric = "sys_srcc" +id_name = None + + +inference = SimpleNamespace( + save_path="preds", + submit_save_path="submissions", + num_tta=5, + batch_size=8, + extend="tile", +) diff --git a/utmosv2/config/fusion_stage2_wo_bc.py b/utmosv2/config/fusion_stage2_wo_bc.py new file mode 100755 index 0000000000000000000000000000000000000000..39b13117c8850d5966ce01870dc7558e2695cdde --- /dev/null +++ b/utmosv2/config/fusion_stage2_wo_bc.py @@ -0,0 +1,160 @@ +from pathlib import Path +from types import SimpleNamespace + +from torchvision import transforms + +from utmosv2.transform.xymasking import XYMasking + +batch_size = 16 +num_folds = 5 + +sr = 16000 + +preprocess = SimpleNamespace( + top_db=30, min_seconds=None, save_path=Path("preprocessed_data") +) + +split = SimpleNamespace( + type="sgkf_kind", + target="mos", + group="sys_id", + kind="dataset", +) + +external_data = [ + "sarulab", + # "blizzard2008", + # "blizzard2009", + # "blizzard2011", + # "blizzard2010-EH1", + # "blizzard2010-EH2", + # 
"blizzard2010-ES1", + # "blizzard2010-ES3", + "somos", +] +use_bvcc = True + + +validation_dataset = "each" + +dataset = SimpleNamespace( + name="ssl_multispec_ext", + specs=[ + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=4096, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=2048, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=1024, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=512, + n_mels=512, + shape=(512, 512), + norm=80, + ), + ], + spec_frames=SimpleNamespace( + num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile" + ), + ssl=SimpleNamespace( + duration=3, + ), +) +transform = dict( + train=transforms.Compose( + [ + transforms.Resize((512, 512)), + XYMasking( + num_masks_x=(0, 2), + num_masks_y=(0, 2), + mask_x_length=(10, 40), + mask_y_length=(10, 30), + fill_value=0, + p=0.5, + ), + # transforms.ToTensor(), + ] + ), + valid=transforms.Compose( + [ + transforms.Resize((512, 512)), + # transforms.ToTensor() + ] + ), +) + +loss = [ + (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), + (SimpleNamespace(name="mse"), 0.2), +] + +optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4) + +scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-5) + +model = SimpleNamespace( + name="ssl_multispec_ext_v2", + multi_spec=SimpleNamespace( + backbone="tf_efficientnetv2_s.in21k_ft_in1k", + pretrained=True, + num_classes=1, + pool_type="catavgmax", + # feature_height=16, + atten=True, + # classifier=None, + ), + ssl=SimpleNamespace( + name="facebook/wav2vec2-base", + attn=1, + freeze=True, + num_classes=1, + ), + ssl_spec=SimpleNamespace( + ssl_weight="ssl_only_stage2_wo_bc", + spec_weight="spec_only_wo_bc", + num_classes=1, 
+ freeze=True, + ), +) + +run = SimpleNamespace( + mixup=True, + mixup_alpha=0.4, + num_epochs=8, +) + +main_metric = "sys_srcc" +id_name = None + + +inference = SimpleNamespace( + save_path=Path("preds"), + submit_save_path=Path("submissions"), + num_tta=5, + batch_size=8, + extend="tile", +) diff --git a/utmosv2/config/fusion_stage2_wo_bvcc.py b/utmosv2/config/fusion_stage2_wo_bvcc.py new file mode 100755 index 0000000000000000000000000000000000000000..3274ad71daa8e5b4ba4b7af1f89a5b1b2514853d --- /dev/null +++ b/utmosv2/config/fusion_stage2_wo_bvcc.py @@ -0,0 +1,160 @@ +from pathlib import Path +from types import SimpleNamespace + +from torchvision import transforms + +from utmosv2.transform.xymasking import XYMasking + +batch_size = 16 +num_folds = 5 + +sr = 16000 + +preprocess = SimpleNamespace( + top_db=30, min_seconds=None, save_path=Path("preprocessed_data") +) + +split = SimpleNamespace( + type="sgkf_kind", + target="mos", + group="sys_id", + kind="dataset", +) + +external_data = [ + "sarulab", + "blizzard2008", + "blizzard2009", + "blizzard2011", + "blizzard2010-EH1", + "blizzard2010-EH2", + "blizzard2010-ES1", + "blizzard2010-ES3", + "somos", +] +use_bvcc = False + + +validation_dataset = "each" + +dataset = SimpleNamespace( + name="ssl_multispec_ext", + specs=[ + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=4096, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=2048, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=1024, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=512, + n_mels=512, + shape=(512, 512), + norm=80, + ), + ], + spec_frames=SimpleNamespace( + num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile" + ), + ssl=SimpleNamespace( + duration=3, + 
), +) +transform = dict( + train=transforms.Compose( + [ + transforms.Resize((512, 512)), + XYMasking( + num_masks_x=(0, 2), + num_masks_y=(0, 2), + mask_x_length=(10, 40), + mask_y_length=(10, 30), + fill_value=0, + p=0.5, + ), + # transforms.ToTensor(), + ] + ), + valid=transforms.Compose( + [ + transforms.Resize((512, 512)), + # transforms.ToTensor() + ] + ), +) + +loss = [ + (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), + (SimpleNamespace(name="mse"), 0.2), +] + +optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4) + +scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-5) + +model = SimpleNamespace( + name="ssl_multispec_ext_v2", + multi_spec=SimpleNamespace( + backbone="tf_efficientnetv2_s.in21k_ft_in1k", + pretrained=True, + num_classes=1, + pool_type="catavgmax", + # feature_height=16, + atten=True, + # classifier=None, + ), + ssl=SimpleNamespace( + name="facebook/wav2vec2-base", + attn=1, + freeze=True, + num_classes=1, + ), + ssl_spec=SimpleNamespace( + ssl_weight="ssl_only_stage2_wo_bvcc", + spec_weight="spec_only_wo_bvcc", + num_classes=1, + freeze=True, + ), +) + +run = SimpleNamespace( + mixup=True, + mixup_alpha=0.4, + num_epochs=8, +) + +main_metric = "sys_srcc" +id_name = None + + +inference = SimpleNamespace( + save_path=Path("preds"), + submit_save_path=Path("submissions"), + num_tta=5, + batch_size=8, + extend="tile", +) diff --git a/utmosv2/config/fusion_stage2_wo_sarulab.py b/utmosv2/config/fusion_stage2_wo_sarulab.py new file mode 100755 index 0000000000000000000000000000000000000000..ce4723efb9b4ba3610f35da8e14428fe30b067e5 --- /dev/null +++ b/utmosv2/config/fusion_stage2_wo_sarulab.py @@ -0,0 +1,160 @@ +from pathlib import Path +from types import SimpleNamespace + +from torchvision import transforms + +from utmosv2.transform.xymasking import XYMasking + +batch_size = 16 +num_folds = 5 + +sr = 16000 + +preprocess = SimpleNamespace( + top_db=30, min_seconds=None, 
save_path=Path("preprocessed_data") +) + +split = SimpleNamespace( + type="sgkf_kind", + target="mos", + group="sys_id", + kind="dataset", +) + +external_data = [ + # "sarulab", + "blizzard2008", + "blizzard2009", + "blizzard2011", + "blizzard2010-EH1", + "blizzard2010-EH2", + "blizzard2010-ES1", + "blizzard2010-ES3", + "somos", +] +use_bvcc = True + + +validation_dataset = "each" + +dataset = SimpleNamespace( + name="ssl_multispec_ext", + specs=[ + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=4096, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=2048, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=1024, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=512, + n_mels=512, + shape=(512, 512), + norm=80, + ), + ], + spec_frames=SimpleNamespace( + num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile" + ), + ssl=SimpleNamespace( + duration=3, + ), +) +transform = dict( + train=transforms.Compose( + [ + transforms.Resize((512, 512)), + XYMasking( + num_masks_x=(0, 2), + num_masks_y=(0, 2), + mask_x_length=(10, 40), + mask_y_length=(10, 30), + fill_value=0, + p=0.5, + ), + # transforms.ToTensor(), + ] + ), + valid=transforms.Compose( + [ + transforms.Resize((512, 512)), + # transforms.ToTensor() + ] + ), +) + +loss = [ + (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), + (SimpleNamespace(name="mse"), 0.2), +] + +optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4) + +scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-5) + +model = SimpleNamespace( + name="ssl_multispec_ext_v2", + multi_spec=SimpleNamespace( + backbone="tf_efficientnetv2_s.in21k_ft_in1k", + pretrained=True, + num_classes=1, + pool_type="catavgmax", + # 
feature_height=16, + atten=True, + # classifier=None, + ), + ssl=SimpleNamespace( + name="facebook/wav2vec2-base", + attn=1, + freeze=True, + num_classes=1, + ), + ssl_spec=SimpleNamespace( + ssl_weight="ssl_only_stage2_wo_sarulab", + spec_weight="spec_only_wo_sarulab", + num_classes=1, + freeze=True, + ), +) + +run = SimpleNamespace( + mixup=True, + mixup_alpha=0.4, + num_epochs=8, +) + +main_metric = "sys_srcc" +id_name = None + + +inference = SimpleNamespace( + save_path=Path("preds"), + submit_save_path=Path("submissions"), + num_tta=5, + batch_size=8, + extend="tile", +) diff --git a/utmosv2/config/fusion_stage2_wo_somos.py b/utmosv2/config/fusion_stage2_wo_somos.py new file mode 100755 index 0000000000000000000000000000000000000000..2ae01891d7bd4173e6c32241ed4a8442295c0cc4 --- /dev/null +++ b/utmosv2/config/fusion_stage2_wo_somos.py @@ -0,0 +1,160 @@ +from pathlib import Path +from types import SimpleNamespace + +from torchvision import transforms + +from utmosv2.transform.xymasking import XYMasking + +batch_size = 16 +num_folds = 5 + +sr = 16000 + +preprocess = SimpleNamespace( + top_db=30, min_seconds=None, save_path=Path("preprocessed_data") +) + +split = SimpleNamespace( + type="sgkf_kind", + target="mos", + group="sys_id", + kind="dataset", +) + +external_data = [ + "sarulab", + "blizzard2008", + "blizzard2009", + "blizzard2011", + "blizzard2010-EH1", + "blizzard2010-EH2", + "blizzard2010-ES1", + "blizzard2010-ES3", + # "somos", +] +use_bvcc = True + + +validation_dataset = "each" + +dataset = SimpleNamespace( + name="ssl_multispec_ext", + specs=[ + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=4096, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=2048, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=1024, + n_mels=512, + shape=(512, 512), + norm=80, + ), + 
SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=512, + n_mels=512, + shape=(512, 512), + norm=80, + ), + ], + spec_frames=SimpleNamespace( + num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile" + ), + ssl=SimpleNamespace( + duration=3, + ), +) +transform = dict( + train=transforms.Compose( + [ + transforms.Resize((512, 512)), + XYMasking( + num_masks_x=(0, 2), + num_masks_y=(0, 2), + mask_x_length=(10, 40), + mask_y_length=(10, 30), + fill_value=0, + p=0.5, + ), + # transforms.ToTensor(), + ] + ), + valid=transforms.Compose( + [ + transforms.Resize((512, 512)), + # transforms.ToTensor() + ] + ), +) + +loss = [ + (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), + (SimpleNamespace(name="mse"), 0.2), +] + +optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4) + +scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-5) + +model = SimpleNamespace( + name="ssl_multispec_ext_v2", + multi_spec=SimpleNamespace( + backbone="tf_efficientnetv2_s.in21k_ft_in1k", + pretrained=True, + num_classes=1, + pool_type="catavgmax", + # feature_height=16, + atten=True, + # classifier=None, + ), + ssl=SimpleNamespace( + name="facebook/wav2vec2-base", + attn=1, + freeze=True, + num_classes=1, + ), + ssl_spec=SimpleNamespace( + ssl_weight="ssl_only_stage2_wo_somos", + spec_weight="spec_only_wo_somos", + num_classes=1, + freeze=True, + ), +) + +run = SimpleNamespace( + mixup=True, + mixup_alpha=0.4, + num_epochs=8, +) + +main_metric = "sys_srcc" +id_name = None + + +inference = SimpleNamespace( + save_path=Path("preds"), + submit_save_path=Path("submissions"), + num_tta=5, + batch_size=8, + extend="tile", +) diff --git a/utmosv2/config/fusion_stage3.py b/utmosv2/config/fusion_stage3.py new file mode 100755 index 0000000000000000000000000000000000000000..5ae28b30bcedc52ab4de0c6b0a27979158d2310f --- /dev/null +++ b/utmosv2/config/fusion_stage3.py @@ -0,0 +1,150 @@ +from pathlib import Path 
+from types import SimpleNamespace + +from torchvision import transforms + +from utmosv2.transform.xymasking import XYMasking + +batch_size = 8 +num_folds = 5 + +sr = 16000 + +preprocess = SimpleNamespace( + top_db=30, min_seconds=None, save_path=Path("preprocessed_data") +) + +split = SimpleNamespace( + type="sgkf_kind", + target="mos", + group="sys_id", + kind="dataset", +) + +external_data = "all" +use_bvcc = True + + +validation_dataset = "each" + +dataset = SimpleNamespace( + name="ssl_multispec_ext", + specs=[ + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=4096, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=2048, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=1024, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=512, + n_mels=512, + shape=(512, 512), + norm=80, + ), + ], + spec_frames=SimpleNamespace( + num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile" + ), + ssl=SimpleNamespace( + duration=3, + ), +) +transform = dict( + train=transforms.Compose( + [ + transforms.Resize((512, 512)), + XYMasking( + num_masks_x=(0, 2), + num_masks_y=(0, 2), + mask_x_length=(10, 40), + mask_y_length=(10, 30), + fill_value=0, + p=0.5, + ), + # transforms.ToTensor(), + ] + ), + valid=transforms.Compose( + [ + transforms.Resize((512, 512)), + # transforms.ToTensor() + ] + ), +) + +loss = [ + (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), + (SimpleNamespace(name="mse"), 0.2), +] + +optimizer = SimpleNamespace(name="adamw", lr=5e-5, weight_decay=1e-4) + +scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-8) + +model = SimpleNamespace( + name="ssl_multispec_ext_v2", + multi_spec=SimpleNamespace( + 
backbone="tf_efficientnetv2_s.in21k_ft_in1k", + pretrained=True, + num_classes=1, + pool_type="catavgmax", + # feature_height=16, + atten=True, + # classifier=None, + ), + ssl=SimpleNamespace( + name="facebook/wav2vec2-base", + attn=1, + freeze=False, + num_classes=1, + ), + ssl_spec=SimpleNamespace( + ssl_weight="ssl_only_stage2", + spec_weight="spec_only", + num_classes=1, + freeze=False, + ), +) + +run = SimpleNamespace( + mixup=True, + mixup_alpha=0.4, + num_epochs=2, +) + +main_metric = "sys_srcc" +id_name = None + + +inference = SimpleNamespace( + save_path=Path("preds"), + submit_save_path=Path("submissions"), + num_tta=5, + batch_size=8, + extend="tile", +) diff --git a/utmosv2/config/fusion_stage3_wo_bc.py b/utmosv2/config/fusion_stage3_wo_bc.py new file mode 100755 index 0000000000000000000000000000000000000000..d522daec43d1fb1ad7c6775180b28c081291a277 --- /dev/null +++ b/utmosv2/config/fusion_stage3_wo_bc.py @@ -0,0 +1,160 @@ +from pathlib import Path +from types import SimpleNamespace + +from torchvision import transforms + +from utmosv2.transform.xymasking import XYMasking + +batch_size = 8 +num_folds = 5 + +sr = 16000 + +preprocess = SimpleNamespace( + top_db=30, min_seconds=None, save_path=Path("preprocessed_data") +) + +split = SimpleNamespace( + type="sgkf_kind", + target="mos", + group="sys_id", + kind="dataset", +) + +external_data = [ + "sarulab", + # "blizzard2008", + # "blizzard2009", + # "blizzard2011", + # "blizzard2010-EH1", + # "blizzard2010-EH2", + # "blizzard2010-ES1", + # "blizzard2010-ES3", + "somos", +] +use_bvcc = True + + +validation_dataset = "each" + +dataset = SimpleNamespace( + name="ssl_multispec_ext", + specs=[ + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=4096, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=2048, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + 
n_fft=4096, + hop_length=32, + win_length=1024, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=512, + n_mels=512, + shape=(512, 512), + norm=80, + ), + ], + spec_frames=SimpleNamespace( + num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile" + ), + ssl=SimpleNamespace( + duration=3, + ), +) +transform = dict( + train=transforms.Compose( + [ + transforms.Resize((512, 512)), + XYMasking( + num_masks_x=(0, 2), + num_masks_y=(0, 2), + mask_x_length=(10, 40), + mask_y_length=(10, 30), + fill_value=0, + p=0.5, + ), + # transforms.ToTensor(), + ] + ), + valid=transforms.Compose( + [ + transforms.Resize((512, 512)), + # transforms.ToTensor() + ] + ), +) + +loss = [ + (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), + (SimpleNamespace(name="mse"), 0.2), +] + +optimizer = SimpleNamespace(name="adamw", lr=5e-5, weight_decay=1e-4) + +scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-8) + +model = SimpleNamespace( + name="ssl_multispec_ext_v2", + multi_spec=SimpleNamespace( + backbone="tf_efficientnetv2_s.in21k_ft_in1k", + pretrained=True, + num_classes=1, + pool_type="catavgmax", + # feature_height=16, + atten=True, + # classifier=None, + ), + ssl=SimpleNamespace( + name="facebook/wav2vec2-base", + attn=1, + freeze=False, + num_classes=1, + ), + ssl_spec=SimpleNamespace( + ssl_weight="ssl_only_stage2_wo_bc", + spec_weight="spec_only_wo_bc", + num_classes=1, + freeze=False, + ), +) + +run = SimpleNamespace( + mixup=True, + mixup_alpha=0.4, + num_epochs=2, +) + +main_metric = "sys_srcc" +id_name = None + + +inference = SimpleNamespace( + save_path=Path("preds"), + submit_save_path=Path("submissions"), + num_tta=5, + batch_size=8, + extend="tile", +) diff --git a/utmosv2/config/fusion_stage3_wo_bvcc.py b/utmosv2/config/fusion_stage3_wo_bvcc.py new file mode 100755 index 
0000000000000000000000000000000000000000..958d47f67d05621de12b9c7222760e0591422031 --- /dev/null +++ b/utmosv2/config/fusion_stage3_wo_bvcc.py @@ -0,0 +1,160 @@ +from pathlib import Path +from types import SimpleNamespace + +from torchvision import transforms + +from utmosv2.transform.xymasking import XYMasking + +batch_size = 8 +num_folds = 5 + +sr = 16000 + +preprocess = SimpleNamespace( + top_db=30, min_seconds=None, save_path=Path("preprocessed_data") +) + +split = SimpleNamespace( + type="sgkf_kind", + target="mos", + group="sys_id", + kind="dataset", +) + +external_data = [ + "sarulab", + "blizzard2008", + "blizzard2009", + "blizzard2011", + "blizzard2010-EH1", + "blizzard2010-EH2", + "blizzard2010-ES1", + "blizzard2010-ES3", + "somos", +] +use_bvcc = False + + +validation_dataset = "each" + +dataset = SimpleNamespace( + name="ssl_multispec_ext", + specs=[ + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=4096, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=2048, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=1024, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=512, + n_mels=512, + shape=(512, 512), + norm=80, + ), + ], + spec_frames=SimpleNamespace( + num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile" + ), + ssl=SimpleNamespace( + duration=3, + ), +) +transform = dict( + train=transforms.Compose( + [ + transforms.Resize((512, 512)), + XYMasking( + num_masks_x=(0, 2), + num_masks_y=(0, 2), + mask_x_length=(10, 40), + mask_y_length=(10, 30), + fill_value=0, + p=0.5, + ), + # transforms.ToTensor(), + ] + ), + valid=transforms.Compose( + [ + transforms.Resize((512, 512)), + # transforms.ToTensor() + ] + ), +) + +loss = [ + 
(SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), + (SimpleNamespace(name="mse"), 0.2), +] + +optimizer = SimpleNamespace(name="adamw", lr=5e-5, weight_decay=1e-4) + +scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-8) + +model = SimpleNamespace( + name="ssl_multispec_ext_v2", + multi_spec=SimpleNamespace( + backbone="tf_efficientnetv2_s.in21k_ft_in1k", + pretrained=True, + num_classes=1, + pool_type="catavgmax", + # feature_height=16, + atten=True, + # classifier=None, + ), + ssl=SimpleNamespace( + name="facebook/wav2vec2-base", + attn=1, + freeze=False, + num_classes=1, + ), + ssl_spec=SimpleNamespace( + ssl_weight="ssl_only_stage2_wo_bvcc", + spec_weight="spec_only_wo_bvcc", + num_classes=1, + freeze=False, + ), +) + +run = SimpleNamespace( + mixup=True, + mixup_alpha=0.4, + num_epochs=2, +) + +main_metric = "sys_srcc" +id_name = None + + +inference = SimpleNamespace( + save_path=Path("preds"), + submit_save_path=Path("submissions"), + num_tta=5, + batch_size=8, + extend="tile", +) diff --git a/utmosv2/config/fusion_stage3_wo_sarulab.py b/utmosv2/config/fusion_stage3_wo_sarulab.py new file mode 100755 index 0000000000000000000000000000000000000000..a1e47052b1b952a990ba75c8b4086e05daf9d172 --- /dev/null +++ b/utmosv2/config/fusion_stage3_wo_sarulab.py @@ -0,0 +1,160 @@ +from pathlib import Path +from types import SimpleNamespace + +from torchvision import transforms + +from utmosv2.transform.xymasking import XYMasking + +batch_size = 8 +num_folds = 5 + +sr = 16000 + +preprocess = SimpleNamespace( + top_db=30, min_seconds=None, save_path=Path("preprocessed_data") +) + +split = SimpleNamespace( + type="sgkf_kind", + target="mos", + group="sys_id", + kind="dataset", +) + +external_data = [ + # "sarulab", + "blizzard2008", + "blizzard2009", + "blizzard2011", + "blizzard2010-EH1", + "blizzard2010-EH2", + "blizzard2010-ES1", + "blizzard2010-ES3", + "somos", +] +use_bvcc = True + + +validation_dataset = "each" + +dataset = 
SimpleNamespace( + name="ssl_multispec_ext", + specs=[ + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=4096, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=2048, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=1024, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=512, + n_mels=512, + shape=(512, 512), + norm=80, + ), + ], + spec_frames=SimpleNamespace( + num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile" + ), + ssl=SimpleNamespace( + duration=3, + ), +) +transform = dict( + train=transforms.Compose( + [ + transforms.Resize((512, 512)), + XYMasking( + num_masks_x=(0, 2), + num_masks_y=(0, 2), + mask_x_length=(10, 40), + mask_y_length=(10, 30), + fill_value=0, + p=0.5, + ), + # transforms.ToTensor(), + ] + ), + valid=transforms.Compose( + [ + transforms.Resize((512, 512)), + # transforms.ToTensor() + ] + ), +) + +loss = [ + (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), + (SimpleNamespace(name="mse"), 0.2), +] + +optimizer = SimpleNamespace(name="adamw", lr=5e-5, weight_decay=1e-4) + +scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-8) + +model = SimpleNamespace( + name="ssl_multispec_ext_v2", + multi_spec=SimpleNamespace( + backbone="tf_efficientnetv2_s.in21k_ft_in1k", + pretrained=True, + num_classes=1, + pool_type="catavgmax", + # feature_height=16, + atten=True, + # classifier=None, + ), + ssl=SimpleNamespace( + name="facebook/wav2vec2-base", + attn=1, + freeze=False, + num_classes=1, + ), + ssl_spec=SimpleNamespace( + ssl_weight="ssl_only_stage2_wo_sarulab", + spec_weight="spec_only_wo_sarulab", + num_classes=1, + freeze=False, + ), +) + +run = SimpleNamespace( + mixup=True, + mixup_alpha=0.4, + num_epochs=2, +) + 
+main_metric = "sys_srcc" +id_name = None + + +inference = SimpleNamespace( + save_path=Path("preds"), + submit_save_path=Path("submissions"), + num_tta=5, + batch_size=8, + extend="tile", +) diff --git a/utmosv2/config/fusion_stage3_wo_somos.py b/utmosv2/config/fusion_stage3_wo_somos.py new file mode 100755 index 0000000000000000000000000000000000000000..dec2bc2672614183dbe82392d1132e61cdd364da --- /dev/null +++ b/utmosv2/config/fusion_stage3_wo_somos.py @@ -0,0 +1,160 @@ +from pathlib import Path +from types import SimpleNamespace + +from torchvision import transforms + +from utmosv2.transform.xymasking import XYMasking + +batch_size = 8 +num_folds = 5 + +sr = 16000 + +preprocess = SimpleNamespace( + top_db=30, min_seconds=None, save_path=Path("preprocessed_data") +) + +split = SimpleNamespace( + type="sgkf_kind", + target="mos", + group="sys_id", + kind="dataset", +) + +external_data = [ + "sarulab", + "blizzard2008", + "blizzard2009", + "blizzard2011", + "blizzard2010-EH1", + "blizzard2010-EH2", + "blizzard2010-ES1", + "blizzard2010-ES3", + # "somos", +] +use_bvcc = True + + +validation_dataset = "each" + +dataset = SimpleNamespace( + name="ssl_multispec_ext", + specs=[ + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=4096, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=2048, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=1024, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=512, + n_mels=512, + shape=(512, 512), + norm=80, + ), + ], + spec_frames=SimpleNamespace( + num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile" + ), + ssl=SimpleNamespace( + duration=3, + ), +) +transform = dict( + train=transforms.Compose( + [ + transforms.Resize((512, 512)), + 
XYMasking( + num_masks_x=(0, 2), + num_masks_y=(0, 2), + mask_x_length=(10, 40), + mask_y_length=(10, 30), + fill_value=0, + p=0.5, + ), + # transforms.ToTensor(), + ] + ), + valid=transforms.Compose( + [ + transforms.Resize((512, 512)), + # transforms.ToTensor() + ] + ), +) + +loss = [ + (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), + (SimpleNamespace(name="mse"), 0.2), +] + +optimizer = SimpleNamespace(name="adamw", lr=5e-5, weight_decay=1e-4) + +scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-8) + +model = SimpleNamespace( + name="ssl_multispec_ext_v2", + multi_spec=SimpleNamespace( + backbone="tf_efficientnetv2_s.in21k_ft_in1k", + pretrained=True, + num_classes=1, + pool_type="catavgmax", + # feature_height=16, + atten=True, + # classifier=None, + ), + ssl=SimpleNamespace( + name="facebook/wav2vec2-base", + attn=1, + freeze=False, + num_classes=1, + ), + ssl_spec=SimpleNamespace( + ssl_weight="ssl_only_stage2_wo_somos", + spec_weight="spec_only_wo_somos", + num_classes=1, + freeze=False, + ), +) + +run = SimpleNamespace( + mixup=True, + mixup_alpha=0.4, + num_epochs=2, +) + +main_metric = "sys_srcc" +id_name = None + + +inference = SimpleNamespace( + save_path=Path("preds"), + submit_save_path=Path("submissions"), + num_tta=5, + batch_size=8, + extend="tile", +) diff --git a/utmosv2/config/fusion_wo_stage1and2.py b/utmosv2/config/fusion_wo_stage1and2.py new file mode 100755 index 0000000000000000000000000000000000000000..7b705c4979b5ba7e300a9026a7555cbdc66f5f3b --- /dev/null +++ b/utmosv2/config/fusion_wo_stage1and2.py @@ -0,0 +1,150 @@ +from pathlib import Path +from types import SimpleNamespace + +from torchvision import transforms + +from utmosv2.transform.xymasking import XYMasking + +batch_size = 8 +num_folds = 5 + +sr = 16000 + +preprocess = SimpleNamespace( + top_db=30, min_seconds=None, save_path=Path("preprocessed_data") +) + +split = SimpleNamespace( + type="sgkf_kind", + target="mos", + group="sys_id", + 
kind="dataset", +) + +external_data = "all" +use_bvcc = True + + +validation_dataset = "each" + +dataset = SimpleNamespace( + name="ssl_multispec_ext", + specs=[ + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=4096, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=2048, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=1024, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=512, + n_mels=512, + shape=(512, 512), + norm=80, + ), + ], + spec_frames=SimpleNamespace( + num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile" + ), + ssl=SimpleNamespace( + duration=3, + ), +) +transform = dict( + train=transforms.Compose( + [ + transforms.Resize((512, 512)), + XYMasking( + num_masks_x=(0, 2), + num_masks_y=(0, 2), + mask_x_length=(10, 40), + mask_y_length=(10, 30), + fill_value=0, + p=0.5, + ), + # transforms.ToTensor(), + ] + ), + valid=transforms.Compose( + [ + transforms.Resize((512, 512)), + # transforms.ToTensor() + ] + ), +) + +loss = [ + (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), + (SimpleNamespace(name="mse"), 0.2), +] + +optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4) + +scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-7) + +model = SimpleNamespace( + name="ssl_multispec_ext_v2", + multi_spec=SimpleNamespace( + backbone="tf_efficientnetv2_s.in21k_ft_in1k", + pretrained=True, + num_classes=1, + pool_type="catavgmax", + # feature_height=16, + atten=True, + # classifier=None, + ), + ssl=SimpleNamespace( + name="facebook/wav2vec2-base", + attn=1, + freeze=False, + num_classes=1, + ), + ssl_spec=SimpleNamespace( + ssl_weight=None, + spec_weight=None, + num_classes=1, + freeze=False, + ), +) + +run = 
SimpleNamespace( + mixup=True, + mixup_alpha=0.4, + num_epochs=20, +) + +main_metric = "sys_srcc" +id_name = None + + +inference = SimpleNamespace( + save_path=Path("preds"), + submit_save_path=Path("submissions"), + num_tta=5, + batch_size=8, + extend="tile", +) diff --git a/utmosv2/config/fusion_wo_stage2.py b/utmosv2/config/fusion_wo_stage2.py new file mode 100755 index 0000000000000000000000000000000000000000..fed28c58dd73fbcd104f04db39995c81f3c31876 --- /dev/null +++ b/utmosv2/config/fusion_wo_stage2.py @@ -0,0 +1,150 @@ +from pathlib import Path +from types import SimpleNamespace + +from torchvision import transforms + +from utmosv2.transform.xymasking import XYMasking + +batch_size = 8 +num_folds = 5 + +sr = 16000 + +preprocess = SimpleNamespace( + top_db=30, min_seconds=None, save_path=Path("preprocessed_data") +) + +split = SimpleNamespace( + type="sgkf_kind", + target="mos", + group="sys_id", + kind="dataset", +) + +external_data = "all" +use_bvcc = True + + +validation_dataset = "each" + +dataset = SimpleNamespace( + name="ssl_multispec_ext", + specs=[ + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=4096, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=2048, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=1024, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=512, + n_mels=512, + shape=(512, 512), + norm=80, + ), + ], + spec_frames=SimpleNamespace( + num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile" + ), + ssl=SimpleNamespace( + duration=3, + ), +) +transform = dict( + train=transforms.Compose( + [ + transforms.Resize((512, 512)), + XYMasking( + num_masks_x=(0, 2), + num_masks_y=(0, 2), + mask_x_length=(10, 40), + mask_y_length=(10, 30), + 
fill_value=0, + p=0.5, + ), + # transforms.ToTensor(), + ] + ), + valid=transforms.Compose( + [ + transforms.Resize((512, 512)), + # transforms.ToTensor() + ] + ), +) + +loss = [ + (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), + (SimpleNamespace(name="mse"), 0.2), +] + +optimizer = SimpleNamespace(name="adamw", lr=1e-4, weight_decay=1e-4) + +scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-7) + +model = SimpleNamespace( + name="ssl_multispec_ext_v2", + multi_spec=SimpleNamespace( + backbone="tf_efficientnetv2_s.in21k_ft_in1k", + pretrained=True, + num_classes=1, + pool_type="catavgmax", + # feature_height=16, + atten=True, + # classifier=None, + ), + ssl=SimpleNamespace( + name="facebook/wav2vec2-base", + attn=1, + freeze=False, + num_classes=1, + ), + ssl_spec=SimpleNamespace( + ssl_weight="ssl_only_stage2", + spec_weight="spec_only", + num_classes=1, + freeze=False, + ), +) + +run = SimpleNamespace( + mixup=True, + mixup_alpha=0.4, + num_epochs=20, +) + +main_metric = "sys_srcc" +id_name = None + + +inference = SimpleNamespace( + save_path=Path("preds"), + submit_save_path=Path("submissions"), + num_tta=5, + batch_size=8, + extend="tile", +) diff --git a/utmosv2/config/spec_only.py b/utmosv2/config/spec_only.py new file mode 100755 index 0000000000000000000000000000000000000000..131b7795ee87ac8c274e1ec1dc1a6db2a34efe5d --- /dev/null +++ b/utmosv2/config/spec_only.py @@ -0,0 +1,135 @@ +from pathlib import Path +from types import SimpleNamespace + +from torchvision import transforms + +from utmosv2.transform.xymasking import XYMasking + +batch_size = 10 +num_folds = 5 + +sr = 16000 + +preprocess = SimpleNamespace( + top_db=30, min_seconds=None, save_path=Path("preprocessed_data") +) + +split = SimpleNamespace( + type="sgkf_kind", + target="mos", + group="sys_id", + kind="dataset", +) + +external_data = "all" +use_bvcc = True + + +validation_dataset = "each" + +dataset = SimpleNamespace( + name="multi_spec_ext", + specs=[ + 
SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=4096, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=2048, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=1024, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=512, + n_mels=512, + shape=(512, 512), + norm=80, + ), + ], + spec_frames=SimpleNamespace( + num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile" + ), +) +transform = dict( + train=transforms.Compose( + [ + transforms.Resize((512, 512)), + XYMasking( + num_masks_x=(0, 2), + num_masks_y=(0, 2), + mask_x_length=(10, 40), + mask_y_length=(10, 30), + fill_value=0, + p=0.5, + ), + # transforms.ToTensor(), + ] + ), + valid=transforms.Compose( + [ + transforms.Resize((512, 512)), + # transforms.ToTensor() + ] + ), +) + +loss = [ + (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), + (SimpleNamespace(name="mse"), 0.2), +] + +optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4) + +scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-7) + +model = SimpleNamespace( + name="multi_spec_ext", + multi_spec=SimpleNamespace( + backbone="tf_efficientnetv2_s.in21k_ft_in1k", + pretrained=True, + num_classes=1, + pool_type="catavgmax", + # feature_height=16, + atten=True, + # classifier=None, + ), +) + +run = SimpleNamespace( + mixup=True, + mixup_alpha=0.4, + num_epochs=20, +) + +main_metric = "sys_srcc" +id_name = None + + +inference = SimpleNamespace( + save_path=Path("preds"), + submit_save_path=Path("submissions"), + num_tta=5, + batch_size=8, + extend="tile", +) diff --git a/utmosv2/config/spec_only_wo_bc.py b/utmosv2/config/spec_only_wo_bc.py new file mode 100755 index 
0000000000000000000000000000000000000000..015c17d8d5f2ecd4348b259540b62fddcf9f0a63 --- /dev/null +++ b/utmosv2/config/spec_only_wo_bc.py @@ -0,0 +1,145 @@ +from pathlib import Path +from types import SimpleNamespace + +from torchvision import transforms + +from utmosv2.transform.xymasking import XYMasking + +batch_size = 10 +num_folds = 5 + +sr = 16000 + +preprocess = SimpleNamespace( + top_db=30, min_seconds=None, save_path=Path("preprocessed_data") +) + +split = SimpleNamespace( + type="sgkf_kind", + target="mos", + group="sys_id", + kind="dataset", +) + +external_data = [ + "sarulab", + # "blizzard2008", + # "blizzard2009", + # "blizzard2011", + # "blizzard2010-EH1", + # "blizzard2010-EH2", + # "blizzard2010-ES1", + # "blizzard2010-ES3", + "somos", +] +use_bvcc = True + + +validation_dataset = "each" + +dataset = SimpleNamespace( + name="multi_spec_ext", + specs=[ + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=4096, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=2048, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=1024, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=512, + n_mels=512, + shape=(512, 512), + norm=80, + ), + ], + spec_frames=SimpleNamespace( + num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile" + ), +) +transform = dict( + train=transforms.Compose( + [ + transforms.Resize((512, 512)), + XYMasking( + num_masks_x=(0, 2), + num_masks_y=(0, 2), + mask_x_length=(10, 40), + mask_y_length=(10, 30), + fill_value=0, + p=0.5, + ), + # transforms.ToTensor(), + ] + ), + valid=transforms.Compose( + [ + transforms.Resize((512, 512)), + # transforms.ToTensor() + ] + ), +) + +loss = [ + (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), + 
(SimpleNamespace(name="mse"), 0.2), +] + +optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4) + +scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-7) + +model = SimpleNamespace( + name="multi_spec_ext", + multi_spec=SimpleNamespace( + backbone="tf_efficientnetv2_s.in21k_ft_in1k", + pretrained=True, + num_classes=1, + pool_type="catavgmax", + # feature_height=16, + atten=True, + # classifier=None, + ), +) + +run = SimpleNamespace( + mixup=True, + mixup_alpha=0.4, + num_epochs=20, +) + +main_metric = "sys_srcc" +id_name = None + + +inference = SimpleNamespace( + save_path=Path("preds"), + submit_save_path=Path("submissions"), + num_tta=5, + batch_size=8, + extend="tile", +) diff --git a/utmosv2/config/spec_only_wo_bvcc.py b/utmosv2/config/spec_only_wo_bvcc.py new file mode 100755 index 0000000000000000000000000000000000000000..bddd5e1925f2f026755a3d2f515030134f768a7a --- /dev/null +++ b/utmosv2/config/spec_only_wo_bvcc.py @@ -0,0 +1,145 @@ +from pathlib import Path +from types import SimpleNamespace + +from torchvision import transforms + +from utmosv2.transform.xymasking import XYMasking + +batch_size = 10 +num_folds = 5 + +sr = 16000 + +preprocess = SimpleNamespace( + top_db=30, min_seconds=None, save_path=Path("preprocessed_data") +) + +split = SimpleNamespace( + type="sgkf_kind", + target="mos", + group="sys_id", + kind="dataset", +) + +external_data = [ + "sarulab", + "blizzard2008", + "blizzard2009", + "blizzard2011", + "blizzard2010-EH1", + "blizzard2010-EH2", + "blizzard2010-ES1", + "blizzard2010-ES3", + "somos", +] +use_bvcc = False + + +validation_dataset = "each" + +dataset = SimpleNamespace( + name="multi_spec_ext", + specs=[ + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=4096, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=2048, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + 
mode="melspec", + n_fft=4096, + hop_length=32, + win_length=1024, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=512, + n_mels=512, + shape=(512, 512), + norm=80, + ), + ], + spec_frames=SimpleNamespace( + num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile" + ), +) +transform = dict( + train=transforms.Compose( + [ + transforms.Resize((512, 512)), + XYMasking( + num_masks_x=(0, 2), + num_masks_y=(0, 2), + mask_x_length=(10, 40), + mask_y_length=(10, 30), + fill_value=0, + p=0.5, + ), + # transforms.ToTensor(), + ] + ), + valid=transforms.Compose( + [ + transforms.Resize((512, 512)), + # transforms.ToTensor() + ] + ), +) + +loss = [ + (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), + (SimpleNamespace(name="mse"), 0.2), +] + +optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4) + +scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-7) + +model = SimpleNamespace( + name="multi_spec_ext", + multi_spec=SimpleNamespace( + backbone="tf_efficientnetv2_s.in21k_ft_in1k", + pretrained=True, + num_classes=1, + pool_type="catavgmax", + # feature_height=16, + atten=True, + # classifier=None, + ), +) + +run = SimpleNamespace( + mixup=True, + mixup_alpha=0.4, + num_epochs=20, +) + +main_metric = "sys_srcc" +id_name = None + + +inference = SimpleNamespace( + save_path=Path("preds"), + submit_save_path=Path("submissions"), + num_tta=5, + batch_size=8, + extend="tile", +) diff --git a/utmosv2/config/spec_only_wo_sarulab.py b/utmosv2/config/spec_only_wo_sarulab.py new file mode 100755 index 0000000000000000000000000000000000000000..0b78cae4285df12c905c776054b4da1befe67ac9 --- /dev/null +++ b/utmosv2/config/spec_only_wo_sarulab.py @@ -0,0 +1,145 @@ +from pathlib import Path +from types import SimpleNamespace + +from torchvision import transforms + +from utmosv2.transform.xymasking import XYMasking + +batch_size = 10 +num_folds 
= 5 + +sr = 16000 + +preprocess = SimpleNamespace( + top_db=30, min_seconds=None, save_path=Path("preprocessed_data") +) + +split = SimpleNamespace( + type="sgkf_kind", + target="mos", + group="sys_id", + kind="dataset", +) + +external_data = [ + # "sarulab", + "blizzard2008", + "blizzard2009", + "blizzard2011", + "blizzard2010-EH1", + "blizzard2010-EH2", + "blizzard2010-ES1", + "blizzard2010-ES3", + "somos", +] +use_bvcc = True + + +validation_dataset = "each" + +dataset = SimpleNamespace( + name="multi_spec_ext", + specs=[ + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=4096, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=2048, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=1024, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=512, + n_mels=512, + shape=(512, 512), + norm=80, + ), + ], + spec_frames=SimpleNamespace( + num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile" + ), +) +transform = dict( + train=transforms.Compose( + [ + transforms.Resize((512, 512)), + XYMasking( + num_masks_x=(0, 2), + num_masks_y=(0, 2), + mask_x_length=(10, 40), + mask_y_length=(10, 30), + fill_value=0, + p=0.5, + ), + # transforms.ToTensor(), + ] + ), + valid=transforms.Compose( + [ + transforms.Resize((512, 512)), + # transforms.ToTensor() + ] + ), +) + +loss = [ + (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), + (SimpleNamespace(name="mse"), 0.2), +] + +optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4) + +scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-7) + +model = SimpleNamespace( + name="multi_spec_ext", + multi_spec=SimpleNamespace( + backbone="tf_efficientnetv2_s.in21k_ft_in1k", + pretrained=True, + num_classes=1, 
+ pool_type="catavgmax", + # feature_height=16, + atten=True, + # classifier=None, + ), +) + +run = SimpleNamespace( + mixup=True, + mixup_alpha=0.4, + num_epochs=20, +) + +main_metric = "sys_srcc" +id_name = None + + +inference = SimpleNamespace( + save_path=Path("preds"), + submit_save_path=Path("submissions"), + num_tta=5, + batch_size=8, + extend="tile", +) diff --git a/utmosv2/config/spec_only_wo_somos.py b/utmosv2/config/spec_only_wo_somos.py new file mode 100755 index 0000000000000000000000000000000000000000..c545b20cd6a6ba03b103eb8f188581d6f100bc72 --- /dev/null +++ b/utmosv2/config/spec_only_wo_somos.py @@ -0,0 +1,145 @@ +from pathlib import Path +from types import SimpleNamespace + +from torchvision import transforms + +from utmosv2.transform.xymasking import XYMasking + +batch_size = 10 +num_folds = 5 + +sr = 16000 + +preprocess = SimpleNamespace( + top_db=30, min_seconds=None, save_path=Path("preprocessed_data") +) + +split = SimpleNamespace( + type="sgkf_kind", + target="mos", + group="sys_id", + kind="dataset", +) + +external_data = [ + "sarulab", + "blizzard2008", + "blizzard2009", + "blizzard2011", + "blizzard2010-EH1", + "blizzard2010-EH2", + "blizzard2010-ES1", + "blizzard2010-ES3", + # "somos", +] +use_bvcc = True + + +validation_dataset = "each" + +dataset = SimpleNamespace( + name="multi_spec_ext", + specs=[ + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=4096, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=2048, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=1024, + n_mels=512, + shape=(512, 512), + norm=80, + ), + SimpleNamespace( + mode="melspec", + n_fft=4096, + hop_length=32, + win_length=512, + n_mels=512, + shape=(512, 512), + norm=80, + ), + ], + spec_frames=SimpleNamespace( + num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, 
extend="tile" + ), +) +transform = dict( + train=transforms.Compose( + [ + transforms.Resize((512, 512)), + XYMasking( + num_masks_x=(0, 2), + num_masks_y=(0, 2), + mask_x_length=(10, 40), + mask_y_length=(10, 30), + fill_value=0, + p=0.5, + ), + # transforms.ToTensor(), + ] + ), + valid=transforms.Compose( + [ + transforms.Resize((512, 512)), + # transforms.ToTensor() + ] + ), +) + +loss = [ + (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), + (SimpleNamespace(name="mse"), 0.2), +] + +optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4) + +scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-7) + +model = SimpleNamespace( + name="multi_spec_ext", + multi_spec=SimpleNamespace( + backbone="tf_efficientnetv2_s.in21k_ft_in1k", + pretrained=True, + num_classes=1, + pool_type="catavgmax", + # feature_height=16, + atten=True, + # classifier=None, + ), +) + +run = SimpleNamespace( + mixup=True, + mixup_alpha=0.4, + num_epochs=20, +) + +main_metric = "sys_srcc" +id_name = None + + +inference = SimpleNamespace( + save_path=Path("preds"), + submit_save_path=Path("submissions"), + num_tta=5, + batch_size=8, + extend="tile", +) diff --git a/utmosv2/config/ssl_only_stage1.py b/utmosv2/config/ssl_only_stage1.py new file mode 100755 index 0000000000000000000000000000000000000000..7af8f4da097839eedc327270057320048965a332 --- /dev/null +++ b/utmosv2/config/ssl_only_stage1.py @@ -0,0 +1,69 @@ +from pathlib import Path +from types import SimpleNamespace + +batch_size = 32 +num_folds = 5 + +sr = 16000 + +preprocess = SimpleNamespace( + top_db=30, min_seconds=None, save_path=Path("preprocessed_data") +) + +split = SimpleNamespace( + type="sgkf_kind", + target="mos", + group="sys_id", + kind="dataset", +) + +dataset = SimpleNamespace( + name="sslext", + ssl=SimpleNamespace( + duration=3, + ), +) + +external_data = "all" +use_bvcc = True + + +validation_dataset = "each" + +loss = [ + (SimpleNamespace(name="pairwize_diff", margin=0.2, 
norm="l1"), 0.7), + (SimpleNamespace(name="mse"), 0.2), +] + +optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4) + +scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-7) + +model_path = "model" +model = SimpleNamespace( + name="sslext", + ssl=SimpleNamespace( + name="facebook/wav2vec2-base", + attn=1, + freeze=True, + num_classes=1, + ), +) + +run = SimpleNamespace( + mixup=True, + mixup_alpha=0.4, + num_epochs=20, +) + +main_metric = "sys_srcc" +id_name = None + + +inference = SimpleNamespace( + save_path=Path("preds"), + submit_save_path=Path("submissions"), + num_tta=5, + batch_size=8, + # extend="tile", +) diff --git a/utmosv2/config/ssl_only_stage1_wo_bc.py b/utmosv2/config/ssl_only_stage1_wo_bc.py new file mode 100755 index 0000000000000000000000000000000000000000..d20add7be581217faa5f1d5cad742a82af523811 --- /dev/null +++ b/utmosv2/config/ssl_only_stage1_wo_bc.py @@ -0,0 +1,79 @@ +from pathlib import Path +from types import SimpleNamespace + +batch_size = 32 +num_folds = 5 + +sr = 16000 + +preprocess = SimpleNamespace( + top_db=30, min_seconds=None, save_path=Path("preprocessed_data") +) + +split = SimpleNamespace( + type="sgkf_kind", + target="mos", + group="sys_id", + kind="dataset", +) + +dataset = SimpleNamespace( + name="sslext", + ssl=SimpleNamespace( + duration=3, + ), +) + +external_data = [ + "sarulab", + # "blizzard2008", + # "blizzard2009", + # "blizzard2011", + # "blizzard2010-EH1", + # "blizzard2010-EH2", + # "blizzard2010-ES1", + # "blizzard2010-ES3", + "somos", +] +use_bvcc = True + + +validation_dataset = "each" + +loss = [ + (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), + (SimpleNamespace(name="mse"), 0.2), +] + +optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4) + +scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-7) + +model_path = "model" +model = SimpleNamespace( + name="sslext", + ssl=SimpleNamespace( + name="facebook/wav2vec2-base", + attn=1, + 
freeze=True, + num_classes=1, + ), +) + +run = SimpleNamespace( + mixup=True, + mixup_alpha=0.4, + num_epochs=20, +) + +main_metric = "sys_srcc" +id_name = None + + +inference = SimpleNamespace( + save_path=Path("preds"), + submit_save_path=Path("submissions"), + num_tta=5, + batch_size=8, + # extend="tile", +) diff --git a/utmosv2/config/ssl_only_stage1_wo_bvcc.py b/utmosv2/config/ssl_only_stage1_wo_bvcc.py new file mode 100755 index 0000000000000000000000000000000000000000..251c3188a3307e2d6b38720675823f0b175d70bc --- /dev/null +++ b/utmosv2/config/ssl_only_stage1_wo_bvcc.py @@ -0,0 +1,79 @@ +from pathlib import Path +from types import SimpleNamespace + +batch_size = 32 +num_folds = 5 + +sr = 16000 + +preprocess = SimpleNamespace( + top_db=30, min_seconds=None, save_path=Path("preprocessed_data") +) + +split = SimpleNamespace( + type="sgkf_kind", + target="mos", + group="sys_id", + kind="dataset", +) + +dataset = SimpleNamespace( + name="sslext", + ssl=SimpleNamespace( + duration=3, + ), +) + +external_data = [ + "sarulab", + "blizzard2008", + "blizzard2009", + "blizzard2011", + "blizzard2010-EH1", + "blizzard2010-EH2", + "blizzard2010-ES1", + "blizzard2010-ES3", + "somos", +] +use_bvcc = False + + +validation_dataset = "each" + +loss = [ + (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), + (SimpleNamespace(name="mse"), 0.2), +] + +optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4) + +scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-7) + +model_path = "model" +model = SimpleNamespace( + name="sslext", + ssl=SimpleNamespace( + name="facebook/wav2vec2-base", + attn=1, + freeze=True, + num_classes=1, + ), +) + +run = SimpleNamespace( + mixup=True, + mixup_alpha=0.4, + num_epochs=20, +) + +main_metric = "sys_srcc" +id_name = None + + +inference = SimpleNamespace( + save_path=Path("preds"), + submit_save_path=Path("submissions"), + num_tta=5, + batch_size=8, + # extend="tile", +) diff --git 
a/utmosv2/config/ssl_only_stage1_wo_sarulab.py b/utmosv2/config/ssl_only_stage1_wo_sarulab.py new file mode 100755 index 0000000000000000000000000000000000000000..0882ba78bdcba64cb4be64b655ea482319b160e9 --- /dev/null +++ b/utmosv2/config/ssl_only_stage1_wo_sarulab.py @@ -0,0 +1,79 @@ +from pathlib import Path +from types import SimpleNamespace + +batch_size = 32 +num_folds = 5 + +sr = 16000 + +preprocess = SimpleNamespace( + top_db=30, min_seconds=None, save_path=Path("preprocessed_data") +) + +split = SimpleNamespace( + type="sgkf_kind", + target="mos", + group="sys_id", + kind="dataset", +) + +dataset = SimpleNamespace( + name="sslext", + ssl=SimpleNamespace( + duration=3, + ), +) + +external_data = [ + # "sarulab", + "blizzard2008", + "blizzard2009", + "blizzard2011", + "blizzard2010-EH1", + "blizzard2010-EH2", + "blizzard2010-ES1", + "blizzard2010-ES3", + "somos", +] +use_bvcc = True + + +validation_dataset = "each" + +loss = [ + (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), + (SimpleNamespace(name="mse"), 0.2), +] + +optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4) + +scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-7) + +model_path = "model" +model = SimpleNamespace( + name="sslext", + ssl=SimpleNamespace( + name="facebook/wav2vec2-base", + attn=1, + freeze=True, + num_classes=1, + ), +) + +run = SimpleNamespace( + mixup=True, + mixup_alpha=0.4, + num_epochs=20, +) + +main_metric = "sys_srcc" +id_name = None + + +inference = SimpleNamespace( + save_path=Path("preds"), + submit_save_path=Path("submissions"), + num_tta=5, + batch_size=8, + # extend="tile", +) diff --git a/utmosv2/config/ssl_only_stage1_wo_somos.py b/utmosv2/config/ssl_only_stage1_wo_somos.py new file mode 100755 index 0000000000000000000000000000000000000000..5812b7e4b6cd5f12e4ff8fda03c822eac375b57f --- /dev/null +++ b/utmosv2/config/ssl_only_stage1_wo_somos.py @@ -0,0 +1,79 @@ +from pathlib import Path +from types import 
SimpleNamespace + +batch_size = 32 +num_folds = 5 + +sr = 16000 + +preprocess = SimpleNamespace( + top_db=30, min_seconds=None, save_path=Path("preprocessed_data") +) + +split = SimpleNamespace( + type="sgkf_kind", + target="mos", + group="sys_id", + kind="dataset", +) + +dataset = SimpleNamespace( + name="sslext", + ssl=SimpleNamespace( + duration=3, + ), +) + +external_data = [ + "sarulab", + "blizzard2008", + "blizzard2009", + "blizzard2011", + "blizzard2010-EH1", + "blizzard2010-EH2", + "blizzard2010-ES1", + "blizzard2010-ES3", + # "somos", +] +use_bvcc = True + + +validation_dataset = "each" + +loss = [ + (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), + (SimpleNamespace(name="mse"), 0.2), +] + +optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4) + +scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-7) + +model_path = "model" +model = SimpleNamespace( + name="sslext", + ssl=SimpleNamespace( + name="facebook/wav2vec2-base", + attn=1, + freeze=True, + num_classes=1, + ), +) + +run = SimpleNamespace( + mixup=True, + mixup_alpha=0.4, + num_epochs=20, +) + +main_metric = "sys_srcc" +id_name = None + + +inference = SimpleNamespace( + save_path=Path("preds"), + submit_save_path=Path("submissions"), + num_tta=5, + batch_size=8, + # extend="tile", +) diff --git a/utmosv2/config/ssl_only_stage2.py b/utmosv2/config/ssl_only_stage2.py new file mode 100755 index 0000000000000000000000000000000000000000..f19b167d54de0d205f3af608d002fa76c1b61066 --- /dev/null +++ b/utmosv2/config/ssl_only_stage2.py @@ -0,0 +1,69 @@ +from pathlib import Path +from types import SimpleNamespace + +batch_size = 32 +num_folds = 5 + +sr = 16000 + +preprocess = SimpleNamespace( + top_db=30, min_seconds=None, save_path=Path("preprocessed_data") +) + +split = SimpleNamespace( + type="sgkf_kind", + target="mos", + group="sys_id", + kind="dataset", +) + +dataset = SimpleNamespace( + name="sslext", + ssl=SimpleNamespace( + duration=3, + ), +) + 
+external_data = "all" +use_bvcc = True + + +validation_dataset = "each" + +loss = [ + (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), + (SimpleNamespace(name="mse"), 0.2), +] + +optimizer = SimpleNamespace(name="adamw", lr=3e-5, weight_decay=1e-4) + +scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-9) + +model_path = "model" +model = SimpleNamespace( + name="sslext", + ssl=SimpleNamespace( + name="facebook/wav2vec2-base", + attn=1, + freeze=False, + num_classes=1, + ), +) + +run = SimpleNamespace( + mixup=True, + mixup_alpha=0.4, + num_epochs=5, +) + +main_metric = "sys_srcc" +id_name = None + + +inference = SimpleNamespace( + save_path=Path("preds"), + submit_save_path=Path("submissions"), + num_tta=5, + batch_size=8, + # extend="tile", +) diff --git a/utmosv2/config/ssl_only_stage2_wo_bc.py b/utmosv2/config/ssl_only_stage2_wo_bc.py new file mode 100755 index 0000000000000000000000000000000000000000..31c966f26ea6f3b5e19832b83b0072f7c231507b --- /dev/null +++ b/utmosv2/config/ssl_only_stage2_wo_bc.py @@ -0,0 +1,79 @@ +from pathlib import Path +from types import SimpleNamespace + +batch_size = 32 +num_folds = 5 + +sr = 16000 + +preprocess = SimpleNamespace( + top_db=30, min_seconds=None, save_path=Path("preprocessed_data") +) + +split = SimpleNamespace( + type="sgkf_kind", + target="mos", + group="sys_id", + kind="dataset", +) + +dataset = SimpleNamespace( + name="sslext", + ssl=SimpleNamespace( + duration=3, + ), +) + +external_data = [ + "sarulab", + # "blizzard2008", + # "blizzard2009", + # "blizzard2011", + # "blizzard2010-EH1", + # "blizzard2010-EH2", + # "blizzard2010-ES1", + # "blizzard2010-ES3", + "somos", +] +use_bvcc = True + + +validation_dataset = "each" + +loss = [ + (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), + (SimpleNamespace(name="mse"), 0.2), +] + +optimizer = SimpleNamespace(name="adamw", lr=3e-5, weight_decay=1e-4) + +scheduler = SimpleNamespace(name="cosine", T_max=None, 
eta_min=1e-9) + +model_path = "model" +model = SimpleNamespace( + name="sslext", + ssl=SimpleNamespace( + name="facebook/wav2vec2-base", + attn=1, + freeze=False, + num_classes=1, + ), +) + +run = SimpleNamespace( + mixup=True, + mixup_alpha=0.4, + num_epochs=5, +) + +main_metric = "sys_srcc" +id_name = None + + +inference = SimpleNamespace( + save_path=Path("preds"), + submit_save_path=Path("submissions"), + num_tta=5, + batch_size=8, + # extend="tile", +) diff --git a/utmosv2/config/ssl_only_stage2_wo_bvcc.py b/utmosv2/config/ssl_only_stage2_wo_bvcc.py new file mode 100755 index 0000000000000000000000000000000000000000..0ae143183f12a4ac09fbdbe9e831237dde659851 --- /dev/null +++ b/utmosv2/config/ssl_only_stage2_wo_bvcc.py @@ -0,0 +1,79 @@ +from pathlib import Path +from types import SimpleNamespace + +batch_size = 32 +num_folds = 5 + +sr = 16000 + +preprocess = SimpleNamespace( + top_db=30, min_seconds=None, save_path=Path("preprocessed_data") +) + +split = SimpleNamespace( + type="sgkf_kind", + target="mos", + group="sys_id", + kind="dataset", +) + +dataset = SimpleNamespace( + name="sslext", + ssl=SimpleNamespace( + duration=3, + ), +) + +external_data = [ + "sarulab", + "blizzard2008", + "blizzard2009", + "blizzard2011", + "blizzard2010-EH1", + "blizzard2010-EH2", + "blizzard2010-ES1", + "blizzard2010-ES3", + "somos", +] +use_bvcc = False + + +validation_dataset = "each" + +loss = [ + (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), + (SimpleNamespace(name="mse"), 0.2), +] + +optimizer = SimpleNamespace(name="adamw", lr=3e-5, weight_decay=1e-4) + +scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-9) + +model_path = "model" +model = SimpleNamespace( + name="sslext", + ssl=SimpleNamespace( + name="facebook/wav2vec2-base", + attn=1, + freeze=False, + num_classes=1, + ), +) + +run = SimpleNamespace( + mixup=True, + mixup_alpha=0.4, + num_epochs=5, +) + +main_metric = "sys_srcc" +id_name = None + + +inference = SimpleNamespace( + 
save_path=Path("preds"), + submit_save_path=Path("submissions"), + num_tta=5, + batch_size=8, + # extend="tile", +) diff --git a/utmosv2/config/ssl_only_stage2_wo_sarulab.py b/utmosv2/config/ssl_only_stage2_wo_sarulab.py new file mode 100755 index 0000000000000000000000000000000000000000..0a022f892c0db72927d7757551f481549da02616 --- /dev/null +++ b/utmosv2/config/ssl_only_stage2_wo_sarulab.py @@ -0,0 +1,79 @@ +from pathlib import Path +from types import SimpleNamespace + +batch_size = 32 +num_folds = 5 + +sr = 16000 + +preprocess = SimpleNamespace( + top_db=30, min_seconds=None, save_path=Path("preprocessed_data") +) + +split = SimpleNamespace( + type="sgkf_kind", + target="mos", + group="sys_id", + kind="dataset", +) + +dataset = SimpleNamespace( + name="sslext", + ssl=SimpleNamespace( + duration=3, + ), +) + +external_data = [ + # "sarulab", + "blizzard2008", + "blizzard2009", + "blizzard2011", + "blizzard2010-EH1", + "blizzard2010-EH2", + "blizzard2010-ES1", + "blizzard2010-ES3", + "somos", +] +use_bvcc = True + + +validation_dataset = "each" + +loss = [ + (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), + (SimpleNamespace(name="mse"), 0.2), +] + +optimizer = SimpleNamespace(name="adamw", lr=3e-5, weight_decay=1e-4) + +scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-9) + +model_path = "model" +model = SimpleNamespace( + name="sslext", + ssl=SimpleNamespace( + name="facebook/wav2vec2-base", + attn=1, + freeze=False, + num_classes=1, + ), +) + +run = SimpleNamespace( + mixup=True, + mixup_alpha=0.4, + num_epochs=5, +) + +main_metric = "sys_srcc" +id_name = None + + +inference = SimpleNamespace( + save_path=Path("preds"), + submit_save_path=Path("submissions"), + num_tta=5, + batch_size=8, + # extend="tile", +) diff --git a/utmosv2/config/ssl_only_stage2_wo_somos.py b/utmosv2/config/ssl_only_stage2_wo_somos.py new file mode 100755 index 0000000000000000000000000000000000000000..4b5250472ac0ed490527859a00604218993f1179 --- 
/dev/null +++ b/utmosv2/config/ssl_only_stage2_wo_somos.py @@ -0,0 +1,79 @@ +from pathlib import Path +from types import SimpleNamespace + +batch_size = 32 +num_folds = 5 + +sr = 16000 + +preprocess = SimpleNamespace( + top_db=30, min_seconds=None, save_path=Path("preprocessed_data") +) + +split = SimpleNamespace( + type="sgkf_kind", + target="mos", + group="sys_id", + kind="dataset", +) + +dataset = SimpleNamespace( + name="sslext", + ssl=SimpleNamespace( + duration=3, + ), +) + +external_data = [ + "sarulab", + "blizzard2008", + "blizzard2009", + "blizzard2011", + "blizzard2010-EH1", + "blizzard2010-EH2", + "blizzard2010-ES1", + "blizzard2010-ES3", + # "somos", +] +use_bvcc = True + + +validation_dataset = "each" + +loss = [ + (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), + (SimpleNamespace(name="mse"), 0.2), +] + +optimizer = SimpleNamespace(name="adamw", lr=3e-5, weight_decay=1e-4) + +scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-9) + +model_path = "model" +model = SimpleNamespace( + name="sslext", + ssl=SimpleNamespace( + name="facebook/wav2vec2-base", + attn=1, + freeze=False, + num_classes=1, + ), +) + +run = SimpleNamespace( + mixup=True, + mixup_alpha=0.4, + num_epochs=5, +) + +main_metric = "sys_srcc" +id_name = None + + +inference = SimpleNamespace( + save_path=Path("preds"), + submit_save_path=Path("submissions"), + num_tta=5, + batch_size=8, + # extend="tile", +) diff --git a/utmosv2/dataset/__init__.py b/utmosv2/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..18ef80ec4c9d325d417662c70ee955448571719d --- /dev/null +++ b/utmosv2/dataset/__init__.py @@ -0,0 +1,11 @@ +from utmosv2.dataset.multi_spec import MultiSpecDataset, MultiSpecExtDataset +from utmosv2.dataset.ssl import SSLDataset, SSLExtDataset +from utmosv2.dataset.ssl_multispec import SSLLMultiSpecExtDataset + +__all__ = [ + "MultiSpecDataset", + "MultiSpecExtDataset", + "SSLLMultiSpecExtDataset", + "SSLDataset", 
def load_audio(cfg, file: str) -> np.ndarray:
    """Load a waveform from ``file``.

    Args:
        cfg: Configuration object; only ``cfg.sr`` (target sampling rate) is read.
        file: Path to the audio. A path ending in ``.wav`` is decoded with
            librosa at its native rate and resampled to ``cfg.sr``; any other
            path is handed to ``np.load``.

    Returns:
        The loaded array. For the ``np.load`` branch no resampling is applied.
    """
    if file.endswith(".wav"):
        # sr=None keeps the file's native rate so the resample below is explicit.
        y, sr = librosa.load(file, sr=None)
        y = librosa.resample(y, orig_sr=sr, target_sr=cfg.sr)
    else:
        y = np.load(file)
    return y


def extend_audio(cfg, y: np.ndarray, length: int, type: str) -> np.ndarray:
    """Make ``y`` strictly longer than ``length`` samples.

    Args:
        cfg: Unused; kept for interface compatibility.
        y: Waveform; only its first axis length is inspected.
        length: Number of samples the result must exceed.
        type: Extension strategy. Only ``"tile"`` (repeat the signal) exists.

    Returns:
        ``y`` itself when it is already longer than ``length``; otherwise ``y``
        repeated ``length // len(y) + 1`` times, which is always longer than
        ``length``.

    Raises:
        ValueError: If ``y`` is empty and would have to be extended.
        NotImplementedError: If extension is needed and ``type`` is unknown.
    """
    if y.shape[0] > length:
        return y
    if type == "tile":
        if y.shape[0] == 0:
            # Previously surfaced as an opaque ZeroDivisionError.
            raise ValueError("Cannot extend an empty audio signal.")
        n = length // y.shape[0] + 1
        return np.tile(y, n)
    raise NotImplementedError(f"Unsupported extend type: {type!r}")


def select_random_start(y: np.ndarray, length: int) -> np.ndarray:
    """Return a random contiguous window of ``length`` samples from ``y``.

    The start index is drawn with ``np.random.randint`` from every valid
    position, including the last one (``len(y) - length``), so an input whose
    length equals ``length`` is returned whole.

    Raises:
        ValueError: If ``y`` is shorter than ``length``.
    """
    max_start = y.shape[0] - length
    if max_start < 0:
        raise ValueError(
            f"Audio of length {y.shape[0]} is shorter than the requested "
            f"window of {length} samples."
        )
    # randint's upper bound is exclusive, hence the +1.
    start = np.random.randint(0, max_start + 1)
    return y[start : start + length]


def get_dataset_map(cfg):
    """Map dataset names to integer indices.

    When ``cfg.data_config`` is set (truthy), it is read as a JSON file and the
    map is built from the order of the ``name`` fields under its ``"data"`` key.
    Otherwise (attribute missing or falsy) the built-in mapping is returned.
    """
    data_config = getattr(cfg, "data_config", None)
    if data_config:
        with open(data_config, "r", encoding="utf-8") as f:
            datasets = json.load(f)
        return {d["name"]: i for i, d in enumerate(datasets["data"])}
    return {
        "bvcc": 0,
        "sarulab": 1,
        "blizzard2008": 2,
        "blizzard2009": 3,
        "blizzard2010-EH1": 4,
        "blizzard2010-EH2": 5,
        "blizzard2010-ES1": 6,
        "blizzard2010-ES3": 7,
        "blizzard2011": 8,
        "somos": 9,
    }


def get_dataset_num(cfg):
    """Return the number of entries in ``get_dataset_map(cfg)``."""
    return len(get_dataset_map(cfg))
def _make_spctrogram(cfg, spec_cfg, y: np.ndarray) -> np.ndarray:
    """Compute the spectrogram selected by ``spec_cfg.mode`` for waveform ``y``.

    Args:
        cfg: Global configuration; forwarded to the helper.
        spec_cfg: Per-spectrogram settings. ``mode`` must be ``"melspec"`` or
            ``"stft"``.
        y: Waveform samples.

    Raises:
        NotImplementedError: If ``spec_cfg.mode`` is not a supported mode.
    """
    if spec_cfg.mode == "melspec":
        return _make_melspec(cfg, spec_cfg, y)
    if spec_cfg.mode == "stft":
        return _make_stft(cfg, spec_cfg, y)
    raise NotImplementedError(f"Unsupported spectrogram mode: {spec_cfg.mode!r}")


def _make_melspec(cfg, spec_cfg, y: np.ndarray) -> np.ndarray:
    """Return a dB-scaled mel spectrogram of ``y``.

    Reads ``cfg.sr`` and ``spec_cfg.n_fft`` / ``hop_length`` / ``n_mels`` /
    ``norm``, plus the optional ``spec_cfg.win_length``. When ``norm`` is not
    ``None`` the dB values are rescaled as ``(spec + norm) / norm``.
    """
    spec = librosa.feature.melspectrogram(
        y=y,
        sr=cfg.sr,
        n_fft=spec_cfg.n_fft,
        hop_length=spec_cfg.hop_length,
        # The configs define several spectrograms that differ only in
        # win_length; without forwarding it they would all use the same
        # window. None falls back to librosa's default (win_length = n_fft).
        win_length=getattr(spec_cfg, "win_length", None),
        n_mels=spec_cfg.n_mels,
    )
    # ref=np.max puts the peak at 0 dB, so all values are <= 0.
    spec = librosa.power_to_db(spec, ref=np.max)
    if spec_cfg.norm is not None:
        # NOTE(review): with norm=80 this maps [-80, 0] dB to [0, 1], which
        # presumably relies on power_to_db's default top_db - confirm.
        spec = (spec + spec_cfg.norm) / spec_cfg.norm
    return spec


def _make_stft(cfg, spec_cfg, y: np.ndarray) -> np.ndarray:
    """Return the dB-scaled magnitude STFT of ``y``.

    Reads ``spec_cfg.n_fft`` / ``hop_length`` and the optional
    ``spec_cfg.win_length``. ``cfg`` is unused and no ``norm`` rescaling is
    applied in this mode.
    """
    spec = librosa.stft(
        y=y,
        n_fft=spec_cfg.n_fft,
        hop_length=spec_cfg.hop_length,
        # Kept consistent with _make_melspec; None preserves the old default.
        win_length=getattr(spec_cfg, "win_length", None),
    )
    spec = np.abs(spec)
    spec = librosa.amplitude_to_db(spec)
    return spec
class PairwizeDiffLoss(nn.Module):
    """Margin loss on pairwise differences between predictions and targets.

    For every ordered pair ``(i, j)`` in the batch, the predicted difference
    ``input[i] - input[j]`` is compared with the target difference
    ``target[i] - target[j]``. Deviations are measured with ``norm`` and only
    the part exceeding ``margin`` is penalised.

    Args:
        margin: Deviation tolerated before a pair contributes to the loss.
        norm: ``"l1"`` (absolute value) or ``"l2_squared"`` (square).

    Raises:
        ValueError: If ``norm`` is not one of the supported names. This is
            raised at construction, and again in ``forward`` if ``norm`` was
            reassigned to an unsupported value afterwards.
    """

    # Built once instead of on every forward pass.
    _NORM_FNS = {
        "l1": torch.abs,
        "l2_squared": torch.square,
    }

    def __init__(self, margin: float = 0.2, norm: str = "l1"):
        super().__init__()
        if norm not in self._NORM_FNS:
            # Fail fast: a typo in the config should not wait for the first batch.
            raise ValueError(
                f'Unknown norm: {norm}. Must be one of ["l1", "l2_squared"]'
            )
        self.margin = margin
        self.norm = norm

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the mean over all ordered pairs, divided by two."""
        norm_fn = self._NORM_FNS.get(self.norm)
        if norm_fn is None:
            raise ValueError(
                f'Unknown norm: {self.norm}. Must be one of ["l1", "l2_squared"]'
            )
        s = input.unsqueeze(1) - input.unsqueeze(0)
        t = target.unsqueeze(1) - target.unsqueeze(0)
        loss = F.relu(norm_fn(s - t) - self.margin)
        # The pair matrix is symmetric under both norms, hence the halving.
        return loss.mean().div(2)


class CombinedLoss(nn.Module):
    """Evaluate several losses on the same ``(input, target)`` pair.

    Args:
        weighted_losses: ``(loss_module, weight)`` tuples.
    """

    def __init__(self, weighted_losses: list[tuple[nn.Module, float]]):
        super().__init__()
        # Kept as a plain list, so the loss modules are not registered as
        # sub-modules of this module.
        self.weighted_losses = weighted_losses

    def forward(
        self, input: torch.Tensor, target: torch.Tensor
    ) -> list[tuple[float, torch.Tensor]]:
        """Return ``(weight, loss_value)`` pairs; weights are not applied here."""
        return [(w, loss(input, target)) for loss, w in self.weighted_losses]
class MultiSpecExtModel(nn.Module):
    """Multi-spectrogram CNN regressor conditioned on a dataset indicator.

    One timm backbone is created per entry in ``cfg.dataset.specs``. For each
    frame, the backbone feature maps are combined with learnable weights; the
    frames are concatenated along the last axis, pooled, optionally passed
    through self-attention, concatenated with the dataset vector ``d`` and fed
    to a linear head.

    Reads ``cfg.model.multi_spec.{backbone, pool_type, atten, num_classes}``,
    the optional ``cfg.model.multi_spec.pretrained`` (defaults to ``True``),
    ``cfg.dataset.specs`` and ``cfg.dataset.spec_frames.num_frames``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        spec_model_cfg = cfg.model.multi_spec
        num_specs = len(cfg.dataset.specs)
        # Honour the config flag instead of always fetching pretrained weights;
        # a missing key keeps the previous behaviour (True).
        pretrained = getattr(spec_model_cfg, "pretrained", True)
        self.backbones = nn.ModuleList(
            [
                timm.create_model(
                    spec_model_cfg.backbone,
                    pretrained=pretrained,
                    num_classes=0,
                )
                for _ in range(num_specs)
            ]
        )
        for backbone in self.backbones:
            # Keep the spatial feature map; pooling is done by self.pooling.
            backbone.global_pool = nn.Identity()

        self.weights = nn.Parameter(F.softmax(torch.randn(num_specs), dim=0))

        self.pooling = timm.layers.SelectAdaptivePool2d(
            # With attention, only the last axis is collapsed so the remaining
            # spatial axis can act as the attention sequence.
            output_size=(None, 1) if spec_model_cfg.atten else 1,
            pool_type=spec_model_cfg.pool_type,
            flatten=False,
        )

        # "catavgmax" concatenates avg and max pooling, doubling the channels.
        pooled_features = self.backbones[0].num_features * (
            2 if spec_model_cfg.pool_type == "catavgmax" else 1
        )

        if spec_model_cfg.atten:
            self.attn = nn.MultiheadAttention(
                embed_dim=pooled_features,
                num_heads=8,
                dropout=0.2,
                batch_first=True,
            )

        # The attention branch concatenates a mean and a max summary.
        fc_in_features = pooled_features * (2 if spec_model_cfg.atten else 1)

        self.num_dataset = get_dataset_num(cfg)

        self.fc = nn.Linear(
            fc_in_features + self.num_dataset, spec_model_cfg.num_classes
        )

    def forward(self, x, d):
        """Predict from stacked spectrograms ``x`` and dataset vector ``d``.

        ``x`` is indexed as a 5-D tensor whose second axis enumerates
        ``num_frames * len(specs)`` spectrograms, frame-major. ``d`` is
        concatenated to the pooled features along dim 1.
        """
        num_specs = len(self.cfg.dataset.specs)
        num_frames = self.cfg.dataset.spec_frames.num_frames
        x = [x[:, i, :, :, :].squeeze(1) for i in range(num_frames * num_specs)]
        # Spectrogram k of every frame goes through backbone k.
        x = [self.backbones[i % num_specs](t) for i, t in enumerate(x)]
        # Weighted sum over the spectrogram types of each frame.
        x = [
            sum([x[i * num_specs + j] * w for j, w in enumerate(self.weights)])
            for i in range(num_frames)
        ]
        x = torch.cat(x, dim=3)
        x = self.pooling(x).squeeze(3)
        if self.cfg.model.multi_spec.atten:
            xt = torch.permute(x, (0, 2, 1))
            y, _ = self.attn(xt, xt, xt)
            x = torch.cat([torch.mean(y, dim=1), torch.max(x, dim=2).values], dim=1)
        else:
            # Without attention the pooled map is (batch, channels, 1); flatten
            # it so it can be concatenated with the 2-D dataset vector.
            x = x.flatten(1)
        x = self.fc(torch.cat([x, d], dim=1))
        return x
class _SSLEncoder(nn.Module):
    """Wrap a Hugging Face feature extractor + model and expose all hidden states."""

    def __init__(self, sr: int, model_name: str, freeze: bool):
        super().__init__()
        self.sr = sr
        self.processor = AutoFeatureExtractor.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        if freeze:
            # Disable gradients for every parameter of the pretrained model.
            self.model.requires_grad_(False)

    def forward(self, x):
        """Return the model's ``hidden_states`` for the batch of waveforms ``x``."""
        # The feature extractor is fed numpy arrays, one per item of ``x``.
        waveforms = [wav.cpu().numpy() for wav in x]
        batch = self.processor(
            waveforms,
            sampling_rate=self.sr,
            return_tensors="pt",
        )
        batch = batch.to(self.model.device)
        return self.model(**batch, output_hidden_states=True).hidden_states
def get_ssl_output_shape(name: str) -> tuple[int, int]:
    """Return ``(number_of_hidden_states, hidden_size)`` for a known SSL model.

    Raises:
        NotImplementedError: if ``name`` is not one of the listed checkpoints.
    """
    known_shapes = (
        (
            (25, 1024),
            (
                "facebook/w2v-bert-2.0",
                "facebook/wav2vec2-large",
                "facebook/wav2vec2-large-robust",
                "facebook/wav2vec2-large-960h",
                "microsoft/wavlm-large",
                "facebook/wav2vec2-large-xlsr-53",
            ),
        ),
        (
            (13, 768),
            (
                "facebook/hubert-base-ls960",
                "facebook/data2vec-audio-base-960h",
                "microsoft/wavlm-base",
                "microsoft/wavlm-base-plus",
                "microsoft/wavlm-base-plus-sv",
                "facebook/wav2vec2-base",
            ),
        ),
    )
    for shape, model_names in known_shapes:
        if name in model_names:
            return shape
    raise NotImplementedError
class SSLMultiSpecExtModelV2(nn.Module):
    """Fuse an ``SSLExtModel`` and a ``MultiSpecExtModel`` under one linear head.

    Both sub-models have their own ``fc`` replaced by ``nn.Identity`` so that
    they emit features; the fused head receives both feature vectors plus the
    vector ``d`` passed to ``forward``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        fusion_cfg = cfg.model.ssl_spec

        def checkpoint_path(run_name):
            # Best-model checkpoint of a previous run for the current fold/seed.
            return f"outputs/{run_name}/fold{cfg.now_fold}_s{cfg.split.seed}_best_model.pth"

        self.ssl = SSLExtModel(cfg)
        self.spec_long = MultiSpecExtModel(cfg)

        # Pretrained branch weights are only loaded in the training phase.
        if fusion_cfg.ssl_weight is not None and cfg.phase == "train":
            ssl_state = torch.load(checkpoint_path(fusion_cfg.ssl_weight))
            self.ssl.load_state_dict(ssl_state)
        if fusion_cfg.spec_weight is not None and cfg.phase == "train":
            spec_state = torch.load(checkpoint_path(fusion_cfg.spec_weight))
            self.spec_long.load_state_dict(spec_state)

        if fusion_cfg.freeze:
            for branch in (self.ssl, self.spec_long):
                for param in branch.parameters():
                    param.requires_grad = False

        # Remember the branch head widths before dropping the heads.
        ssl_input = self.ssl.fc.in_features
        spec_long_input = self.spec_long.fc.in_features
        self.ssl.fc = nn.Identity()
        self.spec_long.fc = nn.Identity()

        self.num_dataset = get_dataset_num(cfg)

        self.fc = nn.Linear(
            ssl_input + spec_long_input + self.num_dataset,
            fusion_cfg.num_classes,
        )

    def _zero_dataset_vector(self, ref):
        """All-zero ``(ref.shape[0], num_dataset)`` tensor on ``ref``'s device."""
        return torch.zeros(ref.shape[0], self.num_dataset).to(ref.device)

    def forward(self, x1, x2, d):
        """Predict from SSL input ``x1``, spectrogram input ``x2`` and vector ``d``.

        The branches receive an all-zero vector in place of ``d``; the real
        ``d`` is only concatenated in front of the fused head.
        """
        ssl_feat = self.ssl(x1, self._zero_dataset_vector(x1))
        spec_feat = self.spec_long(x2, self._zero_dataset_vector(ssl_feat))
        fused = torch.cat([ssl_feat, spec_feat, d], dim=1)
        return self.fc(fused)
def _clip_audio(cfg, data: pd.DataFrame, data_name: str = "bvcc"):
    """Trim every audio file listed in ``data`` and cache it as ``.npy``.

    Each entry of ``data["file_path"]`` is loaded at its native sampling rate,
    trimmed with ``librosa.effects.trim`` (threshold ``cfg.preprocess.top_db``)
    and saved under ``cfg.preprocess.save_path / data_name`` using the
    original base name with ``.wav`` replaced by ``.npy``.

    Args:
        cfg: Configuration providing ``preprocess.save_path`` and
            ``preprocess.top_db``.
        data: Frame with a ``file_path`` column; entries may be ``str`` or
            path-like objects.
        data_name: Sub-directory of the save path to write into.
    """
    out_dir = cfg.preprocess.save_path / data_name
    out_dir.mkdir(parents=True, exist_ok=True)
    for file in tqdm(data["file_path"].values, desc="Clipping audio files"):
        y, _ = librosa.load(file, sr=None)
        y, _ = librosa.effects.trim(y, top_db=cfg.preprocess.top_db)
        # os.fspath() accepts both str and pathlib.Path; the previous
        # ``file.split("/")`` raised AttributeError for Path entries.
        npy_name = os.path.basename(os.fspath(file)).replace(".wav", ".npy")
        np.save(out_dir / npy_name, y)
def add_sys_mean(data: pd.DataFrame):
    """Add a ``sys_mos`` column holding the mean ``mos`` of each row's system.

    The frame is modified in place and nothing is returned.

    The per-system mean is looked up row by row from ``sys_id``, so the result
    does not depend on the frame's index. The previous merge-then-assign
    version aligned on the index and produced wrong or NaN values whenever
    ``data`` had a filtered, shifted or duplicated index.

    Args:
        data: Frame with ``sys_id`` and ``mos`` columns.
    """
    sys_mean = data.groupby("sys_id")["mos"].mean()
    # to_numpy() makes the assignment purely positional (no index alignment).
    data["sys_mos"] = data["sys_id"].map(sys_mean).to_numpy()
pd.read_csv( + "data2/sarulab/VMC2024_MOS.csv", header=None, names=["utt_id", "mos"] + ) + ysdata["mos"] = ysdata["mos"].astype(float) + ysdata["sys_id"] = ysdata["utt_id"].apply( + lambda x: "sarulab-" + x.split("-")[0] + ) + ysdata["file_path"] = ysdata["utt_id"].apply( + lambda x: cfg.preprocess.save_path / "bvcc" / x.replace(".wav", ".npy") + ) + ysdata["dataset"] = "sarulab" + exdata.append(ysdata) + + for name in ["blizzard2008", "blizzard2009", "blizzard2011"]: + if cfg.external_data != "all" and name not in cfg.external_data: + continue + bzdata = pd.read_csv( + f"data2/{name}/{name}_mos.csv", + header=None, + names=["utt_id", "mos"], + ) + bzdata["mos"] = bzdata["mos"].astype(float) + bzdata["sys_id"] = bzdata["utt_id"].apply( + lambda x: f"{name}-" + x.split("_")[0] + ) + bzdata["file_path"] = bzdata["utt_id"].apply( + lambda x: os.path.join(f"data2/{name}/{name}_wavs", x) + ) + bzdata["dataset"] = name + exdata.append(bzdata) + + for a in ["EH1", "EH2", "ES1", "ES3"]: + if cfg.external_data != "all" and f"blizzard2010-{a}" not in cfg.external_data: + continue + d = pd.read_csv( + f"data2/blizzard2010/blizzard2010_mos_{a}.csv", + header=None, + names=["utt_id", "mos"], + ) + d["mos"] = d["mos"].astype(float) + d["sys_id"] = d["utt_id"].apply( + lambda x: f"blizzard2010-{a}-" + x.split("_")[0] + ) + d["file_path"] = d["utt_id"].apply( + lambda x: os.path.join(f"data2/blizzard2010/blizzard2010_wavs_{a}", x) + ) + d["dataset"] = f"blizzard2010-{a}" + exdata.append(d) + + if cfg.external_data == "all" or "somos" in cfg.external_data: + train_mos_list = pd.read_csv( + "data2/somos/training_files/split1/clean/train_mos_list.txt", + ) + val_mos_list = pd.read_csv( + "data2/somos/training_files/split1/clean/valid_mos_list.txt", + ) + test_mos_list = pd.read_csv( + "data2/somos/training_files/split1/clean/test_mos_list.txt", + ) + somosdata = pd.concat([train_mos_list, val_mos_list, test_mos_list], axis=0) + somosdata.columns = ["utt_id", "mos"] + somosdata["mos"] 
def run_inference(
    cfg,
    model: torch.nn.Module,
    test_dataloader: torch.utils.data.DataLoader,
    cycle: int,
    test_data: pd.DataFrame,
    device: torch.device,
) -> tuple[np.ndarray, dict[str, float] | None]:
    """Run one inference pass over ``test_dataloader``.

    Args:
        cfg: Configuration; reads ``inference.num_tta``, ``input_dir`` and
            ``reproduce``.
        model: Model to evaluate; it is switched to eval mode.
        test_dataloader: Loader whose batches are sequences; every element
            except the last one is passed to the model.
        cycle: Zero-based index of the current repetition (progress bar only).
        test_data: Frame handed to ``calc_metrics`` when ``cfg.reproduce`` is set.
        device: Device the model inputs are moved to.

    Returns:
        ``(predictions, metrics)``; ``metrics`` is ``None`` unless
        ``cfg.reproduce`` is truthy.
    """
    model.eval()
    batch_preds = []
    pbar = tqdm(
        test_dataloader,
        total=len(test_dataloader),
        desc=f" [Inference] ({cycle + 1}/{cfg.inference.num_tta})",
    )

    with torch.no_grad():
        for batch in pbar:
            inputs = [item.to(device, non_blocking=True) for item in batch[:-1]]
            with autocast():
                output = model(*inputs).squeeze()
            batch_preds.append(output.cpu().numpy())

    if cfg.input_dir:
        # squeeze() turns a single-item batch into a 0-d array, which
        # np.concatenate rejects; promote such results to 1-d first.
        test_preds = np.concatenate([np.atleast_1d(p) for p in batch_preds])
    else:
        test_preds = np.array(batch_preds)

    if cfg.reproduce:
        test_metrics = calc_metrics(test_data, test_preds)
        print_metrics(test_metrics)
    else:
        test_metrics = None

    return test_preds, test_metrics
def _validate_1epoch(
    cfg,
    model: torch.nn.Module,
    valid_dataloader: torch.utils.data.DataLoader,
    criterion: torch.nn.Module,
    metrics: dict[str, Callable[[np.ndarray, np.ndarray], float]],
    device: torch.device,
) -> tuple[dict[str, float], dict[str, float], np.ndarray]:
    """Evaluate ``model`` on the validation loader for one epoch.

    Args:
        cfg: Configuration; ``cfg.loss`` is read only when the criterion
            returns a list of ``(weight, loss)`` pairs.
        model: Model to evaluate; it is switched to eval mode.
        valid_dataloader: Loader whose batches are sequences; the last element
            is the target and all others are model inputs.
        criterion: Loss returning either a tensor or a list of
            ``(weight, loss)`` pairs.
        metrics: Mapping from metric name to ``metric(prediction, target)``.
        device: Device the inputs and targets are moved to.

    Returns:
        ``(losses, metric_values, predictions)``. Losses and metric values are
        averaged over the number of batches; predictions are the concatenated
        per-batch outputs.
    """
    model.eval()
    valid_loss = defaultdict(float)
    valid_metrics = {name: 0.0 for name in metrics}
    valid_preds = []
    pbar = tqdm(valid_dataloader, total=len(valid_dataloader))

    with torch.no_grad():
        for i, t in enumerate(pbar):
            x, y = t[:-1], t[-1]
            x = [t.to(device, non_blocking=True) for t in x]
            # Keep the un-moved target for the numpy-based metrics below.
            y_cpu = y
            y = y.to(device, non_blocking=True)
            with autocast():
                output = model(*x).squeeze(1)
                loss = criterion(output, y)
            # A list-valued loss is reduced to a weighted sum for the total.
            if isinstance(loss, list):
                loss_total = sum(w * ls for w, ls in loss)
            else:
                loss_total = loss
            valid_loss["loss"] += loss_total.detach().float().cpu().item()
            if isinstance(loss, list):
                # Track each component under the name of its config entry.
                for (cl, _), (_, ls) in zip(cfg.loss, loss):
                    valid_loss[cl.name] += ls.detach().float().cpu().item()
            output = output.cpu().numpy()
            # Metrics are computed per batch and summed; averaged after the loop.
            for name, metric in metrics.items():
                valid_metrics[name] += metric(output, y_cpu.numpy())
            # Progress bar shows running means over the batches seen so far.
            pbar.set_description(
                f' val_loss: {valid_loss["loss"] / (i + 1):.4f} '
                + (
                    f'({", ".join([f"{cl.name}: {valid_loss[cl.name] / (i + 1):.4f}" for cl, _ in cfg.loss])}) '
                    if isinstance(loss, list)
                    else ""
                )
                + " - ".join(
                    [
                        f"val_{name}: {v / (i + 1):.4f}"
                        for name, v in valid_metrics.items()
                    ]
                )
            )
            valid_preds.append(output)

    # Convert the running sums into per-batch averages.
    valid_loss = {name: v / len(valid_dataloader) for name, v in valid_loss.items()}
    valid_metrics = {
        name: v / len(valid_dataloader) for name, v in valid_metrics.items()
    }
    valid_preds = np.concatenate(valid_preds)

    return valid_loss, valid_metrics, valid_preds
class XYMasking:
    """Randomly overwrite stripes of a 3-D image tensor with a fill value.

    "x" masks cover a range of the last axis (all rows), "y" masks cover a
    range of the middle axis (all columns). Counts and lengths can be fixed
    ints or ``(low, high)`` tuples sampled with ``np.random.randint``.

    Args:
        num_masks_x: Number of x masks, or a ``(low, high)`` range for it.
        num_masks_y: Number of y masks, or a ``(low, high)`` range for it.
        mask_x_length: Length of each x mask, or a ``(low, high)`` range.
        mask_y_length: Length of each y mask, or a ``(low, high)`` range.
        fill_value: Value written into the masked regions.
        p: Probability of applying the transform.
    """

    def __init__(
        self,
        num_masks_x: int | tuple[int, int],
        num_masks_y: int | tuple[int, int],
        mask_x_length: int | tuple[int, int],
        mask_y_length: int | tuple[int, int],
        fill_value: int,
        p: float = 1.0,
    ):
        self.num_masks_x = num_masks_x
        self.num_masks_y = num_masks_y
        self.mask_x_length = mask_x_length
        self.mask_y_length = mask_y_length
        self.fill_value = fill_value
        self.p = p

    @staticmethod
    def _draw(spec: int | tuple[int, int]) -> int:
        """Return ``spec`` itself, or a random int from the ``(low, high)`` tuple."""
        if isinstance(spec, tuple):
            return np.random.randint(*spec)
        return spec

    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        """Mask ``img`` in place with probability ``p`` and return it."""
        # rand() lies in [0, 1): p=1.0 always masks, p=0.0 never does.
        # (The previous ``rand() < p`` test skipped masking with probability p.)
        if np.random.rand() >= self.p:
            return img

        # Each offset is bounded by the size of the axis it actually indexes.
        _, size_y, size_x = img.shape

        for _ in range(self._draw(self.num_masks_x)):
            length = self._draw(self.mask_x_length)
            start = np.random.randint(0, size_x - length)
            img[:, :, start : start + length] = self.fill_value

        for _ in range(self._draw(self.num_masks_y)):
            length = self._draw(self.mask_y_length)
            start = np.random.randint(0, size_y - length)
            img[:, start : start + length, :] = self.fill_value

        return img
def split_data(
    cfg, data: pd.DataFrame
) -> Generator[tuple[np.ndarray, np.ndarray], None, None]:
    """Yield ``(train_idx, valid_idx)`` row positions of ``data`` for each fold.

    The strategy is chosen by ``cfg.split.type``:

    * ``"simple"`` - shuffled ``KFold``.
    * ``"stratified"`` - ``StratifiedKFold`` on ``cfg.split.target`` cast to int.
    * ``"group"`` - ``GroupKFold`` on ``cfg.split.group``.
    * ``"stratified_group"`` - ``StratifiedGroupKFold`` on both.
    * ``"sgkf_kind"`` - one ``StratifiedGroupKFold`` per distinct value of
      ``cfg.split.kind``; the k-th folds of all kinds are merged.

    Raises:
        NotImplementedError: for any other split type.
    """
    if cfg.print_config:
        print(f"Using split: {cfg.split.type}")
    if cfg.split.type == "simple":
        kf = KFold(n_splits=cfg.num_folds, shuffle=True, random_state=cfg.split.seed)
        yield from kf.split(data)
    elif cfg.split.type == "stratified":
        kf = StratifiedKFold(
            n_splits=cfg.num_folds, shuffle=True, random_state=cfg.split.seed
        )
        yield from kf.split(data, data[cfg.split.target].astype(int))
    elif cfg.split.type == "group":
        kf = GroupKFold(n_splits=cfg.num_folds)
        yield from kf.split(data, groups=data[cfg.split.group])
    elif cfg.split.type == "stratified_group":
        kf = StratifiedGroupKFold(
            n_splits=cfg.num_folds, shuffle=True, random_state=cfg.split.seed
        )
        yield from kf.split(
            data, data[cfg.split.target].astype(int), groups=data[cfg.split.group]
        )
    elif cfg.split.type == "sgkf_kind":
        kind_col = data[cfg.split.kind]
        positions = []
        fold_iters = []
        for ds in kind_col.unique():
            # Row positions of this kind inside ``data``. The splitter only
            # sees the subset, so its indices are relative to the subset and
            # must be mapped back through ``pos`` before being merged;
            # otherwise indices of different kinds collide.
            pos = np.flatnonzero((kind_col == ds).to_numpy())
            subset = data.iloc[pos]
            splitter = StratifiedGroupKFold(
                n_splits=cfg.num_folds, shuffle=True, random_state=cfg.split.seed
            )
            positions.append(pos)
            fold_iters.append(
                splitter.split(
                    subset,
                    subset[cfg.split.target].astype(int),
                    groups=subset[cfg.split.group],
                )
            )
        for fold in zip(*fold_iters):
            train_idx = np.concatenate(
                [pos[local[0]] for pos, local in zip(positions, fold)]
            )
            valid_idx = np.concatenate(
                [pos[local[1]] for pos, local in zip(positions, fold)]
            )
            yield train_idx, valid_idx
    else:
        raise NotImplementedError
def get_optimizer(cfg, model: nn.Module) -> optim.Optimizer:
    """Build the optimizer named by ``cfg.optimizer.name`` for ``model``.

    Supported names are ``"adam"`` (learning rate only), ``"adamw"`` and
    ``"sgd"`` (learning rate and weight decay).

    Raises:
        NotImplementedError: for any other optimizer name.
    """
    if cfg.print_config:
        print(f"Using optimizer: {cfg.optimizer.name}")

    opt_cfg = cfg.optimizer
    name = opt_cfg.name

    if name == "adam":
        return optim.Adam(model.parameters(), lr=opt_cfg.lr)

    # The remaining optimizers share the same keyword arguments.
    if name == "adamw":
        factory = optim.AdamW
    elif name == "sgd":
        factory = optim.SGD
    else:
        raise NotImplementedError

    return factory(
        model.parameters(),
        lr=opt_cfg.lr,
        weight_decay=opt_cfg.weight_decay,
    )
def configure_inference_args(cfg, args):
    """Copy parsed inference command-line arguments onto ``cfg`` in place.

    Path-like arguments are converted to ``Path`` when they are truthy and
    stored unchanged otherwise. ``cfg.save_path`` is derived from the config
    name, ``cfg.data_config`` is cleared and ``cfg.phase`` is set to
    ``"inference"``.
    """

    def as_path(value):
        # Falsy values (e.g. None) are passed through untouched.
        return Path(value) if value else value

    inference = cfg.inference

    inference.fold = args.fold
    cfg.split.seed = args.seed
    cfg.config_name = args.config
    cfg.input_dir = as_path(args.input_dir)
    cfg.input_path = as_path(args.input_path)
    cfg.num_workers = args.num_workers
    cfg.weight = args.weight
    inference.val_list_path = as_path(args.val_list_path)
    cfg.save_path = Path("models") / cfg.config_name
    cfg.predict_dataset = args.predict_dataset
    cfg.final = args.final
    inference.num_tta = args.num_repetitions
    cfg.reproduce = args.reproduce
    cfg.out_path = as_path(args.out_path)
    cfg.data_config = None
    cfg.phase = "inference"
def get_data(cfg) -> pd.DataFrame:
    """Load the train/val/test MOS lists under ``cfg.input_dir`` into one frame.

    The three header-less CSV files are stacked in train, val, test order
    (each keeping its own row index). The result has the columns ``utt_id``,
    ``mos`` and ``file_path``, the latter being ``cfg.input_dir / "wav/<utt_id>"``.
    """
    list_files = (
        "sets/train_mos_list.txt",
        "sets/val_mos_list.txt",
        "sets/test_mos_list.txt",
    )
    parts = [pd.read_csv(cfg.input_dir / rel, header=None) for rel in list_files]
    data = pd.concat(parts, axis=0)
    data.columns = ["utt_id", "mos"]

    def wav_path(utt_id):
        return cfg.input_dir / f"wav/{utt_id}"

    data["file_path"] = data["utt_id"].apply(wav_path)
    return data
device: torch.device) -> nn.Module: + if cfg.print_config: + print(f"Using model: {cfg.model.name}") + if cfg.model.name == "multi_specv2": + model = MultiSpecModelV2(cfg) + elif cfg.model.name == "sslext": + model = SSLExtModel(cfg) + elif cfg.model.name == "multi_spec_ext": + model = MultiSpecExtModel(cfg) + elif cfg.model.name == "ssl_multispec_ext": + model = SSLMultiSpecExtModelV1(cfg) + elif cfg.model.name == "ssl_multispec_ext_v2": + model = SSLMultiSpecExtModelV2(cfg) + else: + raise NotImplementedError + model = model.to(device) + if cfg.weight is not None: + model.load_state_dict(torch.load(cfg.weight)) + return model + + +def get_metrics() -> dict[str, Callable[[np.ndarray, np.ndarray], float]]: + return { + "mse": lambda x, y: np.mean((x - y) ** 2), + "lcc": lambda x, y: np.corrcoef(x, y)[0][1], + "srcc": lambda x, y: scipy.stats.spearmanr(x, y)[0], + "ktau": lambda x, y: scipy.stats.kendalltau(x, y)[0], + } + + +def calc_metrics(data: pd.DataFrame, preds: np.ndarray) -> dict[str, float]: + data = data.copy() + data["preds"] = preds + data_sys = data.groupby("sys_id", as_index=False)[["mos", "preds"]].mean() + res = {} + for name, d in {"utt": data, "sys": data_sys}.items(): + res[f"{name}_mse"] = np.mean((d["mos"].values - d["preds"].values) ** 2) + res[f"{name}_lcc"] = np.corrcoef(d["mos"].values, d["preds"].values)[0][1] + res[f"{name}_srcc"] = scipy.stats.spearmanr(d["mos"].values, d["preds"].values)[ + 0 + ] + res[f"{name}_ktau"] = scipy.stats.kendalltau( + d["mos"].values, d["preds"].values + )[0] + return res + + +def configure_defaults(cfg): + if cfg.id_name is None: + cfg.id_name = "utt_id" + + +def _get_testdata(cfg, data: pd.DataFrame) -> pd.DataFrame: + with open(cfg.inference.val_list_path, "r") as f: + val_lists = [s.replace("\n", "") + ".wav" for s in f.readlines()] + test_data = data[data["utt_id"].isin(set(val_lists))] + return test_data + + +def get_inference_data(cfg) -> pd.DataFrame: + if cfg.reproduce: + data = get_data(cfg) + data 
= preprocess_test(cfg, data) + data = _get_testdata(cfg, data) + else: + if cfg.input_dir: + files = sorted(glob.glob(str(cfg.input_dir / "*.wav"))) + data = pd.DataFrame({"file_path": files}) + else: + data = pd.DataFrame({"file_path": [cfg.input_path.as_posix()]}) + data["utt_id"] = data["file_path"].apply( + lambda x: x.split("/")[-1].replace(".wav", "") + ) + data["sys_id"] = data["utt_id"].apply(lambda x: x.split("-")[0]) + if cfg.inference.val_list_path: + with open(cfg.inference.val_list_path, "r") as f: + val_lists = [s.replace(".wav", "") for s in f.read().splitlines()] + print(val_lists) + data = data[data["utt_id"].isin(set(val_lists))] + data["dataset"] = cfg.predict_dataset + data["mos"] = 0 + return data + + +def get_train_data(cfg) -> pd.DataFrame: + if cfg.reproduce: + data = get_data(cfg) + data = preprocess(cfg, data) + else: + with open(cfg.data_config, "r") as f: + datasets = json.load(f) + data = [] + for dt in datasets["data"]: + files = sorted(glob.glob(str(Path(dt["dir"]) / "*.wav"))) + d = pd.DataFrame({"file_path": files}) + d["dataset"] = dt["name"] + d["utt_id"] = d["file_path"].apply( + lambda x: x.split("/")[-1].replace(".wav", "") + ) + mos_list = pd.read_csv(dt["mos_list"], header=None) + mos_list.columns = ["utt_id", "mos"] + mos_list["utt_id"] = mos_list["utt_id"].apply( + lambda x: x.replace(".wav", "") + ) + d = d.merge(mos_list, on="utt_id", how="inner") + d["sys_id"] = d["utt_id"].apply(lambda x: x.split("-")[0]) + add_sys_mean(d) + data.append(d) + data = pd.concat(data, axis=0) + + return data + + +def show_inference_data(data: pd.DataFrame): + print( + data[[c for c in data.columns if c != "mos"]] + .rename(columns={"dataset": "predict_dataset"}) + .head() + ) + + +def _get_test_save_name(cfg) -> str: + return f"{cfg.config_name}_[fold{cfg.inference.fold}_tta{cfg.inference.num_tta}_s{cfg.split.seed}]" + + +def save_test_preds( + cfg, data: pd.DataFrame, test_preds: np.ndarray, test_metrics: dict[str, float] +): + test_df = 
pd.DataFrame({cfg.id_name: data[cfg.id_name], "test_preds": test_preds}) + save_path = ( + cfg.inference.save_path + / f"{_get_test_save_name(cfg)}_({cfg.predict_dataset})_test_preds{'_final' if cfg.final else ''}.csv", + ) + test_df.to_csv(save_path, index=False) + save_path = ( + cfg.inference.save_path + / f"{_get_test_save_name(cfg)}_({cfg.predict_dataset})_val_score{'_final' if cfg.final else ''}.json", + ) + with open(save_path, "w") as f: + json.dump(test_metrics, f) + print(f"Test predictions are saved to {save_path}") + + +def make_submission_file(cfg, data: pd.DataFrame, test_preds: np.ndarray): + submit = pd.DataFrame({cfg.id_name: data[cfg.id_name], "prediction": test_preds}) + os.makedirs( + cfg.inference.submit_save_path + / f"{_get_test_save_name(cfg)}_({cfg.predict_dataset})", + exist_ok=True, + ) + sub_file = ( + cfg.inference.submit_save_path + / f"{_get_test_save_name(cfg)}_({cfg.predict_dataset})" + / "answer.txt" + ) + submit.to_csv( + sub_file, + index=False, + header=False, + ) + print(f"Submission file is saved to {sub_file}") + + +def save_preds(cfg, data: pd.DataFrame, test_preds: np.ndarray): + pred = pd.DataFrame({cfg.id_name: data[cfg.id_name], "mos": test_preds}) + if cfg.out_path is None: + print("Predictions:") + print(pred) + else: + pred.to_csv(cfg.out_path, index=False) + print(f"Predictions are saved to {cfg.out_path}")