from pydub import AudioSegment
from typing import Any, List, Dict, Union
from scipy.io.wavfile import write
import io
from modules.utils import rng
from modules.utils.audio import time_stretch, pitch_shift
from modules import generate_audio
from modules.normalization import text_normalize
import logging
import json
import copy
import numpy as np

from modules.speaker import Speaker

logger = logging.getLogger(__name__)


def audio_data_to_segment(audio_data, sr):
    """Convert raw sample data into a pydub ``AudioSegment``.

    The samples are serialized as WAV into an in-memory buffer with
    ``scipy.io.wavfile.write`` and then decoded again by
    ``AudioSegment.from_file``.

    Args:
        audio_data: Sample data handed to ``scipy.io.wavfile.write``.
        sr: Sample rate handed to the WAV writer.

    Returns:
        The ``AudioSegment`` decoded from the in-memory WAV data.
    """
    wav_buffer = io.BytesIO()
    write(wav_buffer, rate=sr, data=audio_data)
    # Rewind so the decoder reads the WAV data from its first byte.
    wav_buffer.seek(0)

    segment = AudioSegment.from_file(wav_buffer, format="wav")
    return segment


def combine_audio_segments(audio_segments: list) -> AudioSegment:
    """Concatenate audio segments, in order, into a single segment.

    Args:
        audio_segments: Segments to join; they are appended one after the
            other starting from an empty ``AudioSegment``.

    Returns:
        The concatenation of all segments (an empty segment for an empty
        list).
    """
    joined = AudioSegment.empty()
    for piece in audio_segments:
        joined += piece
    return joined


def apply_prosody(
    audio_segment: AudioSegment, rate: float, volume: float, pitch: float
) -> AudioSegment:
    """Apply rate, volume and pitch adjustments to an audio segment.

    Each adjustment is skipped when its argument has the neutral value
    (``rate == 1``, ``volume == 0``, ``pitch == 0``). The steps run in the
    order rate, volume, pitch.

    Args:
        audio_segment: Segment to adjust.
        rate: Factor passed to ``time_stretch``.
        volume: Amount added to the segment with ``+=`` (pydub treats adding
            a number to a segment as a gain change).
        pitch: Amount passed to ``pitch_shift``.

    Returns:
        The adjusted segment, or the input unchanged when every argument is
        neutral.
    """
    adjusted = audio_segment

    if rate != 1:
        adjusted = time_stretch(adjusted, rate)
    if volume != 0:
        adjusted += volume
    if pitch != 0:
        adjusted = pitch_shift(adjusted, pitch)

    return adjusted


def to_number(value, t, default=0):
    """Convert ``value`` with the callable ``t``, or return ``default``.

    Args:
        value: The value to convert.
        t: Conversion callable such as ``int`` or ``float``.
        default: Returned when ``t(value)`` raises ``ValueError`` or
            ``TypeError``.

    Returns:
        ``t(value)`` on success, otherwise ``default``.
    """
    try:
        converted = t(value)
    except (TypeError, ValueError):
        return default
    else:
        return converted


class SynthesizeSegments:
    """Synthesize a list of segment dicts in batches.

    Segments whose generation parameters (everything except the text) are
    identical are grouped into a bucket, and each bucket is sent to
    ``generate_audio.generate_audio_batch`` in chunks of ``batch_size``.
    """

    # Fallback values used when a segment selects neither a speaker nor an
    # inference seed. They are evaluated once, when the class body runs, so
    # every instance and every such segment shares the same values.
    batch_default_spk_seed = rng.np_rng()
    batch_default_infer_seed = rng.np_rng()

    def __init__(self, batch_size: int = 8):
        """Store the maximum number of texts sent per batch call."""
        self.batch_size = batch_size

    def segment_to_generate_params(self, segment: Dict[str, Any]) -> Dict[str, Any]:
        """Build the generation parameters for one segment.

        If the segment carries a non-``None`` ``"params"`` entry it is
        returned as-is. Otherwise the parameters are derived from the
        segment's ``"text"``, ``"is_end"`` and ``"attrs"`` entries.

        Args:
            segment: Segment dict; ``"attrs"`` may be missing or ``None``.

        Returns:
            Dict with the keys ``text``, ``temperature``, ``top_P``,
            ``top_K``, ``spk``, ``infer_seed``, ``prompt1``, ``prompt2`` and
            ``prefix``.

        Raises:
            ValueError: If ``spk`` is a non-blank string that is not an
                integer literal.
        """
        if segment.get("params", None) is not None:
            return segment["params"]

        text = str(segment.get("text", "")).strip()
        is_end = segment.get("is_end", False)

        attrs = segment.get("attrs") or {}
        spk = attrs.get("spk", "")
        if isinstance(spk, str):
            # A missing or blank speaker means "use the default"; calling
            # int("") here would raise ValueError.
            spk = int(spk) if spk.strip() else -1
        seed = to_number(attrs.get("seed", ""), int, -1)
        top_k = to_number(attrs.get("top_k", ""), int, None)
        top_p = to_number(attrs.get("top_p", ""), float, None)
        temp = to_number(attrs.get("temp", ""), float, None)

        prompt1 = attrs.get("prompt1", "")
        prompt2 = attrs.get("prompt2", "")
        prefix = attrs.get("prefix", "")
        # Normalization is on unless the attribute is the literal "False".
        disable_normalize = attrs.get("normalize", "") == "False"

        params = {
            "text": text,
            "temperature": temp if temp is not None else 0.3,
            "top_P": top_p if top_p is not None else 0.5,
            "top_K": top_k if top_k is not None else 20,
            # Any falsy value (including 0) selects the default below.
            "spk": spk if spk else -1,
            "infer_seed": seed if seed else -1,
            "prompt1": prompt1 if prompt1 else "",
            "prompt2": prompt2 if prompt2 else "",
            "prefix": prefix if prefix else "",
        }

        if not disable_normalize:
            params["text"] = text_normalize(text, is_end=is_end)

        # Replace the -1 placeholders with the shared class-level defaults.
        if params["spk"] == -1:
            params["spk"] = self.batch_default_spk_seed
        if params["infer_seed"] == -1:
            params["infer_seed"] = self.batch_default_infer_seed

        return params

    @staticmethod
    def _bucket_key(params: Dict[str, Any]) -> str:
        """Return a string key identifying all parameters except the text."""
        key_params = {k: v for k, v in params.items() if k != "text"}
        if isinstance(key_params.get("spk"), Speaker):
            # Speaker objects are represented by their id inside the key.
            key_params["spk"] = str(key_params["spk"].id)
        return json.dumps(key_params, sort_keys=True)

    def _bucket_indexed(self, segments: List[Dict[str, Any]]) -> List[List[tuple]]:
        """Group segments by bucket key, remembering each original position.

        Returns:
            A list of buckets; each bucket is a list of
            ``(original_index, segment, params)`` tuples in input order.
        """
        buckets: Dict[str, List[tuple]] = {}
        for index, segment in enumerate(segments):
            params = self.segment_to_generate_params(segment)
            key = self._bucket_key(params)
            buckets.setdefault(key, []).append((index, segment, params))
        return list(buckets.values())

    @staticmethod
    def _prosody_value(segment: Dict[str, Any], name: str, default: str) -> float:
        """Read a prosody setting from the segment, then from its attrs.

        A top-level entry wins; otherwise the value is looked up in
        ``segment["attrs"]``, which is where ``synthesize_segment`` reads
        the same settings from.
        """
        if name in segment:
            return float(segment[name])
        attrs = segment.get("attrs") or {}
        return float(attrs.get(name, default))

    def bucket_segments(
        self, segments: List[Dict[str, Any]]
    ) -> List[List[Dict[str, Any]]]:
        """Group segments that share all generation parameters except text.

        Args:
            segments: Segment dicts to group.

        Returns:
            A list of buckets, each a list of segments in input order.
        """
        return [
            [segment for _, segment, _ in bucket]
            for bucket in self._bucket_indexed(segments)
        ]

    def synthesize_segments(self, segments: List[Dict[str, Any]]) -> List[AudioSegment]:
        """Synthesize every segment and return the audio in input order.

        Args:
            segments: Segment dicts to synthesize.

        Returns:
            A list with one ``AudioSegment`` per input segment, at the same
            index as the segment it was generated from.
        """
        # Results are written by original index because buckets reorder the
        # segments.
        audio_segments = [None] * len(segments)
        buckets = self._bucket_indexed(segments)
        logger.debug("segments len: %d", len(segments))
        logger.debug("bucket pool size: %d", len(buckets))
        for bucket in buckets:
            for start in range(0, len(bucket), self.batch_size):
                batch = bucket[start : start + self.batch_size]
                texts = [params["text"] for _, _, params in batch]

                # All segments in a bucket share these parameters, so the
                # first entry is representative.
                params = batch[0][2]
                audio_datas = generate_audio.generate_audio_batch(
                    texts=texts,
                    temperature=params["temperature"],
                    top_P=params["top_P"],
                    top_K=params["top_K"],
                    spk=params["spk"],
                    infer_seed=params["infer_seed"],
                    prompt1=params["prompt1"],
                    prompt2=params["prompt2"],
                    prefix=params["prefix"],
                )
                for idx, (original_index, segment, _) in enumerate(batch):
                    (sr, audio_data) = audio_datas[idx]
                    rate = self._prosody_value(segment, "rate", "1.0")
                    volume = self._prosody_value(segment, "volume", "0")
                    pitch = self._prosody_value(segment, "pitch", "0")

                    audio_segment = audio_data_to_segment(audio_data, sr)
                    audio_segment = apply_prosody(audio_segment, rate, volume, pitch)
                    # Use the tracked position rather than list.index(),
                    # which returns the first *equal* dict and would leave
                    # the slots of duplicate segments unfilled.
                    audio_segments[original_index] = audio_segment

        return audio_segments


def generate_audio_segment(
    text: str,
    spk: int = -1,
    seed: int = -1,
    top_p: float = 0.5,
    top_k: int = 20,
    temp: float = 0.3,
    prompt1: str = "",
    prompt2: str = "",
    prefix: str = "",
    enable_normalize=True,
    is_end: bool = False,
) -> AudioSegment:
    """Generate the audio for one piece of text.

    Args:
        text: Text to synthesize.
        spk: Speaker value; a falsy value is passed on as ``-1``.
        seed: Inference seed; a falsy value is passed on as ``-1``.
        top_p: Passed as ``top_P``; ``None`` falls back to ``0.5``.
        top_k: Passed as ``top_K``; ``None`` falls back to ``20``.
        temp: Passed as ``temperature``; ``None`` falls back to ``0.3``.
        prompt1: Passed through; a falsy value becomes ``""``.
        prompt2: Passed through; a falsy value becomes ``""``.
        prefix: Passed through; a falsy value becomes ``""``.
        enable_normalize: When true, ``text`` is run through
            ``text_normalize`` first.
        is_end: Forwarded to ``text_normalize``.

    Returns:
        The generated audio as an ``AudioSegment``.
    """
    if enable_normalize:
        text = text_normalize(text, is_end=is_end)

    logger.debug(f"generate segment: {text}")

    sample_rate, audio_data = generate_audio.generate_audio(
        text=text,
        temperature=temp if temp is not None else 0.3,
        top_P=top_p if top_p is not None else 0.5,
        top_K=top_k if top_k is not None else 20,
        spk=spk if spk else -1,
        infer_seed=seed if seed else -1,
        prompt1=prompt1 if prompt1 else "",
        prompt2=prompt2 if prompt2 else "",
        prefix=prefix if prefix else "",
    )

    # Reuse the module's shared WAV round-trip helper instead of repeating
    # the BytesIO/write/from_file sequence here.
    return audio_data_to_segment(audio_data, sample_rate)


def synthesize_segment(segment: Dict[str, Any]) -> Union[AudioSegment, None]:
    """Synthesize a single segment dict.

    Args:
        segment: Either a pause segment containing ``"break"`` (used as the
            ``duration`` of the silence), or a text segment with ``"text"``
            and optional ``"is_end"`` and ``"attrs"`` entries.

    Returns:
        A silent ``AudioSegment`` for a pause, ``None`` when the text is
        empty after stripping, otherwise the generated audio with the
        prosody settings from ``attrs`` applied.

    Raises:
        ValueError: If ``spk`` is a non-blank string that is not an integer
            literal.
    """
    if "break" in segment:
        return AudioSegment.silent(duration=segment["break"])

    # Treat a missing or None "attrs" entry as "no attributes".
    attrs = segment.get("attrs") or {}
    is_end = segment.get("is_end", False)

    text = str(segment.get("text", "")).strip()
    if text == "":
        return None

    spk = attrs.get("spk", "")
    if isinstance(spk, str):
        # A missing or blank speaker means "use the default"; calling
        # int("") here would raise ValueError.
        spk = int(spk) if spk.strip() else -1
    seed = to_number(attrs.get("seed", ""), int, -1)
    top_k = to_number(attrs.get("top_k", ""), int, None)
    top_p = to_number(attrs.get("top_p", ""), float, None)
    temp = to_number(attrs.get("temp", ""), float, None)

    prompt1 = attrs.get("prompt1", "")
    prompt2 = attrs.get("prompt2", "")
    prefix = attrs.get("prefix", "")
    # Normalization is on unless the attribute is the literal "False".
    disable_normalize = attrs.get("normalize", "") == "False"

    audio_segment = generate_audio_segment(
        text,
        enable_normalize=not disable_normalize,
        spk=spk,
        seed=seed,
        top_k=top_k,
        top_p=top_p,
        temp=temp,
        prompt1=prompt1,
        prompt2=prompt2,
        prefix=prefix,
        is_end=is_end,
    )

    rate = float(attrs.get("rate", "1.0"))
    volume = float(attrs.get("volume", "0"))
    pitch = float(attrs.get("pitch", "0"))

    return apply_prosody(audio_segment, rate, volume, pitch)


# Example usage
if __name__ == "__main__":
    # (text, temp) pairs for the demo. The first two entries share every
    # generation parameter; the third differs only in "temp".
    demo_inputs = [
        ("大🍌,一条大🍌,嘿,你的感觉真的很奇妙  [lbreak]", 0.1),
        ("大🍉,一个大🍉,嘿,你的感觉真的很奇妙  [lbreak]", 0.1),
        ("大🍌,一条大🍌,嘿,你的感觉真的很奇妙  [lbreak]", 0.3),
    ]
    ssml_segments = [
        {"text": demo_text, "attrs": {"spk": 2, "temp": demo_temp, "seed": 42}}
        for demo_text, demo_temp in demo_inputs
    ]

    # Synthesize in batches of two, join the results and write a WAV file.
    synthesizer = SynthesizeSegments(batch_size=2)
    combined_audio = combine_audio_segments(
        synthesizer.synthesize_segments(ssml_segments)
    )
    combined_audio.export("output.wav", format="wav")