Upload 57 files
This view is limited to 50 files because it contains too many changes. See the raw diff for the full changeset.
- app.py +22 -0
- inference.py +113 -0
- models/bandit/core/__init__.py +744 -0
- models/bandit/core/data/__init__.py +2 -0
- models/bandit/core/data/_types.py +18 -0
- models/bandit/core/data/augmentation.py +107 -0
- models/bandit/core/data/augmented.py +35 -0
- models/bandit/core/data/base.py +69 -0
- models/bandit/core/data/dnr/__init__.py +0 -0
- models/bandit/core/data/dnr/datamodule.py +74 -0
- models/bandit/core/data/dnr/dataset.py +392 -0
- models/bandit/core/data/dnr/preprocess.py +54 -0
- models/bandit/core/data/musdb/__init__.py +0 -0
- models/bandit/core/data/musdb/datamodule.py +77 -0
- models/bandit/core/data/musdb/dataset.py +280 -0
- models/bandit/core/data/musdb/preprocess.py +238 -0
- models/bandit/core/data/musdb/validation.yaml +15 -0
- models/bandit/core/loss/__init__.py +2 -0
- models/bandit/core/loss/_complex.py +34 -0
- models/bandit/core/loss/_multistem.py +45 -0
- models/bandit/core/loss/_timefreq.py +113 -0
- models/bandit/core/loss/snr.py +146 -0
- models/bandit/core/metrics/__init__.py +9 -0
- models/bandit/core/metrics/_squim.py +383 -0
- models/bandit/core/metrics/snr.py +150 -0
- models/bandit/core/model/__init__.py +3 -0
- models/bandit/core/model/_spectral.py +58 -0
- models/bandit/core/model/bsrnn/__init__.py +23 -0
- models/bandit/core/model/bsrnn/bandsplit.py +139 -0
- models/bandit/core/model/bsrnn/core.py +661 -0
- models/bandit/core/model/bsrnn/maskestim.py +347 -0
- models/bandit/core/model/bsrnn/tfmodel.py +317 -0
- models/bandit/core/model/bsrnn/utils.py +583 -0
- models/bandit/core/model/bsrnn/wrapper.py +882 -0
- models/bandit/core/utils/__init__.py +0 -0
- models/bandit/core/utils/audio.py +463 -0
- models/bandit/model_from_config.py +31 -0
- models/bs_roformer/__init__.py +2 -0
- models/bs_roformer/attend.py +120 -0
- models/bs_roformer/bs_roformer.py +577 -0
- models/bs_roformer/mel_band_roformer.py +637 -0
- models/demucs4ht.py +713 -0
- models/mdx23c_tfc_tdf_v3.py +242 -0
- models/scnet/__init__.py +1 -0
- models/scnet/scnet.py +373 -0
- models/scnet/separation.py +178 -0
- models/scnet_unofficial/__init__.py +1 -0
- models/scnet_unofficial/modules/__init__.py +3 -0
- models/scnet_unofficial/modules/dualpath_rnn.py +228 -0
- models/scnet_unofficial/modules/sd_encoder.py +285 -0
app.py
ADDED
@@ -0,0 +1,22 @@
import gradio as gr
import os

DESCRIPTION = """
# audio sep
being made
"""

theme = gr.themes.Base(
    font=[gr.themes.GoogleFont('Libre Franklin'), gr.themes.GoogleFont('Public Sans'), 'system-ui', 'sans-serif'],
)
with gr.Blocks(css="footer{display:none !important}", theme=theme) as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )


demo.queue(max_size=20, api_open=False).launch(show_api=False)
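The Blocks app above is still a placeholder: it renders the description and a duplicate button but exposes no separation UI yet. A minimal sketch of how a callback could later be wired in, assuming a hypothetical `separate` function (not part of this commit) that takes and returns an audio file path:

```python
# Sketch only: `separate` is a hypothetical placeholder for the eventual
# inference call; the Gradio components and .click() wiring are standard.
import gradio as gr

def separate(audio_path: str) -> str:
    # Hypothetical: run a separation model, write a stem, return its path.
    return audio_path

with gr.Blocks() as demo:
    inp = gr.Audio(type="filepath", label="Mixture")
    out = gr.Audio(type="filepath", label="Separated stem")
    btn = gr.Button("Separate")
    btn.click(fn=separate, inputs=inp, outputs=out)

demo.queue(max_size=20, api_open=False).launch(show_api=False)
```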
inference.py
ADDED
@@ -0,0 +1,113 @@
# coding: utf-8
__author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/'

import argparse
import time
import librosa
from tqdm import tqdm
import sys
import os
import glob
import torch
import numpy as np
import soundfile as sf
import torch.nn as nn
from utils import demix_track, demix_track_demucs, get_model_from_config

import warnings
warnings.filterwarnings("ignore")


def run_folder(model, args, config, device, verbose=False):
    start_time = time.time()
    model.eval()
    all_mixtures_path = glob.glob(args.input_folder + '/*.*')
    print('Total files found: {}'.format(len(all_mixtures_path)))

    instruments = config.training.instruments
    if config.training.target_instrument is not None:
        instruments = [config.training.target_instrument]

    if not os.path.isdir(args.store_dir):
        os.mkdir(args.store_dir)

    if not verbose:
        all_mixtures_path = tqdm(all_mixtures_path)

    for path in all_mixtures_path:
        if not verbose:
            all_mixtures_path.set_postfix({'track': os.path.basename(path)})
        try:
            # mix, sr = sf.read(path)
            mix, sr = librosa.load(path, sr=44100, mono=False)
            mix = mix.T
        except Exception as e:
            print('Cannot read track: {}'.format(path))
            print('Error message: {}'.format(str(e)))
            continue

        # Convert mono to stereo if needed
        if len(mix.shape) == 1:
            mix = np.stack([mix, mix], axis=-1)

        mixture = torch.tensor(mix.T, dtype=torch.float32)
        if args.model_type == 'htdemucs':
            res = demix_track_demucs(config, model, mixture, device)
        else:
            res = demix_track(config, model, mixture, device)
        for instr in instruments:
            sf.write("{}/{}_{}.wav".format(args.store_dir, os.path.basename(path)[:-4], instr), res[instr].T, sr, subtype='FLOAT')

        if 'vocals' in instruments and args.extract_instrumental:
            instrum_file_name = "{}/{}_{}.wav".format(args.store_dir, os.path.basename(path)[:-4], 'instrumental')
            sf.write(instrum_file_name, mix - res['vocals'].T, sr, subtype='FLOAT')

    time.sleep(1)
    print("Elapsed time: {:.2f} sec".format(time.time() - start_time))


def proc_folder(args):
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_type", type=str, default='mdx23c', help="One of mdx23c, htdemucs, segm_models, mel_band_roformer, bs_roformer, swin_upernet, bandit")
    parser.add_argument("--config_path", type=str, help="path to config file")
    parser.add_argument("--start_check_point", type=str, default='', help="Initial checkpoint with valid weights")
    parser.add_argument("--input_folder", type=str, help="folder with mixtures to process")
    parser.add_argument("--store_dir", default="", type=str, help="path to store results as wav file")
    parser.add_argument("--device_ids", nargs='+', type=int, default=0, help='list of gpu ids')
    parser.add_argument("--extract_instrumental", action='store_true', help="invert vocals to get instrumental if provided")
    if args is None:
        args = parser.parse_args()
    else:
        args = parser.parse_args(args)

    torch.backends.cudnn.benchmark = True

    model, config = get_model_from_config(args.model_type, args.config_path)
    if args.start_check_point != '':
        print('Start from checkpoint: {}'.format(args.start_check_point))
        state_dict = torch.load(args.start_check_point)
        if args.model_type == 'htdemucs':
            # Fix for htdemucs pretrained models
            if 'state' in state_dict:
                state_dict = state_dict['state']
        model.load_state_dict(state_dict)
    print("Instruments: {}".format(config.training.instruments))

    if torch.cuda.is_available():
        device_ids = args.device_ids
        if type(device_ids) == int:
            device = torch.device(f'cuda:{device_ids}')
            model = model.to(device)
        else:
            device = torch.device(f'cuda:{device_ids[0]}')
            model = nn.DataParallel(model, device_ids=device_ids).to(device)
    else:
        device = 'cpu'
        print('CUDA is not available. Run inference on CPU. It will be very slow...')
        model = model.to(device)

    run_folder(model, args, config, device, verbose=False)


if __name__ == "__main__":
    proc_folder(None)
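Because `proc_folder` parses from an explicit argument list when one is passed, the pipeline can be driven programmatically as well as from the shell. A sketch; the checkpoint, config, and folder paths are placeholders:

```python
# Illustrative invocation; all paths/names below are hypothetical.
from inference import proc_folder

proc_folder([
    "--model_type", "htdemucs",
    "--config_path", "configs/config_htdemucs.yaml",
    "--start_check_point", "checkpoints/model.ckpt",
    "--input_folder", "input_mixtures",
    "--store_dir", "separated",
    "--extract_instrumental",
])
```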
models/bandit/core/__init__.py
ADDED
@@ -0,0 +1,744 @@
import os.path
from collections import defaultdict
from itertools import chain, combinations
from typing import (
    Any,
    Dict,
    Iterator,
    Mapping, Optional,
    Tuple, Type,
    TypedDict
)

import pytorch_lightning as pl
import torch
import torchaudio as ta
import torchmetrics as tm
from asteroid import losses as asteroid_losses
# from deepspeed.ops.adam import DeepSpeedCPUAdam
# from geoopt import optim as gooptim
from pytorch_lightning.utilities.types import STEP_OUTPUT
from torch import nn, optim
from torch.optim import lr_scheduler
from torch.optim.lr_scheduler import LRScheduler

from models.bandit.core import loss, metrics as metrics_, model
from models.bandit.core.data._types import BatchedDataDict
from models.bandit.core.data.augmentation import BaseAugmentor, StemAugmentor
from models.bandit.core.utils import audio as audio_
from models.bandit.core.utils.audio import BaseFader

# from pandas.io.json._normalize import nested_to_record

ConfigDict = TypedDict('ConfigDict', {'name': str, 'kwargs': Dict[str, Any]})


class SchedulerConfigDict(ConfigDict):
    monitor: str


OptimizerSchedulerConfigDict = TypedDict(
    'OptimizerSchedulerConfigDict',
    {"optimizer": ConfigDict, "scheduler": SchedulerConfigDict},
    total=False
)


class LRSchedulerReturnDict(TypedDict, total=False):
    scheduler: LRScheduler
    monitor: str


class ConfigureOptimizerReturnDict(TypedDict, total=False):
    optimizer: torch.optim.Optimizer
    lr_scheduler: LRSchedulerReturnDict


OutputType = Dict[str, Any]
MetricsType = Dict[str, torch.Tensor]


def get_optimizer_class(name: str) -> Type[optim.Optimizer]:
    # The deepspeed/geoopt imports above are commented out, so import them
    # lazily here; otherwise these branches would raise NameError at runtime.
    if name == "DeepSpeedCPUAdam":
        from deepspeed.ops.adam import DeepSpeedCPUAdam
        return DeepSpeedCPUAdam

    modules = [optim]
    try:
        from geoopt import optim as gooptim
        modules.append(gooptim)
    except ImportError:
        pass

    for module in modules:
        if name in module.__dict__:
            return module.__dict__[name]

    raise NameError(f"Unknown optimizer: {name}")


def parse_optimizer_config(
        config: OptimizerSchedulerConfigDict,
        parameters: Iterator[nn.Parameter]
) -> ConfigureOptimizerReturnDict:
    optim_class = get_optimizer_class(config["optimizer"]["name"])
    optimizer = optim_class(parameters, **config["optimizer"]["kwargs"])

    optim_dict: ConfigureOptimizerReturnDict = {
        "optimizer": optimizer,
    }

    if "scheduler" in config:

        lr_scheduler_class_ = config["scheduler"]["name"]
        lr_scheduler_class = lr_scheduler.__dict__[lr_scheduler_class_]
        lr_scheduler_dict: LRSchedulerReturnDict = {
            "scheduler": lr_scheduler_class(
                optimizer,
                **config["scheduler"]["kwargs"]
            )
        }

        if lr_scheduler_class_ == "ReduceLROnPlateau":
            lr_scheduler_dict["monitor"] = config["scheduler"]["monitor"]

        optim_dict["lr_scheduler"] = lr_scheduler_dict

    return optim_dict


def parse_model_config(config: ConfigDict) -> Any:
    name = config["name"]

    for module in [model]:
        if name in module.__dict__:
            return module.__dict__[name](**config["kwargs"])

    raise NameError


_LEGACY_LOSS_NAMES = ["HybridL1Loss"]


def _parse_legacy_loss_config(config: ConfigDict) -> nn.Module:
    name = config["name"]

    if name == "HybridL1Loss":
        return loss.TimeFreqL1Loss(**config["kwargs"])

    raise NameError


def parse_loss_config(config: ConfigDict) -> nn.Module:
    name = config["name"]

    if name in _LEGACY_LOSS_NAMES:
        return _parse_legacy_loss_config(config)

    for module in [loss, nn.modules.loss, asteroid_losses]:
        if name in module.__dict__:
            # print(config["kwargs"])
            return module.__dict__[name](**config["kwargs"])

    raise NameError


def get_metric(config: ConfigDict) -> tm.Metric:
    name = config["name"]

    for module in [tm, metrics_]:
        if name in module.__dict__:
            return module.__dict__[name](**config["kwargs"])
    raise NameError


def parse_metric_config(config: Dict[str, ConfigDict]) -> tm.MetricCollection:
    metrics = {}

    for metric in config:
        metrics[metric] = get_metric(config[metric])

    return tm.MetricCollection(metrics)


def parse_fader_config(config: ConfigDict) -> BaseFader:
    name = config["name"]

    for module in [audio_]:
        if name in module.__dict__:
            return module.__dict__[name](**config["kwargs"])

    raise NameError


class LightningSystem(pl.LightningModule):
    _VOX_STEMS = ["speech", "vocals"]
    _BG_STEMS = ["background", "effects", "mne"]

    def __init__(
            self,
            config: Dict,
            loss_adjustment: float = 1.0,
            attach_fader: bool = False
    ) -> None:
        super().__init__()
        self.optimizer_config = config["optimizer"]
        self.model = parse_model_config(config["model"])
        self.loss = parse_loss_config(config["loss"])
        self.metrics = nn.ModuleDict(
            {
                stem: parse_metric_config(config["metrics"]["dev"])
                for stem in self.model.stems
            }
        )

        self.metrics.disallow_fsdp = True

        self.test_metrics = nn.ModuleDict(
            {
                stem: parse_metric_config(config["metrics"]["test"])
                for stem in self.model.stems
            }
        )

        self.test_metrics.disallow_fsdp = True

        self.fs = config["model"]["kwargs"]["fs"]

        self.fader_config = config["inference"]["fader"]
        if attach_fader:
            self.fader = parse_fader_config(config["inference"]["fader"])
        else:
            self.fader = None

        self.augmentation: Optional[BaseAugmentor]
        if config.get("augmentation", None) is not None:
            self.augmentation = StemAugmentor(**config["augmentation"])
        else:
            self.augmentation = None

        self.predict_output_path: Optional[str] = None
        self.loss_adjustment = loss_adjustment

        self.val_prefix = None
        self.test_prefix = None

    def configure_optimizers(self) -> Any:
        return parse_optimizer_config(
            self.optimizer_config,
            self.trainer.model.parameters()
        )

    def compute_loss(self, batch: BatchedDataDict, output: OutputType) -> Dict[str, torch.Tensor]:
        return {"loss": self.loss(output, batch)}

    def update_metrics(
            self,
            batch: BatchedDataDict,
            output: OutputType,
            mode: str
    ) -> None:

        if mode == "test":
            metrics = self.test_metrics
        else:
            metrics = self.metrics

        for stem, metric in metrics.items():

            if stem == "mne:+":
                stem = "mne"

            # print(f"matching for {stem}")
            if mode == "train":
                metric.update(
                    output["audio"][stem],  # .cpu(),
                    batch["audio"][stem],  # .cpu()
                )
            else:
                if stem not in batch["audio"]:
                    matched = False
                    if stem in self._VOX_STEMS:
                        for bstem in self._VOX_STEMS:
                            if bstem in batch["audio"]:
                                batch["audio"][stem] = batch["audio"][bstem]
                                matched = True
                                break
                    elif stem in self._BG_STEMS:
                        for bstem in self._BG_STEMS:
                            if bstem in batch["audio"]:
                                batch["audio"][stem] = batch["audio"][bstem]
                                matched = True
                                break
                else:
                    matched = True

                # print(batch["audio"].keys())

                if matched:
                    # print(f"matched {stem}!")
                    if stem == "mne" and "mne" not in output["audio"]:
                        output["audio"]["mne"] = output["audio"]["music"] + output["audio"]["effects"]

                    metric.update(
                        output["audio"][stem],  # .cpu(),
                        batch["audio"][stem],  # .cpu(),
                    )

        # print(metric.compute())

    def compute_metrics(self, mode: str = "dev") -> Dict[str, torch.Tensor]:

        if mode == "test":
            metrics = self.test_metrics
        else:
            metrics = self.metrics

        metric_dict = {}

        for stem, metric in metrics.items():
            md = metric.compute()
            metric_dict.update(
                {f"{stem}/{k}": v for k, v in md.items()}
            )

        self.log_dict(metric_dict, prog_bar=True, logger=False)

        return metric_dict

    def reset_metrics(self, test_mode: bool = False) -> None:

        if test_mode:
            metrics = self.test_metrics
        else:
            metrics = self.metrics

        for _, metric in metrics.items():
            metric.reset()

    def forward(self, batch: BatchedDataDict) -> Any:
        batch, output = self.model(batch)

        return batch, output

    def common_step(self, batch: BatchedDataDict, mode: str) -> Any:
        batch, output = self.forward(batch)
        # print(batch)
        # print(output)
        loss_dict = self.compute_loss(batch, output)

        with torch.no_grad():
            self.update_metrics(batch, output, mode=mode)

        if mode == "train":
            self.log("loss", loss_dict["loss"], prog_bar=True)

        return output, loss_dict

    def training_step(self, batch: BatchedDataDict) -> Dict[str, Any]:

        if self.augmentation is not None:
            with torch.no_grad():
                batch = self.augmentation(batch)

        _, loss_dict = self.common_step(batch, mode="train")

        with torch.inference_mode():
            self.log_dict_with_prefix(
                loss_dict,
                "train",
                batch_size=batch["audio"]["mixture"].shape[0]
            )

        loss_dict["loss"] *= self.loss_adjustment

        return loss_dict

    def on_train_batch_end(
            self, outputs: STEP_OUTPUT, batch: BatchedDataDict, batch_idx: int
    ) -> None:

        metric_dict = self.compute_metrics()
        self.log_dict_with_prefix(metric_dict, "train")
        self.reset_metrics()

    def validation_step(
            self,
            batch: BatchedDataDict,
            batch_idx: int,
            dataloader_idx: int = 0
    ) -> Dict[str, Any]:

        with torch.inference_mode():
            curr_val_prefix = f"val{dataloader_idx}" if dataloader_idx > 0 else "val"

            if curr_val_prefix != self.val_prefix:
                # print(f"Switching to validation dataloader {dataloader_idx}")
                if self.val_prefix is not None:
                    self._on_validation_epoch_end()
                self.val_prefix = curr_val_prefix
            _, loss_dict = self.common_step(batch, mode="val")

            self.log_dict_with_prefix(
                loss_dict,
                self.val_prefix,
                batch_size=batch["audio"]["mixture"].shape[0],
                prog_bar=True,
                add_dataloader_idx=False
            )

        return loss_dict

    def on_validation_epoch_end(self) -> None:
        self._on_validation_epoch_end()

    def _on_validation_epoch_end(self) -> None:
        metric_dict = self.compute_metrics()
        self.log_dict_with_prefix(metric_dict, self.val_prefix, prog_bar=True,
                                  add_dataloader_idx=False)
        # self.logger.save()
        # print(self.val_prefix, "Validation metrics:", metric_dict)
        self.reset_metrics()

    def old_predtest_step(
            self,
            batch: BatchedDataDict,
            batch_idx: int,
            dataloader_idx: int = 0
    ) -> Tuple[BatchedDataDict, OutputType]:

        audio_batch = batch["audio"]["mixture"]
        track_batch = batch.get("track", ["" for _ in range(len(audio_batch))])

        output_list_of_dicts = [
            self.fader(
                audio[None, ...],
                lambda a: self.test_forward(a, track)
            )
            for audio, track in zip(audio_batch, track_batch)
        ]

        output_dict_of_lists = defaultdict(list)

        for output_dict in output_list_of_dicts:
            for stem, audio in output_dict.items():
                output_dict_of_lists[stem].append(audio)

        output = {
            "audio": {
                stem: torch.concat(output_list, dim=0)
                for stem, output_list in output_dict_of_lists.items()
            }
        }

        return batch, output

    def predtest_step(
            self,
            batch: BatchedDataDict,
            batch_idx: int = -1,
            dataloader_idx: int = 0
    ) -> Tuple[BatchedDataDict, OutputType]:

        if getattr(self.model, "bypass_fader", False):
            batch, output = self.model(batch)
        else:
            audio_batch = batch["audio"]["mixture"]
            output = self.fader(
                audio_batch,
                lambda a: self.test_forward(a, "", batch=batch)
            )

        return batch, output

    def test_forward(
            self,
            audio: torch.Tensor,
            track: str = "",
            batch: BatchedDataDict = None
    ) -> torch.Tensor:

        if self.fader is None:
            self.attach_fader()

        cond = batch.get("condition", None)

        if cond is not None and cond.shape[0] == 1:
            cond = cond.repeat(audio.shape[0], 1)

        _, output = self.forward(
            {"audio": {"mixture": audio},
             "track": track,
             "condition": cond,
             }
        )  # TODO: support track properly

        return output["audio"]

    def on_test_epoch_start(self) -> None:
        self.attach_fader(force_reattach=True)

    def test_step(
            self,
            batch: BatchedDataDict,
            batch_idx: int,
            dataloader_idx: int = 0
    ) -> Any:
        curr_test_prefix = f"test{dataloader_idx}"

        # print(batch["audio"].keys())

        if curr_test_prefix != self.test_prefix:
            # print(f"Switching to test dataloader {dataloader_idx}")
            if self.test_prefix is not None:
                self._on_test_epoch_end()
            self.test_prefix = curr_test_prefix

        with torch.inference_mode():
            _, output = self.predtest_step(batch, batch_idx, dataloader_idx)
            # print(output)
            self.update_metrics(batch, output, mode="test")

        return output

    def on_test_epoch_end(self) -> None:
        self._on_test_epoch_end()

    def _on_test_epoch_end(self) -> None:
        metric_dict = self.compute_metrics(mode="test")
        self.log_dict_with_prefix(metric_dict, self.test_prefix, prog_bar=True,
                                  add_dataloader_idx=False)
        # self.logger.save()
        # print(self.test_prefix, "Test metrics:", metric_dict)
        self.reset_metrics()

    def predict_step(
            self,
            batch: BatchedDataDict,
            batch_idx: int = 0,
            dataloader_idx: int = 0,
            include_track_name: Optional[bool] = None,
            get_no_vox_combinations: bool = True,
            get_residual: bool = False,
            treat_batch_as_channels: bool = False,
            fs: Optional[int] = None,
    ) -> Any:
        assert self.predict_output_path is not None

        batch_size = batch["audio"]["mixture"].shape[0]

        if include_track_name is None:
            include_track_name = batch_size > 1

        with torch.inference_mode():
            batch, output = self.predtest_step(batch, batch_idx, dataloader_idx)
        print('Pred test finished...')
        torch.cuda.empty_cache()
        metric_dict = {}

        if get_residual:
            mixture = batch["audio"]["mixture"]
            extracted = sum([output["audio"][stem] for stem in output["audio"]])
            residual = mixture - extracted
            print(extracted.shape, mixture.shape, residual.shape)

            output["audio"]["residual"] = residual

        if get_no_vox_combinations:
            no_vox_stems = [
                stem for stem in output["audio"] if
                stem not in self._VOX_STEMS
            ]
            no_vox_combinations = chain.from_iterable(
                combinations(no_vox_stems, r) for r in
                range(2, len(no_vox_stems) + 1)
            )

            for combination in no_vox_combinations:
                combination_ = list(combination)
                output["audio"]["+".join(combination_)] = sum(
                    [output["audio"][stem] for stem in combination_]
                )

        if treat_batch_as_channels:
            for stem in output["audio"]:
                output["audio"][stem] = output["audio"][stem].reshape(
                    1, -1, output["audio"][stem].shape[-1]
                )
            batch_size = 1

        for b in range(batch_size):
            print("!!", b)
            for stem in output["audio"]:
                print(f"Saving audio for {stem} to {self.predict_output_path}")
                track_name = batch["track"][b].split("/")[-1]

                if batch.get("audio", {}).get(stem, None) is not None:
                    self.test_metrics[stem].reset()
                    metrics = self.test_metrics[stem](
                        batch["audio"][stem][[b], ...],
                        output["audio"][stem][[b], ...]
                    )
                    snr = metrics["snr"]
                    sisnr = metrics["sisnr"]
                    sdr = metrics["sdr"]
                    metric_dict[stem] = metrics
                    print(
                        track_name,
                        f"snr={snr:2.2f} dB",
                        f"sisnr={sisnr:2.2f}",
                        f"sdr={sdr:2.2f} dB",
                    )
                    filename = f"{stem} - snr={snr:2.2f}dB - sdr={sdr:2.2f}dB.wav"
                else:
                    filename = f"{stem}.wav"

                if include_track_name:
                    output_dir = os.path.join(
                        self.predict_output_path,
                        track_name
                    )
                else:
                    output_dir = self.predict_output_path

                os.makedirs(output_dir, exist_ok=True)

                if fs is None:
                    fs = self.fs

                ta.save(
                    os.path.join(output_dir, filename),
                    output["audio"][stem][b, ...].cpu(),
                    fs,
                )

        return metric_dict

    def get_stems(
            self,
            batch: BatchedDataDict,
            batch_idx: int = 0,
            dataloader_idx: int = 0,
            include_track_name: Optional[bool] = None,
            get_no_vox_combinations: bool = True,
            get_residual: bool = False,
            treat_batch_as_channels: bool = False,
            fs: Optional[int] = None,
    ) -> Any:
        assert self.predict_output_path is not None

        batch_size = batch["audio"]["mixture"].shape[0]

        if include_track_name is None:
            include_track_name = batch_size > 1

        with torch.inference_mode():
            batch, output = self.predtest_step(batch, batch_idx, dataloader_idx)
        torch.cuda.empty_cache()
        metric_dict = {}

        if get_residual:
            mixture = batch["audio"]["mixture"]
            extracted = sum([output["audio"][stem] for stem in output["audio"]])
            residual = mixture - extracted
            # print(extracted.shape, mixture.shape, residual.shape)

            output["audio"]["residual"] = residual

        if get_no_vox_combinations:
            no_vox_stems = [
                stem for stem in output["audio"] if
                stem not in self._VOX_STEMS
            ]
            no_vox_combinations = chain.from_iterable(
                combinations(no_vox_stems, r) for r in
                range(2, len(no_vox_stems) + 1)
            )

            for combination in no_vox_combinations:
                combination_ = list(combination)
                output["audio"]["+".join(combination_)] = sum(
                    [output["audio"][stem] for stem in combination_]
                )

        if treat_batch_as_channels:
            for stem in output["audio"]:
                output["audio"][stem] = output["audio"][stem].reshape(
                    1, -1, output["audio"][stem].shape[-1]
                )
            batch_size = 1

        result = {}
        for b in range(batch_size):
            for stem in output["audio"]:
                track_name = batch["track"][b].split("/")[-1]

                if batch.get("audio", {}).get(stem, None) is not None:
                    self.test_metrics[stem].reset()
                    metrics = self.test_metrics[stem](
                        batch["audio"][stem][[b], ...],
                        output["audio"][stem][[b], ...]
                    )
                    snr = metrics["snr"]
                    sisnr = metrics["sisnr"]
                    sdr = metrics["sdr"]
                    metric_dict[stem] = metrics
                    print(
                        track_name,
                        f"snr={snr:2.2f} dB",
                        f"sisnr={sisnr:2.2f}",
                        f"sdr={sdr:2.2f} dB",
                    )
                    filename = f"{stem} - snr={snr:2.2f}dB - sdr={sdr:2.2f}dB.wav"
                else:
                    filename = f"{stem}.wav"

                if include_track_name:
                    output_dir = os.path.join(
                        self.predict_output_path,
                        track_name
                    )
                else:
                    output_dir = self.predict_output_path

                os.makedirs(output_dir, exist_ok=True)

                if fs is None:
                    fs = self.fs

                result[stem] = output["audio"][stem][b, ...].cpu().numpy()

        return result

    def load_state_dict(
            self, state_dict: Mapping[str, Any], strict: bool = False
    ) -> Any:

        return super().load_state_dict(state_dict, strict=False)

    def set_predict_output_path(self, path: str) -> None:
        self.predict_output_path = path
        os.makedirs(self.predict_output_path, exist_ok=True)

        self.attach_fader()

    def attach_fader(self, force_reattach=False) -> None:
        if self.fader is None or force_reattach:
            self.fader = parse_fader_config(self.fader_config)
            self.fader.to(self.device)

    def log_dict_with_prefix(
            self,
            dict_: Dict[str, torch.Tensor],
            prefix: str,
            batch_size: Optional[int] = None,
            **kwargs: Any
    ) -> None:
        self.log_dict(
            {f"{prefix}/{k}": v for k, v in dict_.items()},
            batch_size=batch_size,
            logger=True,
            sync_dist=True,
            **kwargs,
        )
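`parse_optimizer_config` resolves optimizer and scheduler classes by name, so a config is just nested name/kwargs dicts, and `monitor` is only consulted for `ReduceLROnPlateau`. A sketch with illustrative hyperparameters (any class name resolvable in `torch.optim` / `torch.optim.lr_scheduler` would work the same way):

```python
# Illustrative values only; this mirrors the TypedDicts defined above.
from torch import nn
from models.bandit.core import parse_optimizer_config

config = {
    "optimizer": {"name": "AdamW", "kwargs": {"lr": 1e-3, "weight_decay": 1e-2}},
    "scheduler": {
        "name": "ReduceLROnPlateau",
        "kwargs": {"factor": 0.5, "patience": 3},
        "monitor": "val/loss",  # required only for ReduceLROnPlateau
    },
}

params = nn.Linear(4, 4).parameters()
optim_dict = parse_optimizer_config(config, params)
# -> {"optimizer": AdamW(...),
#     "lr_scheduler": {"scheduler": ReduceLROnPlateau(...), "monitor": "val/loss"}}
```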
models/bandit/core/data/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .dnr.datamodule import DivideAndRemasterDataModule
from .musdb.datamodule import MUSDB18DataModule
models/bandit/core/data/_types.py
ADDED
@@ -0,0 +1,18 @@
from typing import Dict, Sequence, TypedDict

import torch

AudioDict = Dict[str, torch.Tensor]

DataDict = TypedDict('DataDict', {'audio': AudioDict, 'track': str})

BatchedDataDict = TypedDict(
    'BatchedDataDict',
    {'audio': AudioDict, 'track': Sequence[str]}
)


class DataDictWithLanguage(TypedDict):
    audio: AudioDict
    track: str
    language: str
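For illustration, a `DataDict` as the datasets below produce it; the `(channels, samples)` tensor layout is an assumption inferred from the `[:, start:end]` chunking in the DnR dataset, not something these types enforce:

```python
# Sketch; values are placeholders, tensor layout assumed (channels, samples).
import torch
from models.bandit.core.data._types import DataDict

item: DataDict = {
    "audio": {
        "mixture": torch.zeros(2, 44100),
        "speech": torch.zeros(2, 44100),
    },
    "track": "train/1001",
}
```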
models/bandit/core/data/augmentation.py
ADDED
@@ -0,0 +1,107 @@
from abc import ABC
from typing import Any, Dict, Union

import torch
import torch_audiomentations as tam
from torch import nn

from models.bandit.core.data._types import BatchedDataDict, DataDict


class BaseAugmentor(nn.Module, ABC):
    def forward(self, item: Union[DataDict, BatchedDataDict]) -> Union[DataDict, BatchedDataDict]:
        raise NotImplementedError


class StemAugmentor(BaseAugmentor):
    def __init__(
            self,
            audiomentations: Dict[str, Dict[str, Any]],
            fix_clipping: bool = True,
            scaler_margin: float = 0.5,
            apply_both_default_and_common: bool = False,
    ) -> None:
        super().__init__()

        augmentations = {}

        self.has_default = "[default]" in audiomentations
        self.has_common = "[common]" in audiomentations
        self.apply_both_default_and_common = apply_both_default_and_common

        for stem in audiomentations:
            if audiomentations[stem]["name"] == "Compose":
                augmentations[stem] = getattr(
                    tam,
                    audiomentations[stem]["name"]
                )(
                    [
                        getattr(tam, aug["name"])(**aug["kwargs"])
                        for aug in
                        audiomentations[stem]["kwargs"]["transforms"]
                    ],
                    **audiomentations[stem]["kwargs"]["kwargs"],
                )
            else:
                augmentations[stem] = getattr(
                    tam,
                    audiomentations[stem]["name"]
                )(
                    **audiomentations[stem]["kwargs"]
                )

        self.augmentations = nn.ModuleDict(augmentations)
        self.fix_clipping = fix_clipping
        self.scaler_margin = scaler_margin

    def check_and_fix_clipping(
            self, item: Union[DataDict, BatchedDataDict]
    ) -> Union[DataDict, BatchedDataDict]:
        max_abs = []

        for stem in item["audio"]:
            max_abs.append(item["audio"][stem].abs().max().item())

        if max(max_abs) > 1.0:
            scaler = 1.0 / (max(max_abs) + torch.rand(
                (1,),
                device=item["audio"]["mixture"].device
            ) * self.scaler_margin)

            for stem in item["audio"]:
                item["audio"][stem] *= scaler

        return item

    def forward(self, item: Union[DataDict, BatchedDataDict]) -> Union[DataDict, BatchedDataDict]:

        for stem in item["audio"]:
            if stem == "mixture":
                continue

            if self.has_common:
                item["audio"][stem] = self.augmentations["[common]"](
                    item["audio"][stem]
                ).samples

            if stem in self.augmentations:
                item["audio"][stem] = self.augmentations[stem](
                    item["audio"][stem]
                ).samples
            elif self.has_default:
                if not self.has_common or self.apply_both_default_and_common:
                    item["audio"][stem] = self.augmentations["[default]"](
                        item["audio"][stem]
                    ).samples

        item["audio"]["mixture"] = sum(
            [item["audio"][stem] for stem in item["audio"]
             if stem != "mixture"]
        )  # type: ignore[call-overload, assignment]

        if self.fix_clipping:
            item = self.check_and_fix_clipping(item)

        return item
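`StemAugmentor` looks augmentation classes up by name in `torch_audiomentations`, applying `[common]` to every non-mixture stem and `[default]` as the per-stem fallback, then re-sums the mixture. A sketch of a config; `Gain` and `PolarityInversion` are torch_audiomentations transforms, the values are illustrative, and the transforms are assumed to return dict-style outputs since the code reads `.samples` off the result:

```python
# Illustrative config; names resolve via getattr(torch_audiomentations, name).
from models.bandit.core.data.augmentation import StemAugmentor

audiomentations = {
    "[common]": {
        "name": "Gain",
        "kwargs": {"min_gain_in_db": -6.0, "max_gain_in_db": 6.0, "p": 0.5},
    },
    "speech": {
        "name": "PolarityInversion",
        "kwargs": {"p": 0.3},
    },
}

augmentor = StemAugmentor(audiomentations, fix_clipping=True)
# item = augmentor(item)  # augments stems, then re-sums item["audio"]["mixture"]
```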
models/bandit/core/data/augmented.py
ADDED
@@ -0,0 +1,35 @@
import warnings
from typing import Dict, Optional, Union

import torch
from torch import nn
from torch.utils import data


class AugmentedDataset(data.Dataset):
    def __init__(
            self,
            dataset: data.Dataset,
            augmentation: nn.Module = nn.Identity(),
            target_length: Optional[int] = None,
    ) -> None:
        warnings.warn(
            "This class is no longer used. Attach augmentation to "
            "the LightningSystem instead.",
            DeprecationWarning,
        )

        self.dataset = dataset
        self.augmentation = augmentation

        self.ds_length: int = len(dataset)  # type: ignore[arg-type]
        self.length = target_length if target_length is not None else self.ds_length

    def __getitem__(self, index: int) -> Dict[str, Union[str, Dict[str, torch.Tensor]]]:
        item = self.dataset[index % self.ds_length]
        item = self.augmentation(item)
        return item

    def __len__(self) -> int:
        return self.length
models/bandit/core/data/base.py
ADDED
@@ -0,0 +1,69 @@
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

import numpy as np
import pedalboard as pb
import torch
import torchaudio as ta
from torch.utils import data

from models.bandit.core.data._types import AudioDict, DataDict


class BaseSourceSeparationDataset(data.Dataset, ABC):
    def __init__(
            self, split: str,
            stems: List[str],
            files: List[str],
            data_path: str,
            fs: int,
            npy_memmap: bool,
            recompute_mixture: bool
    ):
        self.split = split
        self.stems = stems
        self.stems_no_mixture = [s for s in stems if s != "mixture"]
        self.files = files
        self.data_path = data_path
        self.fs = fs
        self.npy_memmap = npy_memmap
        self.recompute_mixture = recompute_mixture

    @abstractmethod
    def get_stem(
            self,
            *,
            stem: str,
            identifier: Dict[str, Any]
    ) -> torch.Tensor:
        raise NotImplementedError

    def _get_audio(self, stems, identifier: Dict[str, Any]):
        audio = {}
        for stem in stems:
            audio[stem] = self.get_stem(stem=stem, identifier=identifier)

        return audio

    def get_audio(self, identifier: Dict[str, Any]) -> AudioDict:

        if self.recompute_mixture:
            audio = self._get_audio(
                self.stems_no_mixture,
                identifier=identifier
            )
            audio["mixture"] = self.compute_mixture(audio)
            return audio
        else:
            return self._get_audio(self.stems, identifier=identifier)

    @abstractmethod
    def get_identifier(self, index: int) -> Dict[str, Any]:
        pass

    def compute_mixture(self, audio: AudioDict) -> torch.Tensor:

        return sum(
            audio[stem] for stem in audio if stem != "mixture"
        )
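A subclass only has to supply `get_stem` and `get_identifier`; the mixture is either loaded as a stem or re-summed via `compute_mixture`. A minimal illustrative subclass, assuming a hypothetical `<data_path>/<track>/<stem>.wav` layout (the DnR dataset in the next file is the real implementation):

```python
# Sketch only; FolderDataset and its directory layout are hypothetical.
import os
from typing import Any, Dict

import torch
import torchaudio as ta

from models.bandit.core.data.base import BaseSourceSeparationDataset


class FolderDataset(BaseSourceSeparationDataset):
    """Hypothetical layout: <data_path>/<track>/<stem>.wav"""

    def get_stem(self, *, stem: str, identifier: Dict[str, Any]) -> torch.Tensor:
        path = os.path.join(self.data_path, identifier["track"], f"{stem}.wav")
        audio, _ = ta.load(path)
        return audio

    def get_identifier(self, index: int) -> Dict[str, Any]:
        return {"track": self.files[index]}

    def __len__(self) -> int:
        return len(self.files)
```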
models/bandit/core/data/dnr/__init__.py
ADDED
File without changes
models/bandit/core/data/dnr/datamodule.py
ADDED
@@ -0,0 +1,74 @@
import os
from typing import Mapping, Optional

import pytorch_lightning as pl

from .dataset import (
    DivideAndRemasterDataset,
    DivideAndRemasterDeterministicChunkDataset,
    DivideAndRemasterRandomChunkDataset,
    DivideAndRemasterRandomChunkDatasetWithSpeechReverb
)


def DivideAndRemasterDataModule(
        data_root: str = "$DATA_ROOT/DnR/v2",
        batch_size: int = 2,
        num_workers: int = 8,
        train_kwargs: Optional[Mapping] = None,
        val_kwargs: Optional[Mapping] = None,
        test_kwargs: Optional[Mapping] = None,
        datamodule_kwargs: Optional[Mapping] = None,
        use_speech_reverb: bool = False
        # augmentor=None
) -> pl.LightningDataModule:
    if train_kwargs is None:
        train_kwargs = {}

    if val_kwargs is None:
        val_kwargs = {}

    if test_kwargs is None:
        test_kwargs = {}

    if datamodule_kwargs is None:
        datamodule_kwargs = {}

    if num_workers is None:
        num_workers = os.cpu_count()

    if num_workers is None:
        num_workers = 32

    num_workers = min(num_workers, 64)

    if use_speech_reverb:
        train_cls = DivideAndRemasterRandomChunkDatasetWithSpeechReverb
    else:
        train_cls = DivideAndRemasterRandomChunkDataset

    train_dataset = train_cls(
        data_root, "train", **train_kwargs
    )

    # if augmentor is not None:
    #     train_dataset = AugmentedDataset(train_dataset, augmentor)

    datamodule = pl.LightningDataModule.from_datasets(
        train_dataset=train_dataset,
        val_dataset=DivideAndRemasterDeterministicChunkDataset(
            data_root, "val", **val_kwargs
        ),
        test_dataset=DivideAndRemasterDataset(
            data_root,
            "test",
            **test_kwargs
        ),
        batch_size=batch_size,
        num_workers=num_workers,
        **datamodule_kwargs
    )

    datamodule.predict_dataloader = datamodule.test_dataloader  # type: ignore[method-assign]

    return datamodule
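The factory forwards the per-split kwargs straight into the dataset constructors defined in the next file, so the chunking parameters live in `train_kwargs`/`val_kwargs`. A sketch with illustrative values; `$DATA_ROOT` must point at a preprocessed DnR tree:

```python
# Illustrative values; kwargs are forwarded to the dataset constructors.
from models.bandit.core.data.dnr.datamodule import DivideAndRemasterDataModule

datamodule = DivideAndRemasterDataModule(
    data_root="$DATA_ROOT/DnR/v2np",
    batch_size=4,
    num_workers=8,
    train_kwargs={"target_length": 20000, "chunk_size_second": 6.0},
    val_kwargs={"chunk_size_second": 6.0, "hop_size_second": 3.0},
)
```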
models/bandit/core/data/dnr/dataset.py
ADDED
@@ -0,0 +1,392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from abc import ABC
|
3 |
+
from typing import Any, Dict, List, Optional
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import pedalboard as pb
|
7 |
+
import torch
|
8 |
+
import torchaudio as ta
|
9 |
+
from torch.utils import data
|
10 |
+
|
11 |
+
from models.bandit.core.data._types import AudioDict, DataDict
|
12 |
+
from models.bandit.core.data.base import BaseSourceSeparationDataset
|
13 |
+
|
14 |
+
|
15 |
+
class DivideAndRemasterBaseDataset(BaseSourceSeparationDataset, ABC):
|
16 |
+
ALLOWED_STEMS = ["mixture", "speech", "music", "effects", "mne"]
|
17 |
+
STEM_NAME_MAP = {
|
18 |
+
"mixture": "mix",
|
19 |
+
"speech": "speech",
|
20 |
+
"music": "music",
|
21 |
+
"effects": "sfx",
|
22 |
+
}
|
23 |
+
SPLIT_NAME_MAP = {"train": "tr", "val": "cv", "test": "tt"}
|
24 |
+
|
25 |
+
FULL_TRACK_LENGTH_SECOND = 60
|
26 |
+
FULL_TRACK_LENGTH_SAMPLES = FULL_TRACK_LENGTH_SECOND * 44100
|
27 |
+
|
28 |
+
def __init__(
|
29 |
+
self,
|
30 |
+
split: str,
|
31 |
+
stems: List[str],
|
32 |
+
files: List[str],
|
33 |
+
data_path: str,
|
34 |
+
fs: int = 44100,
|
35 |
+
npy_memmap: bool = True,
|
36 |
+
recompute_mixture: bool = False,
|
37 |
+
) -> None:
|
38 |
+
super().__init__(
|
39 |
+
split=split,
|
40 |
+
stems=stems,
|
41 |
+
files=files,
|
42 |
+
data_path=data_path,
|
43 |
+
fs=fs,
|
44 |
+
npy_memmap=npy_memmap,
|
45 |
+
recompute_mixture=recompute_mixture
|
46 |
+
)
|
47 |
+
|
48 |
+
def get_stem(
|
49 |
+
self,
|
50 |
+
*,
|
51 |
+
stem: str,
|
52 |
+
identifier: Dict[str, Any]
|
53 |
+
) -> torch.Tensor:
|
54 |
+
|
55 |
+
if stem == "mne":
|
56 |
+
return self.get_stem(
|
57 |
+
stem="music",
|
58 |
+
identifier=identifier) + self.get_stem(
|
59 |
+
stem="effects",
|
60 |
+
identifier=identifier)
|
61 |
+
|
62 |
+
track = identifier["track"]
|
63 |
+
path = os.path.join(self.data_path, track)
|
64 |
+
|
65 |
+
if self.npy_memmap:
|
66 |
+
audio = np.load(
|
67 |
+
os.path.join(path, f"{self.STEM_NAME_MAP[stem]}.npy"),
|
68 |
+
mmap_mode="r"
|
69 |
+
)
|
70 |
+
else:
|
71 |
+
# noinspection PyUnresolvedReferences
|
72 |
+
audio, _ = ta.load(
|
73 |
+
os.path.join(path, f"{self.STEM_NAME_MAP[stem]}.wav")
|
74 |
+
)
|
75 |
+
|
76 |
+
return audio
|
77 |
+
|
78 |
+
def get_identifier(self, index):
|
79 |
+
return dict(track=self.files[index])
|
80 |
+
|
81 |
+
def __getitem__(self, index: int) -> DataDict:
|
82 |
+
identifier = self.get_identifier(index)
|
83 |
+
audio = self.get_audio(identifier)
|
84 |
+
|
85 |
+
return {"audio": audio, "track": f"{self.split}/{identifier['track']}"}
|
86 |
+
|
87 |
+
|
88 |
+
class DivideAndRemasterDataset(DivideAndRemasterBaseDataset):
|
89 |
+
def __init__(
|
90 |
+
self,
|
91 |
+
data_root: str,
|
92 |
+
split: str,
|
93 |
+
stems: Optional[List[str]] = None,
|
94 |
+
fs: int = 44100,
|
95 |
+
npy_memmap: bool = True,
|
96 |
+
) -> None:
|
97 |
+
|
98 |
+
if stems is None:
|
99 |
+
stems = self.ALLOWED_STEMS
|
100 |
+
self.stems = stems
|
101 |
+
|
102 |
+
data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split])
|
103 |
+
|
104 |
+
files = sorted(os.listdir(data_path))
|
105 |
+
files = [
|
106 |
+
f
|
107 |
+
for f in files
|
108 |
+
if (not f.startswith(".")) and os.path.isdir(
|
109 |
+
os.path.join(data_path, f)
|
110 |
+
)
|
111 |
+
]
|
112 |
+
# pprint(list(enumerate(files)))
|
113 |
+
if split == "train":
|
114 |
+
assert len(files) == 3406, len(files)
|
115 |
+
elif split == "val":
|
116 |
+
assert len(files) == 487, len(files)
|
117 |
+
elif split == "test":
|
118 |
+
assert len(files) == 973, len(files)
|
119 |
+
|
120 |
+
self.n_tracks = len(files)
|
121 |
+
|
122 |
+
super().__init__(
|
123 |
+
data_path=data_path,
|
124 |
+
split=split,
|
125 |
+
stems=stems,
|
126 |
+
files=files,
|
127 |
+
fs=fs,
|
128 |
+
npy_memmap=npy_memmap,
|
129 |
+
)
|
130 |
+
|
131 |
+
def __len__(self) -> int:
|
132 |
+
return self.n_tracks
|
133 |
+
|
134 |
+
|
135 |
+
class DivideAndRemasterRandomChunkDataset(DivideAndRemasterBaseDataset):
|
136 |
+
def __init__(
|
137 |
+
self,
|
138 |
+
data_root: str,
|
139 |
+
split: str,
|
140 |
+
target_length: int,
|
141 |
+
chunk_size_second: float,
|
142 |
+
stems: Optional[List[str]] = None,
|
143 |
+
fs: int = 44100,
|
144 |
+
npy_memmap: bool = True,
|
145 |
+
) -> None:
|
146 |
+
|
147 |
+
if stems is None:
|
148 |
+
stems = self.ALLOWED_STEMS
|
149 |
+
self.stems = stems
|
150 |
+
|
151 |
+
data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split])
|
152 |
+
|
153 |
+
files = sorted(os.listdir(data_path))
|
154 |
+
files = [
|
155 |
+
f
|
156 |
+
for f in files
|
157 |
+
if (not f.startswith(".")) and os.path.isdir(
|
158 |
+
os.path.join(data_path, f)
|
159 |
+
)
|
160 |
+
]
|
161 |
+
|
162 |
+
if split == "train":
|
163 |
+
assert len(files) == 3406, len(files)
|
164 |
+
elif split == "val":
|
165 |
+
assert len(files) == 487, len(files)
|
166 |
+
elif split == "test":
|
            assert len(files) == 973, len(files)

        self.n_tracks = len(files)

        self.target_length = target_length
        self.chunk_size = int(chunk_size_second * fs)

        super().__init__(
            data_path=data_path,
            split=split,
            stems=stems,
            files=files,
            fs=fs,
            npy_memmap=npy_memmap,
        )

    def __len__(self) -> int:
        return self.target_length

    def get_identifier(self, index):
        return super().get_identifier(index % self.n_tracks)

    def get_stem(
        self,
        *,
        stem: str,
        identifier: Dict[str, Any],
        chunk_here: bool = False,
    ) -> torch.Tensor:

        stem = super().get_stem(stem=stem, identifier=identifier)

        if chunk_here:
            start = np.random.randint(
                0, self.FULL_TRACK_LENGTH_SAMPLES - self.chunk_size
            )
            end = start + self.chunk_size

            stem = stem[:, start:end]

        return stem

    def __getitem__(self, index: int) -> DataDict:
        identifier = self.get_identifier(index)
        # self.index_lock = index
        audio = self.get_audio(identifier)
        # self.index_lock = None

        start = np.random.randint(
            0, self.FULL_TRACK_LENGTH_SAMPLES - self.chunk_size
        )
        end = start + self.chunk_size

        audio = {k: v[:, start:end] for k, v in audio.items()}

        return {"audio": audio, "track": f"{self.split}/{identifier['track']}"}


class DivideAndRemasterDeterministicChunkDataset(DivideAndRemasterBaseDataset):
    def __init__(
        self,
        data_root: str,
        split: str,
        chunk_size_second: float,
        hop_size_second: float,
        stems: Optional[List[str]] = None,
        fs: int = 44100,
        npy_memmap: bool = True,
    ) -> None:

        if stems is None:
            stems = self.ALLOWED_STEMS
        self.stems = stems

        data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split])

        files = sorted(os.listdir(data_path))
        files = [
            f
            for f in files
            if (not f.startswith(".")) and os.path.isdir(os.path.join(data_path, f))
        ]
        # pprint(list(enumerate(files)))
        if split == "train":
            assert len(files) == 3406, len(files)
        elif split == "val":
            assert len(files) == 487, len(files)
        elif split == "test":
            assert len(files) == 973, len(files)

        self.n_tracks = len(files)

        self.chunk_size = int(chunk_size_second * fs)
        self.hop_size = int(hop_size_second * fs)
        self.n_chunks_per_track = int(
            (self.FULL_TRACK_LENGTH_SECOND - chunk_size_second) / hop_size_second
        )

        self.length = self.n_tracks * self.n_chunks_per_track

        super().__init__(
            data_path=data_path,
            split=split,
            stems=stems,
            files=files,
            fs=fs,
            npy_memmap=npy_memmap,
        )

    def get_identifier(self, index):
        return super().get_identifier(index % self.n_tracks)

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, item: int) -> DataDict:

        index = item % self.n_tracks
        chunk = item // self.n_tracks

        data_ = super().__getitem__(index)

        audio = data_["audio"]

        start = chunk * self.hop_size
        end = start + self.chunk_size

        for stem in self.stems:
            data_["audio"][stem] = audio[stem][:, start:end]

        return data_


class DivideAndRemasterRandomChunkDatasetWithSpeechReverb(
    DivideAndRemasterRandomChunkDataset
):
    def __init__(
        self,
        data_root: str,
        split: str,
        target_length: int,
        chunk_size_second: float,
        stems: Optional[List[str]] = None,
        fs: int = 44100,
        npy_memmap: bool = True,
    ) -> None:

        if stems is None:
            stems = self.ALLOWED_STEMS

        stems_no_mixture = [s for s in stems if s != "mixture"]

        super().__init__(
            data_root=data_root,
            split=split,
            target_length=target_length,
            chunk_size_second=chunk_size_second,
            stems=stems_no_mixture,
            fs=fs,
            npy_memmap=npy_memmap,
        )

        self.stems = stems
        self.stems_no_mixture = stems_no_mixture

    def __getitem__(self, index: int) -> DataDict:

        data_ = super().__getitem__(index)

        dry = data_["audio"]["speech"][:]
        n_samples = dry.shape[-1]

        wet_level = np.random.rand()

        speech = pb.Reverb(
            room_size=np.random.rand(),
            damping=np.random.rand(),
            wet_level=wet_level,
            dry_level=(1 - wet_level),
            width=np.random.rand(),
        ).process(dry, self.fs, buffer_size=8192 * 4)[..., :n_samples]

        data_["audio"]["speech"] = speech

        data_["audio"]["mixture"] = sum(
            [data_["audio"][s] for s in self.stems_no_mixture]
        )

        return data_

    def __len__(self) -> int:
        return super().__len__()


if __name__ == "__main__":

    from pprint import pprint
    from tqdm import tqdm

    for split_ in ["train", "val", "test"]:
        ds = DivideAndRemasterRandomChunkDatasetWithSpeechReverb(
            data_root="$DATA_ROOT/DnR/v2np",
            split=split_,
            target_length=100,
            chunk_size_second=6.0,
        )

        print(split_, len(ds))

        for track_ in tqdm(ds):  # type: ignore
            pprint(track_)
            track_["audio"] = {k: v.shape for k, v in track_["audio"].items()}
            pprint(track_)
            # break

        break
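A quick sanity check of the deterministic chunk indexing above, as a minimal sketch. It assumes the base-class constant FULL_TRACK_LENGTH_SECOND is 60.0 (the real value lives in the base dataset, not shown here); tracks vary fastest and hop offsets slowest, so one epoch sweeps every track once per hop position:

# Sketch of the index arithmetic in DivideAndRemasterDeterministicChunkDataset.
# Assumes FULL_TRACK_LENGTH_SECOND = 60.0; the real constant comes from the base class.
full_track_length_second = 60.0
chunk_size_second, hop_size_second = 6.0, 3.0
n_tracks = 973  # test split

n_chunks_per_track = int((full_track_length_second - chunk_size_second) / hop_size_second)  # 18
length = n_tracks * n_chunks_per_track  # 17514 chunks per epoch

item = 2000
track_index = item % n_tracks   # 54 -> which track
chunk_index = item // n_tracks  # 2  -> chunk starting at 2 * 3.0 = 6.0 s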
models/bandit/core/data/dnr/preprocess.py
ADDED
@@ -0,0 +1,54 @@
import glob
import os
from typing import Tuple

import numpy as np
import torchaudio as ta
from tqdm.contrib.concurrent import process_map


def process_one(inputs: Tuple[str, str, int]) -> None:
    infile, outfile, target_fs = inputs

    out_dir = os.path.dirname(outfile)
    os.makedirs(out_dir, exist_ok=True)

    data, fs = ta.load(infile)

    if fs != target_fs:
        data = ta.functional.resample(
            data, fs, target_fs, resampling_method="sinc_interp_kaiser"
        )
        fs = target_fs

    data = data.numpy().astype(np.float32)

    # Skip rewriting if an identical array has already been saved.
    if os.path.exists(outfile):
        data_ = np.load(outfile)
        if np.allclose(data, data_):
            return

    np.save(outfile, data)


def preprocess(
    data_path: str,
    output_path: str,
    fs: int,
) -> None:
    files = glob.glob(os.path.join(data_path, "**", "*.wav"), recursive=True)
    print(files)
    outfiles = [
        f.replace(data_path, output_path).replace(".wav", ".npy") for f in files
    ]

    os.makedirs(output_path, exist_ok=True)
    inputs = list(zip(files, outfiles, [fs] * len(files)))

    process_map(process_one, inputs, chunksize=32)


if __name__ == "__main__":
    import fire

    fire.Fire()
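For reference, a hypothetical invocation of the Fire CLI above (fire.Fire() with no arguments exposes the module-level functions as subcommands); the paths are placeholders:

# From the shell:
#   python preprocess.py preprocess --data_path=/data/DnR/v2 \
#       --output_path=/data/DnR/v2np --fs=44100
# Or equivalently, in Python:
from models.bandit.core.data.dnr.preprocess import preprocess

preprocess(data_path="/data/DnR/v2", output_path="/data/DnR/v2np", fs=44100)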
models/bandit/core/data/musdb/__init__.py
ADDED
File without changes
models/bandit/core/data/musdb/datamodule.py
ADDED
@@ -0,0 +1,77 @@
import os.path
from typing import Mapping, Optional

import pytorch_lightning as pl

from models.bandit.core.data.musdb.dataset import (
    MUSDB18BaseDataset,
    MUSDB18FullTrackDataset,
    MUSDB18SadDataset,
    MUSDB18SadOnTheFlyAugmentedDataset
)


def MUSDB18DataModule(
    data_root: str = "$DATA_ROOT/MUSDB18/HQ",
    target_stem: str = "vocals",
    batch_size: int = 2,
    num_workers: int = 8,
    train_kwargs: Optional[Mapping] = None,
    val_kwargs: Optional[Mapping] = None,
    test_kwargs: Optional[Mapping] = None,
    datamodule_kwargs: Optional[Mapping] = None,
    use_on_the_fly: bool = True,
    npy_memmap: bool = True
) -> pl.LightningDataModule:
    if train_kwargs is None:
        train_kwargs = {}

    if val_kwargs is None:
        val_kwargs = {}

    if test_kwargs is None:
        test_kwargs = {}

    if datamodule_kwargs is None:
        datamodule_kwargs = {}

    train_dataset: MUSDB18BaseDataset

    if use_on_the_fly:
        train_dataset = MUSDB18SadOnTheFlyAugmentedDataset(
            data_root=os.path.join(data_root, "saded-np"),
            split="train",
            target_stem=target_stem,
            **train_kwargs
        )
    else:
        train_dataset = MUSDB18SadDataset(
            data_root=os.path.join(data_root, "saded-np"),
            split="train",
            target_stem=target_stem,
            **train_kwargs
        )

    datamodule = pl.LightningDataModule.from_datasets(
        train_dataset=train_dataset,
        val_dataset=MUSDB18SadDataset(
            data_root=os.path.join(data_root, "saded-np"),
            split="val",
            target_stem=target_stem,
            **val_kwargs
        ),
        test_dataset=MUSDB18FullTrackDataset(
            data_root=os.path.join(data_root, "canonical"),
            split="test",
            **test_kwargs
        ),
        batch_size=batch_size,
        num_workers=num_workers,
        **datamodule_kwargs
    )

    datamodule.predict_dataloader = (  # type: ignore[method-assign]
        datamodule.test_dataloader
    )

    return datamodule
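A minimal usage sketch for the factory above, assuming the preprocessed "saded-np" and "canonical" subfolders already exist under data_root (the path here is a placeholder):

from models.bandit.core.data.musdb.datamodule import MUSDB18DataModule

dm = MUSDB18DataModule(
    data_root="/data/MUSDB18/HQ",  # placeholder path
    target_stem="vocals",
    batch_size=2,
    num_workers=8,
)
# Train/val batches are dicts: {"audio": {stem: tensor}, "track": str};
# the test loader yields full canonical tracks.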
models/bandit/core/data/musdb/dataset.py
ADDED
@@ -0,0 +1,280 @@
import os
from abc import ABC
from typing import List, Optional, Tuple

import numpy as np
import torch
import torchaudio as ta
from torch.utils import data

from models.bandit.core.data._types import AudioDict, DataDict
from models.bandit.core.data.base import BaseSourceSeparationDataset


class MUSDB18BaseDataset(BaseSourceSeparationDataset, ABC):

    ALLOWED_STEMS = ["mixture", "vocals", "bass", "drums", "other"]

    def __init__(
        self,
        split: str,
        stems: List[str],
        files: List[str],
        data_path: str,
        fs: int = 44100,
        npy_memmap=False,
    ) -> None:
        super().__init__(
            split=split,
            stems=stems,
            files=files,
            data_path=data_path,
            fs=fs,
            npy_memmap=npy_memmap,
            recompute_mixture=False
        )

    def get_stem(self, *, stem: str, identifier) -> torch.Tensor:
        track = identifier["track"]
        path = os.path.join(self.data_path, track)
        # noinspection PyUnresolvedReferences
        if self.npy_memmap:
            audio = np.load(os.path.join(path, f"{stem}.wav.npy"), mmap_mode="r")
        else:
            audio, _ = ta.load(os.path.join(path, f"{stem}.wav"))

        return audio

    def get_identifier(self, index):
        return dict(track=self.files[index])

    def __getitem__(self, index: int) -> DataDict:
        identifier = self.get_identifier(index)
        audio = self.get_audio(identifier)

        return {"audio": audio, "track": f"{self.split}/{identifier['track']}"}


class MUSDB18FullTrackDataset(MUSDB18BaseDataset):

    N_TRAIN_TRACKS = 100
    N_TEST_TRACKS = 50
    VALIDATION_FILES = [
        "Actions - One Minute Smile",
        "Clara Berry And Wooldog - Waltz For My Victims",
        "Johnny Lokke - Promises & Lies",
        "Patrick Talbot - A Reason To Leave",
        "Triviul - Angelsaint",
        "Alexander Ross - Goodbye Bolero",
        "Fergessen - Nos Palpitants",
        "Leaf - Summerghost",
        "Skelpolu - Human Mistakes",
        "Young Griffo - Pennies",
        "ANiMAL - Rockshow",
        "James May - On The Line",
        "Meaxic - Take A Step",
        "Traffic Experiment - Sirens",
    ]

    def __init__(
        self, data_root: str, split: str, stems: Optional[List[str]] = None
    ) -> None:

        if stems is None:
            stems = self.ALLOWED_STEMS
        self.stems = stems

        if split == "test":
            subset = "test"
        elif split in ["train", "val"]:
            subset = "train"
        else:
            raise NameError

        data_path = os.path.join(data_root, subset)

        files = sorted(os.listdir(data_path))
        files = [f for f in files if not f.startswith(".")]
        # pprint(list(enumerate(files)))
        if subset == "train":
            assert len(files) == 100, len(files)
            if split == "train":
                files = [f for f in files if f not in self.VALIDATION_FILES]
                assert len(files) == 100 - len(self.VALIDATION_FILES)
            else:
                files = [f for f in files if f in self.VALIDATION_FILES]
                assert len(files) == len(self.VALIDATION_FILES)
        else:
            split = "test"
            assert len(files) == 50

        self.n_tracks = len(files)

        super().__init__(
            data_path=data_path,
            split=split,
            stems=stems,
            files=files
        )

    def __len__(self) -> int:
        return self.n_tracks


class MUSDB18SadDataset(MUSDB18BaseDataset):
    def __init__(
        self,
        data_root: str,
        split: str,
        target_stem: str,
        stems: Optional[List[str]] = None,
        target_length: Optional[int] = None,
        npy_memmap=False,
    ) -> None:

        if stems is None:
            stems = self.ALLOWED_STEMS

        data_path = os.path.join(data_root, target_stem, split)

        files = sorted(os.listdir(data_path))
        files = [f for f in files if not f.startswith(".")]

        super().__init__(
            data_path=data_path,
            split=split,
            stems=stems,
            files=files,
            npy_memmap=npy_memmap
        )
        self.n_segments = len(files)
        self.target_stem = target_stem
        self.target_length = (
            target_length if target_length is not None else self.n_segments
        )

    def __len__(self) -> int:
        return self.target_length

    def __getitem__(self, index: int) -> DataDict:
        index = index % self.n_segments
        return super().__getitem__(index)

    def get_identifier(self, index):
        return super().get_identifier(index % self.n_segments)


class MUSDB18SadOnTheFlyAugmentedDataset(MUSDB18SadDataset):
    def __init__(
        self,
        data_root: str,
        split: str,
        target_stem: str,
        stems: Optional[List[str]] = None,
        target_length: int = 20000,
        apply_probability: Optional[float] = None,
        chunk_size_second: float = 3.0,
        random_scale_range_db: Tuple[float, float] = (-10, 10),
        drop_probability: float = 0.1,
        rescale: bool = True,
    ) -> None:
        super().__init__(data_root, split, target_stem, stems)

        if apply_probability is None:
            apply_probability = (target_length - self.n_segments) / target_length

        self.apply_probability = apply_probability
        self.drop_probability = drop_probability
        self.chunk_size_second = chunk_size_second
        self.random_scale_range_db = random_scale_range_db
        self.rescale = rescale

        self.chunk_size_sample = int(self.chunk_size_second * self.fs)
        self.target_length = target_length

    def __len__(self) -> int:
        return self.target_length

    def __getitem__(self, index: int) -> DataDict:

        index = index % self.n_segments

        # if np.random.rand() > self.apply_probability:
        #     return super().__getitem__(index)

        audio = {}
        identifier = self.get_identifier(index)

        # assert self.target_stem in self.stems_no_mixture
        for stem in self.stems_no_mixture:
            if stem == self.target_stem:
                identifier_ = identifier
            else:
                if np.random.rand() < self.apply_probability:
                    index_ = np.random.randint(self.n_segments)
                    identifier_ = self.get_identifier(index_)
                else:
                    identifier_ = identifier

            audio[stem] = self.get_stem(stem=stem, identifier=identifier_)

            if self.chunk_size_sample < audio[stem].shape[-1]:
                chunk_start = np.random.randint(
                    audio[stem].shape[-1] - self.chunk_size_sample
                )
            else:
                chunk_start = 0

            if np.random.rand() < self.drop_probability:
                # db_scale = "-inf"
                linear_scale = 0.0
            else:
                db_scale = np.random.uniform(*self.random_scale_range_db)
                linear_scale = np.power(10, db_scale / 20)
                # db_scale = f"{db_scale:+2.1f}"
            # print(linear_scale)
            audio[stem][..., chunk_start: chunk_start + self.chunk_size_sample] = (
                linear_scale
                * audio[stem][..., chunk_start: chunk_start + self.chunk_size_sample]
            )

        audio["mixture"] = self.compute_mixture(audio)

        if self.rescale:
            max_abs_val = max(
                [torch.max(torch.abs(audio[stem])) for stem in self.stems]
            )  # type: ignore[type-var]
            if max_abs_val > 1:
                audio = {k: v / max_abs_val for k, v in audio.items()}

        track = identifier["track"]

        return {"audio": audio, "track": f"{self.split}/{track}"}

# if __name__ == "__main__":
#
#     from pprint import pprint
#     from tqdm import tqdm
#
#     for split_ in ["train", "val", "test"]:
#         ds = MUSDB18SadOnTheFlyAugmentedDataset(
#             data_root="$DATA_ROOT/MUSDB18/HQ/saded",
#             split=split_,
#             target_stem="vocals"
#         )
#
#         print(split_, len(ds))
#
#         for track_ in tqdm(ds):
#             track_["audio"] = {
#                 k: v.shape for k, v in track_["audio"].items()
#             }
#             pprint(track_)
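A worked example of the default apply_probability used by the on-the-fly augmentation above: with target_length virtual samples per epoch and a hypothetical count of real SAD segments on disk, the default makes the expected number of unswapped draws match the real segment count:

target_length = 20000
n_segments = 14000  # hypothetical number of SAD segments on disk

apply_probability = (target_length - n_segments) / target_length  # 0.3
# Each non-target stem is drawn from a random other segment with
# probability 0.3, and kept from the current segment otherwise.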
models/bandit/core/data/musdb/preprocess.py
ADDED
@@ -0,0 +1,238 @@
import glob
import os

import numpy as np
import pyloudnorm as pyln
import torch
import torchaudio as ta
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm
from tqdm.contrib.concurrent import process_map

from models.bandit.core.data._types import DataDict
from models.bandit.core.data.musdb.dataset import MUSDB18FullTrackDataset


class SourceActivityDetector(nn.Module):
    def __init__(
        self,
        analysis_stem: str,
        output_path: str,
        fs: int = 44100,
        segment_length_second: float = 6.0,
        hop_length_second: float = 3.0,
        n_chunks: int = 10,
        chunk_epsilon: float = 1e-5,
        energy_threshold_quantile: float = 0.15,
        segment_epsilon: float = 1e-3,
        salient_proportion_threshold: float = 0.5,
        target_lufs: float = -24,
    ) -> None:
        super().__init__()

        self.fs = fs
        self.segment_length = int(segment_length_second * self.fs)
        self.hop_length = int(hop_length_second * self.fs)
        self.n_chunks = n_chunks
        assert self.segment_length % self.n_chunks == 0
        self.chunk_size = self.segment_length // self.n_chunks
        self.chunk_epsilon = chunk_epsilon
        self.energy_threshold_quantile = energy_threshold_quantile
        self.segment_epsilon = segment_epsilon
        self.salient_proportion_threshold = salient_proportion_threshold
        self.analysis_stem = analysis_stem

        self.meter = pyln.Meter(self.fs)
        self.target_lufs = target_lufs

        self.output_path = output_path

    def forward(self, data: DataDict) -> None:

        stem_ = self.analysis_stem if (self.analysis_stem != "none") else "mixture"

        x = data["audio"][stem_]

        xnp = x.numpy()
        loudness = self.meter.integrated_loudness(xnp.T)

        for stem in data["audio"]:
            s = data["audio"][stem]
            s = pyln.normalize.loudness(s.numpy().T, loudness, self.target_lufs).T
            s = torch.as_tensor(s)
            data["audio"][stem] = s

        if x.ndim == 3:
            assert x.shape[0] == 1
            x = x[0]

        n_chan, n_samples = x.shape

        n_segments = (
            int(np.ceil((n_samples - self.segment_length) / self.hop_length)) + 1
        )

        segments = torch.zeros((n_segments, n_chan, self.segment_length))
        for i in range(n_segments):
            start = i * self.hop_length
            end = start + self.segment_length
            end = min(end, n_samples)

            xseg = x[:, start:end]

            if end - start < self.segment_length:
                xseg = F.pad(
                    xseg,
                    pad=(0, self.segment_length - (end - start)),
                    value=torch.nan,
                )

            segments[i, :, :] = xseg

        chunks = segments.reshape(
            (n_segments, n_chan, self.n_chunks, self.chunk_size)
        )

        if self.analysis_stem != "none":
            chunk_energies = torch.mean(torch.square(chunks), dim=(1, 3))
            chunk_energies = torch.nan_to_num(chunk_energies, nan=0)
            chunk_energies[chunk_energies == 0] = self.chunk_epsilon

            energy_threshold = torch.nanquantile(
                chunk_energies, q=self.energy_threshold_quantile
            )

            if energy_threshold < self.segment_epsilon:
                energy_threshold = self.segment_epsilon  # type: ignore[assignment]

            chunks_above_threshold = chunk_energies > energy_threshold
            n_chunks_above_threshold = torch.mean(
                chunks_above_threshold.to(torch.float), dim=-1
            )

            segment_above_threshold = (
                n_chunks_above_threshold > self.salient_proportion_threshold
            )

            if torch.sum(segment_above_threshold) == 0:
                return

        else:
            segment_above_threshold = torch.ones((n_segments,))

        for i in range(n_segments):
            if not segment_above_threshold[i]:
                continue

            outpath = os.path.join(
                self.output_path,
                self.analysis_stem,
                f"{data['track']} - {self.analysis_stem}{i:03d}",
            )
            os.makedirs(outpath, exist_ok=True)

            for stem in data["audio"]:
                if stem == self.analysis_stem:
                    segment = torch.nan_to_num(segments[i, :, :], nan=0)
                else:
                    start = i * self.hop_length
                    end = start + self.segment_length
                    end = min(n_samples, end)

                    segment = data["audio"][stem][:, start:end]

                    if end - start < self.segment_length:
                        segment = F.pad(
                            segment,
                            (0, self.segment_length - (end - start))
                        )

                assert segment.shape[-1] == self.segment_length, segment.shape

                # ta.save(os.path.join(outpath, f"{stem}.wav"), segment, self.fs)

                # np.save appends ".npy", so this writes "<stem>.wav.npy",
                # matching what the npy_memmap datasets expect to load.
                np.save(os.path.join(outpath, f"{stem}.wav"), segment)


def preprocess(
    analysis_stem: str,
    output_path: str = "/data/MUSDB18/HQ/saded-np",
    fs: int = 44100,
    segment_length_second: float = 6.0,
    hop_length_second: float = 3.0,
    n_chunks: int = 10,
    chunk_epsilon: float = 1e-5,
    energy_threshold_quantile: float = 0.15,
    segment_epsilon: float = 1e-3,
    salient_proportion_threshold: float = 0.5,
) -> None:

    sad = SourceActivityDetector(
        analysis_stem=analysis_stem,
        output_path=output_path,
        fs=fs,
        segment_length_second=segment_length_second,
        hop_length_second=hop_length_second,
        n_chunks=n_chunks,
        chunk_epsilon=chunk_epsilon,
        energy_threshold_quantile=energy_threshold_quantile,
        segment_epsilon=segment_epsilon,
        salient_proportion_threshold=salient_proportion_threshold,
    )

    for split in ["train", "val", "test"]:
        ds = MUSDB18FullTrackDataset(
            data_root="/data/MUSDB18/HQ/canonical",
            split=split,
        )

        tracks = []
        for i, track in enumerate(tqdm(ds, total=len(ds))):
            if i % 32 == 0 and tracks:
                process_map(sad, tracks, max_workers=8)
                tracks = []
            tracks.append(track)
        process_map(sad, tracks, max_workers=8)


def loudness_norm_one(inputs):
    infile, outfile, target_lufs = inputs

    audio, fs = ta.load(infile)
    audio = audio.mean(dim=0, keepdim=True).numpy().T

    meter = pyln.Meter(fs)
    loudness = meter.integrated_loudness(audio)
    audio = pyln.normalize.loudness(audio, loudness, target_lufs)

    os.makedirs(os.path.dirname(outfile), exist_ok=True)
    np.save(outfile, audio.T)


def loudness_norm(
    data_path: str,
    # output_path: str,
    target_lufs=-17.0,
):
    files = glob.glob(os.path.join(data_path, "**", "*.wav"), recursive=True)

    outfiles = [
        f.replace(".wav", ".npy").replace("saded", "saded-np") for f in files
    ]

    files = [(f, o, target_lufs) for f, o in zip(files, outfiles)]

    process_map(loudness_norm_one, files, chunksize=2)


if __name__ == "__main__":
    import fire

    fire.Fire()
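For reference, hypothetical Fire invocations of the two entry points above; paths and the stem name are placeholders. Note that loudness_norm expects .wav segments under a "saded" tree (matching the commented-out ta.save path), whereas preprocess as written saves .npy segments directly:

#   python preprocess.py preprocess --analysis_stem=vocals \
#       --output_path=/data/MUSDB18/HQ/saded-np
#   python preprocess.py loudness_norm --data_path=/data/MUSDB18/HQ/saded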
models/bandit/core/data/musdb/validation.yaml
ADDED
@@ -0,0 +1,15 @@
validation:
  - 'Actions - One Minute Smile'
  - 'Clara Berry And Wooldog - Waltz For My Victims'
  - 'Johnny Lokke - Promises & Lies'
  - 'Patrick Talbot - A Reason To Leave'
  - 'Triviul - Angelsaint'
  - 'Alexander Ross - Goodbye Bolero'
  - 'Fergessen - Nos Palpitants'
  - 'Leaf - Summerghost'
  - 'Skelpolu - Human Mistakes'
  - 'Young Griffo - Pennies'
  - 'ANiMAL - Rockshow'
  - 'James May - On The Line'
  - 'Meaxic - Take A Step'
  - 'Traffic Experiment - Sirens'
models/bandit/core/loss/__init__.py
ADDED
@@ -0,0 +1,2 @@
from ._multistem import MultiStemWrapperFromConfig
from ._timefreq import (
    ReImL1Loss,
    ReImL2Loss,
    TimeFreqL1Loss,
    TimeFreqL2Loss,
    TimeFreqSignalNoisePNormRatioLoss,
)
models/bandit/core/loss/_complex.py
ADDED
@@ -0,0 +1,34 @@
from typing import Any

import torch
from torch.nn.modules import loss as _loss
from torch.nn.modules.loss import _Loss


class ReImLossWrapper(_Loss):
    """Apply a real-valued loss to the stacked real/imaginary parts of complex tensors."""

    def __init__(self, module: _Loss) -> None:
        super().__init__()
        self.module = module

    def forward(
        self,
        preds: torch.Tensor,
        target: torch.Tensor
    ) -> torch.Tensor:
        return self.module(
            torch.view_as_real(preds),
            torch.view_as_real(target)
        )


class ReImL1Loss(ReImLossWrapper):
    def __init__(self, **kwargs: Any) -> None:
        l1_loss = _loss.L1Loss(**kwargs)
        super().__init__(module=l1_loss)


class ReImL2Loss(ReImLossWrapper):
    def __init__(self, **kwargs: Any) -> None:
        l2_loss = _loss.MSELoss(**kwargs)
        super().__init__(module=l2_loss)
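A minimal sketch of what the wrapper computes, assuming complex spectrogram inputs: the complex tensors are viewed as real tensors with a trailing (real, imag) axis before the wrapped loss is applied. Shapes here are illustrative:

import torch
import torch.nn.functional as F

preds = torch.randn(2, 257, 100, dtype=torch.complex64)
target = torch.randn(2, 257, 100, dtype=torch.complex64)

# Equivalent to ReImL1Loss()(preds, target):
loss = F.l1_loss(torch.view_as_real(preds), torch.view_as_real(target))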
models/bandit/core/loss/_multistem.py
ADDED
@@ -0,0 +1,45 @@
from typing import Any, Dict

import torch
from asteroid import losses as asteroid_losses
from torch import nn
from torch.nn.modules.loss import _Loss

from . import snr


def parse_loss(name: str, kwargs: Dict[str, Any]) -> _Loss:
    # Look the loss class up by name in torch, the local snr module,
    # and asteroid, in that order.
    for module in [nn.modules.loss, snr, asteroid_losses, asteroid_losses.sdr]:
        if name in module.__dict__:
            return module.__dict__[name](**kwargs)

    raise NameError


class MultiStemWrapper(_Loss):
    def __init__(self, module: _Loss, modality: str = "audio") -> None:
        super().__init__()
        self.loss = module
        self.modality = modality

    def forward(
        self,
        preds: Dict[str, Dict[str, torch.Tensor]],
        target: Dict[str, Dict[str, torch.Tensor]],
    ) -> torch.Tensor:
        # Per-stem losses, computed only for stems present in both dicts.
        loss = {
            stem: self.loss(
                preds[self.modality][stem],
                target[self.modality][stem]
            )
            for stem in preds[self.modality] if stem in target[self.modality]
        }

        return sum(list(loss.values()))


class MultiStemWrapperFromConfig(MultiStemWrapper):
    def __init__(self, name: str, kwargs: Any, modality: str = "audio") -> None:
        loss = parse_loss(name, kwargs)
        super().__init__(module=loss, modality=modality)
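A minimal sketch of the wrapper on dict-structured batches (stem names and shapes are illustrative; "L1Loss" resolves through parse_loss to torch.nn.L1Loss):

import torch

loss_fn = MultiStemWrapperFromConfig(name="L1Loss", kwargs={})

preds = {"audio": {"speech": torch.randn(2, 2, 44100),
                   "music": torch.randn(2, 2, 44100)}}
target = {"audio": {"speech": torch.randn(2, 2, 44100),
                    "music": torch.randn(2, 2, 44100)}}

total = loss_fn(preds, target)  # sum of the per-stem L1 losses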
models/bandit/core/loss/_timefreq.py
ADDED
@@ -0,0 +1,113 @@
from typing import Any, Dict, Optional

import torch
from torch import nn
from torch.nn.modules.loss import _Loss

from models.bandit.core.loss._multistem import MultiStemWrapper
from models.bandit.core.loss._complex import ReImL1Loss, ReImL2Loss, ReImLossWrapper
from models.bandit.core.loss.snr import SignalNoisePNormRatio


class TimeFreqWrapper(_Loss):
    def __init__(
        self,
        time_module: _Loss,
        freq_module: Optional[_Loss] = None,
        time_weight: float = 1.0,
        freq_weight: float = 1.0,
        multistem: bool = True,
    ) -> None:
        super().__init__()

        if freq_module is None:
            freq_module = time_module

        if multistem:
            time_module = MultiStemWrapper(time_module, modality="audio")
            freq_module = MultiStemWrapper(freq_module, modality="spectrogram")

        self.time_module = time_module
        self.freq_module = freq_module

        self.time_weight = time_weight
        self.freq_weight = freq_weight

    # TODO: add better type hints
    def forward(self, preds: Any, target: Any) -> torch.Tensor:
        return self.time_weight * self.time_module(
            preds, target
        ) + self.freq_weight * self.freq_module(preds, target)


class TimeFreqL1Loss(TimeFreqWrapper):
    def __init__(
        self,
        time_weight: float = 1.0,
        freq_weight: float = 1.0,
        tkwargs: Optional[Dict[str, Any]] = None,
        fkwargs: Optional[Dict[str, Any]] = None,
        multistem: bool = True,
    ) -> None:
        if tkwargs is None:
            tkwargs = {}
        if fkwargs is None:
            fkwargs = {}
        time_module = nn.L1Loss(**tkwargs)
        freq_module = ReImL1Loss(**fkwargs)
        super().__init__(
            time_module,
            freq_module,
            time_weight,
            freq_weight,
            multistem
        )


class TimeFreqL2Loss(TimeFreqWrapper):
    def __init__(
        self,
        time_weight: float = 1.0,
        freq_weight: float = 1.0,
        tkwargs: Optional[Dict[str, Any]] = None,
        fkwargs: Optional[Dict[str, Any]] = None,
        multistem: bool = True,
    ) -> None:
        if tkwargs is None:
            tkwargs = {}
        if fkwargs is None:
            fkwargs = {}
        time_module = nn.MSELoss(**tkwargs)
        freq_module = ReImL2Loss(**fkwargs)
        super().__init__(
            time_module,
            freq_module,
            time_weight,
            freq_weight,
            multistem
        )


class TimeFreqSignalNoisePNormRatioLoss(TimeFreqWrapper):
    def __init__(
        self,
        time_weight: float = 1.0,
        freq_weight: float = 1.0,
        tkwargs: Optional[Dict[str, Any]] = None,
        fkwargs: Optional[Dict[str, Any]] = None,
        multistem: bool = True,
    ) -> None:
        if tkwargs is None:
            tkwargs = {}
        if fkwargs is None:
            fkwargs = {}
        time_module = SignalNoisePNormRatio(**tkwargs)
        freq_module = SignalNoisePNormRatio(**fkwargs)
        super().__init__(
            time_module,
            freq_module,
            time_weight,
            freq_weight,
            multistem
        )
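A minimal usage sketch, assuming the default multistem=True, in which each batch carries both an "audio" and a "spectrogram" modality per stem (shapes are illustrative):

import torch

loss_fn = TimeFreqL1Loss(time_weight=1.0, freq_weight=1.0)

def make_batch():
    return {
        "audio": {"speech": torch.randn(2, 2, 44100)},
        "spectrogram": {
            "speech": torch.randn(2, 2, 1025, 87, dtype=torch.complex64)
        },
    }

loss = loss_fn(make_batch(), make_batch())  # waveform L1 + real/imag L1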
models/bandit/core/loss/snr.py
ADDED
@@ -0,0 +1,146 @@
import torch
from torch.nn.modules.loss import _Loss


class SignalNoisePNormRatio(_Loss):
    def __init__(
        self,
        p: float = 1.0,
        scale_invariant: bool = False,
        zero_mean: bool = False,
        take_log: bool = True,
        reduction: str = "mean",
        EPS: float = 1e-3,
    ) -> None:
        assert reduction != "sum", NotImplementedError
        super().__init__(reduction=reduction)
        assert not zero_mean

        self.p = p

        self.EPS = EPS
        self.take_log = take_log

        self.scale_invariant = scale_invariant

    def forward(
        self,
        est_target: torch.Tensor,
        target: torch.Tensor
    ) -> torch.Tensor:

        target_ = target
        if self.scale_invariant:
            ndim = target.ndim
            dot = torch.sum(est_target * torch.conj(target), dim=-1, keepdim=True)
            s_target_energy = (
                torch.sum(target * torch.conj(target), dim=-1, keepdim=True)
            )

            if ndim > 2:
                dot = torch.sum(dot, dim=list(range(1, ndim)), keepdim=True)
                s_target_energy = torch.sum(
                    s_target_energy, dim=list(range(1, ndim)), keepdim=True
                )

            target_scaler = (dot + 1e-8) / (s_target_energy + 1e-8)
            target = target_ * target_scaler

        if torch.is_complex(est_target):
            est_target = torch.view_as_real(est_target)
            target = torch.view_as_real(target)

        batch_size = est_target.shape[0]
        est_target = est_target.reshape(batch_size, -1)
        target = target.reshape(batch_size, -1)
        # target_ = target_.reshape(batch_size, -1)

        if self.p == 1:
            e_error = torch.abs(est_target - target).mean(dim=-1)
            e_target = torch.abs(target).mean(dim=-1)
        elif self.p == 2:
            e_error = torch.square(est_target - target).mean(dim=-1)
            e_target = torch.square(target).mean(dim=-1)
        else:
            raise NotImplementedError

        if self.take_log:
            loss = 10 * (
                torch.log10(e_error + self.EPS) - torch.log10(e_target + self.EPS)
            )
        else:
            loss = (e_error + self.EPS) / (e_target + self.EPS)

        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()

        return loss


class MultichannelSingleSrcNegSDR(_Loss):
    def __init__(
        self,
        sdr_type: str,
        p: float = 2.0,
        zero_mean: bool = True,
        take_log: bool = True,
        reduction: str = "mean",
        EPS: float = 1e-8,
    ) -> None:
        assert reduction != "sum", NotImplementedError
        super().__init__(reduction=reduction)

        assert sdr_type in ["snr", "sisdr", "sdsdr"]
        self.sdr_type = sdr_type
        self.zero_mean = zero_mean
        self.take_log = take_log
        self.EPS = EPS

        self.p = p

    def forward(
        self,
        est_target: torch.Tensor,
        target: torch.Tensor
    ) -> torch.Tensor:
        if target.size() != est_target.size() or target.ndim != 3:
            raise TypeError(
                f"Inputs must be of shape [batch, channel, time], "
                f"got {target.size()} and {est_target.size()} instead"
            )
        # Step 1. Zero-mean norm
        if self.zero_mean:
            mean_source = torch.mean(target, dim=[1, 2], keepdim=True)
            mean_estimate = torch.mean(est_target, dim=[1, 2], keepdim=True)
            target = target - mean_source
            est_target = est_target - mean_estimate
        # Step 2. Pair-wise SI-SDR.
        if self.sdr_type in ["sisdr", "sdsdr"]:
            # [batch, 1]
            dot = torch.sum(est_target * target, dim=[1, 2], keepdim=True)
            # [batch, 1]
            s_target_energy = (
                torch.sum(target ** 2, dim=[1, 2], keepdim=True) + self.EPS
            )
            # [batch, time]
            scaled_target = dot * target / s_target_energy
        else:
            # [batch, time]
            scaled_target = target
        if self.sdr_type in ["sdsdr", "snr"]:
            e_noise = est_target - target
        else:
            e_noise = est_target - scaled_target
        # [batch]

        if self.p == 2.0:
            losses = torch.sum(scaled_target ** 2, dim=[1, 2]) / (
                torch.sum(e_noise ** 2, dim=[1, 2]) + self.EPS
            )
        else:
            losses = torch.norm(scaled_target, p=self.p, dim=[1, 2]) / (
                torch.linalg.vector_norm(e_noise, p=self.p, dim=[1, 2]) + self.EPS
            )
        if self.take_log:
            losses = 10 * torch.log10(losses + self.EPS)
        losses = losses.mean() if self.reduction == "mean" else losses
        return -losses
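A worked example of SignalNoisePNormRatio with the defaults (p=1, take_log=True): the loss is 10·(log10(mean|error| + EPS) − log10(mean|target| + EPS)), i.e. an error-to-signal ratio in dB, so more negative is better:

import torch

loss_fn = SignalNoisePNormRatio(p=1.0)

target = torch.ones(1, 1000)
est = target + 0.1  # a constant error of 0.1
# mean|error| = 0.1, mean|target| = 1.0
# loss = 10 * (log10(0.101) - log10(1.001)) ≈ -9.96 dB
print(loss_fn(est, target))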
models/bandit/core/metrics/__init__.py
ADDED
@@ -0,0 +1,9 @@
from .snr import (
    ChunkMedianScaleInvariantSignalDistortionRatio,
    ChunkMedianScaleInvariantSignalNoiseRatio,
    ChunkMedianSignalDistortionRatio,
    ChunkMedianSignalNoiseRatio,
    SafeSignalDistortionRatio,
)

# from .mushra import EstimatedMushraScore
models/bandit/core/metrics/_squim.py
ADDED
@@ -0,0 +1,383 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
|
3 |
+
from torchaudio._internal import load_state_dict_from_url
|
4 |
+
|
5 |
+
import math
|
6 |
+
from typing import List, Optional, Tuple
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
|
12 |
+
|
13 |
+
def transform_wb_pesq_range(x: float) -> float:
|
14 |
+
"""The metric defined by ITU-T P.862 is often called 'PESQ score', which is defined
|
15 |
+
for narrow-band signals and has a value range of [-0.5, 4.5] exactly. Here, we use the metric
|
16 |
+
defined by ITU-T P.862.2, commonly known as 'wide-band PESQ' and will be referred to as "PESQ score".
|
17 |
+
|
18 |
+
Args:
|
19 |
+
x (float): Narrow-band PESQ score.
|
20 |
+
|
21 |
+
Returns:
|
22 |
+
(float): Wide-band PESQ score.
|
23 |
+
"""
|
24 |
+
return 0.999 + (4.999 - 0.999) / (1 + math.exp(-1.3669 * x + 3.8224))
|
25 |
+
|
26 |
+
|
27 |
+
PESQRange: Tuple[float, float] = (
|
28 |
+
1.0, # P.862.2 uses a different input filter than P.862, and the lower bound of
|
29 |
+
# the raw score is not -0.5 anymore. It's hard to figure out the true lower bound.
|
30 |
+
# We are using 1.0 as a reasonable approximation.
|
31 |
+
transform_wb_pesq_range(4.5),
|
32 |
+
)
|
33 |
+
|
34 |
+
|
35 |
+
class RangeSigmoid(nn.Module):
|
36 |
+
def __init__(self, val_range: Tuple[float, float] = (0.0, 1.0)) -> None:
|
37 |
+
super(RangeSigmoid, self).__init__()
|
38 |
+
assert isinstance(val_range, tuple) and len(val_range) == 2
|
39 |
+
self.val_range: Tuple[float, float] = val_range
|
40 |
+
self.sigmoid: nn.modules.Module = nn.Sigmoid()
|
41 |
+
|
42 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
43 |
+
out = self.sigmoid(x) * (self.val_range[1] - self.val_range[0]) + self.val_range[0]
|
44 |
+
return out
|
45 |
+
|
46 |
+
|
47 |
+
class Encoder(nn.Module):
|
48 |
+
"""Encoder module that transform 1D waveform to 2D representations.
|
49 |
+
|
50 |
+
Args:
|
51 |
+
feat_dim (int, optional): The feature dimension after Encoder module. (Default: 512)
|
52 |
+
win_len (int, optional): kernel size in the Conv1D layer. (Default: 32)
|
53 |
+
"""
|
54 |
+
|
55 |
+
def __init__(self, feat_dim: int = 512, win_len: int = 32) -> None:
|
56 |
+
super(Encoder, self).__init__()
|
57 |
+
|
58 |
+
self.conv1d = nn.Conv1d(1, feat_dim, win_len, stride=win_len // 2, bias=False)
|
59 |
+
|
60 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
61 |
+
"""Apply waveforms to convolutional layer and ReLU layer.
|
62 |
+
|
63 |
+
Args:
|
64 |
+
x (torch.Tensor): Input waveforms. Tensor with dimensions `(batch, time)`.
|
65 |
+
|
66 |
+
Returns:
|
67 |
+
(torch,Tensor): Feature Tensor with dimensions `(batch, channel, frame)`.
|
68 |
+
"""
|
69 |
+
out = x.unsqueeze(dim=1)
|
70 |
+
out = F.relu(self.conv1d(out))
|
71 |
+
return out
|
72 |
+
|
73 |
+
|
74 |
+
class SingleRNN(nn.Module):
|
75 |
+
def __init__(self, rnn_type: str, input_size: int, hidden_size: int, dropout: float = 0.0) -> None:
|
76 |
+
super(SingleRNN, self).__init__()
|
77 |
+
|
78 |
+
self.rnn_type = rnn_type
|
79 |
+
self.input_size = input_size
|
80 |
+
self.hidden_size = hidden_size
|
81 |
+
|
82 |
+
self.rnn: nn.modules.Module = getattr(nn, rnn_type)(
|
83 |
+
input_size,
|
84 |
+
hidden_size,
|
85 |
+
1,
|
86 |
+
dropout=dropout,
|
87 |
+
batch_first=True,
|
88 |
+
bidirectional=True,
|
89 |
+
)
|
90 |
+
|
91 |
+
self.proj = nn.Linear(hidden_size * 2, input_size)
|
92 |
+
|
93 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
94 |
+
# input shape: batch, seq, dim
|
95 |
+
out, _ = self.rnn(x)
|
96 |
+
out = self.proj(out)
|
97 |
+
return out
|
98 |
+
|
99 |
+
|
100 |
+
class DPRNN(nn.Module):
|
101 |
+
"""*Dual-path recurrent neural networks (DPRNN)* :cite:`luo2020dual`.
|
102 |
+
|
103 |
+
Args:
|
104 |
+
feat_dim (int, optional): The feature dimension after Encoder module. (Default: 64)
|
105 |
+
hidden_dim (int, optional): Hidden dimension in the RNN layer of DPRNN. (Default: 128)
|
106 |
+
num_blocks (int, optional): Number of DPRNN layers. (Default: 6)
|
107 |
+
rnn_type (str, optional): Type of RNN in DPRNN. Valid options are ["RNN", "LSTM", "GRU"]. (Default: "LSTM")
|
108 |
+
d_model (int, optional): The number of expected features in the input. (Default: 256)
|
109 |
+
chunk_size (int, optional): Chunk size of input for DPRNN. (Default: 100)
|
110 |
+
chunk_stride (int, optional): Stride of chunk input for DPRNN. (Default: 50)
|
111 |
+
"""
|
112 |
+
|
113 |
+
def __init__(
|
114 |
+
self,
|
115 |
+
feat_dim: int = 64,
|
116 |
+
hidden_dim: int = 128,
|
117 |
+
num_blocks: int = 6,
|
118 |
+
rnn_type: str = "LSTM",
|
119 |
+
d_model: int = 256,
|
120 |
+
chunk_size: int = 100,
|
121 |
+
chunk_stride: int = 50,
|
122 |
+
) -> None:
|
123 |
+
super(DPRNN, self).__init__()
|
124 |
+
|
125 |
+
self.num_blocks = num_blocks
|
126 |
+
|
127 |
+
self.row_rnn = nn.ModuleList([])
|
128 |
+
self.col_rnn = nn.ModuleList([])
|
129 |
+
self.row_norm = nn.ModuleList([])
|
130 |
+
self.col_norm = nn.ModuleList([])
|
131 |
+
for _ in range(num_blocks):
|
132 |
+
self.row_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
|
133 |
+
self.col_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
|
134 |
+
self.row_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))
|
135 |
+
self.col_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))
|
136 |
+
self.conv = nn.Sequential(
|
137 |
+
nn.Conv2d(feat_dim, d_model, 1),
|
138 |
+
nn.PReLU(),
|
139 |
+
)
|
140 |
+
self.chunk_size = chunk_size
|
141 |
+
self.chunk_stride = chunk_stride
|
142 |
+
|
143 |
+
def pad_chunk(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
|
144 |
+
# input shape: (B, N, T)
|
145 |
+
seq_len = x.shape[-1]
|
146 |
+
|
147 |
+
rest = self.chunk_size - (self.chunk_stride + seq_len % self.chunk_size) % self.chunk_size
|
148 |
+
out = F.pad(x, [self.chunk_stride, rest + self.chunk_stride])
|
149 |
+
|
150 |
+
return out, rest
|
151 |
+
|
152 |
+
def chunking(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
|
153 |
+
out, rest = self.pad_chunk(x)
|
154 |
+
batch_size, feat_dim, seq_len = out.shape
|
155 |
+
|
156 |
+
segments1 = out[:, :, : -self.chunk_stride].contiguous().view(batch_size, feat_dim, -1, self.chunk_size)
|
157 |
+
segments2 = out[:, :, self.chunk_stride :].contiguous().view(batch_size, feat_dim, -1, self.chunk_size)
|
158 |
+
out = torch.cat([segments1, segments2], dim=3)
|
159 |
+
out = out.view(batch_size, feat_dim, -1, self.chunk_size).transpose(2, 3).contiguous()
|
160 |
+
|
161 |
+
return out, rest
|
162 |
+
|
163 |
+
def merging(self, x: torch.Tensor, rest: int) -> torch.Tensor:
|
164 |
+
batch_size, dim, _, _ = x.shape
|
165 |
+
out = x.transpose(2, 3).contiguous().view(batch_size, dim, -1, self.chunk_size * 2)
|
166 |
+
out1 = out[:, :, :, : self.chunk_size].contiguous().view(batch_size, dim, -1)[:, :, self.chunk_stride :]
|
167 |
+
out2 = out[:, :, :, self.chunk_size :].contiguous().view(batch_size, dim, -1)[:, :, : -self.chunk_stride]
|
168 |
+
out = out1 + out2
|
169 |
+
if rest > 0:
|
170 |
+
out = out[:, :, :-rest]
|
171 |
+
out = out.contiguous()
|
172 |
+
return out
|
173 |
+
|
174 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
175 |
+
x, rest = self.chunking(x)
|
176 |
+
batch_size, _, dim1, dim2 = x.shape
|
177 |
+
out = x
|
178 |
+
for row_rnn, row_norm, col_rnn, col_norm in zip(self.row_rnn, self.row_norm, self.col_rnn, self.col_norm):
|
179 |
+
row_in = out.permute(0, 3, 2, 1).contiguous().view(batch_size * dim2, dim1, -1).contiguous()
|
180 |
+
row_out = row_rnn(row_in)
|
181 |
+
row_out = row_out.view(batch_size, dim2, dim1, -1).permute(0, 3, 2, 1).contiguous()
|
182 |
+
row_out = row_norm(row_out)
|
183 |
+
out = out + row_out
|
184 |
+
|
185 |
+
col_in = out.permute(0, 2, 3, 1).contiguous().view(batch_size * dim1, dim2, -1).contiguous()
|
186 |
+
col_out = col_rnn(col_in)
|
187 |
+
col_out = col_out.view(batch_size, dim1, dim2, -1).permute(0, 3, 1, 2).contiguous()
|
188 |
+
col_out = col_norm(col_out)
|
189 |
+
out = out + col_out
|
190 |
+
out = self.conv(out)
|
191 |
+
out = self.merging(out, rest)
|
192 |
+
out = out.transpose(1, 2).contiguous()
|
193 |
+
return out
|
194 |
+
|
195 |
+
|
196 |
+
class AutoPool(nn.Module):
|
197 |
+
def __init__(self, pool_dim: int = 1) -> None:
|
198 |
+
super(AutoPool, self).__init__()
|
199 |
+
self.pool_dim: int = pool_dim
|
200 |
+
self.softmax: nn.modules.Module = nn.Softmax(dim=pool_dim)
|
201 |
+
self.register_parameter("alpha", nn.Parameter(torch.ones(1)))
|
202 |
+
|
203 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
204 |
+
weight = self.softmax(torch.mul(x, self.alpha))
|
205 |
+
out = torch.sum(torch.mul(x, weight), dim=self.pool_dim)
|
206 |
+
return out
|
207 |
+
|
208 |
+
|
209 |
+
class SquimObjective(nn.Module):
|
210 |
+
"""Speech Quality and Intelligibility Measures (SQUIM) model that predicts **objective** metric scores
|
211 |
+
for speech enhancement (e.g., STOI, PESQ, and SI-SDR).
|
212 |
+
|
213 |
+
Args:
|
214 |
+
encoder (torch.nn.Module): Encoder module to transform 1D waveform to 2D feature representation.
|
215 |
+
dprnn (torch.nn.Module): DPRNN module to model sequential feature.
|
216 |
+
branches (torch.nn.ModuleList): Transformer branches in which each branch estimate one objective metirc score.
|
217 |
+
"""
|
218 |
+
|
219 |
+
def __init__(
|
220 |
+
self,
|
221 |
+
encoder: nn.Module,
|
222 |
+
dprnn: nn.Module,
|
223 |
+
branches: nn.ModuleList,
|
224 |
+
):
|
225 |
+
super(SquimObjective, self).__init__()
|
226 |
+
self.encoder = encoder
|
227 |
+
self.dprnn = dprnn
|
228 |
+
self.branches = branches
|
229 |
+
|
230 |
+
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
231 |
+
"""
|
232 |
+
Args:
|
233 |
+
x (torch.Tensor): Input waveforms. Tensor with dimensions `(batch, time)`.
|
234 |
+
|
235 |
+
Returns:
|
236 |
+
List(torch.Tensor): List of score Tenosrs. Each Tensor is with dimension `(batch,)`.
|
237 |
+
"""
|
238 |
+
if x.ndim != 2:
|
239 |
+
raise ValueError(f"The input must be a 2D Tensor. Found dimension {x.ndim}.")
|
240 |
+
x = x / (torch.mean(x**2, dim=1, keepdim=True) ** 0.5 * 20)
|
241 |
+
out = self.encoder(x)
|
242 |
+
out = self.dprnn(out)
|
243 |
+
scores = []
|
244 |
+
for branch in self.branches:
|
245 |
+
scores.append(branch(out).squeeze(dim=1))
|
246 |
+
return scores
|
247 |
+
|
248 |
+
|
249 |
+
def _create_branch(d_model: int, nhead: int, metric: str) -> nn.modules.Module:
|
250 |
+
"""Create branch module after DPRNN model for predicting metric score.
|
251 |
+
|
252 |
+
Args:
|
253 |
+
d_model (int): The number of expected features in the input.
|
254 |
+
nhead (int): Number of heads in the multi-head attention model.
|
255 |
+
metric (str): The metric name to predict.
|
256 |
+
|
257 |
+
Returns:
|
258 |
+
(nn.Module): Returned module to predict corresponding metric score.
|
259 |
+
"""
|
260 |
+
layer1 = nn.TransformerEncoderLayer(d_model, nhead, d_model * 4, dropout=0.0, batch_first=True)
|
261 |
+
layer2 = AutoPool()
|
262 |
+
    if metric == "stoi":
        layer3 = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.PReLU(),
            nn.Linear(d_model, 1),
            RangeSigmoid(),
        )
    elif metric == "pesq":
        layer3 = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.PReLU(),
            nn.Linear(d_model, 1),
            RangeSigmoid(val_range=PESQRange),
        )
    else:
        layer3: nn.modules.Module = nn.Sequential(
            nn.Linear(d_model, d_model), nn.PReLU(), nn.Linear(d_model, 1)
        )
    return nn.Sequential(layer1, layer2, layer3)


def squim_objective_model(
    feat_dim: int,
    win_len: int,
    d_model: int,
    nhead: int,
    hidden_dim: int,
    num_blocks: int,
    rnn_type: str,
    chunk_size: int,
    chunk_stride: Optional[int] = None,
) -> SquimObjective:
    """Build a custom :class:`torchaudio.prototype.models.SquimObjective` model.

    Args:
        feat_dim (int, optional): The feature dimension after Encoder module.
        win_len (int): Kernel size in the Encoder module.
        d_model (int): The number of expected features in the input.
        nhead (int): Number of heads in the multi-head attention model.
        hidden_dim (int): Hidden dimension in the RNN layer of DPRNN.
        num_blocks (int): Number of DPRNN layers.
        rnn_type (str): Type of RNN in DPRNN. Valid options are ["RNN", "LSTM", "GRU"].
        chunk_size (int): Chunk size of input for DPRNN.
        chunk_stride (int or None, optional): Stride of chunk input for DPRNN.
    """
    if chunk_stride is None:
        chunk_stride = chunk_size // 2
    encoder = Encoder(feat_dim, win_len)
    dprnn = DPRNN(feat_dim, hidden_dim, num_blocks, rnn_type, d_model, chunk_size, chunk_stride)
    branches = nn.ModuleList(
        [
            _create_branch(d_model, nhead, "stoi"),
            _create_branch(d_model, nhead, "pesq"),
            _create_branch(d_model, nhead, "sisdr"),
        ]
    )
    return SquimObjective(encoder, dprnn, branches)


def squim_objective_base() -> SquimObjective:
    """Build :class:`torchaudio.prototype.models.SquimObjective` model with default arguments."""
    return squim_objective_model(
        feat_dim=256,
        win_len=64,
        d_model=256,
        nhead=4,
        hidden_dim=256,
        num_blocks=2,
        rnn_type="LSTM",
        chunk_size=71,
    )


@dataclass
class SquimObjectiveBundle:

    _path: str
    _sample_rate: float

    def _get_state_dict(self, dl_kwargs):
        url = f"https://download.pytorch.org/torchaudio/models/{self._path}"
        dl_kwargs = {} if dl_kwargs is None else dl_kwargs
        state_dict = load_state_dict_from_url(url, **dl_kwargs)
        return state_dict

    def get_model(self, *, dl_kwargs=None) -> SquimObjective:
        """Construct the SquimObjective model, and load the pretrained weight.

        The weight file is downloaded from the internet and cached with
        :func:`torch.hub.load_state_dict_from_url`.

        Args:
            dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.

        Returns:
            Variation of :py:class:`~torchaudio.models.SquimObjective`.
        """
        model = squim_objective_base()
        model.load_state_dict(self._get_state_dict(dl_kwargs))
        model.eval()
        return model

    @property
    def sample_rate(self):
        """Sample rate of the audio that the model is trained on.

        :type: float
        """
        return self._sample_rate


SQUIM_OBJECTIVE = SquimObjectiveBundle(
    "squim_objective_dns2020.pth",
    _sample_rate=16000,
)
SQUIM_OBJECTIVE.__doc__ = """SquimObjective pipeline trained using approach described in
:cite:`kumar2023torchaudio` on the *DNS 2020 Dataset* :cite:`reddy2020interspeech`.

The underlying model is constructed by :py:func:`torchaudio.models.squim_objective_base`.
The weights are under `Creative Commons Attribution 4.0 International License
<https://github.com/microsoft/DNS-Challenge/blob/interspeech2020/master/LICENSE>`__.

Please refer to :py:class:`SquimObjectiveBundle` for usage instructions.
"""
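The bundle above mirrors torchaudio's prototype SQUIM pipeline, so scoring a clip reduces to fetching the model and calling it on a waveform. A minimal sketch, assuming this module is importable as models.bandit.core.metrics._squim and keeps torchaudio's forward interface (one score tensor per branch, in the stoi/pesq/sisdr order the branches are registered above):

import torch

from models.bandit.core.metrics._squim import SQUIM_OBJECTIVE

model = SQUIM_OBJECTIVE.get_model()   # downloads and caches the pretrained weights
wav = torch.randn(1, 16000)           # 1 s of dummy 16 kHz audio, (batch, time)
with torch.no_grad():
    stoi_hat, pesq_hat, si_sdr_hat = model(wav)
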
models/bandit/core/metrics/snr.py
ADDED
@@ -0,0 +1,150 @@
from typing import Any, Callable, Optional

import numpy as np
import torch
import torchmetrics as tm
from torch._C import _LinAlgError
from torchmetrics import functional as tmF


class SafeSignalDistortionRatio(tm.SignalDistortionRatio):
    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

    def update(self, *args, **kwargs) -> Any:
        try:
            super().update(*args, **kwargs)
        except Exception:
            # Skip batches where the underlying SDR solve fails;
            # compute() reports NaN if nothing was accumulated.
            pass

    def compute(self) -> Any:
        if self.total == 0:
            return torch.tensor(torch.nan)
        return super().compute()


class BaseChunkMedianSignalRatio(tm.Metric):
    def __init__(
        self,
        func: Callable,
        window_size: int,
        hop_size: Optional[int] = None,
        zero_mean: bool = False,
    ) -> None:
        super().__init__()

        # self.zero_mean = zero_mean
        self.func = func
        self.window_size = window_size
        if hop_size is None:
            hop_size = window_size
        self.hop_size = hop_size

        self.add_state(
            "sum_snr",
            default=torch.tensor(0.0),
            dist_reduce_fx="sum"
        )
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:

        n_samples = target.shape[-1]

        n_chunks = int(
            np.ceil((n_samples - self.window_size) / self.hop_size) + 1
        )

        snr_chunk = []

        for i in range(n_chunks):
            start = i * self.hop_size

            if n_samples - start < self.window_size:
                continue

            end = start + self.window_size

            try:
                chunk_snr = self.func(
                    preds[..., start:end],
                    target[..., start:end]
                )

                # print(preds.shape, chunk_snr.shape)

                if torch.all(torch.isfinite(chunk_snr)):
                    snr_chunk.append(chunk_snr)
            except _LinAlgError:
                pass

        snr_chunk = torch.stack(snr_chunk, dim=-1)
        snr_batch, _ = torch.nanmedian(snr_chunk, dim=-1)

        self.sum_snr += snr_batch.sum()
        self.total += snr_batch.numel()

    def compute(self) -> Any:
        return self.sum_snr / self.total


class ChunkMedianSignalNoiseRatio(BaseChunkMedianSignalRatio):
    def __init__(
        self,
        window_size: int,
        hop_size: Optional[int] = None,
        zero_mean: bool = False
    ) -> None:
        super().__init__(
            func=tmF.signal_noise_ratio,
            window_size=window_size,
            hop_size=hop_size,
            zero_mean=zero_mean,
        )


class ChunkMedianScaleInvariantSignalNoiseRatio(BaseChunkMedianSignalRatio):
    def __init__(
        self,
        window_size: int,
        hop_size: Optional[int] = None,
        zero_mean: bool = False
    ) -> None:
        super().__init__(
            func=tmF.scale_invariant_signal_noise_ratio,
            window_size=window_size,
            hop_size=hop_size,
            zero_mean=zero_mean,
        )


class ChunkMedianSignalDistortionRatio(BaseChunkMedianSignalRatio):
    def __init__(
        self,
        window_size: int,
        hop_size: Optional[int] = None,
        zero_mean: bool = False
    ) -> None:
        super().__init__(
            func=tmF.signal_distortion_ratio,
            window_size=window_size,
            hop_size=hop_size,
            zero_mean=zero_mean,
        )


class ChunkMedianScaleInvariantSignalDistortionRatio(
    BaseChunkMedianSignalRatio
):
    def __init__(
        self,
        window_size: int,
        hop_size: Optional[int] = None,
        zero_mean: bool = False
    ) -> None:
        super().__init__(
            func=tmF.scale_invariant_signal_distortion_ratio,
            window_size=window_size,
            hop_size=hop_size,
            zero_mean=zero_mean,
        )
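These chunked metrics follow the usual torchmetrics update/compute protocol; each update chops the signals into windows, takes the per-example median over windows, and accumulates a batch mean. A minimal sketch with toy tensors (the window size and shapes are illustrative values, not from the upload):

import torch

from models.bandit.core.metrics.snr import ChunkMedianSignalNoiseRatio

metric = ChunkMedianSignalNoiseRatio(window_size=44100)  # 1 s chunks at 44.1 kHz
preds = torch.randn(2, 2, 88200)    # (batch, channel, time)
target = torch.randn(2, 2, 88200)
metric.update(preds, target)
print(metric.compute())             # scalar tensor
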
models/bandit/core/model/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .bsrnn.wrapper import (
    MultiMaskMultiSourceBandSplitRNNSimple,
)
models/bandit/core/model/_spectral.py
ADDED
@@ -0,0 +1,58 @@
from typing import Dict, Optional

import torch
import torchaudio as ta
from torch import nn


class _SpectralComponent(nn.Module):
    def __init__(
        self,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        **kwargs,
    ) -> None:
        super().__init__()

        assert power is None

        window_fn = torch.__dict__[window_fn]

        self.stft = ta.transforms.Spectrogram(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            pad_mode=pad_mode,
            pad=0,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            normalized=normalized,
            center=center,
            onesided=onesided,
        )

        self.istft = ta.transforms.InverseSpectrogram(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            pad_mode=pad_mode,
            pad=0,
            window_fn=window_fn,
            wkwargs=wkwargs,
            normalized=normalized,
            center=center,
            onesided=onesided,
        )
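Because power=None is asserted, spec.stft yields a complex spectrogram that spec.istft can invert directly. A round-trip sanity check, a sketch only (the shapes are illustrative):

import torch

from models.bandit.core.model._spectral import _SpectralComponent

spec = _SpectralComponent(n_fft=2048, hop_length=512)
x = torch.randn(1, 2, 44100)             # (batch, channel, time)
X = spec.stft(x)                         # complex (batch, channel, n_freq, n_time)
y = spec.istft(X, length=x.shape[-1])    # back to (batch, channel, time)
print((x - y).abs().max())               # should be near zero
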
models/bandit/core/model/bsrnn/__init__.py
ADDED
@@ -0,0 +1,23 @@
from abc import ABC
from typing import Iterable, Mapping, Union

from torch import nn

from models.bandit.core.model.bsrnn.bandsplit import BandSplitModule
from models.bandit.core.model.bsrnn.tfmodel import (
    SeqBandModellingModule,
    TransformerTimeFreqModule,
)


class BandsplitCoreBase(nn.Module, ABC):
    band_split: nn.Module
    tf_model: nn.Module
    mask_estim: Union[nn.Module, Mapping[str, nn.Module], Iterable[nn.Module]]

    def __init__(self) -> None:
        super().__init__()

    @staticmethod
    def mask(x, m):
        return x * m
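The mask static method is plain elementwise (complex ratio) masking: with a complex mask, multiplication rescales each TF bin's magnitude and rotates its phase. A tiny illustration with toy tensors (not from the upload):

import torch

x = torch.randn(1, 3, 4, dtype=torch.complex64)   # complex spectrogram bins
m = torch.full_like(x, 0.5 + 0.0j)                # constant complex mask
s = x * m                                         # what BandsplitCoreBase.mask(x, m) computes
assert torch.allclose(s.abs(), 0.5 * x.abs())
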
models/bandit/core/model/bsrnn/bandsplit.py
ADDED
@@ -0,0 +1,139 @@
from typing import List, Tuple

import torch
from torch import nn

from models.bandit.core.model.bsrnn.utils import (
    band_widths_from_specs,
    check_no_gap,
    check_no_overlap,
    check_nonzero_bandwidth,
)


class NormFC(nn.Module):
    def __init__(
        self,
        emb_dim: int,
        bandwidth: int,
        in_channel: int,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
    ) -> None:
        super().__init__()

        self.treat_channel_as_feature = treat_channel_as_feature

        if normalize_channel_independently:
            raise NotImplementedError

        reim = 2

        self.norm = nn.LayerNorm(in_channel * bandwidth * reim)

        fc_in = bandwidth * reim

        if treat_channel_as_feature:
            fc_in *= in_channel
        else:
            assert emb_dim % in_channel == 0
            emb_dim = emb_dim // in_channel

        self.fc = nn.Linear(fc_in, emb_dim)

    def forward(self, xb):
        # xb = (batch, n_time, in_chan, reim * band_width)

        batch, n_time, in_chan, ribw = xb.shape
        xb = self.norm(xb.reshape(batch, n_time, in_chan * ribw))
        # (batch, n_time, in_chan * reim * band_width)

        if not self.treat_channel_as_feature:
            xb = xb.reshape(batch, n_time, in_chan, ribw)
            # (batch, n_time, in_chan, reim * band_width)

        zb = self.fc(xb)
        # (batch, n_time, emb_dim)
        # OR
        # (batch, n_time, in_chan, emb_dim_per_chan)

        if not self.treat_channel_as_feature:
            batch, n_time, in_chan, emb_dim_per_chan = zb.shape
            # (batch, n_time, in_chan, emb_dim_per_chan)
            zb = zb.reshape((batch, n_time, in_chan * emb_dim_per_chan))

        return zb  # (batch, n_time, emb_dim)


class BandSplitModule(nn.Module):
    def __init__(
        self,
        band_specs: List[Tuple[float, float]],
        emb_dim: int,
        in_channel: int,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
    ) -> None:
        super().__init__()

        check_nonzero_bandwidth(band_specs)

        if require_no_gap:
            check_no_gap(band_specs)

        if require_no_overlap:
            check_no_overlap(band_specs)

        self.band_specs = band_specs
        # list of [fstart, fend) in index.
        # Note that fend is exclusive.
        self.band_widths = band_widths_from_specs(band_specs)
        self.n_bands = len(band_specs)
        self.emb_dim = emb_dim

        self.norm_fc_modules = nn.ModuleList(
            [  # type: ignore
                NormFC(
                    emb_dim=emb_dim,
                    bandwidth=bw,
                    in_channel=in_channel,
                    normalize_channel_independently=normalize_channel_independently,
                    treat_channel_as_feature=treat_channel_as_feature,
                )
                for bw in self.band_widths
            ]
        )

    def forward(self, x: torch.Tensor):
        # x = complex spectrogram (batch, in_chan, n_freq, n_time)

        batch, in_chan, _, n_time = x.shape

        z = torch.zeros(
            size=(batch, self.n_bands, n_time, self.emb_dim),
            device=x.device
        )

        xr = torch.view_as_real(x)  # (batch, in_chan, n_freq, n_time, 2)
        xr = torch.permute(
            xr,
            (0, 3, 1, 4, 2)
        )  # (batch, n_time, in_chan, 2, n_freq)
        batch, n_time, in_chan, reim, band_width = xr.shape
        for i, nfm in enumerate(self.norm_fc_modules):
            # print(f"bandsplit/band{i:02d}")
            fstart, fend = self.band_specs[i]
            xb = xr[..., fstart:fend]
            # (batch, n_time, in_chan, reim, band_width)
            xb = torch.reshape(xb, (batch, n_time, in_chan, -1))
            # (batch, n_time, in_chan, reim * band_width)
            # z.append(nfm(xb))  # (batch, n_time, emb_dim)
            z[:, i, :, :] = nfm(xb.contiguous())

        # z = torch.stack(z, dim=1)

        return z
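A quick shape check for the module above. The band specs here are toy [fstart, fend) values chosen to be gapless over 33 bins; real configurations come from the bandwidth partitioning helpers in utils.py:

import torch

from models.bandit.core.model.bsrnn.bandsplit import BandSplitModule

band_specs = [(0, 8), (8, 16), (16, 33)]                 # gapless, exclusive fend
module = BandSplitModule(band_specs=band_specs, emb_dim=128, in_channel=2)
x = torch.randn(4, 2, 33, 100, dtype=torch.complex64)    # (batch, chan, n_freq, n_time)
z = module(x)
print(z.shape)   # torch.Size([4, 3, 100, 128]) = (batch, n_bands, n_time, emb_dim)
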
models/bandit/core/model/bsrnn/core.py
ADDED
@@ -0,0 +1,661 @@
from typing import Dict, List, Optional, Tuple

import torch
from torch import nn
from torch.nn import functional as F

from models.bandit.core.model.bsrnn import BandsplitCoreBase
from models.bandit.core.model.bsrnn.bandsplit import BandSplitModule
from models.bandit.core.model.bsrnn.maskestim import (
    MaskEstimationModule,
    OverlappingMaskEstimationModule
)
from models.bandit.core.model.bsrnn.tfmodel import (
    ConvolutionalTimeFreqModule,
    SeqBandModellingModule,
    TransformerTimeFreqModule
)


class MultiMaskBandSplitCoreBase(BandsplitCoreBase):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x, cond=None, compute_residual: bool = True):
        # x = complex spectrogram (batch, in_chan, n_freq, n_time)
        # print(x.shape)
        batch, in_chan, n_freq, n_time = x.shape
        x = torch.reshape(x, (-1, 1, n_freq, n_time))

        z = self.band_split(x)  # (batch, n_band, n_time, emb_dim)

        # if torch.any(torch.isnan(z)):
        #     raise ValueError("z nan")

        # print(z)
        q = self.tf_model(z)  # (batch, n_band, n_time, emb_dim)
        # print(q)

        # if torch.any(torch.isnan(q)):
        #     raise ValueError("q nan")

        out = {}

        for stem, mem in self.mask_estim.items():
            m = mem(q, cond=cond)

            # if torch.any(torch.isnan(m)):
            #     raise ValueError("m nan", stem)

            s = self.mask(x, m)
            s = torch.reshape(s, (batch, in_chan, n_freq, n_time))
            out[stem] = s

        return {"spectrogram": out}

    def instantiate_mask_estim(
        self,
        in_channel: int,
        stems: List[str],
        band_specs: List[Tuple[float, float]],
        emb_dim: int,
        mlp_dim: int,
        cond_dim: int,
        hidden_activation: str,
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        overlapping_band: bool = False,
        freq_weights: Optional[List[torch.Tensor]] = None,
        n_freq: Optional[int] = None,
        use_freq_weights: bool = True,
        mult_add_mask: bool = False
    ):
        if hidden_activation_kwargs is None:
            hidden_activation_kwargs = {}

        if "mne:+" in stems:
            stems = [s for s in stems if s != "mne:+"]

        if overlapping_band:
            assert freq_weights is not None
            assert n_freq is not None

            if mult_add_mask:
                self.mask_estim = nn.ModuleDict(
                    {
                        # NOTE: MultAddMaskEstimationModule is not among the
                        # imports above; this branch would raise a NameError
                        # if exercised as uploaded.
                        stem: MultAddMaskEstimationModule(
                            band_specs=band_specs,
                            freq_weights=freq_weights,
                            n_freq=n_freq,
                            emb_dim=emb_dim,
                            mlp_dim=mlp_dim,
                            in_channel=in_channel,
                            hidden_activation=hidden_activation,
                            hidden_activation_kwargs=hidden_activation_kwargs,
                            complex_mask=complex_mask,
                            use_freq_weights=use_freq_weights,
                        )
                        for stem in stems
                    }
                )
            else:
                self.mask_estim = nn.ModuleDict(
                    {
                        stem: OverlappingMaskEstimationModule(
                            band_specs=band_specs,
                            freq_weights=freq_weights,
                            n_freq=n_freq,
                            emb_dim=emb_dim,
                            mlp_dim=mlp_dim,
                            in_channel=in_channel,
                            hidden_activation=hidden_activation,
                            hidden_activation_kwargs=hidden_activation_kwargs,
                            complex_mask=complex_mask,
                            use_freq_weights=use_freq_weights,
                        )
                        for stem in stems
                    }
                )
        else:
            self.mask_estim = nn.ModuleDict(
                {
                    stem: MaskEstimationModule(
                        band_specs=band_specs,
                        emb_dim=emb_dim,
                        mlp_dim=mlp_dim,
                        cond_dim=cond_dim,
                        in_channel=in_channel,
                        hidden_activation=hidden_activation,
                        hidden_activation_kwargs=hidden_activation_kwargs,
                        complex_mask=complex_mask,
                    )
                    for stem in stems
                }
            )

    def instantiate_bandsplit(
        self,
        in_channel: int,
        band_specs: List[Tuple[float, float]],
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        emb_dim: int = 128
    ):
        self.band_split = BandSplitModule(
            in_channel=in_channel,
            band_specs=band_specs,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
        )


class SingleMaskBandsplitCoreBase(BandsplitCoreBase):
    def __init__(self, **kwargs) -> None:
        super().__init__()

    def forward(self, x):
        # x = complex spectrogram (batch, in_chan, n_freq, n_time)
        z = self.band_split(x)  # (batch, n_band, n_time, emb_dim)
        q = self.tf_model(z)  # (batch, n_band, n_time, emb_dim)
        m = self.mask_estim(q)  # (batch, in_chan, n_freq, n_time)

        s = self.mask(x, m)

        return s


class SingleMaskBandsplitCoreRNN(
    SingleMaskBandsplitCoreBase,
):
    def __init__(
        self,
        in_channel: int,
        band_specs: List[Tuple[float, float]],
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
    ) -> None:
        super().__init__()
        self.band_split = BandSplitModule(
            in_channel=in_channel,
            band_specs=band_specs,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
        )
        self.tf_model = SeqBandModellingModule(
            n_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
        )
        self.mask_estim = MaskEstimationModule(
            in_channel=in_channel,
            band_specs=band_specs,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
        )


class SingleMaskBandsplitCoreTransformer(
    SingleMaskBandsplitCoreBase,
):
    def __init__(
        self,
        in_channel: int,
        band_specs: List[Tuple[float, float]],
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        tf_dropout: float = 0.0,
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
    ) -> None:
        super().__init__()
        self.band_split = BandSplitModule(
            in_channel=in_channel,
            band_specs=band_specs,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
        )
        self.tf_model = TransformerTimeFreqModule(
            n_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            dropout=tf_dropout,
        )
        self.mask_estim = MaskEstimationModule(
            in_channel=in_channel,
            band_specs=band_specs,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
        )


class MultiSourceMultiMaskBandSplitCoreRNN(MultiMaskBandSplitCoreBase):
    def __init__(
        self,
        in_channel: int,
        stems: List[str],
        band_specs: List[Tuple[float, float]],
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        cond_dim: int = 0,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        overlapping_band: bool = False,
        freq_weights: Optional[List[torch.Tensor]] = None,
        n_freq: Optional[int] = None,
        use_freq_weights: bool = True,
        mult_add_mask: bool = False
    ) -> None:
        super().__init__()
        self.instantiate_bandsplit(
            in_channel=in_channel,
            band_specs=band_specs,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim
        )

        self.tf_model = SeqBandModellingModule(
            n_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
        )

        self.mult_add_mask = mult_add_mask

        self.instantiate_mask_estim(
            in_channel=in_channel,
            stems=stems,
            band_specs=band_specs,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            cond_dim=cond_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            overlapping_band=overlapping_band,
            freq_weights=freq_weights,
            n_freq=n_freq,
            use_freq_weights=use_freq_weights,
            mult_add_mask=mult_add_mask
        )

    @staticmethod
    def _mult_add_mask(x, m):

        assert m.ndim == 5

        mm = m[..., 0]
        am = m[..., 1]

        # print(mm.shape, am.shape, x.shape, m.shape)

        return x * mm + am

    def mask(self, x, m):
        if self.mult_add_mask:
            return self._mult_add_mask(x, m)
        else:
            return super().mask(x, m)


class MultiSourceMultiMaskBandSplitCoreTransformer(
    MultiMaskBandSplitCoreBase,
):
    def __init__(
        self,
        in_channel: int,
        stems: List[str],
        band_specs: List[Tuple[float, float]],
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        tf_dropout: float = 0.0,
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        overlapping_band: bool = False,
        freq_weights: Optional[List[torch.Tensor]] = None,
        n_freq: Optional[int] = None,
        use_freq_weights: bool = True,
        rnn_type: str = "LSTM",
        cond_dim: int = 0,
        mult_add_mask: bool = False
    ) -> None:
        super().__init__()
        self.instantiate_bandsplit(
            in_channel=in_channel,
            band_specs=band_specs,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim
        )
        self.tf_model = TransformerTimeFreqModule(
            n_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            dropout=tf_dropout,
        )

        self.instantiate_mask_estim(
            in_channel=in_channel,
            stems=stems,
            band_specs=band_specs,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            cond_dim=cond_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            overlapping_band=overlapping_band,
            freq_weights=freq_weights,
            n_freq=n_freq,
            use_freq_weights=use_freq_weights,
            mult_add_mask=mult_add_mask
        )


class MultiSourceMultiMaskBandSplitCoreConv(
    MultiMaskBandSplitCoreBase,
):
    def __init__(
        self,
        in_channel: int,
        stems: List[str],
        band_specs: List[Tuple[float, float]],
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        tf_dropout: float = 0.0,
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        overlapping_band: bool = False,
        freq_weights: Optional[List[torch.Tensor]] = None,
        n_freq: Optional[int] = None,
        use_freq_weights: bool = True,
        rnn_type: str = "LSTM",
        cond_dim: int = 0,
        mult_add_mask: bool = False
    ) -> None:
        super().__init__()
        self.instantiate_bandsplit(
            in_channel=in_channel,
            band_specs=band_specs,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim
        )
        self.tf_model = ConvolutionalTimeFreqModule(
            n_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            dropout=tf_dropout,
        )

        self.instantiate_mask_estim(
            in_channel=in_channel,
            stems=stems,
            band_specs=band_specs,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            cond_dim=cond_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            overlapping_band=overlapping_band,
            freq_weights=freq_weights,
            n_freq=n_freq,
            use_freq_weights=use_freq_weights,
            mult_add_mask=mult_add_mask
        )


class PatchingMaskBandsplitCoreBase(MultiMaskBandSplitCoreBase):
    def __init__(self) -> None:
        super().__init__()

    def mask(self, x, m):
        # x.shape = (batch, n_channel, n_freq, n_time)
        # m.shape = (batch, n_channel, kernel_freq, kernel_time, n_freq, n_time)

        _, n_channel, kernel_freq, kernel_time, n_freq, n_time = m.shape
        padding = ((kernel_freq - 1) // 2, (kernel_time - 1) // 2)

        xf = F.unfold(
            x,
            kernel_size=(kernel_freq, kernel_time),
            padding=padding,
            stride=(1, 1),
        )

        xf = xf.view(
            -1,
            n_channel,
            kernel_freq,
            kernel_time,
            n_freq,
            n_time,
        )

        sf = xf * m

        sf = sf.view(
            -1,
            n_channel * kernel_freq * kernel_time,
            n_freq * n_time,
        )

        s = F.fold(
            sf,
            output_size=(n_freq, n_time),
            kernel_size=(kernel_freq, kernel_time),
            padding=padding,
            stride=(1, 1),
        ).view(
            -1,
            n_channel,
            n_freq,
            n_time,
        )

        return s

    def old_mask(self, x, m):
        # x.shape = (batch, n_channel, n_freq, n_time)
        # m.shape = (kernel_freq, kernel_time, batch, n_channel, n_freq, n_time)

        s = torch.zeros_like(x)

        _, n_channel, n_freq, n_time = x.shape
        kernel_freq, kernel_time, _, _, _, _ = m.shape

        # print(x.shape, m.shape)

        kernel_freq_half = (kernel_freq - 1) // 2
        kernel_time_half = (kernel_time - 1) // 2

        for ifreq in range(kernel_freq):
            for itime in range(kernel_time):
                df, dt = kernel_freq_half - ifreq, kernel_time_half - itime
                x = x.roll(shifts=(df, dt), dims=(2, 3))

                # if df > 0:
                #     x[:, :, :df, :] = 0
                # elif df < 0:
                #     x[:, :, df:, :] = 0

                # if dt > 0:
                #     x[:, :, :, :dt] = 0
                # elif dt < 0:
                #     x[:, :, :, dt:] = 0

                fslice = slice(max(0, df), min(n_freq, n_freq + df))
                tslice = slice(max(0, dt), min(n_time, n_time + dt))

                s[:, :, fslice, tslice] += (
                    x[:, :, fslice, tslice] * m[ifreq, itime, :, :, fslice, tslice]
                )

        return s


class MultiSourceMultiPatchingMaskBandSplitCoreRNN(
    PatchingMaskBandsplitCoreBase
):
    def __init__(
        self,
        in_channel: int,
        stems: List[str],
        band_specs: List[Tuple[float, float]],
        mask_kernel_freq: int,
        mask_kernel_time: int,
        conv_kernel_freq: int,
        conv_kernel_time: int,
        kernel_norm_mlp_version: int,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        overlapping_band: bool = False,
        freq_weights: Optional[List[torch.Tensor]] = None,
        n_freq: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.band_split = BandSplitModule(
            in_channel=in_channel,
            band_specs=band_specs,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
        )

        self.tf_model = SeqBandModellingModule(
            n_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
        )

        if hidden_activation_kwargs is None:
            hidden_activation_kwargs = {}

        if overlapping_band:
            assert freq_weights is not None
            assert n_freq is not None
            self.mask_estim = nn.ModuleDict(
                {
                    # NOTE: PatchingMaskEstimationModule is not among the
                    # imports above; this branch would raise a NameError
                    # if exercised as uploaded.
                    stem: PatchingMaskEstimationModule(
                        band_specs=band_specs,
                        freq_weights=freq_weights,
                        n_freq=n_freq,
                        emb_dim=emb_dim,
                        mlp_dim=mlp_dim,
                        in_channel=in_channel,
                        hidden_activation=hidden_activation,
                        hidden_activation_kwargs=hidden_activation_kwargs,
                        complex_mask=complex_mask,
                        mask_kernel_freq=mask_kernel_freq,
                        mask_kernel_time=mask_kernel_time,
                        conv_kernel_freq=conv_kernel_freq,
                        conv_kernel_time=conv_kernel_time,
                        kernel_norm_mlp_version=kernel_norm_mlp_version
                    )
                    for stem in stems
                }
            )
        else:
            raise NotImplementedError
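An end-to-end shape sketch for the RNN core above; the stems and band specs are arbitrary toy values, and the widths are shrunk from their defaults to keep the example light:

import torch

from models.bandit.core.model.bsrnn.core import MultiSourceMultiMaskBandSplitCoreRNN

core = MultiSourceMultiMaskBandSplitCoreRNN(
    in_channel=1,
    stems=["speech", "music"],
    band_specs=[(0, 8), (8, 16), (16, 33)],
    n_sqm_modules=2,      # default is 12
    emb_dim=32,
    rnn_dim=64,
    mlp_dim=64,
)
x = torch.randn(2, 1, 33, 50, dtype=torch.complex64)  # complex spectrogram
out = core(x)["spectrogram"]
print(out["speech"].shape)   # torch.Size([2, 1, 33, 50]), one complex estimate per stem
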
models/bandit/core/model/bsrnn/maskestim.py
ADDED
@@ -0,0 +1,347 @@
import warnings
from typing import Dict, List, Optional, Tuple, Type

import torch
from torch import nn
from torch.nn.modules import activation

from models.bandit.core.model.bsrnn.utils import (
    band_widths_from_specs,
    check_no_gap,
    check_no_overlap,
    check_nonzero_bandwidth,
)


class BaseNormMLP(nn.Module):
    def __init__(
        self,
        emb_dim: int,
        mlp_dim: int,
        bandwidth: int,
        in_channel: Optional[int],
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs=None,
        complex_mask: bool = True,
    ):
        super().__init__()
        if hidden_activation_kwargs is None:
            hidden_activation_kwargs = {}
        self.hidden_activation_kwargs = hidden_activation_kwargs
        self.norm = nn.LayerNorm(emb_dim)
        self.hidden = torch.jit.script(nn.Sequential(
            nn.Linear(in_features=emb_dim, out_features=mlp_dim),
            activation.__dict__[hidden_activation](
                **self.hidden_activation_kwargs
            ),
        ))

        self.bandwidth = bandwidth
        self.in_channel = in_channel

        self.complex_mask = complex_mask
        self.reim = 2 if complex_mask else 1
        self.glu_mult = 2


class NormMLP(BaseNormMLP):
    def __init__(
        self,
        emb_dim: int,
        mlp_dim: int,
        bandwidth: int,
        in_channel: Optional[int],
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs=None,
        complex_mask: bool = True,
    ) -> None:
        super().__init__(
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            bandwidth=bandwidth,
            in_channel=in_channel,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
        )

        self.output = torch.jit.script(
            nn.Sequential(
                nn.Linear(
                    in_features=mlp_dim,
                    out_features=bandwidth * in_channel * self.reim * 2,
                ),
                nn.GLU(dim=-1),
            )
        )

    def reshape_output(self, mb):
        # print(mb.shape)
        batch, n_time, _ = mb.shape
        if self.complex_mask:
            mb = mb.reshape(
                batch,
                n_time,
                self.in_channel,
                self.bandwidth,
                self.reim
            ).contiguous()
            # print(mb.shape)
            mb = torch.view_as_complex(
                mb
            )  # (batch, n_time, in_channel, bandwidth)
        else:
            mb = mb.reshape(batch, n_time, self.in_channel, self.bandwidth)

        mb = torch.permute(
            mb,
            (0, 2, 3, 1)
        )  # (batch, in_channel, bandwidth, n_time)

        return mb

    def forward(self, qb):
        # qb = (batch, n_time, emb_dim)

        # if torch.any(torch.isnan(qb)):
        #     raise ValueError("qb0")

        qb = self.norm(qb)  # (batch, n_time, emb_dim)

        # if torch.any(torch.isnan(qb)):
        #     raise ValueError("qb1")

        qb = self.hidden(qb)  # (batch, n_time, mlp_dim)
        # if torch.any(torch.isnan(qb)):
        #     raise ValueError("qb2")
        mb = self.output(qb)  # (batch, n_time, bandwidth * in_channel * reim)
        # if torch.any(torch.isnan(qb)):
        #     raise ValueError("mb")
        mb = self.reshape_output(mb)  # (batch, in_channel, bandwidth, n_time)

        return mb


class MultAddNormMLP(NormMLP):
    def __init__(
        self,
        emb_dim: int,
        mlp_dim: int,
        bandwidth: int,
        in_channel: Optional[int],
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs=None,
        complex_mask: bool = True,
    ) -> None:
        super().__init__(
            emb_dim,
            mlp_dim,
            bandwidth,
            in_channel,
            hidden_activation,
            hidden_activation_kwargs,
            complex_mask,
        )

        self.output2 = torch.jit.script(
            nn.Sequential(
                nn.Linear(
                    in_features=mlp_dim,
                    out_features=bandwidth * in_channel * self.reim * 2,
                ),
                nn.GLU(dim=-1),
            )
        )

    def forward(self, qb):

        qb = self.norm(qb)  # (batch, n_time, emb_dim)
        qb = self.hidden(qb)  # (batch, n_time, mlp_dim)
        mmb = self.output(qb)  # (batch, n_time, bandwidth * in_channel * reim)
        mmb = self.reshape_output(mmb)  # (batch, in_channel, bandwidth, n_time)
        amb = self.output2(qb)  # (batch, n_time, bandwidth * in_channel * reim)
        amb = self.reshape_output(amb)  # (batch, in_channel, bandwidth, n_time)

        return mmb, amb


class MaskEstimationModuleSuperBase(nn.Module):
    pass


class MaskEstimationModuleBase(MaskEstimationModuleSuperBase):
    def __init__(
        self,
        band_specs: List[Tuple[float, float]],
        emb_dim: int,
        mlp_dim: int,
        in_channel: Optional[int],
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        norm_mlp_cls: Type[nn.Module] = NormMLP,
        norm_mlp_kwargs: Optional[Dict] = None,
    ) -> None:
        super().__init__()

        self.band_widths = band_widths_from_specs(band_specs)
        self.n_bands = len(band_specs)

        if hidden_activation_kwargs is None:
            hidden_activation_kwargs = {}

        if norm_mlp_kwargs is None:
            norm_mlp_kwargs = {}

        self.norm_mlp = nn.ModuleList(
            [
                norm_mlp_cls(
                    bandwidth=self.band_widths[b],
                    emb_dim=emb_dim,
                    mlp_dim=mlp_dim,
                    in_channel=in_channel,
                    hidden_activation=hidden_activation,
                    hidden_activation_kwargs=hidden_activation_kwargs,
                    complex_mask=complex_mask,
                    **norm_mlp_kwargs,
                )
                for b in range(self.n_bands)
            ]
        )

    def compute_masks(self, q):
        batch, n_bands, n_time, emb_dim = q.shape

        masks = []

        for b, nmlp in enumerate(self.norm_mlp):
            # print(f"maskestim/{b:02d}")
            qb = q[:, b, :, :]
            mb = nmlp(qb)
            masks.append(mb)

        return masks


class OverlappingMaskEstimationModule(MaskEstimationModuleBase):
    def __init__(
        self,
        in_channel: int,
        band_specs: List[Tuple[float, float]],
        freq_weights: List[torch.Tensor],
        n_freq: int,
        emb_dim: int,
        mlp_dim: int,
        cond_dim: int = 0,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        norm_mlp_cls: Type[nn.Module] = NormMLP,
        norm_mlp_kwargs: Optional[Dict] = None,
        use_freq_weights: bool = True,
    ) -> None:
        check_nonzero_bandwidth(band_specs)
        check_no_gap(band_specs)

        # if cond_dim > 0:
        #     raise NotImplementedError

        super().__init__(
            band_specs=band_specs,
            emb_dim=emb_dim + cond_dim,
            mlp_dim=mlp_dim,
            in_channel=in_channel,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            norm_mlp_cls=norm_mlp_cls,
            norm_mlp_kwargs=norm_mlp_kwargs,
        )

        self.n_freq = n_freq
        self.band_specs = band_specs
        self.in_channel = in_channel

        if freq_weights is not None:
            for i, fw in enumerate(freq_weights):
                self.register_buffer(f"freq_weights/{i}", fw)

            self.use_freq_weights = use_freq_weights
        else:
            self.use_freq_weights = False

        self.cond_dim = cond_dim

    def forward(self, q, cond=None):
        # q = (batch, n_bands, n_time, emb_dim)

        batch, n_bands, n_time, emb_dim = q.shape

        if cond is not None:
            # print(cond)  # leftover debug print, disabled here
            if cond.ndim == 2:
                cond = cond[:, None, None, :].expand(-1, n_bands, n_time, -1)
            elif cond.ndim == 3:
                assert cond.shape[1] == n_time
            else:
                raise ValueError(f"Invalid cond shape: {cond.shape}")

            q = torch.cat([q, cond], dim=-1)
        elif self.cond_dim > 0:
            cond = torch.ones(
                (batch, n_bands, n_time, self.cond_dim),
                device=q.device,
                dtype=q.dtype,
            )
            q = torch.cat([q, cond], dim=-1)
        else:
            pass

        mask_list = self.compute_masks(
            q
        )  # [n_bands * (batch, in_channel, bandwidth, n_time)]

        masks = torch.zeros(
            (batch, self.in_channel, self.n_freq, n_time),
            device=q.device,
            dtype=mask_list[0].dtype,
        )

        for im, mask in enumerate(mask_list):
            fstart, fend = self.band_specs[im]
            if self.use_freq_weights:
                fw = self.get_buffer(f"freq_weights/{im}")[:, None]
                mask = mask * fw
            masks[:, :, fstart:fend, :] += mask

        return masks


class MaskEstimationModule(OverlappingMaskEstimationModule):
    def __init__(
        self,
        band_specs: List[Tuple[float, float]],
        emb_dim: int,
        mlp_dim: int,
        in_channel: Optional[int],
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        **kwargs,
    ) -> None:
        check_nonzero_bandwidth(band_specs)
        check_no_gap(band_specs)
        check_no_overlap(band_specs)
        super().__init__(
            in_channel=in_channel,
            band_specs=band_specs,
            freq_weights=None,
            n_freq=None,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
        )

    def forward(self, q, cond=None):
        # q = (batch, n_bands, n_time, emb_dim)

        masks = self.compute_masks(
            q
        )  # [n_bands * (batch, in_channel, bandwidth, n_time)]

        # TODO: currently this requires band specs to have no gap and no overlap
        masks = torch.concat(
            masks,
            dim=2
        )  # (batch, in_channel, n_freq, n_time)

        return masks
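A shape walk-through for the per-band NormMLP above (toy sizes, not from the upload): LayerNorm and the hidden MLP act on (batch, n_time, emb_dim), the GLU output head emits bandwidth * in_channel * 2 real values per frame, and reshape_output reinterprets them as a complex mask:

import torch

from models.bandit.core.model.bsrnn.maskestim import NormMLP

nmlp = NormMLP(emb_dim=32, mlp_dim=64, bandwidth=8, in_channel=2)
qb = torch.randn(4, 100, 32)   # (batch, n_time, emb_dim) for a single band
mb = nmlp(qb)
print(mb.shape, mb.dtype)      # torch.Size([4, 2, 8, 100]) torch.complex64
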
models/bandit/core/model/bsrnn/tfmodel.py
ADDED
@@ -0,0 +1,317 @@
1 |
+
import warnings
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
from torch.nn import functional as F
|
6 |
+
from torch.nn.modules import rnn
|
7 |
+
|
8 |
+
import torch.backends.cuda
|
9 |
+
|
10 |
+
|
11 |
+
class TimeFrequencyModellingModule(nn.Module):
|
12 |
+
def __init__(self) -> None:
|
13 |
+
super().__init__()
|
14 |
+
|
15 |
+
|
16 |
+
class ResidualRNN(nn.Module):
|
17 |
+
def __init__(
|
18 |
+
self,
|
19 |
+
emb_dim: int,
|
20 |
+
rnn_dim: int,
|
21 |
+
bidirectional: bool = True,
|
22 |
+
rnn_type: str = "LSTM",
|
23 |
+
use_batch_trick: bool = True,
|
24 |
+
use_layer_norm: bool = True,
|
25 |
+
) -> None:
|
26 |
+
# n_group is the size of the 2nd dim
|
27 |
+
super().__init__()
|
28 |
+
|
29 |
+
self.use_layer_norm = use_layer_norm
|
30 |
+
if use_layer_norm:
|
31 |
+
self.norm = nn.LayerNorm(emb_dim)
|
32 |
+
else:
|
33 |
+
self.norm = nn.GroupNorm(num_groups=emb_dim, num_channels=emb_dim)
|
34 |
+
|
35 |
+
self.rnn = rnn.__dict__[rnn_type](
|
36 |
+
input_size=emb_dim,
|
37 |
+
hidden_size=rnn_dim,
|
38 |
+
num_layers=1,
|
39 |
+
batch_first=True,
|
40 |
+
bidirectional=bidirectional,
|
41 |
+
)
|
42 |
+
|
43 |
+
self.fc = nn.Linear(
|
44 |
+
in_features=rnn_dim * (2 if bidirectional else 1),
|
45 |
+
out_features=emb_dim
|
46 |
+
)
|
47 |
+
|
48 |
+
self.use_batch_trick = use_batch_trick
|
49 |
+
if not self.use_batch_trick:
|
50 |
+
warnings.warn("NOT USING BATCH TRICK IS EXTREMELY SLOW!!")
|
51 |
+
|
52 |
+
def forward(self, z):
|
53 |
+
# z = (batch, n_uncrossed, n_across, emb_dim)
|
54 |
+
|
55 |
+
z0 = torch.clone(z)
|
56 |
+
|
57 |
+
# print(z.device)
|
58 |
+
|
59 |
+
if self.use_layer_norm:
|
60 |
+
z = self.norm(z) # (batch, n_uncrossed, n_across, emb_dim)
|
61 |
+
else:
|
62 |
+
z = torch.permute(
|
63 |
+
z, (0, 3, 1, 2)
|
64 |
+
) # (batch, emb_dim, n_uncrossed, n_across)
|
65 |
+
|
66 |
+
z = self.norm(z) # (batch, emb_dim, n_uncrossed, n_across)
|
67 |
+
|
68 |
+
z = torch.permute(
|
69 |
+
z, (0, 2, 3, 1)
|
70 |
+
) # (batch, n_uncrossed, n_across, emb_dim)
|
71 |
+
|
72 |
+
batch, n_uncrossed, n_across, emb_dim = z.shape
|
73 |
+
|
74 |
+
if self.use_batch_trick:
|
75 |
+
z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
|
76 |
+
|
77 |
+
z = self.rnn(z.contiguous())[0] # (batch * n_uncrossed, n_across, dir_rnn_dim)
|
78 |
+
|
79 |
+
z = torch.reshape(z, (batch, n_uncrossed, n_across, -1))
|
80 |
+
# (batch, n_uncrossed, n_across, dir_rnn_dim)
|
81 |
+
else:
|
82 |
+
# Note: this is EXTREMELY SLOW
|
83 |
+
zlist = []
|
84 |
+
for i in range(n_uncrossed):
|
85 |
+
zi = self.rnn(z[:, i, :, :])[0] # (batch, n_across, emb_dim)
|
86 |
+
zlist.append(zi)
|
87 |
+
|
88 |
+
z = torch.stack(
|
89 |
+
zlist,
|
90 |
+
dim=1
|
91 |
+
) # (batch, n_uncrossed, n_across, dir_rnn_dim)
|
92 |
+
|
93 |
+
z = self.fc(z) # (batch, n_uncrossed, n_across, emb_dim)
|
94 |
+
|
95 |
+
z = z + z0
|
96 |
+
|
97 |
+
return z
|
98 |
+
|
99 |
+
|
100 |
+
class SeqBandModellingModule(TimeFrequencyModellingModule):
    def __init__(
            self,
            n_modules: int = 12,
            emb_dim: int = 128,
            rnn_dim: int = 256,
            bidirectional: bool = True,
            rnn_type: str = "LSTM",
            parallel_mode: bool = False,
    ) -> None:
        super().__init__()
        self.seqband = nn.ModuleList([])

        if parallel_mode:
            for _ in range(n_modules):
                self.seqband.append(
                    nn.ModuleList(
                        [
                            ResidualRNN(
                                emb_dim=emb_dim,
                                rnn_dim=rnn_dim,
                                bidirectional=bidirectional,
                                rnn_type=rnn_type,
                            ),
                            ResidualRNN(
                                emb_dim=emb_dim,
                                rnn_dim=rnn_dim,
                                bidirectional=bidirectional,
                                rnn_type=rnn_type,
                            ),
                        ]
                    )
                )
        else:
            for _ in range(2 * n_modules):
                self.seqband.append(
                    ResidualRNN(
                        emb_dim=emb_dim,
                        rnn_dim=rnn_dim,
                        bidirectional=bidirectional,
                        rnn_type=rnn_type,
                    )
                )

        self.parallel_mode = parallel_mode

    def forward(self, z):
        # z = (batch, n_bands, n_time, emb_dim)
        if self.parallel_mode:
            for sbm_pair in self.seqband:
                # z: (batch, n_bands, n_time, emb_dim)
                sbm_t, sbm_f = sbm_pair[0], sbm_pair[1]
                zt = sbm_t(z)  # (batch, n_bands, n_time, emb_dim)
                zf = sbm_f(z.transpose(1, 2))  # (batch, n_time, n_bands, emb_dim)
                z = zt + zf.transpose(1, 2)
        else:
            for sbm in self.seqband:
                z = sbm(z)
                z = z.transpose(1, 2)
                # Each pass alternates the axis the RNN runs across:
                # (batch, n_bands, n_time, emb_dim)
                # <--> (batch, n_time, n_bands, emb_dim)

        return z  # (batch, n_bands, n_time, emb_dim)

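With the sequential path, the transpose after each of the 2 * n_modules blocks alternates the RNN between the time and band axes, and the even count means the output orientation matches the input; parallel mode instead runs a pair of RNNs on both axes of the same input and sums. A quick shape check (a sketch, arbitrary sizes):

import torch

z = torch.randn(1, 10, 50, 32)  # (batch, n_bands, n_time, emb_dim)
seq = SeqBandModellingModule(n_modules=2, emb_dim=32, rnn_dim=64)
par = SeqBandModellingModule(n_modules=2, emb_dim=32, rnn_dim=64, parallel_mode=True)
assert seq(z).shape == par(z).shape == z.shape
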
class ResidualTransformer(nn.Module):
    def __init__(
            self,
            emb_dim: int = 128,
            rnn_dim: int = 256,
            bidirectional: bool = True,
            dropout: float = 0.0,
    ) -> None:
        super().__init__()

        self.tf = nn.TransformerEncoderLayer(
            d_model=emb_dim,
            nhead=4,
            dim_feedforward=rnn_dim,
            dropout=dropout,
            batch_first=True,
        )

        self.is_causal = not bidirectional
        self.dropout = dropout

    def forward(self, z):
        batch, n_uncrossed, n_across, emb_dim = z.shape
        z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
        z = self.tf(z, is_causal=self.is_causal)  # (batch * n_uncrossed, n_across, emb_dim)
        z = torch.reshape(z, (batch, n_uncrossed, n_across, emb_dim))

        return z

class TransformerTimeFreqModule(TimeFrequencyModellingModule):
    def __init__(
            self,
            n_modules: int = 12,
            emb_dim: int = 128,
            rnn_dim: int = 256,
            bidirectional: bool = True,
            dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(emb_dim)
        self.seqband = nn.ModuleList([])

        for _ in range(2 * n_modules):
            self.seqband.append(
                ResidualTransformer(
                    emb_dim=emb_dim,
                    rnn_dim=rnn_dim,
                    bidirectional=bidirectional,
                    dropout=dropout,
                )
            )

    def forward(self, z):
        # z = (batch, n_bands, n_time, emb_dim)
        z = self.norm(z)

        for sbm in self.seqband:
            z = sbm(z)
            z = z.transpose(1, 2)
            # Alternates (batch, n_bands, n_time, emb_dim)
            # <--> (batch, n_time, n_bands, emb_dim) between modules.

        return z  # (batch, n_bands, n_time, emb_dim)

class ResidualConvolution(nn.Module):
    def __init__(
            self,
            emb_dim: int = 128,
            rnn_dim: int = 256,
            bidirectional: bool = True,
            dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.norm = nn.InstanceNorm2d(emb_dim, affine=True)

        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=emb_dim,
                out_channels=rnn_dim,
                kernel_size=(3, 3),
                padding="same",
                stride=(1, 1),
            ),
            nn.Tanhshrink(),
        )

        self.is_causal = not bidirectional
        self.dropout = dropout

        self.fc = nn.Conv2d(
            in_channels=rnn_dim,
            out_channels=emb_dim,
            kernel_size=(1, 1),
            padding="same",
            stride=(1, 1),
        )

    def forward(self, z):
        # z = (batch, emb_dim, n_bands, n_time), i.e. channels-first;
        # the caller permutes before and after this module.
        z0 = torch.clone(z)

        z = self.norm(z)
        z = self.conv(z)  # (batch, rnn_dim, n_bands, n_time)
        z = self.fc(z)  # (batch, emb_dim, n_bands, n_time)
        z = z + z0  # residual connection

        return z

class ConvolutionalTimeFreqModule(TimeFrequencyModellingModule):
    def __init__(
            self,
            n_modules: int = 12,
            emb_dim: int = 128,
            rnn_dim: int = 256,
            bidirectional: bool = True,
            dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.seqband = torch.jit.script(
            nn.Sequential(
                *[
                    ResidualConvolution(
                        emb_dim=emb_dim,
                        rnn_dim=rnn_dim,
                        bidirectional=bidirectional,
                        dropout=dropout,
                    )
                    for _ in range(2 * n_modules)
                ]
            )
        )

    def forward(self, z):
        # z = (batch, n_bands, n_time, emb_dim)
        z = torch.permute(z, (0, 3, 1, 2))  # (batch, emb_dim, n_bands, n_time)
        z = self.seqband(z)  # (batch, emb_dim, n_bands, n_time)
        z = torch.permute(z, (0, 2, 3, 1))  # (batch, n_bands, n_time, emb_dim)

        return z
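All three time-frequency modules above share the same (batch, n_bands, n_time, emb_dim) contract, so they are drop-in replacements for one another. A sketch with hypothetical sizes:

import torch

z = torch.randn(1, 12, 40, 32)  # (batch, n_bands, n_time, emb_dim)
for cls in (SeqBandModellingModule, TransformerTimeFreqModule, ConvolutionalTimeFreqModule):
    module = cls(n_modules=1, emb_dim=32, rnn_dim=48)
    assert module(z).shape == z.shape
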
models/bandit/core/model/bsrnn/utils.py
ADDED
@@ -0,0 +1,583 @@
import os
from abc import abstractmethod
from typing import Any, Callable

import numpy as np
import torch
from librosa import hz_to_midi, midi_to_hz
from torch import Tensor
from torchaudio import functional as taF
from spafe.fbanks import bark_fbanks
from spafe.utils.converters import erb2hz, hz2bark, hz2erb
from torchaudio.functional.functional import _create_triangular_filterbank


def band_widths_from_specs(band_specs):
    return [e - i for i, e in band_specs]


def check_nonzero_bandwidth(band_specs):
    # pprint(band_specs)
    for fstart, fend in band_specs:
        if fend - fstart <= 0:
            raise ValueError("Bands cannot be zero-width")


def check_no_overlap(band_specs):
    fend_prev = -1
    for fstart_curr, fend_curr in band_specs:
        if fstart_curr < fend_prev:
            raise ValueError("Bands cannot overlap")
        fend_prev = fend_curr  # advance so the next band is compared against this one


def check_no_gap(band_specs):
    fstart, _ = band_specs[0]
    assert fstart == 0

    fend_prev = -1
    for fstart_curr, fend_curr in band_specs:
        if fstart_curr - fend_prev > 1:
            raise ValueError("Bands cannot leave gap")
        fend_prev = fend_curr

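Band specs throughout this file are (start_bin, end_bin) index pairs with an exclusive end, so contiguous bands share a boundary index. For illustration (made-up values):

specs = [(0, 23), (23, 46), (46, 93)]  # contiguous, end-exclusive bins

check_nonzero_bandwidth(specs)
check_no_overlap(specs)
check_no_gap(specs)
print(band_widths_from_specs(specs))  # [23, 23, 47]
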
class BandsplitSpecification:
    def __init__(self, nfft: int, fs: int) -> None:
        self.fs = fs
        self.nfft = nfft
        self.nyquist = fs / 2
        self.max_index = nfft // 2 + 1

        self.split500 = self.hertz_to_index(500)
        self.split1k = self.hertz_to_index(1000)
        self.split2k = self.hertz_to_index(2000)
        self.split4k = self.hertz_to_index(4000)
        self.split8k = self.hertz_to_index(8000)
        self.split16k = self.hertz_to_index(16000)
        self.split20k = self.hertz_to_index(20000)

        self.above20k = [(self.split20k, self.max_index)]
        self.above16k = [(self.split16k, self.split20k)] + self.above20k

    def index_to_hertz(self, index: int):
        return index * self.fs / self.nfft

    def hertz_to_index(self, hz: float, round: bool = True):
        index = hz * self.nfft / self.fs

        if round:
            index = int(np.round(index))

        return index

    def get_band_specs_with_bandwidth(
            self,
            start_index,
            end_index,
            bandwidth_hz
    ):
        band_specs = []
        lower = start_index

        while lower < end_index:
            upper = int(np.floor(lower + self.hertz_to_index(bandwidth_hz)))
            upper = min(upper, end_index)

            band_specs.append((lower, upper))
            lower = upper

        return band_specs

    @abstractmethod
    def get_band_specs(self):
        raise NotImplementedError

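With the nfft=2048, fs=44100 configuration used elsewhere in the repo, the bin spacing is 44100/2048, roughly 21.5 Hz, so the 500 Hz split lands at bin round(500 * 2048 / 44100) = 23. A sketch:

spec = BandsplitSpecification(nfft=2048, fs=44100)
print(spec.split500, spec.split1k)  # 23 46
print(spec.get_band_specs_with_bandwidth(0, spec.split1k, bandwidth_hz=500))
# [(0, 23), (23, 46)] -- two ~500 Hz bands covering 0-1 kHz
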
class VocalBandsplitSpecification(BandsplitSpecification):
    def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
        super().__init__(nfft=nfft, fs=fs)

        self.version = version

    def get_band_specs(self):
        # All versionN attributes must be plain methods so this call works.
        return getattr(self, f"version{self.version}")()

    def version1(self):
        return self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.max_index, bandwidth_hz=1000
        )

    def version2(self):
        below16k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split16k, bandwidth_hz=1000
        )
        below20k = self.get_band_specs_with_bandwidth(
            start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
        )

        return below16k + below20k + self.above20k

    def version3(self):
        below8k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split8k, bandwidth_hz=1000
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
        )

        return below8k + below16k + self.above16k

    def version4(self):
        below1k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split1k, bandwidth_hz=100
        )
        below8k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k, end_index=self.split8k, bandwidth_hz=1000
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
        )

        return below1k + below8k + below16k + self.above16k

    def version5(self):
        below1k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split1k, bandwidth_hz=100
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k, end_index=self.split16k, bandwidth_hz=1000
        )
        below20k = self.get_band_specs_with_bandwidth(
            start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
        )
        return below1k + below16k + below20k + self.above20k

    def version6(self):
        below1k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split1k, bandwidth_hz=100
        )
        below4k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500
        )
        below8k = self.get_band_specs_with_bandwidth(
            start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
        )
        return below1k + below4k + below8k + below16k + self.above16k

    def version7(self):
        below1k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split1k, bandwidth_hz=100
        )
        below4k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k, end_index=self.split4k, bandwidth_hz=250
        )
        below8k = self.get_band_specs_with_bandwidth(
            start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000
        )
        below20k = self.get_band_specs_with_bandwidth(
            start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
        )
        return below1k + below4k + below8k + below16k + below20k + self.above20k

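Version 7 narrows bands toward the low end (100 Hz bands below 1 kHz, widening to 2 kHz bands near 20 kHz), which is where vocal harmonics need the most resolution. A quick look at what the arithmetic above yields:

spec = VocalBandsplitSpecification(nfft=2048, fs=44100, version="7")
bands = spec.get_band_specs()
print(len(bands), bands[0], bands[-1])
# e.g. 43 (0, 5) (929, 1025) for this nfft/fs, per the rounding above
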
class OtherBandsplitSpecification(VocalBandsplitSpecification):
    def __init__(self, nfft: int, fs: int) -> None:
        super().__init__(nfft=nfft, fs=fs, version="7")


class BassBandsplitSpecification(BandsplitSpecification):
    def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
        # `version` is accepted for interface parity but unused here.
        super().__init__(nfft=nfft, fs=fs)

    def get_band_specs(self):
        below500 = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split500, bandwidth_hz=50
        )
        below1k = self.get_band_specs_with_bandwidth(
            start_index=self.split500, end_index=self.split1k, bandwidth_hz=100
        )
        below4k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500
        )
        below8k = self.get_band_specs_with_bandwidth(
            start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
        )
        above16k = [(self.split16k, self.max_index)]

        return below500 + below1k + below4k + below8k + below16k + above16k

class DrumBandsplitSpecification(BandsplitSpecification):
    def __init__(self, nfft: int, fs: int) -> None:
        super().__init__(nfft=nfft, fs=fs)

    def get_band_specs(self):
        below1k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split1k, bandwidth_hz=50
        )
        below2k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k, end_index=self.split2k, bandwidth_hz=100
        )
        below4k = self.get_band_specs_with_bandwidth(
            start_index=self.split2k, end_index=self.split4k, bandwidth_hz=250
        )
        below8k = self.get_band_specs_with_bandwidth(
            start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000
        )
        above16k = [(self.split16k, self.max_index)]

        return below1k + below2k + below4k + below8k + below16k + above16k

class PerceptualBandsplitSpecification(BandsplitSpecification):
    def __init__(
            self,
            nfft: int,
            fs: int,
            fbank_fn: Callable[[int, int, float, float, int], torch.Tensor],
            n_bands: int,
            f_min: float = 0.0,
            f_max: float = None
    ) -> None:
        super().__init__(nfft=nfft, fs=fs)
        self.n_bands = n_bands
        if f_max is None:
            f_max = fs / 2

        self.filterbank = fbank_fn(n_bands, fs, f_min, f_max, self.max_index)

        weight_per_bin = torch.sum(
            self.filterbank, dim=0, keepdim=True
        )  # (1, n_freqs)
        normalized_fb = self.filterbank / weight_per_bin  # (n_bands, n_freqs)

        freq_weights = []
        band_specs = []
        for i in range(self.n_bands):
            active_bins = torch.nonzero(self.filterbank[i, :]).squeeze().tolist()
            if isinstance(active_bins, int):
                active_bins = (active_bins, active_bins)
            if len(active_bins) == 0:
                continue
            start_index = active_bins[0]
            end_index = active_bins[-1] + 1  # end-exclusive, as elsewhere
            band_specs.append((start_index, end_index))
            freq_weights.append(normalized_fb[i, start_index:end_index])

        self.freq_weights = freq_weights
        self.band_specs = band_specs

    def get_band_specs(self):
        return self.band_specs

    def get_freq_weights(self):
        return self.freq_weights

    def save_to_file(self, dir_path: str) -> None:
        os.makedirs(dir_path, exist_ok=True)

        import pickle

        with open(os.path.join(dir_path, "mel_bandsplit_spec.pkl"), "wb") as f:
            pickle.dump(
                {
                    "band_specs": self.band_specs,
                    "freq_weights": self.freq_weights,
                    "filterbank": self.filterbank,
                },
                f,
            )

def mel_filterbank(n_bands, fs, f_min, f_max, n_freqs):
    fb = taF.melscale_fbanks(
        n_mels=n_bands,
        sample_rate=fs,
        f_min=f_min,
        f_max=f_max,
        n_freqs=n_freqs,
    ).T

    # The DC bin otherwise has zero weight in every filter, which would
    # divide by zero in the per-bin normalization above.
    fb[0, 0] = 1.0

    return fb

class MelBandsplitSpecification(PerceptualBandsplitSpecification):
    def __init__(
            self,
            nfft: int,
            fs: int,
            n_bands: int,
            f_min: float = 0.0,
            f_max: float = None
    ) -> None:
        super().__init__(
            fbank_fn=mel_filterbank, nfft=nfft, fs=fs,
            n_bands=n_bands, f_min=f_min, f_max=f_max
        )

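Unlike the hand-designed specs, the perceptual variants derive (possibly overlapping) bands from a filterbank and also expose per-bin weights for recombining overlapping band estimates downstream. A sketch with hypothetical sizes:

mel = MelBandsplitSpecification(nfft=2048, fs=44100, n_bands=64)
bands = mel.get_band_specs()      # list of (start_bin, end_bin), may overlap
weights = mel.get_freq_weights()  # one weight vector per band, same lengths
assert len(bands) == len(weights)
assert all(w.shape[0] == e - s for (s, e), w in zip(bands, weights))
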
def musical_filterbank(n_bands, fs, f_min, f_max, n_freqs,
                       scale="constant"):
    # `scale` is accepted for interface parity but unused.
    nfft = 2 * (n_freqs - 1)
    df = fs / nfft
    # init freqs; f_min is clamped to one bin width so that the octave
    # count below (log2 of f_max / f_min) stays finite
    f_max = f_max or fs / 2
    f_min = fs / nfft

    n_octaves = np.log2(f_max / f_min)
    n_octaves_per_band = n_octaves / n_bands
    bandwidth_mult = np.power(2.0, n_octaves_per_band)

    low_midi = max(0, hz_to_midi(f_min))
    high_midi = hz_to_midi(f_max)
    midi_points = np.linspace(low_midi, high_midi, n_bands)
    hz_pts = midi_to_hz(midi_points)

    low_pts = hz_pts / bandwidth_mult
    high_pts = hz_pts * bandwidth_mult

    low_bins = np.floor(low_pts / df).astype(int)
    high_bins = np.ceil(high_pts / df).astype(int)

    fb = np.zeros((n_bands, n_freqs))

    for i in range(n_bands):
        fb[i, low_bins[i]:high_bins[i] + 1] = 1.0

    # Extend the lowest and highest bands to cover the spectrum edges.
    fb[0, :low_bins[0]] = 1.0
    fb[-1, high_bins[-1] + 1:] = 1.0

    return torch.as_tensor(fb)

class MusicalBandsplitSpecification(PerceptualBandsplitSpecification):
    def __init__(
            self,
            nfft: int,
            fs: int,
            n_bands: int,
            f_min: float = 0.0,
            f_max: float = None
    ) -> None:
        super().__init__(
            fbank_fn=musical_filterbank, nfft=nfft, fs=fs,
            n_bands=n_bands, f_min=f_min, f_max=f_max
        )

def bark_filterbank(
        n_bands, fs, f_min, f_max, n_freqs
):
    nfft = 2 * (n_freqs - 1)
    fb, _ = bark_fbanks.bark_filter_banks(
        nfilts=n_bands,
        nfft=nfft,
        fs=fs,
        low_freq=f_min,
        high_freq=f_max,
        scale="constant"
    )

    return torch.as_tensor(fb)

class BarkBandsplitSpecification(PerceptualBandsplitSpecification):
    def __init__(
            self,
            nfft: int,
            fs: int,
            n_bands: int,
            f_min: float = 0.0,
            f_max: float = None
    ) -> None:
        super().__init__(
            fbank_fn=bark_filterbank, nfft=nfft, fs=fs,
            n_bands=n_bands, f_min=f_min, f_max=f_max
        )

def triangular_bark_filterbank(
        n_bands, fs, f_min, f_max, n_freqs
):
    all_freqs = torch.linspace(0, fs // 2, n_freqs)

    # calculate Bark-scale band edges
    m_min = hz2bark(f_min)
    m_max = hz2bark(f_max)

    m_pts = torch.linspace(m_min, m_max, n_bands + 2)
    f_pts = 600 * torch.sinh(m_pts / 6)  # inverse of bark = 6 * asinh(f / 600)

    # create filterbank
    fb = _create_triangular_filterbank(all_freqs, f_pts)

    fb = fb.T

    # Give the lowest active band full weight on the bins below its peak,
    # so the bottom of the spectrum is covered.
    first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0]
    first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0]

    fb[first_active_band, :first_active_bin] = 1.0

    return fb

class TriangularBarkBandsplitSpecification(PerceptualBandsplitSpecification):
    def __init__(
            self,
            nfft: int,
            fs: int,
            n_bands: int,
            f_min: float = 0.0,
            f_max: float = None
    ) -> None:
        super().__init__(
            fbank_fn=triangular_bark_filterbank, nfft=nfft, fs=fs,
            n_bands=n_bands, f_min=f_min, f_max=f_max
        )

def minibark_filterbank(
        n_bands, fs, f_min, f_max, n_freqs
):
    fb = bark_filterbank(n_bands, fs, f_min, f_max, n_freqs)

    # Zero out weak responses (below sqrt(0.5), i.e. -3 dB) to shrink
    # each band's support.
    fb[fb < np.sqrt(0.5)] = 0.0

    return fb

class MiniBarkBandsplitSpecification(PerceptualBandsplitSpecification):
    def __init__(
            self,
            nfft: int,
            fs: int,
            n_bands: int,
            f_min: float = 0.0,
            f_max: float = None
    ) -> None:
        super().__init__(
            fbank_fn=minibark_filterbank, nfft=nfft, fs=fs,
            n_bands=n_bands, f_min=f_min, f_max=f_max
        )

def erb_filterbank(
        n_bands: int,
        fs: int,
        f_min: float,
        f_max: float,
        n_freqs: int,
) -> Tensor:
    # freq bins
    A = (1000 * np.log(10)) / (24.7 * 4.37)  # ERB-rate scale constant (~21.33)
    all_freqs = torch.linspace(0, fs // 2, n_freqs)

    # calculate ERB-scale band edges
    m_min = hz2erb(f_min)
    m_max = hz2erb(f_max)

    m_pts = torch.linspace(m_min, m_max, n_bands + 2)
    f_pts = (torch.pow(10, (m_pts / A)) - 1) / 0.00437  # inverse of erb = A * log10(1 + 0.00437 * f)

    # create filterbank
    fb = _create_triangular_filterbank(all_freqs, f_pts)

    fb = fb.T

    # Cover the bins below the first active band's peak.
    first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0]
    first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0]

    fb[first_active_band, :first_active_bin] = 1.0

    return fb

class EquivalentRectangularBandsplitSpecification(PerceptualBandsplitSpecification):
    def __init__(
            self,
            nfft: int,
            fs: int,
            n_bands: int,
            f_min: float = 0.0,
            f_max: float = None
    ) -> None:
        super().__init__(
            fbank_fn=erb_filterbank, nfft=nfft, fs=fs,
            n_bands=n_bands, f_min=f_min, f_max=f_max
        )

if __name__ == "__main__":
    import pandas as pd

    band_defs = []

    for bands in [VocalBandsplitSpecification]:
        band_name = bands.__name__.replace("BandsplitSpecification", "")

        mbs = bands(nfft=2048, fs=44100).get_band_specs()

        # Note: the values written to the f_min/f_max columns are STFT bin
        # indices, not frequencies in hertz.
        for i, (f_min, f_max) in enumerate(mbs):
            band_defs.append({
                "band": band_name,
                "band_index": i,
                "f_min": f_min,
                "f_max": f_max
            })

    df = pd.DataFrame(band_defs)
    df.to_csv("vox7bands.csv", index=False)
models/bandit/core/model/bsrnn/wrapper.py
ADDED
@@ -0,0 +1,882 @@
from pprint import pprint
from typing import Dict, List, Optional, Tuple, Union

import torch
from torch import nn

from models.bandit.core.model._spectral import _SpectralComponent
from models.bandit.core.model.bsrnn.utils import (
    BarkBandsplitSpecification, BassBandsplitSpecification,
    DrumBandsplitSpecification,
    EquivalentRectangularBandsplitSpecification, MelBandsplitSpecification,
    MusicalBandsplitSpecification, OtherBandsplitSpecification,
    TriangularBarkBandsplitSpecification, VocalBandsplitSpecification,
)
from .core import (
    MultiSourceMultiMaskBandSplitCoreConv,
    MultiSourceMultiMaskBandSplitCoreRNN,
    MultiSourceMultiMaskBandSplitCoreTransformer,
    MultiSourceMultiPatchingMaskBandSplitCoreRNN, SingleMaskBandsplitCoreRNN,
    SingleMaskBandsplitCoreTransformer,
)

import pytorch_lightning as pl


def get_band_specs(band_specs, n_fft, fs, n_bands=None):
    if band_specs in ["dnr:speech", "dnr:vox7", "musdb:vocals", "musdb:vox7"]:
        bsm = VocalBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs()
        freq_weights = None
        overlapping_band = False
    elif "tribark" in band_specs:  # must come before the "bark" substring check
        assert n_bands is not None
        specs = TriangularBarkBandsplitSpecification(
            nfft=n_fft, fs=fs, n_bands=n_bands
        )
        bsm = specs.get_band_specs()
        freq_weights = specs.get_freq_weights()
        overlapping_band = True
    elif "bark" in band_specs:
        assert n_bands is not None
        specs = BarkBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
        bsm = specs.get_band_specs()
        freq_weights = specs.get_freq_weights()
        overlapping_band = True
    elif "erb" in band_specs:
        assert n_bands is not None
        specs = EquivalentRectangularBandsplitSpecification(
            nfft=n_fft, fs=fs, n_bands=n_bands
        )
        bsm = specs.get_band_specs()
        freq_weights = specs.get_freq_weights()
        overlapping_band = True
    elif "musical" in band_specs:
        assert n_bands is not None
        specs = MusicalBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
        bsm = specs.get_band_specs()
        freq_weights = specs.get_freq_weights()
        overlapping_band = True
    elif "mel" in band_specs:  # includes "dnr:mel"
        assert n_bands is not None
        specs = MelBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
        bsm = specs.get_band_specs()
        freq_weights = specs.get_freq_weights()
        overlapping_band = True
    else:
        raise NameError(f"Unknown band_specs: {band_specs}")

    return bsm, freq_weights, overlapping_band

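The resolver matches the hand-tuned vocal layout by exact name and the perceptual layouts by substring, so the perceptual branches also require n_bands. A sketch:

bsm, weights, overlap = get_band_specs("dnr:vox7", n_fft=2048, fs=44100)
assert weights is None and overlap is False

bsm, weights, overlap = get_band_specs("dnr:mel", n_fft=2048, fs=44100, n_bands=64)
assert overlap is True and len(weights) == len(bsm)
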
def get_band_specs_map(band_specs_map, n_fft, fs, n_bands=None):
    if band_specs_map == "musdb:all":
        bsm = {
            "vocals": VocalBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
            "drums": DrumBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
            "bass": BassBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
            "other": OtherBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
        }
        freq_weights = None
        overlapping_band = False
    elif band_specs_map == "dnr:vox7":
        bsm_, freq_weights, overlapping_band = get_band_specs(
            "dnr:speech", n_fft, fs, n_bands
        )
        bsm = {
            "speech": bsm_,
            "music": bsm_,
            "effects": bsm_,
        }
    elif "dnr:vox7:" in band_specs_map:
        stem = band_specs_map.split(":")[-1]
        bsm_, freq_weights, overlapping_band = get_band_specs(
            "dnr:speech", n_fft, fs, n_bands
        )
        bsm = {stem: bsm_}
    else:
        raise NameError(f"Unknown band_specs_map: {band_specs_map}")

    return bsm, freq_weights, overlapping_band

class BandSplitWrapperBase(pl.LightningModule):
    bsrnn: nn.Module

    def __init__(self, **kwargs):
        super().__init__()

class SingleMaskMultiSourceBandSplitBase(
        BandSplitWrapperBase,
        _SpectralComponent
):
    def __init__(
            self,
            band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
            fs: int = 44100,
            n_fft: int = 2048,
            win_length: Optional[int] = 2048,
            hop_length: int = 512,
            window_fn: str = "hann_window",
            wkwargs: Optional[Dict] = None,
            power: Optional[int] = None,
            center: bool = True,
            normalized: bool = True,
            pad_mode: str = "constant",
            onesided: bool = True,
            n_bands: int = None,
    ) -> None:
        super().__init__(
            n_fft=n_fft, win_length=win_length, hop_length=hop_length,
            window_fn=window_fn, wkwargs=wkwargs, power=power, center=center,
            normalized=normalized, pad_mode=pad_mode, onesided=onesided,
        )

        if isinstance(band_specs_map, str):
            self.band_specs_map, self.freq_weights, self.overlapping_band = get_band_specs_map(
                band_specs_map,
                n_fft,
                fs,
                n_bands=n_bands
            )

        self.stems = list(self.band_specs_map.keys())

    def forward(self, batch):
        audio = batch["audio"]

        with torch.no_grad():
            batch["spectrogram"] = {
                stem: self.stft(audio[stem]) for stem in audio
            }

        X = batch["spectrogram"]["mixture"]
        length = batch["audio"]["mixture"].shape[-1]

        output = {"spectrogram": {}, "audio": {}}

        # One independent core model per stem, each producing its own mask.
        for stem, bsrnn in self.bsrnn.items():
            S = bsrnn(X)
            s = self.istft(S, length)
            output["spectrogram"][stem] = S
            output["audio"][stem] = s

        return batch, output

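These wrappers expect the batch to carry waveforms keyed by stem, including a "mixture" entry to be separated; every stem's STFT is cached on the batch for use by the loss. A sketch of the expected input (hypothetical shapes):

import torch

batch = {
    "audio": {
        "mixture": torch.randn(4, 2, 44100 * 6),  # (batch, channels, samples)
        "vocals": torch.randn(4, 2, 44100 * 6),   # targets, used by the loss
    }
}
# batch, output = model(batch)
# output["audio"]["vocals"] -> estimated waveform, same length as the mixture
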
class MultiMaskMultiSourceBandSplitBase(
        BandSplitWrapperBase,
        _SpectralComponent
):
    def __init__(
            self,
            stems: List[str],
            band_specs: Union[str, List[Tuple[float, float]]],
            fs: int = 44100,
            n_fft: int = 2048,
            win_length: Optional[int] = 2048,
            hop_length: int = 512,
            window_fn: str = "hann_window",
            wkwargs: Optional[Dict] = None,
            power: Optional[int] = None,
            center: bool = True,
            normalized: bool = True,
            pad_mode: str = "constant",
            onesided: bool = True,
            n_bands: int = None,
    ) -> None:
        super().__init__(
            n_fft=n_fft, win_length=win_length, hop_length=hop_length,
            window_fn=window_fn, wkwargs=wkwargs, power=power, center=center,
            normalized=normalized, pad_mode=pad_mode, onesided=onesided,
        )

        if isinstance(band_specs, str):
            self.band_specs, self.freq_weights, self.overlapping_band = get_band_specs(
                band_specs,
                n_fft,
                fs,
                n_bands
            )

        self.stems = stems

    def forward(self, batch):
        audio = batch["audio"]
        cond = batch.get("condition", None)
        with torch.no_grad():
            batch["spectrogram"] = {
                stem: self.stft(audio[stem]) for stem in audio
            }

        X = batch["spectrogram"]["mixture"]
        length = batch["audio"]["mixture"].shape[-1]

        # A single shared core predicts one mask per stem in one pass.
        output = self.bsrnn(X, cond=cond)
        output["audio"] = {}

        for stem, S in output["spectrogram"].items():
            s = self.istft(S, length)
            output["audio"][stem] = s

        return batch, output

class MultiMaskMultiSourceBandSplitBaseSimple(
        BandSplitWrapperBase,
        _SpectralComponent
):
    def __init__(
            self,
            stems: List[str],
            band_specs: Union[str, List[Tuple[float, float]]],
            fs: int = 44100,
            n_fft: int = 2048,
            win_length: Optional[int] = 2048,
            hop_length: int = 512,
            window_fn: str = "hann_window",
            wkwargs: Optional[Dict] = None,
            power: Optional[int] = None,
            center: bool = True,
            normalized: bool = True,
            pad_mode: str = "constant",
            onesided: bool = True,
            n_bands: int = None,
    ) -> None:
        super().__init__(
            n_fft=n_fft, win_length=win_length, hop_length=hop_length,
            window_fn=window_fn, wkwargs=wkwargs, power=power, center=center,
            normalized=normalized, pad_mode=pad_mode, onesided=onesided,
        )

        if isinstance(band_specs, str):
            self.band_specs, self.freq_weights, self.overlapping_band = get_band_specs(
                band_specs,
                n_fft,
                fs,
                n_bands
            )

        self.stems = stems

    def forward(self, batch):
        # "Simple" interface: takes a raw mixture tensor instead of a batch
        # dict and returns the stem estimates stacked along dim 1.
        with torch.no_grad():
            X = self.stft(batch)
        length = batch.shape[-1]
        output = self.bsrnn(X, cond=None)
        res = []
        for stem, S in output["spectrogram"].items():
            s = self.istft(S, length)
            res.append(s)
        res = torch.stack(res, dim=1)
        return res

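The Simple variant is the inference-friendly interface: a raw waveform tensor in, stem estimates stacked along dim 1 out, in the order of the core's output dict. For illustration (hypothetical shapes):

import torch

mixture = torch.randn(1, 2, 44100 * 3)  # (batch, channels, samples)
# separated = simple_model(mixture)     # simple_model: instance of a
#                                       # ...Simple subclass below
# separated.shape -> (1, n_stems, 2, 44100 * 3)
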
class SingleMaskMultiSourceBandSplitRNN(SingleMaskMultiSourceBandSplitBase):
    def __init__(
            self,
            in_channel: int,
            band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
            fs: int = 44100,
            require_no_overlap: bool = False,
            require_no_gap: bool = True,
            normalize_channel_independently: bool = False,
            treat_channel_as_feature: bool = True,
            n_sqm_modules: int = 12,
            emb_dim: int = 128,
            rnn_dim: int = 256,
            bidirectional: bool = True,
            rnn_type: str = "LSTM",
            mlp_dim: int = 512,
            hidden_activation: str = "Tanh",
            hidden_activation_kwargs: Optional[Dict] = None,
            complex_mask: bool = True,
            n_fft: int = 2048,
            win_length: Optional[int] = 2048,
            hop_length: int = 512,
            window_fn: str = "hann_window",
            wkwargs: Optional[Dict] = None,
            power: Optional[int] = None,
            center: bool = True,
            normalized: bool = True,
            pad_mode: str = "constant",
            onesided: bool = True,
    ) -> None:
        super().__init__(
            band_specs_map=band_specs_map, fs=fs, n_fft=n_fft,
            win_length=win_length, hop_length=hop_length, window_fn=window_fn,
            wkwargs=wkwargs, power=power, center=center, normalized=normalized,
            pad_mode=pad_mode, onesided=onesided,
        )

        self.bsrnn = nn.ModuleDict(
            {
                src: SingleMaskBandsplitCoreRNN(
                    band_specs=specs,
                    in_channel=in_channel,
                    require_no_overlap=require_no_overlap,
                    require_no_gap=require_no_gap,
                    normalize_channel_independently=normalize_channel_independently,
                    treat_channel_as_feature=treat_channel_as_feature,
                    n_sqm_modules=n_sqm_modules,
                    emb_dim=emb_dim,
                    rnn_dim=rnn_dim,
                    bidirectional=bidirectional,
                    rnn_type=rnn_type,
                    mlp_dim=mlp_dim,
                    hidden_activation=hidden_activation,
                    hidden_activation_kwargs=hidden_activation_kwargs,
                    complex_mask=complex_mask,
                )
                for src, specs in self.band_specs_map.items()
            }
        )

class SingleMaskMultiSourceBandSplitTransformer(
        SingleMaskMultiSourceBandSplitBase
):
    def __init__(
            self,
            in_channel: int,
            band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
            fs: int = 44100,
            require_no_overlap: bool = False,
            require_no_gap: bool = True,
            normalize_channel_independently: bool = False,
            treat_channel_as_feature: bool = True,
            n_sqm_modules: int = 12,
            emb_dim: int = 128,
            rnn_dim: int = 256,
            bidirectional: bool = True,
            tf_dropout: float = 0.0,
            mlp_dim: int = 512,
            hidden_activation: str = "Tanh",
            hidden_activation_kwargs: Optional[Dict] = None,
            complex_mask: bool = True,
            n_fft: int = 2048,
            win_length: Optional[int] = 2048,
            hop_length: int = 512,
            window_fn: str = "hann_window",
            wkwargs: Optional[Dict] = None,
            power: Optional[int] = None,
            center: bool = True,
            normalized: bool = True,
            pad_mode: str = "constant",
            onesided: bool = True,
    ) -> None:
        super().__init__(
            band_specs_map=band_specs_map, fs=fs, n_fft=n_fft,
            win_length=win_length, hop_length=hop_length, window_fn=window_fn,
            wkwargs=wkwargs, power=power, center=center, normalized=normalized,
            pad_mode=pad_mode, onesided=onesided,
        )

        self.bsrnn = nn.ModuleDict(
            {
                src: SingleMaskBandsplitCoreTransformer(
                    band_specs=specs,
                    in_channel=in_channel,
                    require_no_overlap=require_no_overlap,
                    require_no_gap=require_no_gap,
                    normalize_channel_independently=normalize_channel_independently,
                    treat_channel_as_feature=treat_channel_as_feature,
                    n_sqm_modules=n_sqm_modules,
                    emb_dim=emb_dim,
                    rnn_dim=rnn_dim,
                    bidirectional=bidirectional,
                    tf_dropout=tf_dropout,
                    mlp_dim=mlp_dim,
                    hidden_activation=hidden_activation,
                    hidden_activation_kwargs=hidden_activation_kwargs,
                    complex_mask=complex_mask,
                )
                for src, specs in self.band_specs_map.items()
            }
        )

class MultiMaskMultiSourceBandSplitRNN(MultiMaskMultiSourceBandSplitBase):
    def __init__(
            self,
            in_channel: int,
            stems: List[str],
            band_specs: Union[str, List[Tuple[float, float]]],
            fs: int = 44100,
            require_no_overlap: bool = False,
            require_no_gap: bool = True,
            normalize_channel_independently: bool = False,
            treat_channel_as_feature: bool = True,
            n_sqm_modules: int = 12,
            emb_dim: int = 128,
            rnn_dim: int = 256,
            cond_dim: int = 0,
            bidirectional: bool = True,
            rnn_type: str = "LSTM",
            mlp_dim: int = 512,
            hidden_activation: str = "Tanh",
            hidden_activation_kwargs: Optional[Dict] = None,
            complex_mask: bool = True,
            n_fft: int = 2048,
            win_length: Optional[int] = 2048,
            hop_length: int = 512,
            window_fn: str = "hann_window",
            wkwargs: Optional[Dict] = None,
            power: Optional[int] = None,
            center: bool = True,
            normalized: bool = True,
            pad_mode: str = "constant",
            onesided: bool = True,
            n_bands: int = None,
            use_freq_weights: bool = True,
            normalize_input: bool = False,
            mult_add_mask: bool = False,
            freeze_encoder: bool = False,
    ) -> None:
        super().__init__(
            stems=stems, band_specs=band_specs, fs=fs, n_fft=n_fft,
            win_length=win_length, hop_length=hop_length, window_fn=window_fn,
            wkwargs=wkwargs, power=power, center=center, normalized=normalized,
            pad_mode=pad_mode, onesided=onesided, n_bands=n_bands,
        )

        self.bsrnn = MultiSourceMultiMaskBandSplitCoreRNN(
            stems=stems,
            band_specs=self.band_specs,
            in_channel=in_channel,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
            mlp_dim=mlp_dim,
            cond_dim=cond_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            overlapping_band=self.overlapping_band,
            freq_weights=self.freq_weights,
            n_freq=n_fft // 2 + 1,
            use_freq_weights=use_freq_weights,
            mult_add_mask=mult_add_mask,
        )

        self.normalize_input = normalize_input
        self.cond_dim = cond_dim

        if freeze_encoder:
            # Freeze the band-split encoder and the TF model; only the
            # mask-estimation head stays trainable.
            for param in self.bsrnn.band_split.parameters():
                param.requires_grad = False

            for param in self.bsrnn.tf_model.parameters():
                param.requires_grad = False

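Putting the pieces together, a DnR-style three-stem model could be configured as below; this is a sketch using the defaults above, not a config taken from the repo:

model = MultiMaskMultiSourceBandSplitRNN(
    in_channel=2,                          # stereo input
    stems=["speech", "music", "effects"],  # DnR stems
    band_specs="dnr:vox7",
    fs=44100,
    n_fft=2048,
    hop_length=512,
)
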
class MultiMaskMultiSourceBandSplitRNNSimple(MultiMaskMultiSourceBandSplitBaseSimple):
    def __init__(
            self,
            in_channel: int,
            stems: List[str],
            band_specs: Union[str, List[Tuple[float, float]]],
            fs: int = 44100,
            require_no_overlap: bool = False,
            require_no_gap: bool = True,
            normalize_channel_independently: bool = False,
            treat_channel_as_feature: bool = True,
            n_sqm_modules: int = 12,
            emb_dim: int = 128,
            rnn_dim: int = 256,
            cond_dim: int = 0,
            bidirectional: bool = True,
            rnn_type: str = "LSTM",
            mlp_dim: int = 512,
            hidden_activation: str = "Tanh",
            hidden_activation_kwargs: Optional[Dict] = None,
            complex_mask: bool = True,
            n_fft: int = 2048,
            win_length: Optional[int] = 2048,
            hop_length: int = 512,
            window_fn: str = "hann_window",
            wkwargs: Optional[Dict] = None,
            power: Optional[int] = None,
            center: bool = True,
            normalized: bool = True,
            pad_mode: str = "constant",
            onesided: bool = True,
            n_bands: int = None,
            use_freq_weights: bool = True,
            normalize_input: bool = False,
            mult_add_mask: bool = False,
            freeze_encoder: bool = False,
    ) -> None:
        super().__init__(
            stems=stems, band_specs=band_specs, fs=fs, n_fft=n_fft,
            win_length=win_length, hop_length=hop_length, window_fn=window_fn,
            wkwargs=wkwargs, power=power, center=center, normalized=normalized,
            pad_mode=pad_mode, onesided=onesided, n_bands=n_bands,
        )

        self.bsrnn = MultiSourceMultiMaskBandSplitCoreRNN(
            stems=stems,
            band_specs=self.band_specs,
            in_channel=in_channel,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
            mlp_dim=mlp_dim,
            cond_dim=cond_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            overlapping_band=self.overlapping_band,
            freq_weights=self.freq_weights,
            n_freq=n_fft // 2 + 1,
            use_freq_weights=use_freq_weights,
            mult_add_mask=mult_add_mask,
        )

        self.normalize_input = normalize_input
        self.cond_dim = cond_dim

        if freeze_encoder:
            for param in self.bsrnn.band_split.parameters():
                param.requires_grad = False

            for param in self.bsrnn.tf_model.parameters():
                param.requires_grad = False

class MultiMaskMultiSourceBandSplitTransformer(
        MultiMaskMultiSourceBandSplitBase
):
    def __init__(
            self,
            in_channel: int,
            stems: List[str],
            band_specs: Union[str, List[Tuple[float, float]]],
            fs: int = 44100,
            require_no_overlap: bool = False,
            require_no_gap: bool = True,
            normalize_channel_independently: bool = False,
            treat_channel_as_feature: bool = True,
            n_sqm_modules: int = 12,
            emb_dim: int = 128,
            rnn_dim: int = 256,
            cond_dim: int = 0,
            bidirectional: bool = True,
            rnn_type: str = "LSTM",
            mlp_dim: int = 512,
            hidden_activation: str = "Tanh",
            hidden_activation_kwargs: Optional[Dict] = None,
            complex_mask: bool = True,
            n_fft: int = 2048,
            win_length: Optional[int] = 2048,
            hop_length: int = 512,
            window_fn: str = "hann_window",
            wkwargs: Optional[Dict] = None,
            power: Optional[int] = None,
            center: bool = True,
            normalized: bool = True,
            pad_mode: str = "constant",
            onesided: bool = True,
            n_bands: int = None,
            use_freq_weights: bool = True,
            normalize_input: bool = False,
            mult_add_mask: bool = False
    ) -> None:
        super().__init__(
            stems=stems, band_specs=band_specs, fs=fs, n_fft=n_fft,
            win_length=win_length, hop_length=hop_length, window_fn=window_fn,
            wkwargs=wkwargs, power=power, center=center, normalized=normalized,
            pad_mode=pad_mode, onesided=onesided, n_bands=n_bands,
        )

        self.bsrnn = MultiSourceMultiMaskBandSplitCoreTransformer(
            stems=stems,
            band_specs=self.band_specs,
            in_channel=in_channel,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
            mlp_dim=mlp_dim,
            cond_dim=cond_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            overlapping_band=self.overlapping_band,
            freq_weights=self.freq_weights,
            n_freq=n_fft // 2 + 1,
            use_freq_weights=use_freq_weights,
            mult_add_mask=mult_add_mask,
        )

class MultiMaskMultiSourceBandSplitConv(
        MultiMaskMultiSourceBandSplitBase
):
    def __init__(
            self,
            in_channel: int,
            stems: List[str],
            band_specs: Union[str, List[Tuple[float, float]]],
            fs: int = 44100,
            require_no_overlap: bool = False,
            require_no_gap: bool = True,
            normalize_channel_independently: bool = False,
            treat_channel_as_feature: bool = True,
            n_sqm_modules: int = 12,
            emb_dim: int = 128,
            rnn_dim: int = 256,
            cond_dim: int = 0,
            bidirectional: bool = True,
            rnn_type: str = "LSTM",
            mlp_dim: int = 512,
            hidden_activation: str = "Tanh",
            hidden_activation_kwargs: Optional[Dict] = None,
            complex_mask: bool = True,
            n_fft: int = 2048,
            win_length: Optional[int] = 2048,
            hop_length: int = 512,
            window_fn: str = "hann_window",
            wkwargs: Optional[Dict] = None,
            power: Optional[int] = None,
            center: bool = True,
            normalized: bool = True,
            pad_mode: str = "constant",
            onesided: bool = True,
            n_bands: int = None,
            use_freq_weights: bool = True,
            normalize_input: bool = False,
            mult_add_mask: bool = False
    ) -> None:
        super().__init__(
            stems=stems, band_specs=band_specs, fs=fs, n_fft=n_fft,
            win_length=win_length, hop_length=hop_length, window_fn=window_fn,
            wkwargs=wkwargs, power=power, center=center, normalized=normalized,
            pad_mode=pad_mode, onesided=onesided, n_bands=n_bands,
        )

        self.bsrnn = MultiSourceMultiMaskBandSplitCoreConv(
            stems=stems,
            band_specs=self.band_specs,
            in_channel=in_channel,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
            mlp_dim=mlp_dim,
            cond_dim=cond_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            overlapping_band=self.overlapping_band,
            freq_weights=self.freq_weights,
            n_freq=n_fft // 2 + 1,
            use_freq_weights=use_freq_weights,
            mult_add_mask=mult_add_mask,
        )

class PatchingMaskMultiSourceBandSplitRNN(MultiMaskMultiSourceBandSplitBase):
|
804 |
+
def __init__(
|
805 |
+
self,
|
806 |
+
in_channel: int,
|
807 |
+
stems: List[str],
|
808 |
+
band_specs: Union[str, List[Tuple[float, float]]],
|
809 |
+
kernel_norm_mlp_version: int = 1,
|
810 |
+
mask_kernel_freq: int = 3,
|
811 |
+
mask_kernel_time: int = 3,
|
812 |
+
conv_kernel_freq: int = 1,
|
813 |
+
conv_kernel_time: int = 1,
|
814 |
+
fs: int = 44100,
|
815 |
+
require_no_overlap: bool = False,
|
816 |
+
require_no_gap: bool = True,
|
817 |
+
normalize_channel_independently: bool = False,
|
818 |
+
treat_channel_as_feature: bool = True,
|
819 |
+
n_sqm_modules: int = 12,
|
820 |
+
emb_dim: int = 128,
|
821 |
+
rnn_dim: int = 256,
|
822 |
+
bidirectional: bool = True,
|
823 |
+
rnn_type: str = "LSTM",
|
824 |
+
mlp_dim: int = 512,
|
825 |
+
hidden_activation: str = "Tanh",
|
826 |
+
hidden_activation_kwargs: Optional[Dict] = None,
|
827 |
+
complex_mask: bool = True,
|
828 |
+
n_fft: int = 2048,
|
829 |
+
win_length: Optional[int] = 2048,
|
830 |
+
hop_length: int = 512,
|
831 |
+
window_fn: str = "hann_window",
|
832 |
+
wkwargs: Optional[Dict] = None,
|
833 |
+
power: Optional[int] = None,
|
834 |
+
center: bool = True,
|
835 |
+
normalized: bool = True,
|
836 |
+
pad_mode: str = "constant",
|
837 |
+
onesided: bool = True,
|
838 |
+
n_bands: int = None,
|
839 |
+
) -> None:
|
840 |
+
super().__init__(
|
841 |
+
stems=stems,
|
842 |
+
band_specs=band_specs,
|
843 |
+
fs=fs,
|
844 |
+
n_fft=n_fft,
|
845 |
+
win_length=win_length,
|
846 |
+
hop_length=hop_length,
|
847 |
+
window_fn=window_fn,
|
848 |
+
wkwargs=wkwargs,
|
849 |
+
power=power,
|
850 |
+
center=center,
|
851 |
+
normalized=normalized,
|
852 |
+
pad_mode=pad_mode,
|
853 |
+
onesided=onesided,
|
854 |
+
n_bands=n_bands,
|
855 |
+
)
|
856 |
+
|
857 |
+
self.bsrnn = MultiSourceMultiPatchingMaskBandSplitCoreRNN(
|
858 |
+
stems=stems,
|
859 |
+
band_specs=self.band_specs,
|
860 |
+
in_channel=in_channel,
|
861 |
+
require_no_overlap=require_no_overlap,
|
862 |
+
require_no_gap=require_no_gap,
|
863 |
+
normalize_channel_independently=normalize_channel_independently,
|
864 |
+
treat_channel_as_feature=treat_channel_as_feature,
|
865 |
+
n_sqm_modules=n_sqm_modules,
|
866 |
+
emb_dim=emb_dim,
|
867 |
+
rnn_dim=rnn_dim,
|
868 |
+
bidirectional=bidirectional,
|
869 |
+
rnn_type=rnn_type,
|
870 |
+
mlp_dim=mlp_dim,
|
871 |
+
hidden_activation=hidden_activation,
|
872 |
+
hidden_activation_kwargs=hidden_activation_kwargs,
|
873 |
+
complex_mask=complex_mask,
|
874 |
+
overlapping_band=self.overlapping_band,
|
875 |
+
freq_weights=self.freq_weights,
|
876 |
+
n_freq=n_fft // 2 + 1,
|
877 |
+
mask_kernel_freq=mask_kernel_freq,
|
878 |
+
mask_kernel_time=mask_kernel_time,
|
879 |
+
conv_kernel_freq=conv_kernel_freq,
|
880 |
+
conv_kernel_time=conv_kernel_time,
|
881 |
+
kernel_norm_mlp_version=kernel_norm_mlp_version,
|
882 |
+
)
|
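Usage note: all three wrappers share the same STFT setup in `super().__init__` and differ only in the mask-estimation core they build. A minimal construction sketch follows; the stem names and band tuples are illustrative placeholders only, and the exact semantics of `band_specs` (string presets vs. explicit (low, high) pairs) are defined in `bandsplit.py`:

# A sketch only; the stems and band boundaries below are hypothetical.
from models.bandit.core.model.bsrnn.wrapper import MultiMaskMultiSourceBandSplitConv

model = MultiMaskMultiSourceBandSplitConv(
    in_channel=2,                          # stereo mixture
    stems=["speech", "music", "effects"],
    band_specs=[                           # assumed (low, high) pairs tiling 0..fs/2
        (0.0, 500.0), (500.0, 1000.0), (1000.0, 2000.0),
        (2000.0, 4000.0), (4000.0, 8000.0), (8000.0, 22050.0),
    ],
    fs=44100,
    n_fft=2048,
    hop_length=512,
)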
models/bandit/core/utils/__init__.py
ADDED
File without changes
models/bandit/core/utils/audio.py
ADDED
@@ -0,0 +1,463 @@
from collections import defaultdict

from tqdm import tqdm
from typing import Callable, Dict, List, Optional, Tuple

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F


@torch.jit.script
def merge(
        combined: torch.Tensor,
        original_batch_size: int,
        n_channel: int,
        n_chunks: int,
        chunk_size: int,
):
    combined = torch.reshape(
        combined,
        (original_batch_size, n_chunks, n_channel, chunk_size)
    )
    combined = torch.permute(combined, (0, 2, 3, 1)).reshape(
        original_batch_size * n_channel,
        chunk_size,
        n_chunks
    )

    return combined


@torch.jit.script
def unfold(
        padded_audio: torch.Tensor,
        original_batch_size: int,
        n_channel: int,
        chunk_size: int,
        hop_size: int
) -> torch.Tensor:
    unfolded_input = F.unfold(
        padded_audio[:, :, None, :],
        kernel_size=(1, chunk_size),
        stride=(1, hop_size)
    )

    _, _, n_chunks = unfolded_input.shape
    unfolded_input = unfolded_input.view(
        original_batch_size,
        n_channel,
        chunk_size,
        n_chunks
    )
    unfolded_input = torch.permute(
        unfolded_input,
        (0, 3, 1, 2)
    ).reshape(
        original_batch_size * n_chunks,
        n_channel,
        chunk_size
    )

    return unfolded_input


@torch.jit.script
# @torch.compile
def merge_chunks_all(
        combined: torch.Tensor,
        original_batch_size: int,
        n_channel: int,
        n_samples: int,
        n_padded_samples: int,
        n_chunks: int,
        chunk_size: int,
        hop_size: int,
        edge_frame_pad_sizes: Tuple[int, int],
        standard_window: torch.Tensor,
        first_window: torch.Tensor,
        last_window: torch.Tensor
):
    combined = merge(
        combined,
        original_batch_size,
        n_channel,
        n_chunks,
        chunk_size
    )

    combined = combined * standard_window[:, None].to(combined.device)

    combined = F.fold(
        combined.to(torch.float32),
        output_size=(1, n_padded_samples),
        kernel_size=(1, chunk_size),
        stride=(1, hop_size)
    )

    combined = combined.view(
        original_batch_size,
        n_channel,
        n_padded_samples
    )

    pad_front, pad_back = edge_frame_pad_sizes
    combined = combined[..., pad_front:-pad_back]

    combined = combined[..., :n_samples]

    return combined


# @torch.jit.script
def merge_chunks_edge(
        combined: torch.Tensor,
        original_batch_size: int,
        n_channel: int,
        n_samples: int,
        n_padded_samples: int,
        n_chunks: int,
        chunk_size: int,
        hop_size: int,
        edge_frame_pad_sizes: Tuple[int, int],
        standard_window: torch.Tensor,
        first_window: torch.Tensor,
        last_window: torch.Tensor
):
    combined = merge(
        combined,
        original_batch_size,
        n_channel,
        n_chunks,
        chunk_size
    )

    combined[..., 0] = combined[..., 0] * first_window
    combined[..., -1] = combined[..., -1] * last_window
    combined[..., 1:-1] = combined[..., 1:-1] * standard_window[:, None]

    combined = F.fold(
        combined,
        output_size=(1, n_padded_samples),
        kernel_size=(1, chunk_size),
        stride=(1, hop_size)
    )

    combined = combined.view(
        original_batch_size,
        n_channel,
        n_padded_samples
    )

    combined = combined[..., :n_samples]

    return combined


class BaseFader(nn.Module):
    def __init__(
            self,
            chunk_size_second: float,
            hop_size_second: float,
            fs: int,
            fade_edge_frames: bool,
            batch_size: int,
    ) -> None:
        super().__init__()

        self.chunk_size = int(chunk_size_second * fs)
        self.hop_size = int(hop_size_second * fs)
        self.overlap_size = self.chunk_size - self.hop_size
        self.fade_edge_frames = fade_edge_frames
        self.batch_size = batch_size

    # @torch.jit.script
    def prepare(self, audio):
        if self.fade_edge_frames:
            audio = F.pad(audio, self.edge_frame_pad_sizes, mode="reflect")

        n_samples = audio.shape[-1]
        n_chunks = int(
            np.ceil((n_samples - self.chunk_size) / self.hop_size) + 1
        )

        padded_size = (n_chunks - 1) * self.hop_size + self.chunk_size
        pad_size = padded_size - n_samples

        padded_audio = F.pad(audio, (0, pad_size))

        return padded_audio, n_chunks

    def forward(
            self,
            audio: torch.Tensor,
            model_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]],
    ):
        original_dtype = audio.dtype
        original_device = audio.device

        audio = audio.to("cpu")

        original_batch_size, n_channel, n_samples = audio.shape
        padded_audio, n_chunks = self.prepare(audio)
        del audio
        n_padded_samples = padded_audio.shape[-1]

        if n_channel > 1:
            padded_audio = padded_audio.view(
                original_batch_size * n_channel, 1, n_padded_samples
            )

        unfolded_input = unfold(
            padded_audio,
            original_batch_size,
            n_channel,
            self.chunk_size,
            self.hop_size
        )

        n_total_chunks, n_channel, chunk_size = unfolded_input.shape

        n_batch = np.ceil(n_total_chunks / self.batch_size).astype(int)

        chunks_in = [
            unfolded_input[b * self.batch_size:(b + 1) * self.batch_size, ...].clone()
            for b in range(n_batch)
        ]

        all_chunks_out = defaultdict(
            lambda: torch.zeros_like(
                unfolded_input, device="cpu"
            )
        )

        # for b, cin in enumerate(tqdm(chunks_in)):
        for b, cin in enumerate(chunks_in):
            # skip all-silent batches; their output slots stay zero
            if torch.allclose(cin, torch.tensor(0.0)):
                del cin
                continue

            chunks_out = model_fn(cin.to(original_device))
            del cin
            for s, c in chunks_out.items():
                all_chunks_out[s][b * self.batch_size:(b + 1) * self.batch_size, ...] = c.cpu()
            del chunks_out

        del unfolded_input
        del padded_audio

        if self.fade_edge_frames:
            fn = merge_chunks_all
        else:
            fn = merge_chunks_edge

        outputs = {}

        torch.cuda.empty_cache()

        for s, c in all_chunks_out.items():
            combined: torch.Tensor = fn(
                c,
                original_batch_size,
                n_channel,
                n_samples,
                n_padded_samples,
                n_chunks,
                self.chunk_size,
                self.hop_size,
                self.edge_frame_pad_sizes,
                self.standard_window,
                self.__dict__.get("first_window", self.standard_window),
                self.__dict__.get("last_window", self.standard_window)
            )

            outputs[s] = combined.to(
                dtype=original_dtype,
                device=original_device
            )

        return {
            "audio": outputs
        }

    # def old_forward(
    #         self,
    #         audio: torch.Tensor,
    #         model_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]],
    # ):
    #
    #     n_samples = audio.shape[-1]
    #     original_batch_size = audio.shape[0]
    #
    #     padded_audio, n_chunks = self.prepare(audio)
    #
    #     ndim = padded_audio.ndim
    #     broadcaster = [1 for _ in range(ndim - 1)] + [self.chunk_size]
    #
    #     outputs = defaultdict(
    #         lambda: torch.zeros_like(
    #             padded_audio, device=audio.device, dtype=torch.float64
    #         )
    #     )
    #
    #     all_chunks_out = []
    #     len_chunks_in = []
    #
    #     batch_size_ = int(self.batch_size // original_batch_size)
    #     for b in range(int(np.ceil(n_chunks / batch_size_))):
    #         chunks_in = []
    #         for j in range(batch_size_):
    #             i = b * batch_size_ + j
    #             if i == n_chunks:
    #                 break
    #
    #             start = i * hop_size
    #             end = start + self.chunk_size
    #             chunk_in = padded_audio[..., start:end]
    #             chunks_in.append(chunk_in)
    #
    #         chunks_in = torch.concat(chunks_in, dim=0)
    #         chunks_out = model_fn(chunks_in)
    #         all_chunks_out.append(chunks_out)
    #         len_chunks_in.append(len(chunks_in))
    #
    #     for b, (chunks_out, lci) in enumerate(
    #             zip(all_chunks_out, len_chunks_in)
    #     ):
    #         for stem in chunks_out:
    #             for j in range(lci // original_batch_size):
    #                 i = b * batch_size_ + j
    #
    #                 if self.fade_edge_frames:
    #                     window = self.standard_window
    #                 else:
    #                     if i == 0:
    #                         window = self.first_window
    #                     elif i == n_chunks - 1:
    #                         window = self.last_window
    #                     else:
    #                         window = self.standard_window
    #
    #                 start = i * hop_size
    #                 end = start + self.chunk_size
    #
    #                 chunk_out = chunks_out[stem][j * original_batch_size:(j + 1) * original_batch_size, ...]
    #                 contrib = window.view(*broadcaster) * chunk_out
    #                 outputs[stem][..., start:end] = (
    #                         outputs[stem][..., start:end] + contrib
    #                 )
    #
    #     if self.fade_edge_frames:
    #         pad_front, pad_back = self.edge_frame_pad_sizes
    #         outputs = {k: v[..., pad_front:-pad_back] for k, v in
    #                    outputs.items()}
    #
    #     outputs = {k: v[..., :n_samples].to(audio.dtype) for k, v in
    #                outputs.items()}
    #
    #     return {
    #         "audio": outputs
    #     }


class LinearFader(BaseFader):
    def __init__(
            self,
            chunk_size_second: float,
            hop_size_second: float,
            fs: int,
            fade_edge_frames: bool = False,
            batch_size: int = 1,
    ) -> None:
        assert hop_size_second >= chunk_size_second / 2

        super().__init__(
            chunk_size_second=chunk_size_second,
            hop_size_second=hop_size_second,
            fs=fs,
            fade_edge_frames=fade_edge_frames,
            batch_size=batch_size,
        )

        in_fade = torch.linspace(0.0, 1.0, self.overlap_size + 1)[:-1]
        out_fade = torch.linspace(1.0, 0.0, self.overlap_size + 1)[1:]
        center_ones = torch.ones(self.chunk_size - 2 * self.overlap_size)
        inout_ones = torch.ones(self.overlap_size)

        # using nn.Parameters allows lightning to take care of devices for us
        self.register_buffer(
            "standard_window",
            torch.concat([in_fade, center_ones, out_fade])
        )

        self.fade_edge_frames = fade_edge_frames
        # renamed from `edge_frame_pad_size` (singular): prepare() and
        # forward() read the plural attribute.
        self.edge_frame_pad_sizes = (self.overlap_size, self.overlap_size)

        if not self.fade_edge_frames:
            self.first_window = nn.Parameter(
                torch.concat([inout_ones, center_ones, out_fade]),
                requires_grad=False
            )
            self.last_window = nn.Parameter(
                torch.concat([in_fade, center_ones, inout_ones]),
                requires_grad=False
            )


class OverlapAddFader(BaseFader):
    def __init__(
            self,
            window_type: str,
            chunk_size_second: float,
            hop_size_second: float,
            fs: int,
            batch_size: int = 1,
    ) -> None:
        assert (chunk_size_second / hop_size_second) % 2 == 0
        assert int(chunk_size_second * fs) % 2 == 0

        super().__init__(
            chunk_size_second=chunk_size_second,
            hop_size_second=hop_size_second,
            fs=fs,
            fade_edge_frames=True,
            batch_size=batch_size,
        )

        self.hop_multiplier = self.chunk_size / (2 * self.hop_size)
        # print(f"hop multiplier: {self.hop_multiplier}")

        self.edge_frame_pad_sizes = (
            2 * self.overlap_size,
            2 * self.overlap_size
        )

        self.register_buffer(
            "standard_window",
            # was `torch.windows.__dict__[window_type]`; on current PyTorch
            # the window factories live under `torch.signal.windows`.
            torch.signal.windows.__dict__[window_type](
                self.chunk_size, sym=False,  # dtype=torch.float64
            ) / self.hop_multiplier
        )


if __name__ == "__main__":
    import torchaudio as ta

    fs = 44100
    ola = OverlapAddFader(
        "hann",
        6.0,
        1.0,
        fs,
        batch_size=16
    )
    audio_, _ = ta.load(
        "$DATA_ROOT/MUSDB18/HQ/canonical/test/BKS - Too "
        "Much/vocals.wav"
    )
    audio_ = audio_[None, ...]
    out = ola(audio_, lambda x: {"stem": x})["audio"]["stem"]
    print(torch.allclose(out, audio_))
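Usage note for the faders above: `forward` expects a callable mapping a batch of chunks to a dict of equally shaped stem tensors, so any separator can be plugged in. A minimal sketch follows; the stem names and the toy `model_fn` are stand-ins, not part of the library:

import torch

from models.bandit.core.utils.audio import OverlapAddFader

fader = OverlapAddFader("hann", 6.0, 1.0, fs=44100, batch_size=8)

@torch.no_grad()
def model_fn(chunk: torch.Tensor):
    # Stand-in for a real separation model; any callable returning a dict
    # of stem tensors shaped like `chunk` (batch, channel, samples) works.
    return {"vocals": 0.5 * chunk, "accompaniment": 0.5 * chunk}

mixture = torch.randn(1, 2, 44100 * 30)    # 30 s of stereo audio
stems = fader(mixture, model_fn)["audio"]  # dict: stem name -> tensor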
models/bandit/model_from_config.py
ADDED
@@ -0,0 +1,31 @@
import sys
import os.path
import torch

code_path = os.path.dirname(os.path.abspath(__file__)) + '/'
sys.path.append(code_path)

import yaml
from ml_collections import ConfigDict

torch.set_float32_matmul_precision("medium")


def get_model(
        config_path,
        weights_path,
        device,
):
    from models.bandit.core.model import MultiMaskMultiSourceBandSplitRNNSimple

    with open(config_path) as f:
        config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))

    model = MultiMaskMultiSourceBandSplitRNNSimple(
        **config.model
    )
    # NB: the checkpoint filename is hard-coded here; the `weights_path`
    # argument is currently unused.
    d = torch.load(code_path + 'model_bandit_plus_dnr_sdr_11.47.chpt')
    model.load_state_dict(d)
    model.to(device)
    return model, config
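A minimal usage sketch for `get_model`; the config path below is a placeholder, and the hard-coded checkpoint named above must sit next to this file:

import torch

from models.bandit.model_from_config import get_model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, config = get_model(
    config_path="config/dnr_bandit.yaml",                  # placeholder path
    weights_path="model_bandit_plus_dnr_sdr_11.47.chpt",   # currently ignored, see note above
    device=device,
)
model.eval()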
models/bs_roformer/__init__.py
ADDED
@@ -0,0 +1,2 @@
from models.bs_roformer.bs_roformer import BSRoformer
from models.bs_roformer.mel_band_roformer import MelBandRoformer
models/bs_roformer/attend.py
ADDED
@@ -0,0 +1,120 @@
from functools import wraps
from packaging import version
from collections import namedtuple

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, reduce

# constants

FlashAttentionConfig = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# helpers

def exists(val):
    return val is not None


def default(v, d):
    return v if exists(v) else d


def once(fn):
    called = False

    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)

    return inner


print_once = once(print)

# main class

class Attend(nn.Module):
    def __init__(
            self,
            dropout=0.,
            flash=False,
            scale=None
    ):
        super().__init__()
        self.scale = scale
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # determine efficient attention configs for cuda and cpu

        self.cpu_config = FlashAttentionConfig(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = FlashAttentionConfig(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = FlashAttentionConfig(False, True, True)

    def flash_attn(self, q, k, v):
        _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device

        if exists(self.scale):
            default_scale = q.shape[-1] ** -0.5
            q = q * (self.scale / default_scale)

        # check if there is a compatible device for flash attention

        config = self.cuda_config if is_cuda else self.cpu_config

        # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale

        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.dropout if self.training else 0.
            )

        return out

    def forward(self, q, k, v):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        q_len, k_len, device = q.shape[-2], k.shape[-2], q.device

        scale = default(self.scale, q.shape[-1] ** -0.5)

        if self.flash:
            return self.flash_attn(q, k, v)

        # similarity

        sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale

        # attention

        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        # aggregate values

        out = einsum("b h i j, b h j d -> b h i d", attn, v)

        return out
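A quick shape sanity check for `Attend` (the sizes are arbitrary; `flash=True` additionally requires PyTorch >= 2.0, as asserted above):

import torch

from models.bs_roformer.attend import Attend

attend = Attend(dropout=0.0, flash=False)

# (batch, heads, sequence length, head dimension)
q = torch.randn(2, 8, 128, 64)
k = torch.randn(2, 8, 128, 64)
v = torch.randn(2, 8, 128, 64)

out = attend(q, k, v)
print(out.shape)  # torch.Size([2, 8, 128, 64])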
models/bs_roformer/bs_roformer.py
ADDED
@@ -0,0 +1,577 @@
from functools import partial

import torch
from torch import nn, einsum, Tensor
from torch.nn import Module, ModuleList
import torch.nn.functional as F

from models.bs_roformer.attend import Attend

from beartype.typing import Tuple, Optional, List, Callable
from beartype import beartype

from rotary_embedding_torch import RotaryEmbedding

from einops import rearrange, pack, unpack
from einops.layers.torch import Rearrange

# helper functions

def exists(val):
    return val is not None


def default(v, d):
    return v if exists(v) else d


def pack_one(t, pattern):
    return pack([t], pattern)


def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]


# norm

def l2norm(t):
    return F.normalize(t, dim=-1, p=2)


class RMSNorm(Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return F.normalize(x, dim=-1) * self.scale * self.gamma


# attention

class FeedForward(Module):
    def __init__(
            self,
            dim,
            mult=4,
            dropout=0.
    ):
        super().__init__()
        dim_inner = int(dim * mult)
        self.net = nn.Sequential(
            RMSNorm(dim),
            nn.Linear(dim, dim_inner),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim_inner, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class Attention(Module):
    def __init__(
            self,
            dim,
            heads=8,
            dim_head=64,
            dropout=0.,
            rotary_embed=None,
            flash=True
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        dim_inner = heads * dim_head

        self.rotary_embed = rotary_embed

        self.attend = Attend(flash=flash, dropout=dropout)

        self.norm = RMSNorm(dim)
        self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)

        self.to_gates = nn.Linear(dim, heads)

        self.to_out = nn.Sequential(
            nn.Linear(dim_inner, dim, bias=False),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        x = self.norm(x)

        q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads)

        if exists(self.rotary_embed):
            q = self.rotary_embed.rotate_queries_or_keys(q)
            k = self.rotary_embed.rotate_queries_or_keys(k)

        out = self.attend(q, k, v)

        gates = self.to_gates(x)
        out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)


class LinearAttention(Module):
    """
    this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.
    """

    @beartype
    def __init__(
            self,
            *,
            dim,
            dim_head=32,
            heads=8,
            scale=8,
            flash=False,
            dropout=0.
    ):
        super().__init__()
        dim_inner = dim_head * heads
        self.norm = RMSNorm(dim)

        self.to_qkv = nn.Sequential(
            nn.Linear(dim, dim_inner * 3, bias=False),
            Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads)
        )

        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))

        self.attend = Attend(
            scale=scale,
            dropout=dropout,
            flash=flash
        )

        self.to_out = nn.Sequential(
            Rearrange('b h d n -> b n (h d)'),
            nn.Linear(dim_inner, dim, bias=False)
        )

    def forward(
            self,
            x
    ):
        x = self.norm(x)

        q, k, v = self.to_qkv(x)

        q, k = map(l2norm, (q, k))
        q = q * self.temperature.exp()

        out = self.attend(q, k, v)

        return self.to_out(out)


class Transformer(Module):
    def __init__(
            self,
            *,
            dim,
            depth,
            dim_head=64,
            heads=8,
            attn_dropout=0.,
            ff_dropout=0.,
            ff_mult=4,
            norm_output=True,
            rotary_embed=None,
            flash_attn=True,
            linear_attn=False
    ):
        super().__init__()
        self.layers = ModuleList([])

        for _ in range(depth):
            if linear_attn:
                attn = LinearAttention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, flash=flash_attn)
            else:
                attn = Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout,
                                 rotary_embed=rotary_embed, flash=flash_attn)

            self.layers.append(ModuleList([
                attn,
                FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
            ]))

        self.norm = RMSNorm(dim) if norm_output else nn.Identity()

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return self.norm(x)


# bandsplit module

class BandSplit(Module):
    @beartype
    def __init__(
            self,
            dim,
            dim_inputs: Tuple[int, ...]
    ):
        super().__init__()
        self.dim_inputs = dim_inputs
        self.to_features = ModuleList([])

        for dim_in in dim_inputs:
            net = nn.Sequential(
                RMSNorm(dim_in),
                nn.Linear(dim_in, dim)
            )

            self.to_features.append(net)

    def forward(self, x):
        x = x.split(self.dim_inputs, dim=-1)

        outs = []
        for split_input, to_feature in zip(x, self.to_features):
            split_output = to_feature(split_input)
            outs.append(split_output)

        return torch.stack(outs, dim=-2)


def MLP(
        dim_in,
        dim_out,
        dim_hidden=None,
        depth=1,
        activation=nn.Tanh
):
    dim_hidden = default(dim_hidden, dim_in)

    net = []
    dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out)

    for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
        is_last = ind == (len(dims) - 2)

        net.append(nn.Linear(layer_dim_in, layer_dim_out))

        if is_last:
            continue

        net.append(activation())

    return nn.Sequential(*net)


class MaskEstimator(Module):
    @beartype
    def __init__(
            self,
            dim,
            dim_inputs: Tuple[int, ...],
            depth,
            mlp_expansion_factor=4
    ):
        super().__init__()
        self.dim_inputs = dim_inputs
        self.to_freqs = ModuleList([])
        dim_hidden = dim * mlp_expansion_factor

        for dim_in in dim_inputs:
            net = []

            mlp = nn.Sequential(
                MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth),
                nn.GLU(dim=-1)
            )

            self.to_freqs.append(mlp)

    def forward(self, x):
        x = x.unbind(dim=-2)

        outs = []

        for band_features, mlp in zip(x, self.to_freqs):
            freq_out = mlp(band_features)
            outs.append(freq_out)

        return torch.cat(outs, dim=-1)


# main class

DEFAULT_FREQS_PER_BANDS = (
    2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
    2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
    2, 2, 2, 2,
    4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
    12, 12, 12, 12, 12, 12, 12, 12,
    24, 24, 24, 24, 24, 24, 24, 24,
    48, 48, 48, 48, 48, 48, 48, 48,
    128, 129,
)


class BSRoformer(Module):

    @beartype
    def __init__(
            self,
            dim,
            *,
            depth,
            stereo=False,
            num_stems=1,
            time_transformer_depth=2,
            freq_transformer_depth=2,
            linear_transformer_depth=0,
            freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
            # in the paper, they divide into ~60 bands, test with 1 for starters
            dim_head=64,
            heads=8,
            attn_dropout=0.,
            ff_dropout=0.,
            flash_attn=True,
            dim_freqs_in=1025,
            stft_n_fft=2048,
            stft_hop_length=512,
            # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
            stft_win_length=2048,
            stft_normalized=False,
            stft_window_fn: Optional[Callable] = None,
            mask_estimator_depth=2,
            multi_stft_resolution_loss_weight=1.,
            multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
            multi_stft_hop_size=147,
            multi_stft_normalized=False,
            multi_stft_window_fn: Callable = torch.hann_window
    ):
        super().__init__()

        self.stereo = stereo
        self.audio_channels = 2 if stereo else 1
        self.num_stems = num_stems

        self.layers = ModuleList([])

        transformer_kwargs = dict(
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            attn_dropout=attn_dropout,
            ff_dropout=ff_dropout,
            flash_attn=flash_attn,
            norm_output=False
        )

        time_rotary_embed = RotaryEmbedding(dim=dim_head)
        freq_rotary_embed = RotaryEmbedding(dim=dim_head)

        for _ in range(depth):
            tran_modules = []
            if linear_transformer_depth > 0:
                tran_modules.append(Transformer(depth=linear_transformer_depth, linear_attn=True, **transformer_kwargs))
            tran_modules.append(
                Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs)
            )
            tran_modules.append(
                Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs)
            )
            self.layers.append(nn.ModuleList(tran_modules))

        self.final_norm = RMSNorm(dim)

        self.stft_kwargs = dict(
            n_fft=stft_n_fft,
            hop_length=stft_hop_length,
            win_length=stft_win_length,
            normalized=stft_normalized
        )

        self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)

        freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, return_complex=True).shape[1]

        assert len(freqs_per_bands) > 1
        assert sum(freqs_per_bands) == freqs, f'the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}'

        freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in freqs_per_bands)

        self.band_split = BandSplit(
            dim=dim,
            dim_inputs=freqs_per_bands_with_complex
        )

        self.mask_estimators = nn.ModuleList([])

        for _ in range(num_stems):
            mask_estimator = MaskEstimator(
                dim=dim,
                dim_inputs=freqs_per_bands_with_complex,
                depth=mask_estimator_depth
            )

            self.mask_estimators.append(mask_estimator)

        # for the multi-resolution stft loss

        self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
        self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
        self.multi_stft_n_fft = stft_n_fft
        self.multi_stft_window_fn = multi_stft_window_fn

        self.multi_stft_kwargs = dict(
            hop_length=multi_stft_hop_size,
            normalized=multi_stft_normalized
        )

    def forward(
            self,
            raw_audio,
            target=None,
            return_loss_breakdown=False
    ):
        """
        einops

        b - batch
        f - freq
        t - time
        s - audio channel (1 for mono, 2 for stereo)
        n - number of 'stems'
        c - complex (2)
        d - feature dimension
        """

        device = raw_audio.device

        if raw_audio.ndim == 2:
            raw_audio = rearrange(raw_audio, 'b t -> b 1 t')

        channels = raw_audio.shape[1]
        assert (not self.stereo and channels == 1) or (self.stereo and channels == 2), (
            'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). '
            'also need to be False if mono (channel dimension of 1)'
        )

        # to stft

        raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')

        stft_window = self.stft_window_fn(device=device)

        stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
        stft_repr = torch.view_as_real(stft_repr)

        stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')
        # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
        stft_repr = rearrange(stft_repr, 'b s f t c -> b (f s) t c')

        x = rearrange(stft_repr, 'b f t c -> b t (f c)')

        x = self.band_split(x)

        # axial / hierarchical attention

        for transformer_block in self.layers:

            if len(transformer_block) == 3:
                linear_transformer, time_transformer, freq_transformer = transformer_block

                x, ft_ps = pack([x], 'b * d')
                x = linear_transformer(x)
                x, = unpack(x, ft_ps, 'b * d')
            else:
                time_transformer, freq_transformer = transformer_block

            x = rearrange(x, 'b t f d -> b f t d')
            x, ps = pack([x], '* t d')

            x = time_transformer(x)

            x, = unpack(x, ps, '* t d')
            x = rearrange(x, 'b f t d -> b t f d')
            x, ps = pack([x], '* f d')

            x = freq_transformer(x)

            x, = unpack(x, ps, '* f d')

        x = self.final_norm(x)

        num_stems = len(self.mask_estimators)

        mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
        mask = rearrange(mask, 'b n t (f c) -> b n f t c', c=2)

        # modulate frequency representation

        stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')

        # complex number multiplication

        stft_repr = torch.view_as_complex(stft_repr)
        mask = torch.view_as_complex(mask)

        stft_repr = stft_repr * mask

        # istft

        stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)

        recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False)

        recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', s=self.audio_channels, n=num_stems)

        if num_stems == 1:
            recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')

        # if a target is passed in, calculate loss for learning

        if not exists(target):
            return recon_audio

        if self.num_stems > 1:
            assert target.ndim == 4 and target.shape[1] == self.num_stems

        if target.ndim == 2:
            target = rearrange(target, '... t -> ... 1 t')

        target = target[..., :recon_audio.shape[-1]]  # protect against lost length on istft

        loss = F.l1_loss(recon_audio, target)

        multi_stft_resolution_loss = 0.

        for window_size in self.multi_stft_resolutions_window_sizes:
            res_stft_kwargs = dict(
                n_fft=max(window_size, self.multi_stft_n_fft),  # not sure what n_fft is across multi resolution stft
                win_length=window_size,
                return_complex=True,
                window=self.multi_stft_window_fn(window_size, device=device),
                **self.multi_stft_kwargs,
            )

            recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs)
            target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs)

            multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)

        weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight

        total_loss = loss + weighted_multi_resolution_loss

        if not return_loss_breakdown:
            return total_loss

        return total_loss, (loss, multi_stft_resolution_loss)
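A minimal inference/training sketch for `BSRoformer`; the hyperparameters are illustrative (real checkpoints define their own sizes), and with the default STFT settings `freqs_per_bands` must sum to 1025, as the assertion enforces:

import torch

from models.bs_roformer import BSRoformer

model = BSRoformer(dim=64, depth=2, stereo=True, num_stems=1, flash_attn=False)

mixture = torch.randn(1, 2, 44100 * 2)  # 2 s of stereo audio at 44.1 kHz

with torch.no_grad():
    estimate = model(mixture)            # no target -> separated audio
print(estimate.shape)                    # (1, 2, ~88200); istft may trim a few samples

# passing a target returns the training loss instead
target = torch.randn_like(estimate)
loss = model(mixture, target=target)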
models/bs_roformer/mel_band_roformer.py
ADDED
@@ -0,0 +1,637 @@
from functools import partial

import torch
from torch import nn, einsum, Tensor
from torch.nn import Module, ModuleList
import torch.nn.functional as F

from models.bs_roformer.attend import Attend

from beartype.typing import Tuple, Optional, List, Callable
from beartype import beartype

from rotary_embedding_torch import RotaryEmbedding

from einops import rearrange, pack, unpack, reduce, repeat
from einops.layers.torch import Rearrange

from librosa import filters


# helper functions

def exists(val):
    return val is not None


def default(v, d):
    return v if exists(v) else d


def pack_one(t, pattern):
    return pack([t], pattern)


def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]


def pad_at_dim(t, pad, dim=-1, value=0.):
    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = ((0, 0) * dims_from_right)
    return F.pad(t, (*zeros, *pad), value=value)


def l2norm(t):
    return F.normalize(t, dim=-1, p=2)


# norm

class RMSNorm(Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return F.normalize(x, dim=-1) * self.scale * self.gamma


# attention

class FeedForward(Module):
    def __init__(
            self,
            dim,
            mult=4,
            dropout=0.
    ):
        super().__init__()
        dim_inner = int(dim * mult)
        self.net = nn.Sequential(
            RMSNorm(dim),
            nn.Linear(dim, dim_inner),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim_inner, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class Attention(Module):
    def __init__(
            self,
            dim,
            heads=8,
            dim_head=64,
            dropout=0.,
            rotary_embed=None,
            flash=True
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        dim_inner = heads * dim_head

        self.rotary_embed = rotary_embed

        self.attend = Attend(flash=flash, dropout=dropout)

        self.norm = RMSNorm(dim)
        self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)

        self.to_gates = nn.Linear(dim, heads)

        self.to_out = nn.Sequential(
            nn.Linear(dim_inner, dim, bias=False),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        x = self.norm(x)

        q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads)

        if exists(self.rotary_embed):
            q = self.rotary_embed.rotate_queries_or_keys(q)
            k = self.rotary_embed.rotate_queries_or_keys(k)

        out = self.attend(q, k, v)

        gates = self.to_gates(x)
        out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)


class LinearAttention(Module):
    """
    this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.
    """

    @beartype
    def __init__(
            self,
            *,
            dim,
            dim_head=32,
            heads=8,
            scale=8,
            flash=False,
            dropout=0.
    ):
        super().__init__()
        dim_inner = dim_head * heads
        self.norm = RMSNorm(dim)

        self.to_qkv = nn.Sequential(
            nn.Linear(dim, dim_inner * 3, bias=False),
            Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads)
        )

        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))

        self.attend = Attend(
            scale=scale,
            dropout=dropout,
            flash=flash
        )

        self.to_out = nn.Sequential(
            Rearrange('b h d n -> b n (h d)'),
            nn.Linear(dim_inner, dim, bias=False)
        )

    def forward(
            self,
            x
    ):
        x = self.norm(x)

        q, k, v = self.to_qkv(x)

        q, k = map(l2norm, (q, k))
        q = q * self.temperature.exp()

        out = self.attend(q, k, v)

        return self.to_out(out)


class Transformer(Module):
    def __init__(
            self,
            *,
            dim,
            depth,
            dim_head=64,
            heads=8,
            attn_dropout=0.,
            ff_dropout=0.,
            ff_mult=4,
            norm_output=True,
            rotary_embed=None,
            flash_attn=True,
            linear_attn=False
    ):
        super().__init__()
        self.layers = ModuleList([])

        for _ in range(depth):
            if linear_attn:
                attn = LinearAttention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, flash=flash_attn)
            else:
                attn = Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout,
                                 rotary_embed=rotary_embed, flash=flash_attn)

            self.layers.append(ModuleList([
                attn,
                FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
            ]))

        self.norm = RMSNorm(dim) if norm_output else nn.Identity()

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return self.norm(x)


# bandsplit module

class BandSplit(Module):
    @beartype
    def __init__(
            self,
            dim,
            dim_inputs: Tuple[int, ...]
    ):
        super().__init__()
        self.dim_inputs = dim_inputs
        self.to_features = ModuleList([])

        for dim_in in dim_inputs:
            net = nn.Sequential(
                RMSNorm(dim_in),
                nn.Linear(dim_in, dim)
            )

            self.to_features.append(net)

    def forward(self, x):
        x = x.split(self.dim_inputs, dim=-1)

        outs = []
        for split_input, to_feature in zip(x, self.to_features):
            split_output = to_feature(split_input)
            outs.append(split_output)

        return torch.stack(outs, dim=-2)


def MLP(
        dim_in,
        dim_out,
        dim_hidden=None,
        depth=1,
        activation=nn.Tanh
):
    dim_hidden = default(dim_hidden, dim_in)

    net = []
    dims = (dim_in, *((dim_hidden,) * depth), dim_out)

    for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
        is_last = ind == (len(dims) - 2)

        net.append(nn.Linear(layer_dim_in, layer_dim_out))

        if is_last:
            continue

        net.append(activation())

    return nn.Sequential(*net)


class MaskEstimator(Module):
    @beartype
    def __init__(
            self,
            dim,
            dim_inputs: Tuple[int, ...],
            depth,
            mlp_expansion_factor=4
    ):
        super().__init__()
        self.dim_inputs = dim_inputs
        self.to_freqs = ModuleList([])
        dim_hidden = dim * mlp_expansion_factor

        for dim_in in dim_inputs:
            net = []

            mlp = nn.Sequential(
                MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth),
                nn.GLU(dim=-1)
            )

            self.to_freqs.append(mlp)

    def forward(self, x):
        x = x.unbind(dim=-2)

        outs = []

        for band_features, mlp in zip(x, self.to_freqs):
            freq_out = mlp(band_features)
            outs.append(freq_out)

        return torch.cat(outs, dim=-1)


# main class

class MelBandRoformer(Module):

    @beartype
    def __init__(
            self,
            dim,
            *,
            depth,
            stereo=False,
            num_stems=1,
            time_transformer_depth=2,
            freq_transformer_depth=2,
            linear_transformer_depth=0,
            num_bands=60,
            dim_head=64,
            heads=8,
            attn_dropout=0.1,
            ff_dropout=0.1,
            flash_attn=True,
            dim_freqs_in=1025,
            sample_rate=44100,  # needed for mel filter bank from librosa
            stft_n_fft=2048,
            stft_hop_length=512,
            # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
            stft_win_length=2048,
            stft_normalized=False,
            stft_window_fn: Optional[Callable] = None,
            mask_estimator_depth=1,
            multi_stft_resolution_loss_weight=1.,
            multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
            multi_stft_hop_size=147,
            multi_stft_normalized=False,
            multi_stft_window_fn: Callable = torch.hann_window,
            match_input_audio_length=False,  # if True, pad output tensor to match length of input tensor
    ):
        super().__init__()

        self.stereo = stereo
        self.audio_channels = 2 if stereo else 1
        self.num_stems = num_stems

        self.layers = ModuleList([])

        transformer_kwargs = dict(
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            attn_dropout=attn_dropout,
            ff_dropout=ff_dropout,
            flash_attn=flash_attn
        )

        time_rotary_embed = RotaryEmbedding(dim=dim_head)
        freq_rotary_embed = RotaryEmbedding(dim=dim_head)

        for _ in range(depth):
            tran_modules = []
            if linear_transformer_depth > 0:
                tran_modules.append(Transformer(depth=linear_transformer_depth, linear_attn=True, **transformer_kwargs))
            tran_modules.append(
                Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs)
            )
            tran_modules.append(
                Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs)
            )
            self.layers.append(nn.ModuleList(tran_modules))

        self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)

        self.stft_kwargs = dict(
            n_fft=stft_n_fft,
            hop_length=stft_hop_length,
            win_length=stft_win_length,
            normalized=stft_normalized
        )

        freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, return_complex=True).shape[1]

        # create mel filter bank
        # with librosa.filters.mel as in section 2 of paper

        mel_filter_bank_numpy = filters.mel(sr=sample_rate, n_fft=stft_n_fft, n_mels=num_bands)

        mel_filter_bank = torch.from_numpy(mel_filter_bank_numpy)

        # for some reason, it doesn't include the first freq? just force a value for now

        mel_filter_bank[0][0] = 1.

        # In some systems/envs we get 0.0 instead of ~1.9e-18 in the last position,
        # so let's force a positive value

        mel_filter_bank[-1, -1] = 1.

        # binary as in paper (then estimated masks are averaged for overlapping regions)

        freqs_per_band = mel_filter_bank > 0
        assert freqs_per_band.any(dim=0).all(), 'all frequencies need to be covered by all bands for now'

        repeated_freq_indices = repeat(torch.arange(freqs), 'f -> b f', b=num_bands)
        freq_indices = repeated_freq_indices[freqs_per_band]

        if stereo:
            freq_indices = repeat(freq_indices, 'f -> f s', s=2)
            freq_indices = freq_indices * 2 + torch.arange(2)
            freq_indices = rearrange(freq_indices, 'f s -> (f s)')

        self.register_buffer('freq_indices', freq_indices, persistent=False)
        self.register_buffer('freqs_per_band', freqs_per_band, persistent=False)
330 |
+
depth,
|
331 |
+
stereo=False,
|
332 |
+
num_stems=1,
|
333 |
+
time_transformer_depth=2,
|
334 |
+
freq_transformer_depth=2,
|
335 |
+
linear_transformer_depth=0,
|
336 |
+
num_bands=60,
|
337 |
+
dim_head=64,
|
338 |
+
heads=8,
|
339 |
+
attn_dropout=0.1,
|
340 |
+
ff_dropout=0.1,
|
341 |
+
flash_attn=True,
|
342 |
+
dim_freqs_in=1025,
|
343 |
+
sample_rate=44100, # needed for mel filter bank from librosa
|
344 |
+
stft_n_fft=2048,
|
345 |
+
stft_hop_length=512,
|
346 |
+
# 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
|
347 |
+
stft_win_length=2048,
|
348 |
+
stft_normalized=False,
|
349 |
+
stft_window_fn: Optional[Callable] = None,
|
350 |
+
mask_estimator_depth=1,
|
351 |
+
multi_stft_resolution_loss_weight=1.,
|
352 |
+
multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
|
353 |
+
multi_stft_hop_size=147,
|
354 |
+
multi_stft_normalized=False,
|
355 |
+
multi_stft_window_fn: Callable = torch.hann_window,
|
356 |
+
match_input_audio_length=False, # if True, pad output tensor to match length of input tensor
|
357 |
+
):
|
358 |
+
super().__init__()
|
359 |
+
|
360 |
+
self.stereo = stereo
|
361 |
+
self.audio_channels = 2 if stereo else 1
|
362 |
+
self.num_stems = num_stems
|
363 |
+
|
364 |
+
self.layers = ModuleList([])
|
365 |
+
|
366 |
+
transformer_kwargs = dict(
|
367 |
+
dim=dim,
|
368 |
+
heads=heads,
|
369 |
+
dim_head=dim_head,
|
370 |
+
attn_dropout=attn_dropout,
|
371 |
+
ff_dropout=ff_dropout,
|
372 |
+
flash_attn=flash_attn
|
373 |
+
)
|
374 |
+
|
375 |
+
time_rotary_embed = RotaryEmbedding(dim=dim_head)
|
376 |
+
freq_rotary_embed = RotaryEmbedding(dim=dim_head)
|
377 |
+
|
378 |
+
for _ in range(depth):
|
379 |
+
tran_modules = []
|
380 |
+
if linear_transformer_depth > 0:
|
381 |
+
tran_modules.append(Transformer(depth=linear_transformer_depth, linear_attn=True, **transformer_kwargs))
|
382 |
+
tran_modules.append(
|
383 |
+
Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs)
|
384 |
+
)
|
385 |
+
tran_modules.append(
|
386 |
+
Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs)
|
387 |
+
)
|
388 |
+
self.layers.append(nn.ModuleList(tran_modules))
|
389 |
+
|
390 |
+
self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)
|
391 |
+
|
392 |
+
self.stft_kwargs = dict(
|
393 |
+
n_fft=stft_n_fft,
|
394 |
+
hop_length=stft_hop_length,
|
395 |
+
win_length=stft_win_length,
|
396 |
+
normalized=stft_normalized
|
397 |
+
)
|
398 |
+
|
399 |
+
freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, return_complex=True).shape[1]
|
400 |
+
|
401 |
+
# create mel filter bank
|
402 |
+
# with librosa.filters.mel as in section 2 of paper
|
403 |
+
|
404 |
+
mel_filter_bank_numpy = filters.mel(sr=sample_rate, n_fft=stft_n_fft, n_mels=num_bands)
|
405 |
+
|
406 |
+
mel_filter_bank = torch.from_numpy(mel_filter_bank_numpy)
|
407 |
+
|
408 |
+
# for some reason, it doesn't include the first freq? just force a value for now
|
409 |
+
|
410 |
+
mel_filter_bank[0][0] = 1.
|
411 |
+
|
412 |
+
# In some systems/envs we get 0.0 instead of ~1.9e-18 in the last position,
|
413 |
+
# so let's force a positive value
|
414 |
+
|
415 |
+
mel_filter_bank[-1, -1] = 1.
|
416 |
+
|
417 |
+
# binary as in paper (then estimated masks are averaged for overlapping regions)
|
418 |
+
|
419 |
+
freqs_per_band = mel_filter_bank > 0
|
420 |
+
assert freqs_per_band.any(dim=0).all(), 'all frequencies need to be covered by all bands for now'
|
421 |
+
|
422 |
+
repeated_freq_indices = repeat(torch.arange(freqs), 'f -> b f', b=num_bands)
|
423 |
+
freq_indices = repeated_freq_indices[freqs_per_band]
|
424 |
+
|
425 |
+
if stereo:
|
426 |
+
freq_indices = repeat(freq_indices, 'f -> f s', s=2)
|
427 |
+
freq_indices = freq_indices * 2 + torch.arange(2)
|
428 |
+
freq_indices = rearrange(freq_indices, 'f s -> (f s)')
|
429 |
+
|
430 |
+
self.register_buffer('freq_indices', freq_indices, persistent=False)
|
431 |
+
self.register_buffer('freqs_per_band', freqs_per_band, persistent=False)
|
432 |
+
|
433 |
+
num_freqs_per_band = reduce(freqs_per_band, 'b f -> b', 'sum')
|
434 |
+
num_bands_per_freq = reduce(freqs_per_band, 'b f -> f', 'sum')
|
435 |
+
|
436 |
+
self.register_buffer('num_freqs_per_band', num_freqs_per_band, persistent=False)
|
437 |
+
self.register_buffer('num_bands_per_freq', num_bands_per_freq, persistent=False)
|
438 |
+
|
439 |
+
# band split and mask estimator
|
440 |
+
|
441 |
+
freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in num_freqs_per_band.tolist())
|
442 |
+
|
443 |
+
self.band_split = BandSplit(
|
444 |
+
dim=dim,
|
445 |
+
dim_inputs=freqs_per_bands_with_complex
|
446 |
+
)
|
447 |
+
|
448 |
+
self.mask_estimators = nn.ModuleList([])
|
449 |
+
|
450 |
+
for _ in range(num_stems):
|
451 |
+
mask_estimator = MaskEstimator(
|
452 |
+
dim=dim,
|
453 |
+
dim_inputs=freqs_per_bands_with_complex,
|
454 |
+
depth=mask_estimator_depth
|
455 |
+
)
|
456 |
+
|
457 |
+
self.mask_estimators.append(mask_estimator)
|
458 |
+
|
459 |
+
# for the multi-resolution stft loss
|
460 |
+
|
461 |
+
self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
|
462 |
+
self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
|
463 |
+
self.multi_stft_n_fft = stft_n_fft
|
464 |
+
self.multi_stft_window_fn = multi_stft_window_fn
|
465 |
+
|
466 |
+
self.multi_stft_kwargs = dict(
|
467 |
+
hop_length=multi_stft_hop_size,
|
468 |
+
normalized=multi_stft_normalized
|
469 |
+
)
|
470 |
+
|
471 |
+
self.match_input_audio_length = match_input_audio_length
|
472 |
+
|
473 |
+
def forward(
|
474 |
+
self,
|
475 |
+
raw_audio,
|
476 |
+
target=None,
|
477 |
+
return_loss_breakdown=False
|
478 |
+
):
|
479 |
+
"""
|
480 |
+
einops
|
481 |
+
|
482 |
+
b - batch
|
483 |
+
f - freq
|
484 |
+
t - time
|
485 |
+
s - audio channel (1 for mono, 2 for stereo)
|
486 |
+
n - number of 'stems'
|
487 |
+
c - complex (2)
|
488 |
+
d - feature dimension
|
489 |
+
"""
|
490 |
+
|
491 |
+
device = raw_audio.device
|
492 |
+
|
493 |
+
if raw_audio.ndim == 2:
|
494 |
+
raw_audio = rearrange(raw_audio, 'b t -> b 1 t')
|
495 |
+
|
496 |
+
batch, channels, raw_audio_length = raw_audio.shape
|
497 |
+
|
498 |
+
istft_length = raw_audio_length if self.match_input_audio_length else None
|
499 |
+
|
500 |
+
assert (not self.stereo and channels == 1) or (
|
501 |
+
self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)'
|
502 |
+
|
503 |
+
# to stft
|
504 |
+
|
505 |
+
raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')
|
506 |
+
|
507 |
+
stft_window = self.stft_window_fn(device=device)
|
508 |
+
|
509 |
+
stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
|
510 |
+
stft_repr = torch.view_as_real(stft_repr)
|
511 |
+
|
512 |
+
stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')
|
513 |
+
stft_repr = rearrange(stft_repr,
|
514 |
+
'b s f t c -> b (f s) t c') # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
|
515 |
+
|
516 |
+
# index out all frequencies for all frequency ranges across bands ascending in one go
|
517 |
+
|
518 |
+
batch_arange = torch.arange(batch, device=device)[..., None]
|
519 |
+
|
520 |
+
# account for stereo
|
521 |
+
|
522 |
+
x = stft_repr[batch_arange, self.freq_indices]
|
523 |
+
|
524 |
+
# fold the complex (real and imag) into the frequencies dimension
|
525 |
+
|
526 |
+
x = rearrange(x, 'b f t c -> b t (f c)')
|
527 |
+
|
528 |
+
x = self.band_split(x)
|
529 |
+
|
530 |
+
# axial / hierarchical attention
|
531 |
+
|
532 |
+
for transformer_block in self.layers:
|
533 |
+
|
534 |
+
if len(transformer_block) == 3:
|
535 |
+
linear_transformer, time_transformer, freq_transformer = transformer_block
|
536 |
+
|
537 |
+
x, ft_ps = pack([x], 'b * d')
|
538 |
+
x = linear_transformer(x)
|
539 |
+
x, = unpack(x, ft_ps, 'b * d')
|
540 |
+
else:
|
541 |
+
time_transformer, freq_transformer = transformer_block
|
542 |
+
|
543 |
+
x = rearrange(x, 'b t f d -> b f t d')
|
544 |
+
x, ps = pack([x], '* t d')
|
545 |
+
|
546 |
+
x = time_transformer(x)
|
547 |
+
|
548 |
+
x, = unpack(x, ps, '* t d')
|
549 |
+
x = rearrange(x, 'b f t d -> b t f d')
|
550 |
+
x, ps = pack([x], '* f d')
|
551 |
+
|
552 |
+
x = freq_transformer(x)
|
553 |
+
|
554 |
+
x, = unpack(x, ps, '* f d')
|
555 |
+
|
556 |
+
num_stems = len(self.mask_estimators)
|
557 |
+
|
558 |
+
masks = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
|
559 |
+
masks = rearrange(masks, 'b n t (f c) -> b n f t c', c=2)
|
560 |
+
|
561 |
+
# modulate frequency representation
|
562 |
+
|
563 |
+
stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')
|
564 |
+
|
565 |
+
# complex number multiplication
|
566 |
+
|
567 |
+
stft_repr = torch.view_as_complex(stft_repr)
|
568 |
+
masks = torch.view_as_complex(masks)
|
569 |
+
|
570 |
+
masks = masks.type(stft_repr.dtype)
|
571 |
+
|
572 |
+
# need to average the estimated mask for the overlapped frequencies
|
573 |
+
|
574 |
+
scatter_indices = repeat(self.freq_indices, 'f -> b n f t', b=batch, n=num_stems, t=stft_repr.shape[-1])
|
575 |
+
|
576 |
+
stft_repr_expanded_stems = repeat(stft_repr, 'b 1 ... -> b n ...', n=num_stems)
|
577 |
+
masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(2, scatter_indices, masks)
|
578 |
+
|
579 |
+
denom = repeat(self.num_bands_per_freq, 'f -> (f r) 1', r=channels)
|
580 |
+
|
581 |
+
masks_averaged = masks_summed / denom.clamp(min=1e-8)
|
582 |
+
|
583 |
+
# modulate stft repr with estimated mask
|
584 |
+
|
585 |
+
stft_repr = stft_repr * masks_averaged
|
586 |
+
|
587 |
+
# istft
|
588 |
+
|
589 |
+
stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)
|
590 |
+
|
591 |
+
recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False,
|
592 |
+
length=istft_length)
|
593 |
+
|
594 |
+
recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', b=batch, s=self.audio_channels, n=num_stems)
|
595 |
+
|
596 |
+
if num_stems == 1:
|
597 |
+
recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')
|
598 |
+
|
599 |
+
# if a target is passed in, calculate loss for learning
|
600 |
+
|
601 |
+
if not exists(target):
|
602 |
+
return recon_audio
|
603 |
+
|
604 |
+
if self.num_stems > 1:
|
605 |
+
assert target.ndim == 4 and target.shape[1] == self.num_stems
|
606 |
+
|
607 |
+
if target.ndim == 2:
|
608 |
+
target = rearrange(target, '... t -> ... 1 t')
|
609 |
+
|
610 |
+
target = target[..., :recon_audio.shape[-1]] # protect against lost length on istft
|
611 |
+
|
612 |
+
loss = F.l1_loss(recon_audio, target)
|
613 |
+
|
614 |
+
multi_stft_resolution_loss = 0.
|
615 |
+
|
616 |
+
for window_size in self.multi_stft_resolutions_window_sizes:
|
617 |
+
res_stft_kwargs = dict(
|
618 |
+
n_fft=max(window_size, self.multi_stft_n_fft), # not sure what n_fft is across multi resolution stft
|
619 |
+
win_length=window_size,
|
620 |
+
return_complex=True,
|
621 |
+
window=self.multi_stft_window_fn(window_size, device=device),
|
622 |
+
**self.multi_stft_kwargs,
|
623 |
+
)
|
624 |
+
|
625 |
+
recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs)
|
626 |
+
target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs)
|
627 |
+
|
628 |
+
multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)
|
629 |
+
|
630 |
+
weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
|
631 |
+
|
632 |
+
total_loss = loss + weighted_multi_resolution_loss
|
633 |
+
|
634 |
+
if not return_loss_breakdown:
|
635 |
+
return total_loss
|
636 |
+
|
637 |
+
return total_loss, (loss, multi_stft_resolution_loss)
|
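A minimal usage sketch for the class above, with random audio and illustrative hyperparameters; the dims, depths, and shapes below are placeholders, not this repo's config defaults.

# Usage sketch (illustrative only): one separation call and one loss step.
import torch

model = MelBandRoformer(dim=192, depth=2, stereo=True, num_stems=1)
mix = torch.randn(2, 2, 44100)      # (batch, channels, time): 1 s of stereo audio
target = torch.randn(2, 2, 44100)   # reference stem, same shape when num_stems=1

estimate = model(mix)               # no target -> separated audio (batch, channels, time)
total = model(mix, target=target)   # scalar: L1 + weighted multi-resolution STFT loss
total, (l1, multi_stft) = model(mix, target=target, return_loss_breakdown=True)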
models/demucs4ht.py
ADDED
@@ -0,0 +1,713 @@
import math
from fractions import Fraction
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from omegaconf import OmegaConf
from openunmix.filtering import wiener

from demucs.demucs import Demucs, rescale_module
from demucs.hdemucs import HDemucs, pad1d, ScaledEmbedding, HEncLayer, MultiWrap, HDecLayer
from demucs.spec import spectro, ispectro
from demucs.states import capture_init
from demucs.transformer import CrossTransformerEncoder


class HTDemucs(nn.Module):
    """
    Spectrogram and hybrid Demucs model.
    The spectrogram model has the same structure as Demucs, except the first few layers are over the
    frequency axis, until there is only 1 frequency, and then it moves to time convolutions.
    Frequency layers can still access information across time steps thanks to the DConv residual.

    Hybrid models have a parallel time branch. At some layer, the time branch has the same stride
    as the frequency branch and then the two are combined. The opposite happens in the decoder.

    Models can either use naive iSTFT from masking, Wiener filtering ([Uhlich et al. 2017]),
    or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on
    Open Unmix implementation [Stoter et al. 2019].

    The loss is always on the temporal domain, by backpropagating through the above
    output methods and iSTFT. This makes it possible to define hybrid models nicely. However, it
    breaks Wiener filtering a bit, as doing more iterations at test time will change the spectrogram
    contribution without changing the one from the waveform, which will lead to worse performance.
    I tried using the residual option in OpenUnmix Wiener implementation, but it didn't improve.
    CaC on the other hand provides similar performance and works naturally with hybrid models.

    This model also uses frequency embeddings to improve the efficiency of convolutions
    over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf).

    Unlike classic Demucs, there is no resampling here, and normalization is always applied.
    """

    @capture_init
    def __init__(
            self,
            sources,
            # Channels
            audio_channels=2,
            channels=48,
            channels_time=None,
            growth=2,
            # STFT
            nfft=4096,
            num_subbands=1,
            wiener_iters=0,
            end_iters=0,
            wiener_residual=False,
            cac=True,
            # Main structure
            depth=4,
            rewrite=True,
            # Frequency branch
            multi_freqs=None,
            multi_freqs_depth=3,
            freq_emb=0.2,
            emb_scale=10,
            emb_smooth=True,
            # Convolutions
            kernel_size=8,
            time_stride=2,
            stride=4,
            context=1,
            context_enc=0,
            # Normalization
            norm_starts=4,
            norm_groups=4,
            # DConv residual branch
            dconv_mode=1,
            dconv_depth=2,
            dconv_comp=8,
            dconv_init=1e-3,
            # Before the Transformer
            bottom_channels=0,
            # Transformer
            t_layers=5,
            t_emb="sin",
            t_hidden_scale=4.0,
            t_heads=8,
            t_dropout=0.0,
            t_max_positions=10000,
            t_norm_in=True,
            t_norm_in_group=False,
            t_group_norm=False,
            t_norm_first=True,
            t_norm_out=True,
            t_max_period=10000.0,
            t_weight_decay=0.0,
            t_lr=None,
            t_layer_scale=True,
            t_gelu=True,
            t_weight_pos_embed=1.0,
            t_sin_random_shift=0,
            t_cape_mean_normalize=True,
            t_cape_augment=True,
            t_cape_glob_loc_scale=[5000.0, 1.0, 1.4],
            t_sparse_self_attn=False,
            t_sparse_cross_attn=False,
            t_mask_type="diag",
            t_mask_random_seed=42,
            t_sparse_attn_window=500,
            t_global_window=100,
            t_sparsity=0.95,
            t_auto_sparsity=False,
            # ------ Particular parameters
            t_cross_first=False,
            # Weight init
            rescale=0.1,
            # Metadata
            samplerate=44100,
            segment=10,
            use_train_segment=False,
    ):
        """
        Args:
            sources (list[str]): list of source names.
            audio_channels (int): input/output audio channels.
            channels (int): initial number of hidden channels.
            channels_time: if not None, use a different `channels` value for the time branch.
            growth: increase the number of hidden channels by this factor at each layer.
            nfft: number of fft bins. Note that changing this requires careful computation of
                various shape parameters and will not work out of the box for hybrid models.
            wiener_iters: when using Wiener filtering, number of iterations at test time.
            end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`.
            wiener_residual: add residual source before wiener filtering.
            cac: uses complex as channels, i.e. complex numbers are 2 channels each
                in input and output. no further processing is done before ISTFT.
            depth (int): number of layers in the encoder and in the decoder.
            rewrite (bool): add 1x1 convolution to each layer.
            multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`.
            multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost
                layers will be wrapped.
            freq_emb: add frequency embedding after the first frequency layer if > 0,
                the actual value controls the weight of the embedding.
            emb_scale: equivalent to scaling the embedding learning rate
            emb_smooth: initialize the embedding with a smooth one (with respect to frequencies).
            kernel_size: kernel_size for encoder and decoder layers.
            stride: stride for encoder and decoder layers.
            time_stride: stride for the final time layer, after the merge.
            context: context for 1x1 conv in the decoder.
            context_enc: context for 1x1 conv in the encoder.
            norm_starts: layer at which group norm starts being used.
                decoder layers are numbered in reverse order.
            norm_groups: number of groups for group norm.
            dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
            dconv_depth: depth of residual DConv branch.
            dconv_comp: compression of DConv branch.
            dconv_init: initial scale for the DConv branch LayerScale.
            bottom_channels: if >0 it adds a linear layer (1x1 Conv) before and after the
                transformer in order to change the number of channels
            t_layers: number of layers in each branch (waveform and spec) of the transformer
            t_emb: "sin", "cape" or "scaled"
            t_hidden_scale: the hidden scale of the Feedforward parts of the transformer
                for instance if C = 384 (the number of channels in the transformer) and
                t_hidden_scale = 4.0 then the intermediate layer of the FFN has dimension
                384 * 4 = 1536
            t_heads: number of heads for the transformer
            t_dropout: dropout in the transformer
            t_max_positions: max_positions for the "scaled" positional embedding, only
                useful if t_emb="scaled"
            t_norm_in: (bool) norm before adding the positional embedding and entering the
                transformer layers
            t_norm_in_group: (bool) if True while t_norm_in=True, the norm is on all the
                timesteps (GroupNorm with group=1)
            t_group_norm: (bool) if True, the norms of the Encoder Layers are on all the
                timesteps (GroupNorm with group=1)
            t_norm_first: (bool) if True the norm is before the attention and before the FFN
            t_norm_out: (bool) if True, there is a GroupNorm (group=1) at the end of each layer
            t_max_period: (float) denominator in the sinusoidal embedding expression
            t_weight_decay: (float) weight decay for the transformer
            t_lr: (float) specific learning rate for the transformer
            t_layer_scale: (bool) Layer Scale for the transformer
            t_gelu: (bool) activations of the transformer are GeLU if True, ReLU else
            t_weight_pos_embed: (float) weighting of the positional embedding
            t_cape_mean_normalize: (bool) if t_emb="cape", normalisation of positional embeddings
                see: https://arxiv.org/abs/2106.03143
            t_cape_augment: (bool) if t_emb="cape", must be True during training and False
                during inference, see: https://arxiv.org/abs/2106.03143
            t_cape_glob_loc_scale: (list of 3 floats) if t_emb="cape", CAPE parameters
                see: https://arxiv.org/abs/2106.03143
            t_sparse_self_attn: (bool) if True, the self attentions are sparse
            t_sparse_cross_attn: (bool) if True, the cross-attentions are sparse (don't use it
                unless you designed really specific masks)
            t_mask_type: (str) can be "diag", "jmask", "random", "global" or any combination
                with '_' between: i.e. "diag_jmask_random" (note that this is permutation
                invariant i.e. "diag_jmask_random" is equivalent to "jmask_random_diag")
            t_mask_random_seed: (int) if "random" is in t_mask_type, controls the seed
                that generates the random part of the mask
            t_sparse_attn_window: (int) if "diag" is in t_mask_type, for a query (i), and
                a key (j), the mask is True if |i-j|<=t_sparse_attn_window
            t_global_window: (int) if "global" is in t_mask_type, mask[:t_global_window, :]
                and mask[:, :t_global_window] will be True
            t_sparsity: (float) if "random" is in t_mask_type, t_sparsity is the sparsity
                level of the random part of the mask.
            t_cross_first: (bool) if True cross attention is the first layer of the
                transformer (False seems to be better)
            rescale: weight rescaling trick
            use_train_segment: (bool) if True, the actual size that was used during
                training is used during inference.
        """
        super().__init__()
        self.num_subbands = num_subbands
        self.cac = cac
        self.wiener_residual = wiener_residual
        self.audio_channels = audio_channels
        self.sources = sources
        self.kernel_size = kernel_size
        self.context = context
        self.stride = stride
        self.depth = depth
        self.bottom_channels = bottom_channels
        self.channels = channels
        self.samplerate = samplerate
        self.segment = segment
        self.use_train_segment = use_train_segment
        self.nfft = nfft
        self.hop_length = nfft // 4
        self.wiener_iters = wiener_iters
        self.end_iters = end_iters
        self.freq_emb = None
        assert wiener_iters == end_iters

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        self.tencoder = nn.ModuleList()
        self.tdecoder = nn.ModuleList()

        chin = audio_channels
        chin_z = chin  # number of channels for the freq branch
        if self.cac:
            chin_z *= 2
        if self.num_subbands > 1:
            chin_z *= self.num_subbands
        chout = channels_time or channels
        chout_z = channels
        freqs = nfft // 2

        for index in range(depth):
            norm = index >= norm_starts
            freq = freqs > 1
            stri = stride
            ker = kernel_size
            if not freq:
                assert freqs == 1
                ker = time_stride * 2
                stri = time_stride

            pad = True
            last_freq = False
            if freq and freqs <= kernel_size:
                ker = freqs
                pad = False
                last_freq = True

            kw = {
                "kernel_size": ker,
                "stride": stri,
                "freq": freq,
                "pad": pad,
                "norm": norm,
                "rewrite": rewrite,
                "norm_groups": norm_groups,
                "dconv_kw": {
                    "depth": dconv_depth,
                    "compress": dconv_comp,
                    "init": dconv_init,
                    "gelu": True,
                },
            }
            kwt = dict(kw)
            kwt["freq"] = 0
            kwt["kernel_size"] = kernel_size
            kwt["stride"] = stride
            kwt["pad"] = True
            kw_dec = dict(kw)
            multi = False
            if multi_freqs and index < multi_freqs_depth:
                multi = True
                kw_dec["context_freq"] = False

            if last_freq:
                chout_z = max(chout, chout_z)
                chout = chout_z

            enc = HEncLayer(
                chin_z, chout_z, dconv=dconv_mode & 1, context=context_enc, **kw
            )
            if freq:
                tenc = HEncLayer(
                    chin,
                    chout,
                    dconv=dconv_mode & 1,
                    context=context_enc,
                    empty=last_freq,
                    **kwt
                )
                self.tencoder.append(tenc)

            if multi:
                enc = MultiWrap(enc, multi_freqs)
            self.encoder.append(enc)
            if index == 0:
                chin = self.audio_channels * len(self.sources)
                chin_z = chin
                if self.cac:
                    chin_z *= 2
                if self.num_subbands > 1:
                    chin_z *= self.num_subbands
            dec = HDecLayer(
                chout_z,
                chin_z,
                dconv=dconv_mode & 2,
                last=index == 0,
                context=context,
                **kw_dec
            )
            if multi:
                dec = MultiWrap(dec, multi_freqs)
            if freq:
                tdec = HDecLayer(
                    chout,
                    chin,
                    dconv=dconv_mode & 2,
                    empty=last_freq,
                    last=index == 0,
                    context=context,
                    **kwt
                )
                self.tdecoder.insert(0, tdec)
            self.decoder.insert(0, dec)

            chin = chout
            chin_z = chout_z
            chout = int(growth * chout)
            chout_z = int(growth * chout_z)
            if freq:
                if freqs <= kernel_size:
                    freqs = 1
                else:
                    freqs //= stride
            if index == 0 and freq_emb:
                self.freq_emb = ScaledEmbedding(
                    freqs, chin_z, smooth=emb_smooth, scale=emb_scale
                )
                self.freq_emb_scale = freq_emb

        if rescale:
            rescale_module(self, reference=rescale)

        transformer_channels = channels * growth ** (depth - 1)
        if bottom_channels:
            self.channel_upsampler = nn.Conv1d(transformer_channels, bottom_channels, 1)
            self.channel_downsampler = nn.Conv1d(
                bottom_channels, transformer_channels, 1
            )
            self.channel_upsampler_t = nn.Conv1d(
                transformer_channels, bottom_channels, 1
            )
            self.channel_downsampler_t = nn.Conv1d(
                bottom_channels, transformer_channels, 1
            )

            transformer_channels = bottom_channels

        if t_layers > 0:
            self.crosstransformer = CrossTransformerEncoder(
                dim=transformer_channels,
                emb=t_emb,
                hidden_scale=t_hidden_scale,
                num_heads=t_heads,
                num_layers=t_layers,
                cross_first=t_cross_first,
                dropout=t_dropout,
                max_positions=t_max_positions,
                norm_in=t_norm_in,
                norm_in_group=t_norm_in_group,
                group_norm=t_group_norm,
                norm_first=t_norm_first,
                norm_out=t_norm_out,
                max_period=t_max_period,
                weight_decay=t_weight_decay,
                lr=t_lr,
                layer_scale=t_layer_scale,
                gelu=t_gelu,
                sin_random_shift=t_sin_random_shift,
                weight_pos_embed=t_weight_pos_embed,
                cape_mean_normalize=t_cape_mean_normalize,
                cape_augment=t_cape_augment,
                cape_glob_loc_scale=t_cape_glob_loc_scale,
                sparse_self_attn=t_sparse_self_attn,
                sparse_cross_attn=t_sparse_cross_attn,
                mask_type=t_mask_type,
                mask_random_seed=t_mask_random_seed,
                sparse_attn_window=t_sparse_attn_window,
                global_window=t_global_window,
                sparsity=t_sparsity,
                auto_sparsity=t_auto_sparsity,
            )
        else:
            self.crosstransformer = None

    def _spec(self, x):
        hl = self.hop_length
        nfft = self.nfft
        x0 = x  # noqa

        # We re-pad the signal in order to keep the property
        # that the size of the output is exactly the size of the input
        # divided by the stride (here hop_length), when divisible.
        # This is achieved by padding by 1/4th of the kernel size (here nfft),
        # which is not supported by torch.stft.
        # Having all convolution operations follow this convention allows us to easily
        # align the time and frequency branches later on.
        assert hl == nfft // 4
        le = int(math.ceil(x.shape[-1] / hl))
        pad = hl // 2 * 3
        x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode="reflect")

        z = spectro(x, nfft, hl)[..., :-1, :]
        assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
        z = z[..., 2: 2 + le]
        return z

    def _ispec(self, z, length=None, scale=0):
        hl = self.hop_length // (4**scale)
        z = F.pad(z, (0, 0, 0, 1))
        z = F.pad(z, (2, 2))
        pad = hl // 2 * 3
        le = hl * int(math.ceil(length / hl)) + 2 * pad
        x = ispectro(z, hl, length=le)
        x = x[..., pad: pad + length]
        return x

    def _magnitude(self, z):
        # return the magnitude of the spectrogram, except when cac is True,
        # in which case we just move the complex dimension to the channel one.
        if self.cac:
            B, C, Fr, T = z.shape
            m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
            m = m.reshape(B, C * 2, Fr, T)
        else:
            m = z.abs()
        return m

    def _mask(self, z, m):
        # Apply masking given the mixture spectrogram `z` and the estimated mask `m`.
        # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored.
        niters = self.wiener_iters
        if self.cac:
            B, S, C, Fr, T = m.shape
            out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
            out = torch.view_as_complex(out.contiguous())
            return out
        if self.training:
            niters = self.end_iters
        if niters < 0:
            z = z[:, None]
            return z / (1e-8 + z.abs()) * m
        else:
            return self._wiener(m, z, niters)

    def _wiener(self, mag_out, mix_stft, niters):
        # apply wiener filtering from OpenUnmix.
        init = mix_stft.dtype
        wiener_win_len = 300
        residual = self.wiener_residual

        B, S, C, Fq, T = mag_out.shape
        mag_out = mag_out.permute(0, 4, 3, 2, 1)
        mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))

        outs = []
        for sample in range(B):
            out = []
            for pos in range(0, T, wiener_win_len):
                frame = slice(pos, pos + wiener_win_len)
                z_out = wiener(
                    mag_out[sample, frame],
                    mix_stft[sample, frame],
                    niters,
                    residual=residual,
                )
                out.append(z_out.transpose(-1, -2))
            outs.append(torch.cat(out, dim=0))
        out = torch.view_as_complex(torch.stack(outs, 0))
        out = out.permute(0, 4, 3, 2, 1).contiguous()
        if residual:
            out = out[:, :-1]
        assert list(out.shape) == [B, S, C, Fq, T]
        return out.to(init)

    def valid_length(self, length: int):
        """
        Return a length that is appropriate for evaluation.
        In our case, always return the training length, unless
        it is smaller than the given length, in which case this
        raises an error.
        """
        if not self.use_train_segment:
            return length
        training_length = int(self.segment * self.samplerate)
        if training_length < length:
            raise ValueError(
                f"Given length {length} is longer than "
                f"training length {training_length}")
        return training_length

    def cac2cws(self, x):
        k = self.num_subbands
        b, c, f, t = x.shape
        x = x.reshape(b, c, k, f // k, t)
        x = x.reshape(b, c * k, f // k, t)
        return x

    def cws2cac(self, x):
        k = self.num_subbands
        b, c, f, t = x.shape
        x = x.reshape(b, c // k, k, f, t)
        x = x.reshape(b, c // k, f * k, t)
        return x

    def forward(self, mix):
        length = mix.shape[-1]
        length_pre_pad = None
        if self.use_train_segment:
            if self.training:
                self.segment = Fraction(mix.shape[-1], self.samplerate)
            else:
                training_length = int(self.segment * self.samplerate)
                if mix.shape[-1] < training_length:
                    length_pre_pad = mix.shape[-1]
                    mix = F.pad(mix, (0, training_length - length_pre_pad))
        z = self._spec(mix)
        mag = self._magnitude(z)
        x = mag

        if self.num_subbands > 1:
            x = self.cac2cws(x)

        B, C, Fq, T = x.shape

        # unlike previous Demucs, we always normalize because it is easier.
        mean = x.mean(dim=(1, 2, 3), keepdim=True)
        std = x.std(dim=(1, 2, 3), keepdim=True)
        x = (x - mean) / (1e-5 + std)
        # x will be the freq. branch input.

        # Prepare the time branch input.
        xt = mix
        meant = xt.mean(dim=(1, 2), keepdim=True)
        stdt = xt.std(dim=(1, 2), keepdim=True)
        xt = (xt - meant) / (1e-5 + stdt)

        # okay, this is a giant mess I know...
        saved = []  # skip connections, freq.
        saved_t = []  # skip connections, time.
        lengths = []  # saved lengths to properly remove padding, freq branch.
        lengths_t = []  # saved lengths for time branch.
        for idx, encode in enumerate(self.encoder):
            lengths.append(x.shape[-1])
            inject = None
            if idx < len(self.tencoder):
                # we have not yet merged branches.
                lengths_t.append(xt.shape[-1])
                tenc = self.tencoder[idx]
                xt = tenc(xt)
                if not tenc.empty:
                    # save for skip connection
                    saved_t.append(xt)
                else:
                    # tenc contains just the first conv., so that now time and freq.
                    # branches have the same shape and can be merged.
                    inject = xt
            x = encode(x, inject)
            if idx == 0 and self.freq_emb is not None:
                # add frequency embedding to allow for non equivariant convolutions
                # over the frequency axis.
                frs = torch.arange(x.shape[-2], device=x.device)
                emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
                x = x + self.freq_emb_scale * emb

            saved.append(x)
        if self.crosstransformer:
            if self.bottom_channels:
                b, c, f, t = x.shape
                x = rearrange(x, "b c f t-> b c (f t)")
                x = self.channel_upsampler(x)
                x = rearrange(x, "b c (f t)-> b c f t", f=f)
                xt = self.channel_upsampler_t(xt)

            x, xt = self.crosstransformer(x, xt)

            if self.bottom_channels:
                x = rearrange(x, "b c f t-> b c (f t)")
                x = self.channel_downsampler(x)
                x = rearrange(x, "b c (f t)-> b c f t", f=f)
                xt = self.channel_downsampler_t(xt)

        for idx, decode in enumerate(self.decoder):
            skip = saved.pop(-1)
            x, pre = decode(x, skip, lengths.pop(-1))
            # `pre` contains the output just before final transposed convolution,
            # which is used when the freq. and time branch separate.

            offset = self.depth - len(self.tdecoder)
            if idx >= offset:
                tdec = self.tdecoder[idx - offset]
                length_t = lengths_t.pop(-1)
                if tdec.empty:
                    assert pre.shape[2] == 1, pre.shape
                    pre = pre[:, :, 0]
                    xt, _ = tdec(pre, None, length_t)
                else:
                    skip = saved_t.pop(-1)
                    xt, _ = tdec(xt, skip, length_t)

        # Let's make sure we used all stored skip connections.
        assert len(saved) == 0
        assert len(lengths_t) == 0
        assert len(saved_t) == 0

        S = len(self.sources)

        if self.num_subbands > 1:
            x = x.view(B, -1, Fq, T)
            x = self.cws2cac(x)

        x = x.view(B, S, -1, Fq * self.num_subbands, T)
        x = x * std[:, None] + mean[:, None]

        zout = self._mask(z, x)
        if self.use_train_segment:
            if self.training:
                x = self._ispec(zout, length)
            else:
                x = self._ispec(zout, training_length)
        else:
            x = self._ispec(zout, length)

        if self.use_train_segment:
            if self.training:
                xt = xt.view(B, S, -1, length)
            else:
                xt = xt.view(B, S, -1, training_length)
        else:
            xt = xt.view(B, S, -1, length)
        xt = xt * stdt[:, None] + meant[:, None]
        x = xt + x
        if length_pre_pad:
            x = x[..., :length_pre_pad]
        return x


def get_model(args):
    extra = {
        'sources': list(args.training.instruments),
        'audio_channels': args.training.channels,
        'samplerate': args.training.samplerate,
        # 'segment': args.model_segment or 4 * args.dset.segment,
        'segment': args.training.segment,
    }
    klass = {
        'demucs': Demucs,
        'hdemucs': HDemucs,
        'htdemucs': HTDemucs,
    }[args.model]
    kw = OmegaConf.to_container(getattr(args, args.model), resolve=True)
    model = klass(**extra, **kw)
    return model
models/mdx23c_tfc_tdf_v3.py
ADDED
@@ -0,0 +1,242 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial


class STFT:
    def __init__(self, config):
        self.n_fft = config.n_fft
        self.hop_length = config.hop_length
        self.window = torch.hann_window(window_length=self.n_fft, periodic=True)
        self.dim_f = config.dim_f

    def __call__(self, x):
        window = self.window.to(x.device)
        batch_dims = x.shape[:-2]
        c, t = x.shape[-2:]
        x = x.reshape([-1, t])
        x = torch.stft(
            x,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            window=window,
            center=True,
            return_complex=True
        )
        x = torch.view_as_real(x)
        x = x.permute([0, 3, 1, 2])
        x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape([*batch_dims, c * 2, -1, x.shape[-1]])
        return x[..., :self.dim_f, :]

    def inverse(self, x):
        window = self.window.to(x.device)
        batch_dims = x.shape[:-3]
        c, f, t = x.shape[-3:]
        n = self.n_fft // 2 + 1
        f_pad = torch.zeros([*batch_dims, c, n - f, t]).to(x.device)
        x = torch.cat([x, f_pad], -2)
        x = x.reshape([*batch_dims, c // 2, 2, n, t]).reshape([-1, 2, n, t])
        x = x.permute([0, 2, 3, 1])
        x = x[..., 0] + x[..., 1] * 1.j
        x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True)
        x = x.reshape([*batch_dims, 2, -1])
        return x


def get_norm(norm_type):
    def norm(c, norm_type):
        if norm_type == 'BatchNorm':
            return nn.BatchNorm2d(c)
        elif norm_type == 'InstanceNorm':
            return nn.InstanceNorm2d(c, affine=True)
        elif 'GroupNorm' in norm_type:
            g = int(norm_type.replace('GroupNorm', ''))
            return nn.GroupNorm(num_groups=g, num_channels=c)
        else:
            return nn.Identity()

    return partial(norm, norm_type=norm_type)


def get_act(act_type):
    if act_type == 'gelu':
        return nn.GELU()
    elif act_type == 'relu':
        return nn.ReLU()
    elif act_type[:3] == 'elu':
        alpha = float(act_type.replace('elu', ''))
        return nn.ELU(alpha)
    else:
        raise ValueError(f'Unknown activation type: {act_type}')


class Upscale(nn.Module):
    def __init__(self, in_c, out_c, scale, norm, act):
        super().__init__()
        self.conv = nn.Sequential(
            norm(in_c),
            act,
            nn.ConvTranspose2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False)
        )

    def forward(self, x):
        return self.conv(x)


class Downscale(nn.Module):
    def __init__(self, in_c, out_c, scale, norm, act):
        super().__init__()
        self.conv = nn.Sequential(
            norm(in_c),
            act,
            nn.Conv2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False)
        )

    def forward(self, x):
        return self.conv(x)


class TFC_TDF(nn.Module):
    def __init__(self, in_c, c, l, f, bn, norm, act):
        super().__init__()

        self.blocks = nn.ModuleList()
        for i in range(l):
            block = nn.Module()

            block.tfc1 = nn.Sequential(
                norm(in_c),
                act,
                nn.Conv2d(in_c, c, 3, 1, 1, bias=False),
            )
            block.tdf = nn.Sequential(
                norm(c),
                act,
                nn.Linear(f, f // bn, bias=False),
                norm(c),
                act,
                nn.Linear(f // bn, f, bias=False),
            )
            block.tfc2 = nn.Sequential(
                norm(c),
                act,
                nn.Conv2d(c, c, 3, 1, 1, bias=False),
            )
            block.shortcut = nn.Conv2d(in_c, c, 1, 1, 0, bias=False)

            self.blocks.append(block)
            in_c = c

    def forward(self, x):
        for block in self.blocks:
            s = block.shortcut(x)
            x = block.tfc1(x)
            x = x + block.tdf(x)
            x = block.tfc2(x)
            x = x + s
        return x


class TFC_TDF_net(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        norm = get_norm(norm_type=config.model.norm)
        act = get_act(act_type=config.model.act)

        self.num_target_instruments = 1 if config.training.target_instrument else len(config.training.instruments)
        self.num_subbands = config.model.num_subbands

        dim_c = self.num_subbands * config.audio.num_channels * 2
        n = config.model.num_scales
        scale = config.model.scale
        l = config.model.num_blocks_per_scale
        c = config.model.num_channels
        g = config.model.growth
        bn = config.model.bottleneck_factor
        f = config.audio.dim_f // self.num_subbands

        self.first_conv = nn.Conv2d(dim_c, c, 1, 1, 0, bias=False)

        self.encoder_blocks = nn.ModuleList()
        for i in range(n):
            block = nn.Module()
            block.tfc_tdf = TFC_TDF(c, c, l, f, bn, norm, act)
            block.downscale = Downscale(c, c + g, scale, norm, act)
            f = f // scale[1]
            c += g
            self.encoder_blocks.append(block)

        self.bottleneck_block = TFC_TDF(c, c, l, f, bn, norm, act)

        self.decoder_blocks = nn.ModuleList()
        for i in range(n):
            block = nn.Module()
            block.upscale = Upscale(c, c - g, scale, norm, act)
            f = f * scale[1]
            c -= g
            block.tfc_tdf = TFC_TDF(2 * c, c, l, f, bn, norm, act)
            self.decoder_blocks.append(block)

        self.final_conv = nn.Sequential(
            nn.Conv2d(c + dim_c, c, 1, 1, 0, bias=False),
            act,
            nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False)
        )

        self.stft = STFT(config.audio)

    def cac2cws(self, x):
        k = self.num_subbands
        b, c, f, t = x.shape
        x = x.reshape(b, c, k, f // k, t)
        x = x.reshape(b, c * k, f // k, t)
        return x

    def cws2cac(self, x):
        k = self.num_subbands
        b, c, f, t = x.shape
        x = x.reshape(b, c // k, k, f, t)
        x = x.reshape(b, c // k, f * k, t)
        return x

    def forward(self, x):

        x = self.stft(x)

        mix = x = self.cac2cws(x)

        first_conv_out = x = self.first_conv(x)

        x = x.transpose(-1, -2)

        encoder_outputs = []
        for block in self.encoder_blocks:
            x = block.tfc_tdf(x)
            encoder_outputs.append(x)
            x = block.downscale(x)

        x = self.bottleneck_block(x)

        for block in self.decoder_blocks:
            x = block.upscale(x)
            x = torch.cat([x, encoder_outputs.pop()], 1)
            x = block.tfc_tdf(x)

        x = x.transpose(-1, -2)

        x = x * first_conv_out  # reduce artifacts

        x = self.final_conv(torch.cat([mix, x], 1))

        x = self.cws2cac(x)

        if self.num_target_instruments > 1:
            b, c, f, t = x.shape
            x = x.reshape(b, self.num_target_instruments, -1, f, t)

        x = self.stft.inverse(x)

        return x
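A shape sketch for the cac2cws / cws2cac reshapes above, with illustrative numbers: with k subbands, a complex-as-channels spectrogram (b, c, f, t) becomes (b, c * k, f // k, t), and cws2cac is its exact inverse.

# Round-trip check of the subband packing (illustrative numbers only).
import torch

b, c, f, t, k = 1, 4, 8, 5, 2
x = torch.randn(b, c, f, t)
cws = x.reshape(b, c, k, f // k, t).reshape(b, c * k, f // k, t)   # cac2cws
back = cws.reshape(b, c, k, f // k, t).reshape(b, c, f, t)         # cws2cac
assert torch.equal(x, back)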
models/scnet/__init__.py
ADDED
@@ -0,0 +1 @@
from .scnet import SCNet
models/scnet/scnet.py
ADDED
@@ -0,0 +1,373 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from collections import deque
|
5 |
+
from .separation import SeparationNet
|
6 |
+
import typing as tp
|
7 |
+
import math
|
8 |
+
|
9 |
+
class Swish(nn.Module):
|
10 |
+
def forward(self, x):
|
11 |
+
return x * x.sigmoid()
|
12 |
+
|
13 |
+
|
14 |
+
class ConvolutionModule(nn.Module):
|
15 |
+
"""
|
16 |
+
Convolution Module in SD block.
|
17 |
+
|
18 |
+
Args:
|
19 |
+
channels (int): input/output channels.
|
20 |
+
depth (int): number of layers in the residual branch. Each layer has its own
|
21 |
+
compress (float): amount of channel compression.
|
22 |
+
kernel (int): kernel size for the convolutions.
|
23 |
+
"""
|
24 |
+
def __init__(self, channels, depth=2, compress=4, kernel=3):
|
25 |
+
super().__init__()
|
26 |
+
assert kernel % 2 == 1
|
27 |
+
self.depth = abs(depth)
|
28 |
+
hidden_size = int(channels / compress)
|
29 |
+
norm = lambda d: nn.GroupNorm(1, d)
|
30 |
+
self.layers = nn.ModuleList([])
|
31 |
+
for _ in range(self.depth):
|
32 |
+
padding = (kernel // 2)
|
33 |
+
mods = [
|
34 |
+
norm(channels),
|
35 |
+
nn.Conv1d(channels, hidden_size*2, kernel, padding = padding),
|
36 |
+
nn.GLU(1),
|
37 |
+
nn.Conv1d(hidden_size, hidden_size, kernel, padding = padding, groups = hidden_size),
|
38 |
+
norm(hidden_size),
|
39 |
+
Swish(),
|
40 |
+
nn.Conv1d(hidden_size, channels, 1),
|
41 |
+
]
|
42 |
+
layer = nn.Sequential(*mods)
|
43 |
+
self.layers.append(layer)
|
44 |
+
|
45 |
+
def forward(self, x):
|
46 |
+
for layer in self.layers:
|
47 |
+
x = x + layer(x)
|
48 |
+
return x
|
49 |
+
|
50 |
+
|
51 |
+
class FusionLayer(nn.Module):
|
52 |
+
"""
|
53 |
+
A FusionLayer within the decoder.
|
54 |
+
|
55 |
+
Args:
|
56 |
+
- channels (int): Number of input channels.
|
57 |
+
- kernel_size (int, optional): Kernel size for the convolutional layer, defaults to 3.
|
58 |
+
- stride (int, optional): Stride for the convolutional layer, defaults to 1.
|
59 |
+
- padding (int, optional): Padding for the convolutional layer, defaults to 1.
|
60 |
+
"""
|
61 |
+
|
62 |
+
def __init__(self, channels, kernel_size=3, stride=1, padding=1):
|
63 |
+
super(FusionLayer, self).__init__()
|
64 |
+
self.conv = nn.Conv2d(channels * 2, channels * 2, kernel_size, stride=stride, padding=padding)
|
65 |
+
|
66 |
+
def forward(self, x, skip=None):
|
67 |
+
if skip is not None:
|
68 |
+
x += skip
|
69 |
+
x = x.repeat(1, 2, 1, 1)
|
70 |
+
x = self.conv(x)
|
71 |
+
x = F.glu(x, dim=1)
|
72 |
+
return x
|
73 |
+
|
74 |
+
|
75 |
+
class SDlayer(nn.Module):
    """
    Implements a Sparse Down-sample Layer for processing different frequency bands separately.

    Args:
    - channels_in (int): Input channel count.
    - channels_out (int): Output channel count.
    - band_configs (dict): A dictionary containing configuration for each frequency band.
        Keys are 'low', 'mid', 'high' for each band, and values are
        dictionaries with keys 'SR', 'stride', and 'kernel' for proportion,
        stride, and kernel size, respectively.
    """
    def __init__(self, channels_in, channels_out, band_configs):
        super(SDlayer, self).__init__()

        # Initializing convolutional layers for each band
        self.convs = nn.ModuleList()
        self.strides = []
        self.kernels = []
        for config in band_configs.values():
            self.convs.append(nn.Conv2d(channels_in, channels_out, (config['kernel'], 1), (config['stride'], 1), (0, 0)))
            self.strides.append(config['stride'])
            self.kernels.append(config['kernel'])

        # Saving rate proportions for determining splits
        self.SR_low = band_configs['low']['SR']
        self.SR_mid = band_configs['mid']['SR']

    def forward(self, x):
        B, C, Fr, T = x.shape
        # Define splitting points based on sampling rates
        splits = [
            (0, math.ceil(Fr * self.SR_low)),
            (math.ceil(Fr * self.SR_low), math.ceil(Fr * (self.SR_low + self.SR_mid))),
            (math.ceil(Fr * (self.SR_low + self.SR_mid)), Fr)
        ]

        # Processing each band with the corresponding convolution
        outputs = []
        original_lengths = []
        for conv, stride, kernel, (start, end) in zip(self.convs, self.strides, self.kernels, splits):
            extracted = x[:, :, start:end, :]
            original_lengths.append(end - start)
            current_length = extracted.shape[2]

            # padding
            if stride == 1:
                total_padding = kernel - stride
            else:
                total_padding = (stride - current_length % stride) % stride
            pad_left = total_padding // 2
            pad_right = total_padding - pad_left

            padded = F.pad(extracted, (0, 0, pad_left, pad_right))

            output = conv(padded)
            outputs.append(output)

        return outputs, original_lengths

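
To make the split and padding arithmetic concrete: with the default SR values (.175/.392/.433) and Fr = 2049 bins (nfft = 4096), the band boundaries and their downsampled lengths work out as below (a worked example, not part of the upload):

    import math
    Fr = 2049                                      # nfft // 2 + 1
    low  = (0, math.ceil(Fr * 0.175))              # (0, 359): 359 bins, stride 1 -> 359
    mid  = (359, math.ceil(Fr * (0.175 + 0.392)))  # (359, 1162): 803 bins, pad to 804, stride 4 -> 201
    high = (1162, Fr)                              # (1162, 2049): 887 bins, pad to 896, stride 16 -> 56
    # total frequency length after the SD layer: 359 + 201 + 56 = 616
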
class SUlayer(nn.Module):
    """
    Implements a Sparse Up-sample Layer in the decoder.

    Args:
    - channels_in: The number of input channels.
    - channels_out: The number of output channels.
    - band_configs: Dictionary containing the configurations for the transposed convolutions.
    """
    def __init__(self, channels_in, channels_out, band_configs):
        super(SUlayer, self).__init__()

        # Initializing transposed convolutional layers for each band
        self.convtrs = nn.ModuleList([
            nn.ConvTranspose2d(channels_in, channels_out, [config['kernel'], 1], [config['stride'], 1])
            for _, config in band_configs.items()
        ])

    def forward(self, x, lengths, origin_lengths):
        B, C, Fr, T = x.shape
        # Define splitting points based on input lengths
        splits = [
            (0, lengths[0]),
            (lengths[0], lengths[0] + lengths[1]),
            (lengths[0] + lengths[1], None)
        ]
        # Processing each band with the corresponding transposed convolution
        outputs = []
        for idx, (convtr, (start, end)) in enumerate(zip(self.convtrs, splits)):
            out = convtr(x[:, :, start:end, :])
            # Calculate the distance to trim the output symmetrically to the original length
            current_Fr_length = out.shape[2]
            dist = abs(origin_lengths[idx] - current_Fr_length) // 2

            # Trim the output to the original length symmetrically
            trimmed_out = out[:, :, dist:dist + origin_lengths[idx], :]

            outputs.append(trimmed_out)

        # Concatenate trimmed outputs along the frequency dimension to return the final tensor
        x = torch.cat(outputs, dim=2)

        return x

class SDblock(nn.Module):
    """
    Implements a simplified Sparse Down-sample block in the encoder.

    Args:
    - channels_in (int): Number of input channels.
    - channels_out (int): Number of output channels.
    - band_configs (dict): Configuration for the SDlayer specifying band splits and convolutions.
    - conv_config (dict): Configuration for convolution modules applied to each band.
    - depths (list of int): List specifying the convolution depths for low, mid, and high frequency bands.
    """
    def __init__(self, channels_in, channels_out, band_configs={}, conv_config={}, depths=[3, 2, 1], kernel_size=3):
        super(SDblock, self).__init__()
        self.SDlayer = SDlayer(channels_in, channels_out, band_configs)

        # Dynamically create convolution modules for each band based on depths
        self.conv_modules = nn.ModuleList([
            ConvolutionModule(channels_out, depth, **conv_config) for depth in depths
        ])
        # Set the kernel_size to an odd number.
        self.globalconv = nn.Conv2d(channels_out, channels_out, kernel_size, 1, (kernel_size - 1) // 2)

    def forward(self, x):
        bands, original_lengths = self.SDlayer(x)
        # B, C, f, T = band.shape
        bands = [
            F.gelu(
                conv(band.permute(0, 2, 1, 3).reshape(-1, band.shape[1], band.shape[3]))
                .view(band.shape[0], band.shape[2], band.shape[1], band.shape[3])
                .permute(0, 2, 1, 3)
            )
            for conv, band in zip(self.conv_modules, bands)
        ]
        lengths = [band.size(-2) for band in bands]
        full_band = torch.cat(bands, dim=2)
        skip = full_band

        output = self.globalconv(full_band)

        return output, skip, lengths, original_lengths

class SCNet(nn.Module):
    """
    The implementation of SCNet: Sparse Compression Network for Music Source Separation. Paper: https://arxiv.org/abs/2401.13276.pdf

    Args:
    - sources (List[str]): List of sources to be separated.
    - audio_channels (int): Number of audio channels.
    - nfft (int): Number of FFTs to determine the frequency dimension of the input.
    - hop_size (int): Hop size for the STFT.
    - win_size (int): Window size for the STFT.
    - normalized (bool): Whether to normalize the STFT.
    - dims (List[int]): List of channel dimensions for each block.
    - band_configs (Dict[str, Dict[str, int]]): Configuration for each frequency band, including how to divide the frequency bands,
        and the settings for the upsampling/downsampling convolutional layers.
    - conv_depths (List[int]): List specifying the number of convolution modules in each SD block.
    - compress (int): Compression factor for the convolution module.
    - conv_kernel (int): Kernel size for the convolution layer in the convolution module.
    - num_dplayer (int): Number of dual-path layers.
    - expand (int): Expansion factor in the dual-path RNN, default is 1.
    """
    def __init__(self,
                 sources=['drums', 'bass', 'other', 'vocals'],
                 audio_channels=2,
                 # Main structure
                 dims=[4, 32, 64, 128],  # dims = [4, 64, 128, 256] in SCNet-large
                 # STFT
                 nfft=4096,
                 hop_size=1024,
                 win_size=4096,
                 normalized=True,
                 # SD/SU layer
                 band_configs={
                     'low': {'SR': .175, 'stride': 1, 'kernel': 3},
                     'mid': {'SR': .392, 'stride': 4, 'kernel': 4},
                     'high': {'SR': .433, 'stride': 16, 'kernel': 16}
                 },
                 # Convolution Module
                 conv_depths=[3, 2, 1],
                 compress=4,
                 conv_kernel=3,
                 # Dual-path RNN
                 num_dplayer=6,
                 expand=1,
                 # Mamba
                 use_mamba=False,
                 mamba_config={
                     'd_stat': 16,
                     'd_conv': 4,
                     'd_expand': 2
                 }):
        super().__init__()
        self.sources = sources
        self.audio_channels = audio_channels
        self.dims = dims
        self.band_configs = band_configs
        self.hop_length = hop_size
        self.conv_config = {
            'compress': compress,
            'kernel': conv_kernel,
        }

        self.stft_config = {
            'n_fft': nfft,
            'hop_length': hop_size,
            'win_length': win_size,
            'center': True,
            'normalized': normalized
        }

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        for index in range(len(dims) - 1):
            enc = SDblock(
                channels_in=dims[index],
                channels_out=dims[index + 1],
                band_configs=self.band_configs,
                conv_config=self.conv_config,
                depths=conv_depths
            )
            self.encoder.append(enc)

            dec = nn.Sequential(
                FusionLayer(channels=dims[index + 1]),
                SUlayer(
                    channels_in=dims[index + 1],
                    channels_out=dims[index] if index != 0 else dims[index] * len(sources),
                    band_configs=self.band_configs,
                )
            )
            self.decoder.insert(0, dec)

        self.separation_net = SeparationNet(
            channels=dims[-1],
            expand=expand,
            num_layers=num_dplayer,
            use_mamba=use_mamba,
            **mamba_config
        )

    def forward(self, x):
        # B, C, L = x.shape
        B = x.shape[0]
        # In the initial padding, ensure that the number of frames after the STFT (the length of the T dimension) is even,
        # so that the RFFT operation can be used in the separation network.
        padding = self.hop_length - x.shape[-1] % self.hop_length
        if (x.shape[-1] + padding) // self.hop_length % 2 == 0:
            padding += self.hop_length
        x = F.pad(x, (0, padding))

        # STFT
        L = x.shape[-1]
        x = x.reshape(-1, L)
        x = torch.stft(x, **self.stft_config, return_complex=True)
        x = torch.view_as_real(x)
        x = x.permute(0, 3, 1, 2).reshape(x.shape[0] // self.audio_channels, x.shape[3] * self.audio_channels, x.shape[1], x.shape[2])

        B, C, Fr, T = x.shape

        save_skip = deque()
        save_lengths = deque()
        save_original_lengths = deque()
        # Encoder
        for sd_layer in self.encoder:
            x, skip, lengths, original_lengths = sd_layer(x)
            save_skip.append(skip)
            save_lengths.append(lengths)
            save_original_lengths.append(original_lengths)

        # Separation
        x = self.separation_net(x)

        # Decoder
        for fusion_layer, su_layer in self.decoder:
            x = fusion_layer(x, save_skip.pop())
            x = su_layer(x, save_lengths.pop(), save_original_lengths.pop())

        # Output
        n = self.dims[0]
        x = x.view(B, n, -1, Fr, T)
        x = x.reshape(-1, 2, Fr, T).permute(0, 2, 3, 1)
        x = torch.view_as_complex(x.contiguous())
        x = torch.istft(x, **self.stft_config)
        x = x.reshape(B, len(self.sources), self.audio_channels, -1)

        x = x[:, :, :, :-padding]

        return x
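
A minimal end-to-end sketch of the default model (hypothetical smoke test, not part of the upload; assumes the imports at the top of this file are in place and use_mamba is left at False so the LSTM path is used):

    model = SCNet()                      # drums/bass/other/vocals, stereo
    mix = torch.randn(1, 2, 44100)       # (B, audio_channels, samples)
    with torch.no_grad():
        stems = model(mix)
    print(stems.shape)                   # (1, 4, 2, samples): one waveform per source
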
models/scnet/separation.py
ADDED
@@ -0,0 +1,178 @@
import torch
import torch.nn as nn
from torch.nn.modules.rnn import LSTM
import torch.nn.functional as Func

try:
    from mamba_ssm.modules.mamba_simple import Mamba
except Exception as e:
    print('No mamba found. Please install mamba_ssm')


class RMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return Func.normalize(x, dim=-1) * self.scale * self.gamma


class MambaModule(nn.Module):
    def __init__(self, d_model, d_state, d_conv, d_expand):
        super().__init__()
        self.norm = RMSNorm(dim=d_model)
        self.mamba = Mamba(
            d_model=d_model,
            d_state=d_state,
            d_conv=d_conv,
            expand=d_expand
        )

    def forward(self, x):
        x = x + self.mamba(self.norm(x))
        return x


class FeatureConversion(nn.Module):
    """
    Integrates into the adjacent Dual-Path layer.

    Args:
        channels (int): Number of input channels.
        inverse (bool): If True, uses irfft; otherwise, uses rfft.
    """
    def __init__(self, channels, inverse):
        super().__init__()
        self.inverse = inverse
        self.channels = channels

    def forward(self, x):
        # B, C, F, T = x.shape
        if self.inverse:
            x = x.float()
            x_r = x[:, :self.channels // 2, :, :]
            x_i = x[:, self.channels // 2:, :, :]
            x = torch.complex(x_r, x_i)
            x = torch.fft.irfft(x, dim=3, norm="ortho")
        else:
            x = x.float()
            x = torch.fft.rfft(x, dim=3, norm="ortho")
            x_real = x.real
            x_imag = x.imag
            x = torch.cat([x_real, x_imag], dim=1)
        return x


class DualPathRNN(nn.Module):
    """
    Dual-Path RNN in the Separation Network.

    Args:
        d_model (int): The number of expected features in the input (input_size).
        expand (int): Expansion factor used to calculate the hidden_size of the LSTM.
        bidirectional (bool): If True, becomes a bidirectional LSTM.
    """
    def __init__(self, d_model, expand, bidirectional=True):
        super(DualPathRNN, self).__init__()

        self.d_model = d_model
        self.hidden_size = d_model * expand
        self.bidirectional = bidirectional
        # Initialize LSTM layers and normalization layers
        self.lstm_layers = nn.ModuleList([self._init_lstm_layer(self.d_model, self.hidden_size) for _ in range(2)])
        self.linear_layers = nn.ModuleList([nn.Linear(self.hidden_size * 2, self.d_model) for _ in range(2)])
        self.norm_layers = nn.ModuleList([nn.GroupNorm(1, d_model) for _ in range(2)])

    def _init_lstm_layer(self, d_model, hidden_size):
        return LSTM(d_model, hidden_size, num_layers=1, bidirectional=self.bidirectional, batch_first=True)

    def forward(self, x):
        B, C, F, T = x.shape

        # Process dual-path RNN

        original_x = x
        # Frequency-path
        x = self.norm_layers[0](x)
        x = x.transpose(1, 3).contiguous().view(B * T, F, C)
        x, _ = self.lstm_layers[0](x)
        x = self.linear_layers[0](x)
        x = x.view(B, T, F, C).transpose(1, 3)
        x = x + original_x

        original_x = x
        # Time-path
        x = self.norm_layers[1](x)
        x = x.transpose(1, 2).contiguous().view(B * F, C, T).transpose(1, 2)
        x, _ = self.lstm_layers[1](x)
        x = self.linear_layers[1](x)
        x = x.transpose(1, 2).contiguous().view(B, F, C, T).transpose(1, 2)
        x = x + original_x

        return x


class DualPathMamba(nn.Module):
    """
    Dual-Path Mamba.
    """
    def __init__(self, d_model, d_stat, d_conv, d_expand):
        super(DualPathMamba, self).__init__()
        # Initialize Mamba layers
        self.mamba_layers = nn.ModuleList([MambaModule(d_model, d_stat, d_conv, d_expand) for _ in range(2)])

    def forward(self, x):
        B, C, F, T = x.shape

        # Process dual-path Mamba

        # Frequency-path
        x = x.transpose(1, 3).contiguous().view(B * T, F, C)
        x = self.mamba_layers[0](x)
        x = x.view(B, T, F, C).transpose(1, 3)

        # Time-path
        x = x.transpose(1, 2).contiguous().view(B * F, C, T).transpose(1, 2)
        x = self.mamba_layers[1](x)
        x = x.transpose(1, 2).contiguous().view(B, F, C, T).transpose(1, 2)

        return x


class SeparationNet(nn.Module):
    """
    The separation network: a stack of dual-path layers with feature
    conversion between the time and frequency domains.

    Args:
    - channels (int): Number of input channels.
    - expand (int): Expansion factor used to calculate the hidden_size of the LSTM.
    - num_layers (int): Number of dual-path layers.
    - use_mamba (bool): If True, use the Mamba module to replace the RNN.
    - d_stat (int), d_conv (int), d_expand (int): These are built-in parameters of the Mamba model.
    """
    def __init__(self, channels, expand=1, num_layers=6, use_mamba=True, d_stat=16, d_conv=4, d_expand=2):
        super(SeparationNet, self).__init__()

        self.num_layers = num_layers
        if use_mamba:
            self.dp_modules = nn.ModuleList([
                DualPathMamba(channels * (2 if i % 2 == 1 else 1), d_stat, d_conv, d_expand * (2 if i % 2 == 1 else 1)) for i in range(num_layers)
            ])
        else:
            self.dp_modules = nn.ModuleList([
                DualPathRNN(channels * (2 if i % 2 == 1 else 1), expand) for i in range(num_layers)
            ])

        self.feature_conversion = nn.ModuleList([
            FeatureConversion(channels * 2, inverse=False if i % 2 == 0 else True) for i in range(num_layers)
        ])

    def forward(self, x):
        for i in range(self.num_layers):
            x = self.dp_modules[i](x)
            x = self.feature_conversion[i](x)
        return x
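
The even-frame padding in SCNet.forward exists precisely so these conversions compose cleanly: with an even T, the forward rfft maps (B, C, F, T) to (B, 2C, F, T//2 + 1) and the inverse conversion restores the original tensor. A hypothetical round trip, not part of the upload:

    fc_fwd = FeatureConversion(channels=8, inverse=False)
    fc_inv = FeatureConversion(channels=8, inverse=True)
    x = torch.randn(1, 4, 5, 6)          # (B, C, F, T) with even T
    y = fc_fwd(x)                        # -> (1, 8, 5, 4): channels doubled, T halved
    z = fc_inv(y)                        # -> (1, 4, 5, 6): exact inverse under norm="ortho"
    assert torch.allclose(x, z, atol=1e-5)
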
models/scnet_unofficial/__init__.py
ADDED
@@ -0,0 +1 @@
from models.scnet_unofficial.scnet import SCNet
models/scnet_unofficial/modules/__init__.py
ADDED
@@ -0,0 +1,3 @@
from models.scnet_unofficial.modules.dualpath_rnn import DualPathRNN
from models.scnet_unofficial.modules.sd_encoder import SDBlock
from models.scnet_unofficial.modules.su_decoder import SUBlock
models/scnet_unofficial/modules/dualpath_rnn.py
ADDED
@@ -0,0 +1,228 @@
import torch
import torch.nn as nn
import torch.nn.functional as Func

# Mamba is optional; it is only needed when DualPathRNN is built with
# use_mamba=True.
try:
    from mamba_ssm.modules.mamba_simple import Mamba
except ImportError:
    Mamba = None


class RMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return Func.normalize(x, dim=-1) * self.scale * self.gamma


class MambaModule(nn.Module):
    def __init__(self, d_model, d_state, d_conv, d_expand):
        super().__init__()
        self.norm = RMSNorm(dim=d_model)
        # `expand` is the block-expansion factor expected by mamba_ssm's Mamba.
        self.mamba = Mamba(
            d_model=d_model,
            d_state=d_state,
            d_conv=d_conv,
            expand=d_expand
        )

    def forward(self, x):
        x = x + self.mamba(self.norm(x))
        return x


class RNNModule(nn.Module):
    """
    RNNModule class implements a recurrent neural network module with LSTM cells.

    Args:
    - input_dim (int): Dimensionality of the input features.
    - hidden_dim (int): Dimensionality of the hidden state of the LSTM.
    - bidirectional (bool, optional): If True, uses a bidirectional LSTM. Defaults to True.

    Shapes:
    - Input: (B, T, D) where
        B is batch size,
        T is sequence length,
        D is input dimensionality.
    - Output: (B, T, D) where
        B is batch size,
        T is sequence length,
        D is input dimensionality.
    """

    def __init__(self, input_dim: int, hidden_dim: int, bidirectional: bool = True):
        """
        Initializes RNNModule with input dimension, hidden dimension, and bidirectional flag.
        """
        super().__init__()
        self.groupnorm = nn.GroupNorm(num_groups=1, num_channels=input_dim)
        self.rnn = nn.LSTM(
            input_dim, hidden_dim, batch_first=True, bidirectional=bidirectional
        )
        self.fc = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, input_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Performs forward pass through the RNNModule.

        Args:
        - x (torch.Tensor): Input tensor of shape (B, T, D).

        Returns:
        - torch.Tensor: Output tensor of shape (B, T, D).
        """
        x = x.transpose(1, 2)
        x = self.groupnorm(x)
        x = x.transpose(1, 2)

        x, (hidden, _) = self.rnn(x)
        x = self.fc(x)
        return x


class RFFTModule(nn.Module):
    """
    RFFTModule class implements a module for performing real-valued Fast Fourier Transform (FFT)
    or its inverse on input tensors.

    Args:
    - inverse (bool, optional): If False, performs forward FFT. If True, performs inverse FFT. Defaults to False.

    Shapes:
    - Input: (B, F, T, D) where
        B is batch size,
        F is the number of features,
        T is sequence length,
        D is input dimensionality.
    - Output: (B, F, T // 2 + 1, D * 2) if performing forward FFT.
              (B, F, time_dim, D // 2) if performing inverse FFT.
    """

    def __init__(self, inverse: bool = False):
        """
        Initializes RFFTModule with inverse flag.
        """
        super().__init__()
        self.inverse = inverse

    def forward(self, x: torch.Tensor, time_dim: int) -> torch.Tensor:
        """
        Performs forward or inverse FFT on the input tensor x.

        Args:
        - x (torch.Tensor): Input tensor of shape (B, F, T, D).
        - time_dim (int): Input size of the time dimension.

        Returns:
        - torch.Tensor: Output tensor after the FFT or its inverse operation.
        """
        dtype = x.dtype
        B, F, T, D = x.shape

        # RuntimeError: cuFFT only supports dimensions whose sizes are powers of two when computing in half precision
        x = x.float()

        if not self.inverse:
            x = torch.fft.rfft(x, dim=2)
            x = torch.view_as_real(x)
            x = x.reshape(B, F, T // 2 + 1, D * 2)
        else:
            x = x.reshape(B, F, T, D // 2, 2)
            x = torch.view_as_complex(x)
            x = torch.fft.irfft(x, n=time_dim, dim=2)

        x = x.to(dtype)
        return x

    def extra_repr(self) -> str:
        """
        Returns extra representation string with module's configuration.
        """
        return f"inverse={self.inverse}"


class DualPathRNN(nn.Module):
    """
    DualPathRNN class implements a neural network with alternating layers of RNNModule and RFFTModule.

    Args:
    - n_layers (int): Number of layers in the network.
    - input_dim (int): Dimensionality of the input features.
    - hidden_dim (int): Dimensionality of the hidden state of the RNNModule.

    Shapes:
    - Input: (B, F, T, D) where
        B is batch size,
        F is the number of features (frequency dimension),
        T is sequence length (time dimension),
        D is input dimensionality (channel dimension).
    - Output: (B, F, T, D) where
        B is batch size,
        F is the number of features (frequency dimension),
        T is sequence length (time dimension),
        D is input dimensionality (channel dimension).
    """

    def __init__(
        self,
        n_layers: int,
        input_dim: int,
        hidden_dim: int,
        use_mamba: bool = False,
        d_state: int = 16,
        d_conv: int = 4,
        d_expand: int = 2
    ):
        """
        Initializes DualPathRNN with the specified number of layers, input dimension, and hidden dimension.
        """
        super().__init__()

        if use_mamba:
            # Mamba was imported (or set to None) at module level; fail fast
            # if the optional dependency is missing.
            assert Mamba is not None, "mamba_ssm must be installed when use_mamba=True"
            net = MambaModule
            dkwargs = {"d_model": input_dim, "d_state": d_state, "d_conv": d_conv, "d_expand": d_expand}
            ukwargs = {"d_model": input_dim * 2, "d_state": d_state, "d_conv": d_conv, "d_expand": d_expand * 2}
        else:
            net = RNNModule
            dkwargs = {"input_dim": input_dim, "hidden_dim": hidden_dim}
            ukwargs = {"input_dim": input_dim * 2, "hidden_dim": hidden_dim * 2}

        self.layers = nn.ModuleList()
        for i in range(1, n_layers + 1):
            kwargs = dkwargs if i % 2 == 1 else ukwargs
            layer = nn.ModuleList([
                net(**kwargs),
                net(**kwargs),
                RFFTModule(inverse=(i % 2 == 0)),
            ])
            self.layers.append(layer)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Performs forward pass through the DualPathRNN.

        Args:
        - x (torch.Tensor): Input tensor of shape (B, F, T, D).

        Returns:
        - torch.Tensor: Output tensor of shape (B, F, T, D).
        """

        time_dim = x.shape[2]

        for time_layer, freq_layer, rfft_layer in self.layers:
            B, F, T, D = x.shape

            x = x.reshape((B * F), T, D)
            x = time_layer(x)
            x = x.reshape(B, F, T, D)
            x = x.permute(0, 2, 1, 3)

            x = x.reshape((B * T), F, D)
            x = freq_layer(x)
            x = x.reshape(B, T, F, D)
            x = x.permute(0, 2, 1, 3)

            x = rfft_layer(x, time_dim)

        return x
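
With an even n_layers and an even time dimension, the alternating rfft/irfft layers double and then restore the channel dimension, so the overall module is shape-preserving. A hypothetical shape check, not part of the upload:

    dprnn = DualPathRNN(n_layers=2, input_dim=16, hidden_dim=32)
    x = torch.randn(1, 4, 10, 16)            # (B, F, T, D), T even
    assert dprnn(x).shape == (1, 4, 10, 16)  # layer 1 halves T and doubles D, layer 2 undoes it
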
models/scnet_unofficial/modules/sd_encoder.py
ADDED
@@ -0,0 +1,285 @@
from typing import List, Tuple

import torch
import torch.nn as nn

from models.scnet_unofficial.utils import create_intervals


class Downsample(nn.Module):
    """
    Downsample class implements a module for downsampling input tensors using 2D convolution.

    Args:
    - input_dim (int): Dimensionality of the input channels.
    - output_dim (int): Dimensionality of the output channels.
    - stride (int): Stride value for the convolution operation.

    Shapes:
    - Input: (B, C_in, F, T) where
        B is batch size,
        C_in is the number of input channels,
        F is the frequency dimension,
        T is the time dimension.
    - Output: (B, C_out, F // stride, T) where
        B is batch size,
        C_out is the number of output channels,
        F // stride is the downsampled frequency dimension.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        stride: int,
    ):
        """
        Initializes Downsample with input dimension, output dimension, and stride.
        """
        super().__init__()
        self.conv = nn.Conv2d(input_dim, output_dim, 1, (stride, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Performs forward pass through the Downsample module.

        Args:
        - x (torch.Tensor): Input tensor of shape (B, C_in, F, T).

        Returns:
        - torch.Tensor: Downsampled tensor of shape (B, C_out, F // stride, T).
        """
        return self.conv(x)


class ConvolutionModule(nn.Module):
    """
    ConvolutionModule class implements a module with a sequence of convolutional layers similar to Conformer.

    Args:
    - input_dim (int): Dimensionality of the input features.
    - hidden_dim (int): Dimensionality of the hidden features.
    - kernel_sizes (List[int]): List of kernel sizes for the convolutional layers.
    - bias (bool, optional): If True, adds a learnable bias to the output. Default is False.

    Shapes:
    - Input: (B, T, D) where
        B is batch size,
        T is sequence length,
        D is input dimensionality.
    - Output: (B, T, D) where
        B is batch size,
        T is sequence length,
        D is input dimensionality.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        kernel_sizes: List[int],
        bias: bool = False,
    ) -> None:
        """
        Initializes ConvolutionModule with input dimension, hidden dimension, kernel sizes, and bias.
        """
        super().__init__()
        self.sequential = nn.Sequential(
            nn.GroupNorm(num_groups=1, num_channels=input_dim),
            nn.Conv1d(
                input_dim,
                2 * hidden_dim,
                kernel_sizes[0],
                stride=1,
                padding=(kernel_sizes[0] - 1) // 2,
                bias=bias,
            ),
            nn.GLU(dim=1),
            nn.Conv1d(
                hidden_dim,
                hidden_dim,
                kernel_sizes[1],
                stride=1,
                padding=(kernel_sizes[1] - 1) // 2,
                groups=hidden_dim,
                bias=bias,
            ),
            nn.GroupNorm(num_groups=1, num_channels=hidden_dim),
            nn.SiLU(),
            nn.Conv1d(
                hidden_dim,
                input_dim,
                kernel_sizes[2],
                stride=1,
                padding=(kernel_sizes[2] - 1) // 2,
                bias=bias,
            ),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Performs forward pass through the ConvolutionModule.

        Args:
        - x (torch.Tensor): Input tensor of shape (B, T, D).

        Returns:
        - torch.Tensor: Output tensor of shape (B, T, D).
        """
        x = x.transpose(1, 2)
        x = x + self.sequential(x)
        x = x.transpose(1, 2)
        return x


class SDLayer(nn.Module):
    """
    SDLayer class implements a subband decomposition layer with downsampling and convolutional modules.

    Args:
    - subband_interval (Tuple[float, float]): Tuple representing the frequency interval for subband decomposition.
    - input_dim (int): Dimensionality of the input channels.
    - output_dim (int): Dimensionality of the output channels after downsampling.
    - downsample_stride (int): Stride value for the downsampling operation.
    - n_conv_modules (int): Number of convolutional modules.
    - kernel_sizes (List[int]): List of kernel sizes for the convolutional layers.
    - bias (bool, optional): If True, adds a learnable bias to the convolutional layers. Default is True.

    Shapes:
    - Input: (B, Fi, T, Ci) where
        B is batch size,
        Fi is the number of input subbands,
        T is sequence length, and
        Ci is the number of input channels.
    - Output: (B, Fi+1, T, Ci+1) where
        B is batch size,
        Fi+1 is the number of output subbands,
        T is sequence length,
        Ci+1 is the number of output channels.
    """

    def __init__(
        self,
        subband_interval: Tuple[float, float],
        input_dim: int,
        output_dim: int,
        downsample_stride: int,
        n_conv_modules: int,
        kernel_sizes: List[int],
        bias: bool = True,
    ):
        """
        Initializes SDLayer with subband interval, input dimension,
        output dimension, downsample stride, number of convolutional modules, kernel sizes, and bias.
        """
        super().__init__()
        self.subband_interval = subband_interval
        self.downsample = Downsample(input_dim, output_dim, downsample_stride)
        self.activation = nn.GELU()
        conv_modules = [
            ConvolutionModule(
                input_dim=output_dim,
                hidden_dim=output_dim // 4,
                kernel_sizes=kernel_sizes,
                bias=bias,
            )
            for _ in range(n_conv_modules)
        ]
        self.conv_modules = nn.Sequential(*conv_modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Performs forward pass through the SDLayer.

        Args:
        - x (torch.Tensor): Input tensor of shape (B, Fi, T, Ci).

        Returns:
        - torch.Tensor: Output tensor of shape (B, Fi+1, T, Ci+1).
        """
        B, F, T, C = x.shape
        x = x[:, int(self.subband_interval[0] * F) : int(self.subband_interval[1] * F)]
        x = x.permute(0, 3, 1, 2)
        x = self.downsample(x)
        x = self.activation(x)
        x = x.permute(0, 2, 3, 1)

        B, F, T, C = x.shape
        x = x.reshape((B * F), T, C)
        x = self.conv_modules(x)
        x = x.reshape(B, F, T, C)

        return x


class SDBlock(nn.Module):
    """
    SDBlock class implements a block with subband decomposition layers and global convolution.

    Args:
    - input_dim (int): Dimensionality of the input channels.
    - output_dim (int): Dimensionality of the output channels.
    - bandsplit_ratios (List[float]): List of ratios for splitting the frequency bands.
    - downsample_strides (List[int]): List of stride values for downsampling in each subband layer.
    - n_conv_modules (List[int]): List specifying the number of convolutional modules in each subband layer.
    - kernel_sizes (List[int], optional): List of kernel sizes for the convolutional layers. Default is None.

    Shapes:
    - Input: (B, Fi, T, Ci) where
        B is batch size,
        Fi is the number of input subbands,
        T is sequence length,
        Ci is the number of input channels.
    - Output: (B, Fi+1, T, Ci+1) where
        B is batch size,
        Fi+1 is the number of output subbands,
        T is sequence length,
        Ci+1 is the number of output channels.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        bandsplit_ratios: List[float],
        downsample_strides: List[int],
        n_conv_modules: List[int],
        kernel_sizes: List[int] = None,
    ):
        """
        Initializes SDBlock with input dimension, output dimension, band split ratios, downsample strides, number of convolutional modules, and kernel sizes.
        """
        super().__init__()
        if kernel_sizes is None:
            kernel_sizes = [3, 3, 1]
        assert sum(bandsplit_ratios) == 1, "The split ratios must sum up to 1."
        subband_intervals = create_intervals(bandsplit_ratios)
        self.sd_layers = nn.ModuleList(
            SDLayer(
                input_dim=input_dim,
                output_dim=output_dim,
                subband_interval=sbi,
                downsample_stride=dss,
                n_conv_modules=ncm,
                kernel_sizes=kernel_sizes,
            )
            for sbi, dss, ncm in zip(
                subband_intervals, downsample_strides, n_conv_modules
            )
        )
        self.global_conv2d = nn.Conv2d(output_dim, output_dim, 1, 1)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Performs forward pass through the SDBlock.

        Args:
        - x (torch.Tensor): Input tensor of shape (B, Fi, T, Ci).

        Returns:
        - Tuple[torch.Tensor, torch.Tensor]: Output tensor and skip connection tensor.
        """
        x_skip = torch.concat([layer(x) for layer in self.sd_layers], dim=1)
        x = self.global_conv2d(x_skip.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        return x, x_skip
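
A hypothetical smoke test, not part of the upload; it assumes create_intervals (defined in models/scnet_unofficial/utils, not shown in this diff) turns the ratios into cumulative (start, end) intervals such as [(0.0, 0.5), (0.5, 1.0)]:

    block = SDBlock(
        input_dim=4,
        output_dim=8,
        bandsplit_ratios=[0.5, 0.5],
        downsample_strides=[1, 2],
        n_conv_modules=[1, 1],
    )
    x = torch.randn(2, 32, 10, 4)             # (B, F, T, C)
    out, skip = block(x)                      # F: 16 (stride 1) + 8 (stride 2) = 24
    assert out.shape == (2, 24, 10, 8) and skip.shape == out.shape
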