# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random

import pyarrow.parquet as pq
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import numpy as np
import re

# Select the soundfile backend for torchaudio audio I/O.
# NOTE(review): ``set_audio_backend`` is deprecated in newer torchaudio
# releases — confirm it still exists in the pinned version.
torchaudio.set_audio_backend('soundfile')

# Audio format names; presumably matched against file extensions elsewhere.
AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
# Song-structure label -> integer id, used by `padding` to build the
# "chorus" tensor. "verse" deliberately shares the id of "verse1".
CHORUS = {"intro": 0, "chorus": 1, "verse1": 2, "verse2": 3, "verse": 2,
		  "outro": 4}

# LRC-style metadata tag lines such as "[ti:Title]" or "[offset:+100]";
# `clean_lyrics` drops these lines entirely.
metadata_pattern = re.compile(r'^\[(ti|ar|al|by|offset):.*\]$')
# A line starting with a single "[mm:ss.xx]" timestamp; group 1 is the text.
# NOTE(review): "[mm:ss.xxx]" or "[mm:ss]" stamps do not match, so
# `clean_lyrics` keeps such lines verbatim (timestamp included).
timestamp_pattern = re.compile(r'^\[\d{2}:\d{2}\.\d{2}\](.*)$')


def parquet_opener(data, mode='train', audio_data=None):
	""" Expand each parquet file reference into one sample per table row.

	Args:
		data: Iterable of dicts, each with a ``'src'`` key holding a path or
			URL that ``pyarrow.parquet.read_table`` can open.
		mode: unused; kept for pipeline signature compatibility.
		audio_data: unused; kept for backward compatibility. Defaults to
			``None`` rather than a shared mutable ``{}``.

	Yields:
		For every table row, a new dict holding the input sample's keys
		plus that row's columns (row values win on key clashes). The input
		dicts are left unmodified.

	Files that fail to load are logged at WARNING level and skipped.
	"""
	for sample in data:
		assert 'src' in sample

		url = sample['src']
		# Only the load is best-effort; keep the try out of the yield loop.
		try:
			df = pq.read_table(url).to_pandas()
		except Exception as ex:
			logging.warning('Failed to open {}, ex info {}'.format(url, ex))
			continue
		# iterrows yields one Series per row regardless of index uniqueness,
		# unlike df.loc[label], which returns a DataFrame for duplicate labels.
		for _, row in df.iterrows():
			yield {**sample, **dict(row)}


def clean_lyrics(data, mode="train"):
	""" Strip LRC-style metadata lines and timestamps from each sample's lyrics.

	``sample["text"]`` is split into lines. Lines matching
	``metadata_pattern`` are dropped. A leading timestamp matched by
	``timestamp_pattern`` is removed, keeping only the text after it. Every
	remaining line is stripped of surrounding whitespace, and empty results
	are discarded. The survivors are joined with ``"\\n"`` and written back.

	Args:
		data: Iterable of sample dicts with a ``"text"`` key.
		mode: unused; kept for pipeline signature compatibility.

	Yields:
		The same sample dicts with cleaned ``"text"``.
	"""
	for sample in data:
		kept = []
		for raw in sample["text"].splitlines():
			if metadata_pattern.match(raw):
				continue
			stamped = timestamp_pattern.match(raw)
			body = (stamped.group(1) if stamped else raw).strip()
			if body:
				kept.append(body)
		sample["text"] = '\n'.join(kept)
		yield sample


def cut_by_length(data, max_length=8000, num_times=4, mode="train"):
	""" Truncate each sample's token sequences in place.

	Args:
		data: Iterable of sample dicts.
		max_length: number of items kept from ``sample["semantic_token"][0]``.
		num_times: ``"acoustic_token"`` is cut to ``max_length * num_times``
			items along its first axis.
		mode: unused; kept for pipeline signature compatibility.

	Yields:
		Each sample with ``"semantic_token"`` (when present) replaced by a
		one-element list holding its truncated first entry. ``"acoustic_token"``
		is taken from ``"speech_token"`` when absent, then sliced.
	"""
	for sample in data:
		if "semantic_token" in sample:
			head = sample["semantic_token"][0]
			sample["semantic_token"] = [head[:max_length]]

		if "acoustic_token" in sample:
			source_key = "acoustic_token"
		else:
			source_key = "speech_token"
		acoustic_limit = max_length * num_times
		sample["acoustic_token"] = sample[source_key][:acoustic_limit]

		yield sample


def _filter_train(data, max_length, min_length):
	"""Length-filter LM training samples and normalise chorus/time fields."""
	for sample in data:
		if "semantic_token" in sample:
			new_sample_frames = sample['semantic_token'][0].shape[0]
		else:
			# Count the tokens; the raw container cannot be compared with an
			# int. NOTE(review): assumes speech_token is a flat token
			# sequence — confirm against the data schema.
			new_sample_frames = len(sample['speech_token'])

		if "text_token" in sample:
			new_sample_frames += len(sample['text_token'])

		if new_sample_frames > max_length or new_sample_frames < min_length:
			print(f"skipped 1 item length={new_sample_frames}")
			continue

		# "a,b,c" -> ["a", "b", "c"], with bare "verse" renamed to "verse1".
		sample["chorus"] = ["verse1" if part == "verse" else part
		                    for part in sample["chorus"].split(",")]
		# Scalar boundaries are wrapped so they line up with the label list.
		if not isinstance(sample["time_start"], np.ndarray):
			sample["time_start"] = [sample["time_start"]]
			sample["time_end"] = [sample["time_end"]]

		yield sample


def _filter_train_flow(data, max_length, max_acoustic_length, min_length,
                       min_acoustic_length):
	"""Length-filter flow-training samples on semantic and acoustic lengths."""
	for sample in data:
		# Skip instead of reusing lengths computed for a previous sample.
		if "semantic_token" not in sample or "acoustic_token" not in sample:
			print("skipped 1 item: missing semantic_token or acoustic_token")
			continue
		new_sample_frames = sample['semantic_token'][0].shape[0]
		target_sample_frames = sample['acoustic_token'][0].shape[0]

		# Semantic length must lie in [min_length, max_length]; acoustic
		# length in [min_acoustic_length, max_acoustic_length].
		if (new_sample_frames > max_length
				or new_sample_frames < min_length
				or target_sample_frames > max_acoustic_length
				or target_sample_frames < min_acoustic_length):
			print(
				f"skipped 1 item length={new_sample_frames}, target_length={target_sample_frames}")
			continue

		yield sample


def filter(data,
           max_length=22500,  # 22500 #5min #10240
           max_acoustic_length=45000,
           min_length=10,
           min_acoustic_length=150,
           token_max_length=200,
           token_min_length=1,
           min_output_input_ratio=0.0005,
           max_output_input_ratio=1,
           mode='train'):
	""" Filter samples by token length, dispatching on ``mode``.

	Note: this name shadows the builtin ``filter`` within this module.

	Args:
		data: Iterable of sample dicts.
		max_length: drop a sample whose semantic length (plus the text-token
			count in ``"train"`` mode) exceeds this.
		max_acoustic_length: ``"train_flow"`` only; drop a sample whose
			acoustic length (``sample['acoustic_token'][0].shape[0]``)
			exceeds this.
		min_length: drop a sample whose semantic length (plus the text-token
			count in ``"train"`` mode) is below this.
		min_acoustic_length: ``"train_flow"`` only; drop a sample whose
			acoustic length is below this.
		token_max_length, token_min_length, min_output_input_ratio,
		max_output_input_ratio: currently unused; kept for config
			compatibility.
		mode: ``"train"`` also splits ``"chorus"`` on commas and wraps
			non-ndarray ``time_start``/``time_end`` in lists;
			``"train_flow"`` checks semantic and acoustic lengths;
			``"inference"`` passes every sample through. Any other value
			yields nothing.

	Yields:
		The samples that pass the filter.
	"""
	if mode == "train":
		yield from _filter_train(data, max_length, min_length)
	elif mode == "train_flow":
		yield from _filter_train_flow(data, max_length, max_acoustic_length,
		                              min_length, min_acoustic_length)
	elif mode == "inference":
		yield from data


def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
	""" Resample each waveform to ``resample_rate`` and peak-normalise it.
	Inplace operation.

	Samples whose rate differs from ``resample_rate`` and is below
	``min_sample_rate`` are dropped. After any resampling, a waveform whose
	absolute peak exceeds 1 is divided by that peak (in place).

	Args:
		data: Iterable of sample dicts with ``'sample_rate'`` and a
			``'speech'`` tensor.
		resample_rate: target resample rate.
		min_sample_rate: lowest source rate that is resampled rather than
			dropped.
		mode: unused; kept for pipeline signature compatibility.

	Yields:
		The surviving sample dicts with ``'speech'``/``'sample_rate'``
		updated.
	"""
	# One Resample transform per source rate. Constructing it precomputes
	# the resampling kernel, so rebuilding it for every sample is wasted work.
	resamplers = {}
	for sample in data:
		assert 'sample_rate' in sample
		assert 'speech' in sample
		sample_rate = sample['sample_rate']
		waveform = sample['speech']
		if sample_rate != resample_rate:
			if sample_rate < min_sample_rate:
				continue
			sample['sample_rate'] = resample_rate
			resampler = resamplers.get(sample_rate)
			if resampler is None:
				resampler = torchaudio.transforms.Resample(
						orig_freq=sample_rate, new_freq=resample_rate)
				resamplers[sample_rate] = resampler
			sample['speech'] = resampler(waveform)
		max_val = sample['speech'].abs().max()
		if max_val > 1:
			sample['speech'] /= max_val
		yield sample


def truncate(data, truncate_length=24576, mode='train'):
	""" Crop or zero-pad each sample's audio to a fixed number of steps.

	Args:
		data: Iterable of sample dicts with a 2-D ``'audio'`` tensor; dim 1
			is treated as time.
		truncate_length: target size of dim 1.
		mode: unused; kept for pipeline signature compatibility.

	Yields:
		The same sample dicts with ``'audio'`` replaced. Longer clips are
		cropped to a uniformly random ``truncate_length`` window. Shorter
		(or equal) ones are right-padded with zeros.
	"""
	for sample in data:
		waveform = sample['audio']
		num_steps = waveform.shape[1]
		if num_steps > truncate_length:
			start = random.randint(0, num_steps - truncate_length)
			waveform = waveform[:, start: start + truncate_length]
		else:
			# F.pad keeps the input's channel count, dtype and device, which a
			# concatenated torch.zeros(1, ...) does not.
			waveform = F.pad(waveform, (0, truncate_length - num_steps))
		sample['audio'] = waveform
		yield sample


def upsample(data, resample_rate=48000, min_sample_rate=16000, mode='train',
			 n_codebook=4):
	""" Repeat semantic tokens so they match a higher target rate.

	Samples without ``'acoustic_token'`` are dropped. The source rate is
	``sample['sample_rate']``, or 24000 when that key is absent. When the
	source rate differs from ``resample_rate``:
	  * samples below ``min_sample_rate`` are dropped;
	  * otherwise ``sample['semantic_token']`` becomes a (1, N) numpy array
	    holding ``sample['semantic_token'][0][:-1]`` with each element
	    repeated ``int(n_codebook * resample_rate / source_rate)`` times, and
	    ``sample['sample_rate']`` is set to ``resample_rate``.
	Samples already at ``resample_rate`` pass through unchanged.

	Args:
		data: Iterable of sample dicts, each containing ``'semantic_token'``.
		resample_rate: target rate.
		min_sample_rate: lowest source rate that gets upsampled.
		mode: unused; kept for pipeline signature compatibility.
		n_codebook: extra multiplier on the repetition factor.

	Yields:
		The surviving samples, modified in place.
	"""
	for sample in data:
		assert 'semantic_token' in sample
		# TODO: unify data processing key names
		if 'acoustic_token' not in sample:
			continue

		source_rate = sample.get('sample_rate', 24000)
		# The final semantic token is excluded before repetition.
		trimmed = np.array(sample['semantic_token'][0][:-1])
		factor = int(n_codebook * resample_rate / source_rate)

		if source_rate == resample_rate:
			yield sample
			continue
		if source_rate < min_sample_rate:
			continue

		sample['sample_rate'] = resample_rate
		sample['semantic_token'] = np.array([np.repeat(trimmed, factor)])
		yield sample

def compute_fbank(data,
				  feat_extractor,
				  mode='train'):
	""" Replace each sample's raw waveform with extracted features.

	Args:
		data: Iterable of sample dicts that must contain ``'sample_rate'``,
			``'speech'``, ``'utt'`` and ``'text_token'``.
		feat_extractor: callable applied to ``sample['speech']``. Its output
			is squeezed on dim 0 and then has dims 0 and 1 swapped.
			(Presumably a (1, feat_dim, frames) fbank, giving
			(frames, feat_dim) — TODO confirm.)
		mode: unused; kept for pipeline signature compatibility.

	Yields:
		The same sample dicts with ``'speech_feat'`` added and ``'speech'``
		removed.
	"""
	for sample in data:
		for required in ('sample_rate', 'speech', 'utt', 'text_token'):
			assert required in sample
		feats = feat_extractor(sample['speech'])
		feats = feats.squeeze(dim=0)
		sample['speech_feat'] = feats.transpose(0, 1)
		del sample['speech']
		yield sample


def parse_embedding(data, normalize, mode='train'):
	""" Convert utterance/speaker embeddings to float32 tensors.

	Args:
		data: Iterable of sample dicts with ``'utt_embedding'`` and
			``'spk_embedding'`` entries (anything ``torch.tensor`` accepts).
		normalize: when true, each embedding is L2-normalised along dim 0.
		mode: unused; kept for pipeline signature compatibility.

	Yields:
		The same sample dicts with both embeddings replaced by tensors.
	"""
	embedding_keys = ('utt_embedding', 'spk_embedding')
	for sample in data:
		for key in embedding_keys:
			sample[key] = torch.tensor(sample[key], dtype=torch.float32)
		if normalize:
			for key in embedding_keys:
				sample[key] = F.normalize(sample[key], dim=0)
		yield sample

def tokenize(data, get_tokenizer, allowed_special, mode='train'):
	""" Encode each sample's text into token ids.
	Inplace operation.

	Args:
		data: Iterable of sample dicts with a ``'text'`` key.
		get_tokenizer: zero-argument factory, called once when iteration
			starts. The returned object must provide
			``encode(text, allowed_special=...)``.
		allowed_special: forwarded unchanged to ``encode``.
		mode: unused; kept for pipeline signature compatibility.

	Yields:
		The same sample dicts with ``'text_token'`` set to the encoding.
	"""
	encode = get_tokenizer().encode

	for sample in data:
		assert 'text' in sample
		sample['text_token'] = encode(sample['text'],
									  allowed_special=allowed_special)
		yield sample


def shuffle(data, shuffle_size=10000, mode='train'):
	""" Shuffle the stream locally within fixed-size windows.

	Samples are collected until ``shuffle_size`` are buffered. That window
	is shuffled with ``random.shuffle`` and emitted before the next window
	starts. The final, possibly shorter, window is shuffled and emitted too.

	Args:
		data: Iterable of samples.
		shuffle_size: number of samples per shuffle window.
		mode: unused; kept for pipeline signature compatibility.

	Yields:
		Every input sample exactly once.
	"""
	pending = []
	for item in data:
		pending.append(item)
		if len(pending) < shuffle_size:
			continue
		random.shuffle(pending)
		yield from pending
		pending = []

	# Flush whatever did not fill a complete window.
	random.shuffle(pending)
	yield from pending


def sort(data, sort_size=500, mode='train'):
	""" Flatten acoustic tokens and locally sort samples by their length.

	Meant to run after shuffle and before batching, so that a batch groups
	samples of similar length; ``sort_size`` should be smaller than the
	shuffle buffer.

	For each sample, a ``"chorus"`` value equal to ``"verse"`` is renamed to
	``"verse1"``. ``"acoustic_token"`` (an array with a ``shape``) is joined
	with ``np.concatenate`` and converted via ``torch.from_numpy``. When its
	leading dimension is 1, the join uses its first element; otherwise it
	uses the whole array.

	Args:
		data: Iterable of sample dicts.
		sort_size: number of samples buffered per sort.
		mode: unused; kept for pipeline signature compatibility.

	Yields:
		The samples, ascending by ``acoustic_token.size(0)`` within each
		buffer of ``sort_size`` (the final partial buffer included).
	"""
	def token_count(item):
		return item['acoustic_token'].size(0)

	window = []
	for sample in data:
		if sample["chorus"] == "verse":
			sample["chorus"] = "verse1"

		tokens = sample["acoustic_token"]
		chunks = tokens[0] if tokens.shape[0] == 1 else tokens
		sample["acoustic_token"] = np.concatenate(chunks)
		sample["acoustic_token"] = torch.from_numpy(sample["acoustic_token"])

		window.append(sample)
		if len(window) < sort_size:
			continue
		window.sort(key=token_count)
		yield from window
		window = []

	# Flush the remaining samples, sorted the same way.
	window.sort(key=token_count)
	yield from window


def static_batch(data, batch_size=32):
	""" Group consecutive samples into lists of ``batch_size``.

	Args:
		data: Iterable of samples.
		batch_size: samples per batch; the last batch may be smaller.

	Yields:
		Lists of samples, in input order.

	Raises:
		ValueError: if ``data`` produced no samples. This is raised once the
			input is exhausted.
	"""
	pending = []
	total = 0
	for sample in data:
		total += 1
		pending.append(sample)
		if len(pending) < batch_size:
			continue
		yield pending
		pending = []

	if total == 0:
		raise ValueError("data is empty")
	if pending:
		yield pending


def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
	""" Dynamically batch samples under a padded-frame budget.

	A sample's frame count is ``sample['semantic_token'][0].shape[0]`` when
	``'semantic_token'`` is present, otherwise ``sample['acoustic_token'].size(0)``.
	``len(sample['text_token'])`` is added when present. A batch is emitted
	as soon as adding the next sample would make
	``longest_frames * batch_len`` exceed ``max_frames_in_batch``. A single
	sample above the budget still forms its own batch.

	Args:
		data: Iterable of sample dicts, each with a tensor ``'acoustic_token'``.
		max_frames_in_batch: padded-frame budget per batch.
		mode: unused; kept for pipeline signature compatibility.

	Yields:
		Lists of samples, in input order.
	"""
	buf = []
	longest_frames = 0
	for sample in data:
		assert 'acoustic_token' in sample
		assert isinstance(sample['acoustic_token'], torch.Tensor)

		if 'semantic_token' in sample:
			new_sample_frames = sample['semantic_token'][0].shape[0]
		else:
			# 'semantic_token' is absent, so measure the acoustic tokens (the
			# only token field guaranteed by the asserts above).
			new_sample_frames = sample['acoustic_token'].size(0)

		if "text_token" in sample:
			new_sample_frames += len(sample['text_token'])

		longest_frames = max(longest_frames, new_sample_frames)
		# Cost of the batch once every member is padded to the longest one.
		frames_after_padding = longest_frames * (len(buf) + 1)

		if frames_after_padding > max_frames_in_batch:
			if buf:
				yield buf
			buf = [sample]
			longest_frames = new_sample_frames
		else:
			buf.append(sample)
	if buf:
		yield buf


def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000,
		  mode='train'):
	""" Wrapper for static/dynamic batch.

	Args:
		data: Iterable of samples.
		batch_type: ``'static'`` or ``'dynamic'``. It is ignored in
			``'inference'`` mode (batches of 1) and in ``'processing'`` mode
			(static batches of ``batch_size``).
		batch_size: samples per batch for static batching.
		max_frames_in_batch: padded-frame budget for dynamic batching.
		mode: pipeline mode.

	Returns:
		A generator of lists of samples.

	Raises:
		ValueError: if ``batch_type`` is unsupported in a mode that uses it.
	"""
	if mode == 'inference':
		return static_batch(data, 1)
	if mode == 'processing' or batch_type == 'static':
		return static_batch(data, batch_size)
	if batch_type == 'dynamic':
		return dynamic_batch(data, max_frames_in_batch)
	# logging.fatal only logs; raise so callers do not get back None.
	logging.fatal('Unsupported batch type {}'.format(batch_type))
	raise ValueError('Unsupported batch type {}'.format(batch_type))


def _pad_with_lengths(seqs, padding_value=0):
	"""Pad tensors into one batch-first tensor and return their dim-0 sizes.

	Args:
		seqs: non-empty list of tensors; each sequence length is ``size(0)``.
		padding_value: fill value past the end of shorter sequences.

	Returns:
		``(padded, lengths)``: the ``pad_sequence(..., batch_first=True)``
		result and an int32 tensor of the original lengths.
	"""
	lengths = torch.tensor([s.size(0) for s in seqs], dtype=torch.int32)
	padded = pad_sequence(seqs, batch_first=True, padding_value=padding_value)
	return padded, lengths


def _collate_samples(samples):
	"""Collate an already-ordered, non-empty list of samples into a batch dict.

	``acoustic_token``/``acoustic_token_len`` are ``None`` when the first
	sample has no ``'acoustic_token'``. The ``semantic_token`` entries are
	added only when the first sample has ``'semantic_token'``.
	"""
	utts = [s['utt'] for s in samples]

	if "acoustic_token" in samples[0]:
		acoustic_token, acoustic_token_len = _pad_with_lengths(
				[s['acoustic_token'].clone().to(torch.int32) for s in samples])
	else:
		acoustic_token = None
		acoustic_token_len = None

	text = [s['text'] for s in samples]
	text_token, text_token_len = _pad_with_lengths(
			[torch.tensor(s['text_token']).long() for s in samples])
	time_start = torch.tensor([s['time_start'] for s in samples])
	time_end = torch.tensor([s['time_end'] for s in samples])

	# A single string label gives one id per sample. A list of labels gives
	# a batch-first tensor of ids padded with -1.
	if isinstance(samples[0]['chorus'], str):
		chorus = torch.tensor([CHORUS[s['chorus']] for s in samples])
	else:
		chorus = pad_sequence(
				[torch.tensor([CHORUS[t] for t in s['chorus']]) for s in samples],
				batch_first=True, padding_value=-1)

	collated = {
		"utts"              : utts,
		"acoustic_token"    : acoustic_token,
		"acoustic_token_len": acoustic_token_len,
		"time_start"        : time_start,
		"time_end"          : time_end,
		"chorus"            : chorus,
		"text"              : text,
		"text_token"        : text_token,
		"text_token_len"    : text_token_len,
	}

	if "semantic_token" in samples[0]:
		semantic_token, semantic_token_len = _pad_with_lengths(
				[torch.tensor(s['semantic_token'][0], dtype=torch.int32)
				 for s in samples])
		collated.update({"semantic_token"    : semantic_token,
						 "semantic_token_len": semantic_token_len})

	return collated


def padding(data, mode='train'):
	""" Collate batches of samples into padded tensors.

	Args:
		data: Iterable of batches, each a list of sample dicts.
		mode: ``"train"`` reorders each batch by descending
			``acoustic_token.size(0)`` before collating. ``"inference"``
			keeps the incoming order and tolerates a missing
			``'acoustic_token'``. Any other mode yields nothing.

	Yields:
		One dict per non-empty batch with keys ``utts``, ``acoustic_token``,
		``acoustic_token_len``, ``time_start``, ``time_end``, ``chorus``,
		``text``, ``text_token`` and ``text_token_len``, plus
		``semantic_token``/``semantic_token_len`` when present. Acoustic,
		text and semantic tokens are zero-padded and batch-first. The
		``*_len`` tensors are int32. Empty batches are skipped with a
		warning.
	"""
	if mode not in ("train", "inference"):
		return

	for batch_samples in data:
		assert isinstance(batch_samples, list)
		if not batch_samples:
			logging.warning("sample is empty []!")
			continue

		if mode == "train":
			acoustic_feat_len = torch.tensor(
					[x['acoustic_token'].size(0) for x in batch_samples],
					dtype=torch.int32)
			order = torch.argsort(acoustic_feat_len, descending=True)
			batch_samples = [batch_samples[i] for i in order]

		yield _collate_samples(batch_samples)