import librosa
import soundfile
from pathlib import Path

import torch

from .splitter import Splitter


def sound_split(
        model: Splitter,
        input: str = "data/audio_example.mp3",
        output_dir: str = "output",
        write_src: bool = False,
) -> None:
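    """Separate an audio file into stems with a pretrained Splitter model.

    The file at ``input`` is loaded (resampled to 44.1 kHz, channels preserved),
    passed through ``model.separate``, and each returned stem is written to
    ``output_dir`` as a 16-bit PCM WAV file. With ``write_src=True`` the source
    waveform is also written out as ``input.wav``.
    """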
    sr = 44100
    # choose a device and move the model onto it so the model and audio match
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device).eval()

    # load the input audio and resample to the target rate
    fpath_src = Path(input)
    wav, _ = librosa.load(
        fpath_src,
        mono=False,
        res_type="kaiser_fast",
        sr=sr,
    )
    # librosa returns a float32 array, shape (channels, samples) for stereo input
    wav = torch.from_numpy(wav).to(device)

    # NOTE: peak normalization is left disabled; the waveform is passed to the
    # model exactly as loaded

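    # run separation without gradient tracking; `separate` is expected to return
    # a dict mapping stem names to waveform tensors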
    with torch.no_grad():
        stems = model.separate(wav)

    if write_src:
        stems["input"] = wav
    for name, stem in stems.items():
        fpath_dst = Path(output_dir) / f"{name}.wav"
        print(f"Writing {fpath_dst}")
        fpath_dst.parent.mkdir(parents=True, exist_ok=True)
        # soundfile expects (samples, channels), hence the transpose
        soundfile.write(fpath_dst, stem.detach().cpu().numpy().T, sr, subtype="PCM_16")
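

# Minimal usage sketch, not part of the original module: it assumes Splitter
# exposes a `from_pretrained(path)` constructor and uses a placeholder
# checkpoint path. Run it via `python -m <package>.<this_module>` so the
# relative import above resolves.
if __name__ == "__main__":
    splitter = Splitter.from_pretrained("checkpoints/2stems.pth")  # placeholder path
    sound_split(
        splitter,
        input="data/audio_example.mp3",
        output_dir="output",
        write_src=True,
    )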