"""Vocoder wrapper.

Copyright PolyAI Limited.
"""
import enum

import numpy as np
import soundfile as sf
import torch
import torch.nn as nn
from speechtokenizer import SpeechTokenizer


class VocoderType(enum.Enum):
    """Supported vocoder back-ends.

    Each member's value is a ``(name, compression_ratio)`` tuple, which the
    enum machinery unpacks into ``__init__``.
    """

    SPEECHTOKENIZER = ("SPEECHTOKENIZER", 320)

    def __init__(self, name, compression_ratio):
        # Overrides the member name with the first tuple element. For
        # SPEECHTOKENIZER this is identical to the attribute name.
        self._name_ = name
        # 320 for SPEECHTOKENIZER; presumably audio samples per code frame —
        # TODO confirm against the model's hop length.
        self.compression_ratio = compression_ratio

    def get_vocoder(self, ckpt_path, config_path, **kwargs):
        """Build the wrapper module for this vocoder type.

        Args:
            ckpt_path: Checkpoint path. When falsy, the wrapper's default
                checkpoint path is used.
            config_path: Config path. When falsy, the wrapper's default config
                path is used.
            **kwargs: Accepted for interface compatibility; currently unused.

        Returns:
            An ``STWrapper`` instance for ``SPEECHTOKENIZER``.

        Raises:
            ValueError: If this member has no known wrapper.
        """
        if self is not VocoderType.SPEECHTOKENIZER:
            raise ValueError(f"Unknown vocoder type {self.name}")

        # Forward only the paths the caller actually supplied. Passing a None
        # config_path alongside a real checkpoint would otherwise override the
        # wrapper's default and fail inside the loader.
        paths = {}
        if ckpt_path:
            paths["ckpt_path"] = ckpt_path
        if config_path:
            paths["config_path"] = config_path
        return STWrapper(**paths)


class STWrapper(nn.Module):
    """``nn.Module`` wrapper around a SpeechTokenizer model.

    ``encode``/``decode`` move their input to the model's device and return
    results on the caller's original device. The ``*_to_file`` helpers read
    and write ``.npy`` code files and audio files via ``soundfile``.
    """

    def __init__(
            self,
            ckpt_path: str = './ckpt/speechtokenizer/SpeechTokenizer.pt',
            config_path: str = './ckpt/speechtokenizer/config.json',
        ):
        """Load the SpeechTokenizer model.

        Args:
            ckpt_path: Path to the model checkpoint.
            config_path: Path to the model config.
        """
        super().__init__()
        self.model = SpeechTokenizer.load_from_checkpoint(
            config_path, ckpt_path)

    def eval(self):
        """Put the wrapper and the wrapped model into evaluation mode.

        Returns:
            ``self``, matching ``nn.Module.eval`` so calls can be chained
            (e.g. ``STWrapper().eval()``).
        """
        # Clears ``self.training`` and recurses into registered submodules.
        super().eval()
        # Kept explicit, as before, in case the model customises eval().
        self.model.eval()
        return self

    @torch.no_grad()
    def decode(self, codes: torch.Tensor, verbose: bool = False):
        """Decode discrete codes into audio.

        Args:
            codes: Code tensor. It is moved to the model's device for decoding.
            verbose: Unused; kept for interface compatibility.

        Returns:
            The model's decoded audio, moved back to ``codes``'s device.
        """
        source_device = codes.device
        audio = self.model.decode(codes.to(self.device))
        return audio.to(source_device)

    def decode_to_file(self, codes_path, out_path) -> None:
        """Decode codes stored in a ``.npy`` file and write the audio.

        Args:
            codes_path: Path to a ``.npy`` array of codes.
            out_path: Destination audio file. It is written at
                ``self.model.sample_rate``.
        """
        codes = torch.from_numpy(np.load(codes_path))
        wav = self.decode(codes).cpu().numpy()
        # soundfile.write only accepts (frames,) or (frames, channels) arrays.
        # Flatten when every leading axis is a singleton. This assumes the
        # decoder returns e.g. (1, 1, T) for codes from encode_to_file — TODO
        # confirm the layout for batched / multi-channel output.
        if wav.ndim > 1 and all(dim == 1 for dim in wav.shape[:-1]):
            wav = wav.reshape(-1)
        sf.write(out_path, wav, samplerate=self.model.sample_rate)

    @torch.no_grad()
    def encode(self, wav, verbose=False, n_quantizers: int = None):
        """Encode audio into discrete codes.

        Args:
            wav: Audio tensor. It is moved to the model's device for encoding.
                ``encode_to_file`` passes a (1, 1, T) tensor.
            verbose: Unused; kept for interface compatibility.
            n_quantizers: Unused; the model's default number of quantizers is
                always applied.

        Returns:
            Codes moved back to ``wav``'s device.
        """
        source_device = wav.device
        codes = self.model.encode(wav.to(self.device))  # codes: (n_q, B, T)
        return codes.to(source_device)

    def encode_to_file(self, wav_path, out_path) -> None:
        """Encode an audio file and save the codes with ``np.save``.

        The audio is not resampled. A warning is emitted when the file's
        sample rate differs from ``self.model.sample_rate``.

        Args:
            wav_path: Input audio file.
            out_path: Destination ``.npy`` path.
        """
        import warnings

        wav, file_rate = sf.read(wav_path, dtype='float32')
        model_rate = self.model.sample_rate
        if file_rate != model_rate:
            warnings.warn(
                f"{wav_path} has sample rate {file_rate} Hz but the model "
                f"expects {model_rate} Hz; encoding without resampling.",
                stacklevel=2,
            )
        # NOTE(review): sf.read returns (frames, channels) for multi-channel
        # files, which becomes a 4-D tensor here — presumably only mono input
        # is expected; TODO confirm.
        wav = torch.from_numpy(wav).unsqueeze(0).unsqueeze(0)
        codes = self.encode(wav).cpu().numpy()
        np.save(out_path, codes)

    def remove_weight_norm(self):
        """No-op, kept so callers can treat this like other vocoder wrappers."""
        pass

    @property
    def device(self):
        """Device of the wrapped model's first parameter."""
        return next(self.model.parameters()).device