# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import faster_whisper
from typing import List, Union, Optional, NamedTuple
import torch
import numpy as np
import tqdm
from whisperx.audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
from whisperx.types import TranscriptionResult, SingleSegment
from whisperx.asr import WhisperModel, FasterWhisperPipeline, find_numeral_symbol_tokens


class VadFreeFasterWhisperPipeline(FasterWhisperPipeline):
    """
    FasterWhisperPipeline variant that runs without its own VAD.

    The parent pipeline is constructed with ``vad=None``; speech regions are
    instead supplied by the caller through ``vad_segments`` in
    :meth:`transcribe`.
    """

    def __init__(
        self,
        model,
        options: NamedTuple,
        tokenizer=None,
        device: Union[int, str, "torch.device"] = -1,
        framework="pt",
        language: Optional[str] = None,
        suppress_numerals: bool = False,
        **kwargs,
    ):
        """
        Initialize the VadFreeFasterWhisperPipeline.

        Args:
            model: The Whisper model instance.
            options: Transcription options.
            tokenizer: The tokenizer instance (may be None; it is then built
                lazily in :meth:`transcribe`).
            device: Device to run the model on.
            framework: The framework to use ('pt' for PyTorch).
            language: The language for transcription.
            suppress_numerals: Whether to suppress numeral tokens.
            **kwargs: Additional keyword arguments forwarded to the parent.

        Returns:
            None
        """
        super().__init__(
            model=model,
            vad=None,  # no built-in VAD; segments come from the caller
            vad_params={},
            options=options,
            tokenizer=tokenizer,
            device=device,
            framework=framework,
            language=language,
            suppress_numerals=suppress_numerals,
            **kwargs,
        )

    def detect_language(self, audio: np.ndarray):
        """
        Detect the language of the audio.

        If the audio is longer than ``N_SAMPLES``, a randomly positioned window
        of ``N_SAMPLES`` samples is analysed (so results may vary between
        calls). Otherwise the whole clip is used and padded up to
        ``N_SAMPLES`` via ``log_mel_spectrogram``.

        Args:
            audio (np.ndarray): The input audio signal (presumably 1-D samples
                at ``SAMPLE_RATE`` — TODO confirm).

        Returns:
            tuple: ``(language, language_probability)`` for the top-ranked
            language token.
        """
        model_n_mels = self.model.feat_kwargs.get("feature_size")
        num_samples = audio.shape[0]
        if num_samples > N_SAMPLES:
            # Randomly sample N_SAMPLES from the audio array
            start_index = np.random.randint(0, num_samples - N_SAMPLES)
            audio_sample = audio[start_index : start_index + N_SAMPLES]
        else:
            audio_sample = audio[:N_SAMPLES]
        padding = 0 if num_samples >= N_SAMPLES else N_SAMPLES - num_samples
        segment = log_mel_spectrogram(
            audio_sample,
            n_mels=model_n_mels if model_n_mels is not None else 80,
            padding=padding,
        )
        encoder_output = self.model.encode(segment)
        results = self.model.model.detect_language(encoder_output)
        language_token, language_probability = results[0][0]
        # Drop two characters from each end of the token (assumed to be the
        # "<|" / "|>" markers around the language code).
        language = language_token[2:-2]
        return language, language_probability

    def transcribe(
        self,
        audio: Union[str, np.ndarray],
        vad_segments: List[dict],
        batch_size=None,
        num_workers=0,
        language=None,
        task=None,
        chunk_size=30,
        print_progress=False,
        combined_progress=False,
    ) -> TranscriptionResult:
        """
        Transcribe the given speech segments of the audio into text.

        Args:
            audio (Union[str, np.ndarray]): The input audio signal or path to
                an audio file (loaded with ``load_audio``).
            vad_segments (List[dict]): Segments to transcribe. Each dict needs
                ``"start"`` and ``"end"`` in seconds and may carry a
                ``"speaker"``.
            batch_size (int, optional): Batch size for transcription. Falls
                back to ``self._batch_size`` when falsy. Defaults to None.
            num_workers (int, optional): Number of workers for loading data.
                Defaults to 0.
            language (str, optional): Language code. When None, it is taken
                from the current tokenizer or detected from the audio.
            task (str, optional): Task type ('transcribe' or 'translate').
                Defaults to None.
            chunk_size (int, optional): Not used by this implementation;
                accepted for signature compatibility. Defaults to 30.
            print_progress (bool, optional): Whether to show a progress bar.
                Defaults to False.
            combined_progress (bool, optional): Not used by this
                implementation; accepted for signature compatibility.

        Returns:
            TranscriptionResult: ``{"segments": [...], "language": str}``.
        """
        if isinstance(audio, str):
            audio = load_audio(audio)

        def data(audio, segments):
            # Yield the raw sample slice of each segment (times in seconds).
            for seg in segments:
                f1 = int(seg["start"] * SAMPLE_RATE)
                f2 = int(seg["end"] * SAMPLE_RATE)
                yield {"inputs": audio[f1:f2]}

        if self.tokenizer is None:
            # detect_language returns (language, probability); only the
            # language code may be passed on to the tokenizer / result.
            language = language or self.detect_language(audio)[0]
            task = task or "transcribe"
            self.tokenizer = faster_whisper.tokenizer.Tokenizer(
                self.model.hf_tokenizer,
                self.model.model.is_multilingual,
                task=task,
                language=language,
            )
        else:
            language = language or self.tokenizer.language_code
            task = task or self.tokenizer.task
            if task != self.tokenizer.task or language != self.tokenizer.language_code:
                self.tokenizer = faster_whisper.tokenizer.Tokenizer(
                    self.model.hf_tokenizer,
                    self.model.model.is_multilingual,
                    task=task,
                    language=language,
                )

        previous_suppress_tokens = self.options.suppress_tokens
        segments: List[SingleSegment] = []
        try:
            if self.suppress_numerals:
                numeral_symbol_tokens = find_numeral_symbol_tokens(self.tokenizer)
                new_suppressed_tokens = list(
                    set(numeral_symbol_tokens + self.options.suppress_tokens)
                )
                self.options = self.options._replace(
                    suppress_tokens=new_suppressed_tokens
                )

            batch_size = batch_size or self._batch_size
            # The bar is only displayed when requested and is always closed.
            with tqdm.tqdm(
                total=len(vad_segments),
                desc="Transcribing",
                disable=not print_progress,
            ) as progress:
                outputs = self(
                    data(audio, vad_segments),
                    batch_size=batch_size,
                    num_workers=num_workers,
                )
                for idx, out in enumerate(outputs):
                    progress.update(1)
                    text = out["text"]
                    # NOTE(review): unbatched output appears to wrap the text
                    # in a sequence — TODO confirm.
                    if batch_size in [0, 1, None]:
                        text = text[0]
                    segment = vad_segments[idx]
                    segments.append(
                        {
                            "text": text,
                            "start": round(segment["start"], 3),
                            "end": round(segment["end"], 3),
                            "speaker": segment.get("speaker", None),
                        }
                    )
        finally:
            # Restore per-call state even if transcription fails part-way.
            # revert the tokenizer if multilingual inference is enabled
            if self.preset_language is None:
                self.tokenizer = None
            # revert suppressed tokens if suppress_numerals is enabled
            if self.suppress_numerals:
                self.options = self.options._replace(
                    suppress_tokens=previous_suppress_tokens
                )

        return {"segments": segments, "language": language}


def load_asr_model(
    whisper_arch: str,
    device: str,
    device_index: int = 0,
    compute_type: str = "float16",
    asr_options: Optional[dict] = None,
    language: Optional[str] = None,
    vad_model=None,
    vad_options=None,
    model: Optional[WhisperModel] = None,
    task: str = "transcribe",
    download_root: Optional[str] = None,
    threads: int = 4,
) -> VadFreeFasterWhisperPipeline:
    """
    Load a Whisper model and wrap it in a VAD-free transcription pipeline.

    Args:
        whisper_arch (str): The name of the Whisper model to load. Names
            ending in ".en" force ``language="en"``.
        device (str): The device to load the model on.
        device_index (int, optional): The device index. Defaults to 0.
        compute_type (str, optional): The compute type to use for the model. Defaults to "float16".
        asr_options (Optional[dict], optional): Overrides merged on top of the
            default transcription options. May include ``"suppress_numerals"``.
            Defaults to None.
        language (Optional[str], optional): The language of the model. When
            None, no tokenizer is built and the language is detected per call.
        vad_model: Unused; accepted for signature compatibility.
        vad_options: Unused; accepted for signature compatibility.
        model (Optional[WhisperModel], optional): A pre-built WhisperModel to
            use instead of loading ``whisper_arch``. Defaults to None.
        task (str, optional): The task type ('transcribe' or 'translate'). Defaults to "transcribe".
        download_root (Optional[str], optional): The root directory to download the model to. Defaults to None.
        threads (int, optional): The number of CPU threads to use per worker. Defaults to 4.

    Returns:
        VadFreeFasterWhisperPipeline: The loaded Whisper pipeline.

    Note:
        This function raises nothing itself. Errors from model loading or
        from building ``TranscriptionOptions`` propagate unchanged. Unknown
        keys in ``asr_options`` presumably surface as a ``TypeError`` —
        TODO confirm.
    """
    # Architectures whose name ends in ".en" are pinned to English.
    if whisper_arch.endswith(".en"):
        language = "en"

    model = model or WhisperModel(
        whisper_arch,
        device=device,
        device_index=device_index,
        compute_type=compute_type,
        download_root=download_root,
        cpu_threads=threads,
    )
    if language is not None:
        tokenizer = faster_whisper.tokenizer.Tokenizer(
            model.hf_tokenizer,
            model.model.is_multilingual,
            task=task,
            language=language,
        )
    else:
        print(
            "No language specified, language will be detected for each audio file (increases inference time)."
        )
        tokenizer = None

    default_asr_options = {
        "beam_size": 5,
        "best_of": 5,
        "patience": 1,
        "length_penalty": 1,
        "repetition_penalty": 1,
        "no_repeat_ngram_size": 0,
        "temperatures": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
        "compression_ratio_threshold": 2.4,
        "log_prob_threshold": -1.0,
        "no_speech_threshold": 0.6,
        "condition_on_previous_text": False,
        "prompt_reset_on_temperature": 0.5,
        "initial_prompt": None,
        "prefix": None,
        "suppress_blank": True,
        "suppress_tokens": [-1],
        "without_timestamps": True,
        "max_initial_timestamp": 0.0,
        "word_timestamps": False,
        "prepend_punctuations": "\"'“¿([{-",
        "append_punctuations": "\"'.。,,!!??::”)]}、",
        "suppress_numerals": False,
        "max_new_tokens": None,
        "clip_timestamps": None,
        "hallucination_silence_threshold": None,
    }

    if asr_options is not None:
        default_asr_options.update(asr_options)

    # "suppress_numerals" is a pipeline setting, not a TranscriptionOptions
    # field, so it must be removed before building the options object.
    suppress_numerals = default_asr_options.pop("suppress_numerals")

    transcription_options = faster_whisper.transcribe.TranscriptionOptions(
        **default_asr_options
    )

    return VadFreeFasterWhisperPipeline(
        model=model,
        options=transcription_options,
        tokenizer=tokenizer,
        language=language,
        suppress_numerals=suppress_numerals,
    )