# coding: utf-8
__author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/'

import sys
import os

sys.path.append(os.path.dirname(__file__))

import argparse
import numpy as np
import torch
import torch.nn as nn
import yaml
import soundfile as sf
import matplotlib.pyplot as plt
from ml_collections import ConfigDict
from omegaconf import OmegaConf
from tqdm.auto import tqdm
from typing import Dict, List, Tuple, Any, Union
import loralib as lora

def load_config(model_type: str, config_path: str) -> Union[ConfigDict, OmegaConf]:
    """
    Load the configuration from the specified path based on the model type.

    Parameters:
    ----------
    model_type : str
        The type of model to load (e.g., 'htdemucs', 'mdx23c', etc.).
    config_path : str
        The path to the YAML or OmegaConf configuration file.

    Returns:
    -------
    config : Any
        The loaded configuration, which can be in different formats (e.g., OmegaConf or ConfigDict).

    Raises:
    ------
    FileNotFoundError:
        If the configuration file at `config_path` is not found.
    ValueError:
        If there is an error loading the configuration file.
    """
    try:
        with open(config_path, 'r') as f:
            if model_type == 'htdemucs':
                config = OmegaConf.load(config_path)
            else:
                config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
        return config
    except FileNotFoundError:
        raise FileNotFoundError(f"Configuration file not found at {config_path}")
    except Exception as e:
        raise ValueError(f"Error loading configuration: {e}")

def get_model_from_config(model_type: str, config_path: str) -> Tuple:
    """
    Load the model specified by the model type and configuration file.

    Parameters:
    ----------
    model_type : str
        The type of model to load (e.g., 'mdx23c', 'htdemucs', 'scnet', etc.).
    config_path : str
        The path to the configuration file (YAML or OmegaConf format).

    Returns:
    -------
    model : nn.Module or None
        The initialized model based on the `model_type`, or None if the model type is not recognized.
    config : Any
        The configuration used to initialize the model. This could be in different formats
        depending on the model type (e.g., OmegaConf, ConfigDict).

    Raises:
    ------
    ValueError:
        If the `model_type` is unknown or an error occurs during model initialization.
    """
    config = load_config(model_type, config_path)

    if model_type == 'mdx23c':
        from models.mdx23c_tfc_tdf_v3 import TFC_TDF_net
        model = TFC_TDF_net(config)
    elif model_type == 'htdemucs':
        from models.demucs4ht import get_model
        model = get_model(config)
    elif model_type == 'segm_models':
        from models.segm_models import Segm_Models_Net
        model = Segm_Models_Net(config)
    elif model_type == 'torchseg':
        from models.torchseg_models import Torchseg_Net
        model = Torchseg_Net(config)
    elif model_type == 'mel_band_roformer':
        from models.bs_roformer import MelBandRoformer
        model = MelBandRoformer(**dict(config.model))
    elif model_type == 'bs_roformer':
        from models.bs_roformer import BSRoformer
        model = BSRoformer(**dict(config.model))
    elif model_type == 'swin_upernet':
        from models.upernet_swin_transformers import Swin_UperNet_Model
        model = Swin_UperNet_Model(config)
    elif model_type == 'bandit':
        from models.bandit.core.model import MultiMaskMultiSourceBandSplitRNNSimple
        model = MultiMaskMultiSourceBandSplitRNNSimple(**config.model)
    elif model_type == 'bandit_v2':
        from models.bandit_v2.bandit import Bandit
        model = Bandit(**config.kwargs)
    elif model_type == 'scnet_unofficial':
        from models.scnet_unofficial import SCNet
        model = SCNet(**config.model)
    elif model_type == 'scnet':
        from models.scnet import SCNet
        model = SCNet(**config.model)
    elif model_type == 'apollo':
        from models.look2hear.models import BaseModel
        model = BaseModel.apollo(**config.model)
    elif model_type == 'bs_mamba2':
        from models.ts_bs_mamba2 import Separator
        model = Separator(**config.model)
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    return model, config
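
# A minimal usage sketch (hypothetical paths; the corresponding model module
# must be importable from the repo's models package):
#
#     model, config = get_model_from_config('mdx23c', 'configs/config_mdx23c.yaml')
#     model = model.to('cuda' if torch.cuda.is_available() else 'cpu')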

def read_audio_transposed(path: str, instr: str = None, skip_err: bool = False) -> Tuple[np.ndarray, int]:
    """
    Reads an audio file, ensuring mono audio is converted to two-dimensional format,
    and transposes the data to have channels as the first dimension.

    Parameters
    ----------
    path : str
        Path to the audio file.
    instr : str
        Name of the instrument (used only in the message printed when `skip_err` is True).
    skip_err : bool
        If True, read errors are not raised; (None, None) is returned instead.

    Returns
    -------
    Tuple[np.ndarray, int]
        A tuple containing:
        - Transposed audio data as a NumPy array with shape (channels, length).
          For mono audio, the shape will be (1, length).
        - Sampling rate (int), e.g., 44100.
        Returns (None, None) if reading fails and `skip_err` is True.
    """
    try:
        mix, sr = sf.read(path)
    except Exception as e:
        if skip_err:
            print(f"No stem {instr}: skip!")
            return None, None
        else:
            raise RuntimeError(f"Error reading the file at {path}: {e}")
    else:
        if len(mix.shape) == 1:  # For mono audio
            mix = np.expand_dims(mix, axis=-1)
        return mix.T, sr
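
# A minimal usage sketch (the file path is hypothetical):
#
#     audio, sr = read_audio_transposed('mixture.wav')
#     if audio is not None:
#         print(audio.shape)  # (channels, length), e.g. (2, 441000) for 10 s of stereo at 44.1 kHz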

def normalize_audio(audio: np.ndarray) -> Tuple[np.ndarray, Dict[str, float]]:
    """
    Normalize an audio signal by subtracting the mean and dividing by the standard deviation.

    Parameters:
    ----------
    audio : np.ndarray
        Input audio array with shape (channels, time) or (time,).

    Returns:
    -------
    Tuple[np.ndarray, Dict[str, float]]
        - Normalized audio array with the same shape as the input.
        - Dictionary containing the mean and standard deviation of the original audio.
    """
    mono = audio.mean(0)
    mean, std = mono.mean(), mono.std()
    return (audio - mean) / std, {"mean": mean, "std": std}

def denormalize_audio(audio: np.ndarray, norm_params: Dict[str, float]) -> np.ndarray:
    """
    Denormalize an audio signal by reversing the normalization process (multiplying by the standard
    deviation and adding the mean).

    Parameters:
    ----------
    audio : np.ndarray
        Normalized audio array to be denormalized.
    norm_params : Dict[str, float]
        Dictionary containing the 'mean' and 'std' values used for normalization.

    Returns:
    -------
    np.ndarray
        Denormalized audio array with the same shape as the input.
    """
    return audio * norm_params["std"] + norm_params["mean"]
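
# A minimal round-trip sketch showing that denormalize_audio inverts normalize_audio
# (synthetic data; real callers pass audio from read_audio_transposed):
#
#     audio = np.random.randn(2, 44100).astype(np.float32)
#     normed, params = normalize_audio(audio)
#     restored = denormalize_audio(normed, params)
#     assert np.allclose(audio, restored, atol=1e-5)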

def apply_tta(
    config,
    model: torch.nn.Module,
    mix: np.ndarray,
    waveforms_orig: Dict[str, np.ndarray],
    device: torch.device,
    model_type: str
) -> Dict[str, np.ndarray]:
    """
    Apply Test-Time Augmentation (TTA) for source separation.

    This function processes the input mixture with test-time augmentations, including
    channel inversion and polarity inversion, to enhance the separation results. The
    results from all augmentations are averaged to produce the final output.

    Parameters:
    ----------
    config : Any
        Configuration object containing model and processing parameters.
    model : torch.nn.Module
        The trained model used for source separation.
    mix : np.ndarray
        The mixed audio array with shape (channels, time).
    waveforms_orig : Dict[str, np.ndarray]
        Dictionary of original separated waveforms (before TTA) for each instrument.
    device : torch.device
        Device (CPU or CUDA) on which the model will be executed.
    model_type : str
        Type of the model being used (e.g., 'htdemucs', 'mdx23c').

    Returns:
    -------
    Dict[str, np.ndarray]
        Updated dictionary of separated waveforms after applying TTA.
    """
    # Create augmentations: channel inversion and polarity inversion
    track_proc_list = [mix[::-1].copy(), -1.0 * mix.copy()]

    # Process each augmented mixture and undo the augmentation on the outputs
    for i, augmented_mix in enumerate(track_proc_list):
        waveforms = demix(config, model, augmented_mix, device, model_type=model_type)
        for el in waveforms:
            if i == 0:
                # Channel inversion: swap the channels back before accumulating
                waveforms_orig[el] += waveforms[el][::-1].copy()
            else:
                # Polarity inversion: negate the output before accumulating
                waveforms_orig[el] -= waveforms[el]

    # Average the results across the original pass and all augmentations
    for el in waveforms_orig:
        waveforms_orig[el] /= len(track_proc_list) + 1

    return waveforms_orig
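
# A minimal usage sketch (assumes model, config and device are already set up and
# `mix` comes from read_audio_transposed; model type is hypothetical):
#
#     waveforms = demix(config, model, mix, device, model_type='mdx23c')
#     waveforms = apply_tta(config, model, mix, waveforms, device, 'mdx23c')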

def _getWindowingArray(window_size: int, fade_size: int) -> torch.Tensor:
    """
    Generate a windowing array with a linear fade-in at the beginning and a fade-out at the end.

    This function creates a window of size `window_size` where the first `fade_size` elements
    linearly increase from 0 to 1 (fade-in) and the last `fade_size` elements linearly decrease
    from 1 to 0 (fade-out). The middle part of the window is filled with ones.

    Parameters:
    ----------
    window_size : int
        The total size of the window.
    fade_size : int
        The size of the fade-in and fade-out regions.

    Returns:
    -------
    torch.Tensor
        A tensor of shape (window_size,) containing the generated windowing array.

    Example:
    -------
    If `window_size=10` and `fade_size=3`, the output will be:
    tensor([0.0000, 0.5000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.5000, 0.0000])
    """
    fadein = torch.linspace(0, 1, fade_size)
    fadeout = torch.linspace(1, 0, fade_size)

    window = torch.ones(window_size)
    window[-fade_size:] = fadeout
    window[:fade_size] = fadein
    return window
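
# A minimal sketch of why this window matters for demix() below: overlapping chunks
# are weighted by the window, the weights are accumulated in a counter, and dividing
# the summed result by the counter reconstructs a smooth signal without clicks at
# chunk borders (synthetic sizes, not taken from any config):
#
#     window = _getWindowingArray(window_size=8, fade_size=2)
#     result = torch.zeros(12)
#     counter = torch.zeros(12)
#     for start in (0, 4):  # step = chunk_size // num_overlap = 4
#         w = window.clone()
#         if start == 0:         # first chunk: no fade-in (as demix does below)
#             w[:2] = 1
#         if start + 8 >= 12:    # last chunk: no fade-out
#             w[-2:] = 1
#         result[start:start + 8] += torch.ones(8) * w
#         counter[start:start + 8] += w
#     restored = result / counter  # all ones: the seams are invisible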

def demix(
    config: ConfigDict,
    model: torch.nn.Module,
    mix: np.ndarray,
    device: torch.device,
    model_type: str,
    pbar: bool = False
) -> Union[Dict[str, np.ndarray], np.ndarray]:
    """
    Unified function for audio source separation with support for multiple processing modes.

    This function separates audio into its constituent sources using either a generic custom logic
    or a Demucs-specific logic. It supports batch processing and overlapping window-based chunking
    for efficient and artifact-free separation.

    Parameters:
    ----------
    config : ConfigDict
        Configuration object containing audio and inference settings.
    model : torch.nn.Module
        The trained model used for audio source separation.
    mix : np.ndarray
        Input audio array with shape (channels, time).
    device : torch.device
        The computation device (CPU or CUDA).
    model_type : str
        Type of the model. 'htdemucs' selects the Demucs-specific processing mode;
        any other value selects the generic mode.
    pbar : bool, optional
        If True, displays a progress bar during chunk processing. Default is False.

    Returns:
    -------
    Union[Dict[str, np.ndarray], np.ndarray]
        - A dictionary mapping target instruments to separated audio sources if multiple instruments are present.
        - A numpy array of the separated source if only one instrument is present.
    """
    mix = torch.tensor(mix, dtype=torch.float32)

    if model_type == 'htdemucs':
        mode = 'demucs'
    else:
        mode = 'generic'

    # Define processing parameters based on the mode
    if mode == 'demucs':
        chunk_size = config.training.samplerate * config.training.segment
        num_instruments = len(config.training.instruments)
        num_overlap = config.inference.num_overlap
        step = chunk_size // num_overlap
    else:
        chunk_size = config.audio.chunk_size
        num_instruments = len(prefer_target_instrument(config))
        num_overlap = config.inference.num_overlap

        fade_size = chunk_size // 10
        step = chunk_size // num_overlap
        border = chunk_size - step
        length_init = mix.shape[-1]
        windowing_array = _getWindowingArray(chunk_size, fade_size)
        # Add padding for generic mode to handle edge artifacts
        if length_init > 2 * border and border > 0:
            mix = nn.functional.pad(mix, (border, border), mode="reflect")

    batch_size = config.inference.batch_size

    use_amp = getattr(config.training, 'use_amp', True)

    with torch.cuda.amp.autocast(enabled=use_amp):
        with torch.inference_mode():
            # Initialize result and counter tensors
            req_shape = (num_instruments,) + mix.shape
            result = torch.zeros(req_shape, dtype=torch.float32)
            counter = torch.zeros(req_shape, dtype=torch.float32)

            i = 0
            batch_data = []
            batch_locations = []
            progress_bar = tqdm(
                total=mix.shape[1], desc="Processing audio chunks", leave=False
            ) if pbar else None

            while i < mix.shape[1]:
                # Extract chunk and apply padding if necessary
                part = mix[:, i:i + chunk_size].to(device)
                chunk_len = part.shape[-1]
                if mode == "generic" and chunk_len > chunk_size // 2:
                    pad_mode = "reflect"
                else:
                    pad_mode = "constant"
                part = nn.functional.pad(part, (0, chunk_size - chunk_len), mode=pad_mode, value=0)

                batch_data.append(part)
                batch_locations.append((i, chunk_len))
                i += step

                # Process batch if it's full or the end is reached
                if len(batch_data) >= batch_size or i >= mix.shape[1]:
                    arr = torch.stack(batch_data, dim=0)
                    x = model(arr)

                    if mode == "generic":
                        window = windowing_array.clone()  # using clone() fixes the clicks at chunk edges when using batch_size=1
                        if i - step == 0:  # First audio chunk, no fade-in
                            window[:fade_size] = 1
                        elif i >= mix.shape[1]:  # Last audio chunk, no fade-out
                            window[-fade_size:] = 1

                    for j, (start, seg_len) in enumerate(batch_locations):
                        if mode == "generic":
                            result[..., start:start + seg_len] += x[j, ..., :seg_len].cpu() * window[..., :seg_len]
                            counter[..., start:start + seg_len] += window[..., :seg_len]
                        else:
                            result[..., start:start + seg_len] += x[j, ..., :seg_len].cpu()
                            counter[..., start:start + seg_len] += 1.0

                    batch_data.clear()
                    batch_locations.clear()

                if progress_bar:
                    progress_bar.update(step)

            if progress_bar:
                progress_bar.close()

            # Compute final estimated sources
            estimated_sources = result / counter
            estimated_sources = estimated_sources.cpu().numpy()
            np.nan_to_num(estimated_sources, copy=False, nan=0.0)

            # Remove padding for generic mode
            if mode == "generic":
                if length_init > 2 * border and border > 0:
                    estimated_sources = estimated_sources[..., border:-border]

    # Return the result as a dictionary or a single array
    if mode == "demucs":
        instruments = config.training.instruments
    else:
        instruments = prefer_target_instrument(config)

    ret_data = {k: v for k, v in zip(instruments, estimated_sources)}

    if mode == "demucs" and num_instruments <= 1:
        return estimated_sources
    else:
        return ret_data
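
# A minimal end-to-end usage sketch (paths, model type and stem name are hypothetical):
#
#     model, config = get_model_from_config('mdx23c', 'configs/config_mdx23c.yaml')
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#     model = model.to(device).eval()
#     mix, sr = read_audio_transposed('mixture.wav')
#     waveforms = demix(config, model, mix, device, model_type='mdx23c', pbar=True)
#     sf.write('vocals.wav', waveforms['vocals'].T, sr)  # assumes a 'vocals' stem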

def prefer_target_instrument(config: ConfigDict) -> List[str]:
    """
    Return the list of target instruments based on the configuration.

    If a specific target instrument is specified in the configuration,
    it returns a list with that instrument. Otherwise, it returns the list of instruments.

    Parameters:
    ----------
    config : ConfigDict
        Configuration object containing the list of instruments or the target instrument.

    Returns:
    -------
    List[str]
        A list of target instruments.
    """
    if getattr(config.training, 'target_instrument', None):
        return [config.training.target_instrument]
    else:
        return config.training.instruments
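
# A minimal sketch of the two cases (field values are illustrative):
#
#     # With config.training.target_instrument = 'vocals':
#     #     prefer_target_instrument(config)  -> ['vocals']
#     # With target_instrument unset and
#     # config.training.instruments = ['vocals', 'drums', 'bass', 'other']:
#     #     prefer_target_instrument(config)  -> ['vocals', 'drums', 'bass', 'other']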

def load_not_compatible_weights(model: torch.nn.Module, weights: str, verbose: bool = False) -> None:
    """
    Load weights into a model, handling mismatched shapes and dimensions.

    Args:
        model: PyTorch model into which the weights will be loaded.
        weights: Path to the weights file.
        verbose: If True, prints detailed information about matching and mismatched layers.
    """
    new_model = model.state_dict()
    old_model = torch.load(weights)
    if 'state' in old_model:
        # Fix for htdemucs weights loading
        old_model = old_model['state']
    if 'state_dict' in old_model:
        # Fix for apollo weights loading
        old_model = old_model['state_dict']

    for el in new_model:
        if el in old_model:
            if verbose:
                print(f'Match found for {el}!')
            if new_model[el].shape == old_model[el].shape:
                if verbose:
                    print('Action: Just copy weights!')
                new_model[el] = old_model[el]
            else:
                if len(new_model[el].shape) != len(old_model[el].shape):
                    if verbose:
                        print('Action: Different number of dimensions! Skipping this layer.')
                else:
                    if verbose:
                        print(f'Shape is different: {tuple(new_model[el].shape)} != {tuple(old_model[el].shape)}')
                    ln = len(new_model[el].shape)
                    max_shape = []
                    slices_old = []
                    slices_new = []
                    for i in range(ln):
                        max_shape.append(max(new_model[el].shape[i], old_model[el].shape[i]))
                        slices_old.append(slice(0, old_model[el].shape[i]))
                        slices_new.append(slice(0, new_model[el].shape[i]))
                    slices_old = tuple(slices_old)
                    slices_new = tuple(slices_new)
                    # Copy the old weights into a zero-padded tensor of the maximum
                    # shape, then crop it to the new layer's shape
                    max_matrix = np.zeros(max_shape, dtype=np.float32)
                    max_matrix[slices_old] = old_model[el].cpu().numpy()
                    max_matrix = torch.from_numpy(max_matrix)
                    new_model[el] = max_matrix[slices_new]
        else:
            if verbose:
                print(f'Match not found for {el}!')
    model.load_state_dict(new_model)
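
# A minimal usage sketch (the checkpoint path is hypothetical): loads a checkpoint
# whose tensor shapes only partially match the model, copying the overlapping
# slices and skipping layers with a different number of dimensions:
#
#     model, config = get_model_from_config('mdx23c', 'configs/config_mdx23c.yaml')
#     load_not_compatible_weights(model, 'results/model_old.ckpt', verbose=True)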

def load_lora_weights(model: torch.nn.Module, lora_path: str, device: str = 'cpu') -> None:
    """
    Load LoRA weights into a model.

    This function updates the given model with LoRA-specific weights from the specified checkpoint file.
    It does not require the checkpoint to match the model's full state dictionary, as only LoRA layers are updated.

    Parameters:
    ----------
    model : nn.Module
        The PyTorch model into which the LoRA weights will be loaded.
    lora_path : str
        Path to the LoRA checkpoint file.
    device : str, optional
        The device to load the weights onto, by default 'cpu'. Common values are 'cpu' or 'cuda'.

    Returns:
    -------
    None
        The model is updated in place.
    """
    lora_state_dict = torch.load(lora_path, map_location=device)
    model.load_state_dict(lora_state_dict, strict=False)

def load_start_checkpoint(args: argparse.Namespace, model: torch.nn.Module, type_: str = 'train') -> None:
    """
    Load the starting checkpoint for a model.

    Args:
        args: Parsed command-line arguments containing the checkpoint path.
        model: PyTorch model to load the checkpoint into.
        type_: How to load the weights. For 'train', weights that are not fully
            compatible with the model can still be loaded (mismatched layers are
            padded/cropped or skipped).
    """
    print(f'Start from checkpoint: {args.start_check_point}')
    if type_ in ['train']:
        load_not_compatible_weights(model, args.start_check_point, verbose=False)
    else:
        device = 'cpu'
        if args.model_type in ['htdemucs', 'apollo']:
            state_dict = torch.load(args.start_check_point, map_location=device, weights_only=False)
            # Fix for htdemucs pretrained models
            if 'state' in state_dict:
                state_dict = state_dict['state']
            # Fix for apollo pretrained models
            if 'state_dict' in state_dict:
                state_dict = state_dict['state_dict']
        else:
            state_dict = torch.load(args.start_check_point, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)

    if args.lora_checkpoint:
        print(f"Loading LoRA weights from: {args.lora_checkpoint}")
        load_lora_weights(model, args.lora_checkpoint)
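
# A minimal usage sketch (a hypothetical Namespace stands in for parsed CLI args):
#
#     args = argparse.Namespace(
#         start_check_point='results/model.ckpt',
#         model_type='mdx23c',
#         lora_checkpoint='',
#     )
#     load_start_checkpoint(args, model, type_='inference')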

def bind_lora_to_model(config: Dict[str, Any], model: nn.Module) -> nn.Module:
    """
    Replaces specific layers in the model with LoRA-extended versions.

    Parameters:
    ----------
    config : Dict[str, Any]
        Configuration containing parameters for LoRA. It should include a 'lora' key with parameters for `MergedLinear`.
    model : nn.Module
        The original model in which the layers will be replaced.

    Returns:
    -------
    nn.Module
        The modified model with the replaced layers.
    """
    if 'lora' not in config:
        raise ValueError("Configuration must contain the 'lora' key with parameters for LoRA.")

    replaced_layers = 0  # Counter for replaced layers

    for name, module in model.named_modules():
        hierarchy = name.split('.')
        layer_name = hierarchy[-1]

        # Replace every nn.Linear layer (add e.g. `and layer_name == 'to_qkv'`
        # to restrict the replacement to specific layers)
        if isinstance(module, nn.Linear):
            try:
                # Get the parent module
                parent_module = model
                for submodule_name in hierarchy[:-1]:
                    parent_module = getattr(parent_module, submodule_name)

                # Replace the module with a LoRA-enabled layer
                setattr(
                    parent_module,
                    layer_name,
                    lora.MergedLinear(
                        in_features=module.in_features,
                        out_features=module.out_features,
                        bias=module.bias is not None,
                        **config['lora']
                    )
                )
                replaced_layers += 1  # Increment the counter
            except Exception as e:
                print(f"Error replacing layer {name}: {e}")

    if replaced_layers == 0:
        print("Warning: No layers were replaced. Check the model structure and configuration.")
    else:
        print(f"Number of layers replaced with LoRA: {replaced_layers}")

    return model
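
# A minimal usage sketch (the 'lora' parameters are illustrative values for
# loralib's MergedLinear, not taken from any specific config):
#
#     lora_config = {'lora': {'r': 8, 'lora_alpha': 16, 'enable_lora': [True, False, True]}}
#     model = bind_lora_to_model(lora_config, model)
#     lora.mark_only_lora_as_trainable(model)  # freeze everything except the LoRA weights
#     # ... after training, reload the LoRA weights with:
#     # load_lora_weights(model, 'results/lora.ckpt')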

def draw_spectrogram(waveform, sample_rate, length, output_file):
    import librosa.display

    # Cut only the required part of the waveform (expected shape: (time, channels))
    x = waveform[:int(length * sample_rate), :]
    X = librosa.stft(x.mean(axis=-1))  # perform short-time Fourier transform on the mono signal
    Xdb = librosa.amplitude_to_db(np.abs(X), ref=np.max)  # convert the amplitude spectrogram to a dB-scaled spectrogram

    fig, ax = plt.subplots()
    img = librosa.display.specshow(
        Xdb,
        cmap='plasma',
        sr=sample_rate,
        x_axis='time',
        y_axis='linear',
        ax=ax
    )
    ax.set(title='File: ' + os.path.basename(output_file))
    fig.colorbar(img, ax=ax, format="%+2.f dB")
    if output_file is not None:
        plt.savefig(output_file)
    plt.close(fig)  # release the figure to avoid leaking memory across repeated calls
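
# A minimal usage sketch (file names are hypothetical; note that draw_spectrogram
# expects audio with shape (time, channels), i.e. the transpose of what
# read_audio_transposed returns):
#
#     audio, sr = read_audio_transposed('vocals.wav')
#     draw_spectrogram(audio.T, sr, length=10, output_file='vocals_spec.png')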