import argparse
import os
import shutil
from pathlib import Path

import soundfile as sf
import torch
from tqdm import tqdm

from common.log import logger
from common.stdout_wrapper import SAFE_STDOUT

vad_model, utils = torch.hub.load(
    repo_or_dir="snakers4/silero-vad",
    model="silero_vad",
    onnx=True,
    trust_repo=True,
)

(get_speech_timestamps, _, read_audio, *_) = utils


def get_stamps(
    audio_file, min_silence_dur_ms: int = 700, min_sec: float = 2, max_sec: float = 12
):
    """
    min_silence_dur_ms: int (ミリ秒):
        このミリ秒数以上を無音だと判断する。
        逆に、この秒数以下の無音区間では区切られない。
        小さくすると、音声がぶつ切りに小さくなりすぎ、
        大きくすると音声一つ一つが長くなりすぎる。
        データセットによってたぶん要調整。
    min_sec: float (秒):
        この秒数より小さい発話は無視する。
    max_sec: float (秒):
        この秒数より大きい発話は無視する。
    """

    sampling_rate = 16000  # 16kHzか8kHzのみ対応

    min_ms = int(min_sec * 1000)

    wav = read_audio(audio_file, sampling_rate=sampling_rate)
    speech_timestamps = get_speech_timestamps(
        wav,
        vad_model,
        sampling_rate=sampling_rate,
        min_silence_duration_ms=min_silence_dur_ms,
        min_speech_duration_ms=min_ms,
        max_speech_duration_s=max_sec,
    )

    return speech_timestamps


def split_wav(
    audio_file,
    target_dir="raw",
    min_sec=2,
    max_sec=12,
    min_silence_dur_ms=700,
):
    margin = 200  # ミリ秒単位で、音声の前後に余裕を持たせる
    speech_timestamps = get_stamps(
        audio_file,
        min_silence_dur_ms=min_silence_dur_ms,
        min_sec=min_sec,
        max_sec=max_sec,
    )

    data, sr = sf.read(audio_file)

    total_ms = len(data) / sr * 1000

    file_name = os.path.basename(audio_file).split(".")[0]
    os.makedirs(target_dir, exist_ok=True)

    total_time_ms = 0

    # タイムスタンプに従って分割し、ファイルに保存
    for i, ts in enumerate(speech_timestamps):
        start_ms = max(ts["start"] / 16 - margin, 0)
        end_ms = min(ts["end"] / 16 + margin, total_ms)

        start_sample = int(start_ms / 1000 * sr)
        end_sample = int(end_ms / 1000 * sr)
        segment = data[start_sample:end_sample]

        sf.write(os.path.join(target_dir, f"{file_name}-{i}.wav"), segment, sr)
        total_time_ms += end_ms - start_ms

    return total_time_ms / 1000


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--min_sec", "-m", type=float, default=2, help="Minimum seconds of a slice"
    )
    parser.add_argument(
        "--max_sec", "-M", type=float, default=12, help="Maximum seconds of a slice"
    )
    parser.add_argument(
        "--input_dir",
        "-i",
        type=str,
        default="inputs",
        help="Directory of input wav files",
    )
    parser.add_argument(
        "--output_dir",
        "-o",
        type=str,
        default="raw",
        help="Directory of output wav files",
    )
    parser.add_argument(
        "--min_silence_dur_ms",
        "-s",
        type=int,
        default=700,
        help="Silence above this duration (ms) is considered as a split point.",
    )
    args = parser.parse_args()

    input_dir = args.input_dir
    output_dir = args.output_dir
    min_sec = args.min_sec
    max_sec = args.max_sec
    min_silence_dur_ms = args.min_silence_dur_ms

    wav_files = Path(input_dir).glob("**/*.wav")
    wav_files = list(wav_files)
    logger.info(f"Found {len(wav_files)} wav files.")
    if os.path.exists(output_dir):
        logger.warning(f"Output directory {output_dir} already exists, deleting...")
        shutil.rmtree(output_dir)

    total_sec = 0
    for wav_file in tqdm(wav_files, file=SAFE_STDOUT):
        time_sec = split_wav(
            audio_file=str(wav_file),
            target_dir=output_dir,
            min_sec=min_sec,
            max_sec=max_sec,
            min_silence_dur_ms=min_silence_dur_ms,
        )
        total_sec += time_sec

    logger.info(f"Slice done! Total time: {total_sec / 60:.2f} min.")