import fam.llm.fast_inference_utils
from fam.llm.fast_inference import TTS as FAMTTS
from fam.llm.inference import Model as FAMModel
from fam.llm.inference import InferenceConfig
from fam.llm.adapters.tilted_encodec import TiltedEncodec
from fam.llm.adapters.flattened_encodec import FlattenedInterleavedEncodec2Codebook
from fam.llm.decoders import EncodecDecoder
from fam.llm.enhancers import get_enhancer
from fam.llm.utils import get_default_dtype, get_device
from fam.llm.fast_model import Transformer
from fam.llm.model import GPT, GPTConfig
from fam.quantiser.text.tokenise import TrainedBPETokeniser
from fam.quantiser.audio.speaker_encoder.model import SpeakerEncoder as FAMSpeakerEncoder
from fam.quantiser.audio.speaker_encoder.model import mel_n_channels, model_hidden_size, model_embedding_size, model_num_layers

import os
from pathlib import Path
from typing import Optional, Union
from json import load, dump
from base64 import b64encode, b64decode

import torch
from torch import nn
from huggingface_hub import snapshot_download, HfFileSystem
from safetensors.torch import load_model, save_model

def convert_to_safetensors(
	stage1_path: str,
	stage2_path: str,
	spk_emb_ckpt_path: str,
	precision: torch.dtype,
	output_path: str
):
	"""Convert MetaVoice ``.pt`` checkpoints into a safetensors model directory.

	Writes ``second_stage.safetensors``, ``first_stage.safetensors``,
	``speaker_encoder.safetensors`` and ``config.json`` into ``output_path``.
	``config.json`` holds the first-stage tokenizer info and the second-stage
	config/meta/model_args, with the tokenizer ``mergeable_ranks`` byte keys
	base64-encoded so they can be stored as JSON.

	Args:
		stage1_path: path to the first-stage ``.pt`` checkpoint.
		stage2_path: path to the second-stage ``.pt`` checkpoint.
		spk_emb_ckpt_path: path to the speaker-encoder checkpoint.
		precision: ``torch.float16`` selects float16; any other value selects bfloat16.
		output_path: destination directory; created if it does not exist.

	Note:
		The checkpoints are read with ``torch.load`` (pickle); only convert
		checkpoints from a trusted source.
	"""
	os.makedirs(output_path, exist_ok=True)

	# Second stage: build the model from the .pt checkpoint and save its weights.
	config_second_stage = InferenceConfig(
		ckpt_path=stage2_path,
		num_samples=1,
		seed=0,
		device='cpu',
		dtype='float16' if precision == torch.float16 else 'bfloat16',
		compile=False,
		init_from='resume',
		output_dir='.',
	)
	data_adapter_second_stage = TiltedEncodec(end_of_audio_token=512)
	stage2_model = Model(config_second_stage, TrainedBPETokeniser, EncodecDecoder, data_adapter_fn=data_adapter_second_stage.decode)

	# The raw checkpoint is re-read for its config/meta/model_args, which go into config.json.
	stage2_checkpoint = torch.load(stage2_path, map_location='cpu')
	stage2_state_dict = stage2_checkpoint['model']
	unwanted_prefix = '_orig_mod.'
	# Iterate over a snapshot of the keys: popping and re-inserting entries while
	# iterating the live dict view raises RuntimeError once a key is renamed.
	# NOTE(review): the renamed state dict is not used below (the weights saved are
	# those held by stage2_model.model); confirm whether this loop is still needed.
	for k in list(stage2_state_dict.keys()):
		if k.startswith(unwanted_prefix):
			stage2_state_dict[k[len(unwanted_prefix):]] = stage2_state_dict.pop(k)
	save_model(stage2_model.model, os.path.join(output_path, 'second_stage.safetensors'))

	# First stage and speaker encoder, loaded through fam's fast-inference loader.
	stage1_model, tokenizer, smodel = fam.llm.fast_inference_utils._load_model(stage1_path, spk_emb_ckpt_path, 'cpu', precision)
	tokenizer_info = torch.load(stage1_path, map_location='cpu').get('meta', {}).get('tokenizer', {})
	save_model(stage1_model, os.path.join(output_path, 'first_stage.safetensors'))
	save_model(smodel, os.path.join(output_path, 'speaker_encoder.safetensors'))

	# JSON cannot hold bytes keys, so base64-encode the merge tables. This is done
	# before config.json is opened so a failure here does not leave an empty file.
	tokenizer_info['mergeable_ranks'] = {b64encode(k).decode('ascii'): v for k, v in tokenizer_info['mergeable_ranks'].items()}
	stage2_tokenizer = stage2_checkpoint['meta']['tokenizer']
	stage2_tokenizer['mergeable_ranks'] = {b64encode(k).decode('ascii'): v for k, v in stage2_tokenizer['mergeable_ranks'].items()}

	with open(os.path.join(output_path, 'config.json'), 'w', encoding='utf-8') as f:
		dump({
			'model_name': 'metavoice-1B-v0.1',
			'stage1': {
				'tokenizer_info': tokenizer_info
			},
			'stage2': {
				'config': stage2_checkpoint['config'],
				'meta': stage2_checkpoint['meta'],
				'model_args': stage2_checkpoint['model_args']
			}
		}, f)

class SpeakerEncoder(FAMSpeakerEncoder):
	"""Speaker encoder that accepts ``.safetensors`` as well as ``.pt`` weights.

	Builds its own LSTM -> Linear -> ReLU layers instead of calling the parent
	constructor (presumably so the weights can be read from a safetensors file).

	Args:
		weights_fpath: weights file. Paths ending in ``.safetensors`` are read with
			``safetensors.torch.load_model``; anything else with ``torch.load``,
			whose ``'model_state'`` entry is loaded non-strictly.
		device: target device (str or ``torch.device``); when omitted, CUDA is used
			if available, otherwise CPU.
		verbose: accepted for signature compatibility; not used here.
		eval: when True, switch the module to eval mode after loading.
	"""

	def __init__(
		self,
		weights_fpath: str,
		device: Optional[Union[str, torch.device]] = None,
		verbose: bool = True,
		eval: bool = False,
	):
		# Skip FAMSpeakerEncoder.__init__; only set up the nn.Module machinery.
		nn.Module.__init__(self)

		# Network layers.
		self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
		self.linear = nn.Linear(model_hidden_size, model_embedding_size)
		self.relu = nn.ReLU()

		# Resolve the target device: default first, then normalise strings.
		if device is None:
			device = 'cuda' if torch.cuda.is_available() else 'cpu'
		if isinstance(device, str):
			device = torch.device(device)
		self.device = device

		path = str(weights_fpath)
		if not path.endswith('.safetensors'):
			# NOTE(review): torch.load unpickles arbitrary objects; trusted files only.
			state = torch.load(path, map_location='cpu')['model_state']
			self.load_state_dict(state, strict=False)
		else:
			load_model(self, path)
		self.to(device)

		if eval:
			self.eval()

def load_safetensors_model(checkpoint_path, spk_emb_ckpt_path, device, precision):
	"""Load the first-stage model, tokenizer and speaker encoder from safetensors.

	Has the same call signature and return shape as
	``fam.llm.fast_inference_utils._load_model``; ``MetaVoiceModel`` installs it in
	that module's place when the model directory contains safetensors files.

	Args:
		checkpoint_path: path to ``first_stage.safetensors``. A ``config.json`` with a
			``'stage1'`` section must sit in the same directory.
		spk_emb_ckpt_path: path to the speaker-encoder weights.
		device: device the models are built on and moved to.
		precision: torch dtype the first-stage model is cast to.

	Returns:
		``(model, tokenizer, smodel)``: the first-stage ``Transformer`` in eval mode,
		the ``TrainedBPETokeniser``, and the ``SpeakerEncoder`` in eval mode.
	"""
	##### MODEL
	with torch.device(device):
		model = Transformer.from_name('metavoice-1B')
		load_model(model, checkpoint_path)
	model = model.to(device=device, dtype=precision)

	###### TOKENIZER
	# os.path.join keeps a bare filename resolving to ./config.json; building the
	# path with an f-string '/' would turn an empty dirname into '/config.json'.
	config_path = os.path.join(os.path.dirname(checkpoint_path), 'config.json')
	with open(config_path, 'r', encoding='utf-8') as f:
		tokenizer_info = load(f)['stage1']['tokenizer_info']
	# config.json stores the merge-table byte keys base64-encoded; restore the bytes.
	tokenizer_info['mergeable_ranks'] = {b64decode(k): v for k, v in tokenizer_info['mergeable_ranks'].items()}
	tokenizer = TrainedBPETokeniser(**tokenizer_info)

	###### SPEAKER EMBEDDER
	smodel = SpeakerEncoder(
		weights_fpath=spk_emb_ckpt_path,
		device=device,
		eval=True,
		verbose=False,
	)
	return model.eval(), tokenizer, smodel

class Model(FAMModel):
	"""``fam`` second-stage ``Model`` that can also be initialised from safetensors.

	With ``config.init_from == 'safetensors'``, the GPT architecture and metadata
	are read from the ``'stage2'`` section of the ``config.json`` next to
	``config.ckpt_path``, and the weights from the safetensors file itself.
	"""

	def _init_model(self):
		"""Build ``self.model`` and its metadata from safetensors when requested.

		The parent ``_init_model`` is called afterwards for every ``init_from``
		value. NOTE(review): it presumably performs the non-safetensors loading and
		the shared eval/device/compile steps; confirm against
		``fam.llm.inference.Model``.
		"""
		if self.config.init_from == 'safetensors':
			# os.path.join keeps a bare checkpoint filename resolving to ./config.json;
			# an f-string '/' would turn an empty dirname into '/config.json'.
			config_path = os.path.join(os.path.dirname(self.config.ckpt_path), 'config.json')
			with open(config_path, 'r', encoding='utf-8') as f:
				config = load(f)['stage2']
			self.vocab_sizes = config['model_args']['vocab_sizes']
			self.checkpoint_config = config['config']
			# config.json stores the merge-table byte keys base64-encoded; restore the bytes.
			config['meta']['tokenizer']['mergeable_ranks'] = {b64decode(k): v for k, v in config['meta']['tokenizer']['mergeable_ranks'].items()}

			self.meta = config['meta']
			self.load_meta = True
			self.use_bpe_tokenizer = 'stoi' not in self.meta or 'itos' not in self.meta
			self.speaker_cond = self.meta.get('speaker_cond')

			speaker_emb_size = None
			if self.speaker_cond:
				speaker_emb_size = self.meta['speaker_emb_size']

			model_args = config['model_args']
			# For non-causal checkpoints, the encodec context window is the block size.
			if 'causal' in self.checkpoint_config and self.checkpoint_config['causal'] is False:
				self._encodec_ctx_window = model_args['block_size']

			gptconf = GPTConfig(**model_args)
			self.model = GPT(gptconf, speaker_emb_dim=speaker_emb_size)
			load_model(self.model, self.config.ckpt_path)

		super()._init_model()

class MetaVoiceModel(FAMTTS):
	"""MetaVoice TTS wrapper that prefers safetensors weights over pickled ``.pt`` files.

	Args:
		model_name: local model directory, or a Hugging Face Hub repo id.
		seed: seed passed to the second-stage ``InferenceConfig``.
		output_dir: directory for generated outputs; created if missing.
		enforce_safetensors: when True, reject models that lack any of
			``_SAFETENSORS_FILES`` and download only those files from the Hub. When
			False, print a warning and accept pickled checkpoints as well.

	Raises:
		AssertionError: ``enforce_safetensors`` is set and the model is missing one
			of the required safetensors files.
	"""

	# Files a safetensors-compatible model directory must contain.
	_SAFETENSORS_FILES = ('config.json', 'second_stage.safetensors', 'first_stage.safetensors', 'speaker_encoder.safetensors')

	_UNSAFE_WARNING = 'WARNING: metavoice is allowing the use of non-safetensors models. Ensure you understand the risks of loading untrusted models at https://pytorch.org/docs/stable/generated/torch.load.html'

	def __init__(self, model_name: str, *, seed: int = 1337, output_dir: str = 'outputs', enforce_safetensors: bool = True):
		self._dtype = get_default_dtype()
		self._device = get_device()
		self._model_dir = self._resolve_model_dir(model_name, enforce_safetensors)

		self.first_stage_adapter = FlattenedInterleavedEncodec2Codebook(end_of_audio_token=self.END_OF_AUDIO_TOKEN)
		self.output_dir = output_dir
		os.makedirs(self.output_dir, exist_ok=True)

		is_safetensors = os.path.exists(f'{self._model_dir}/second_stage.safetensors')
		second_stage_ckpt_path = f'{self._model_dir}/{"second_stage.safetensors" if is_safetensors else "second_stage.pt"}'
		config_second_stage = InferenceConfig(
			ckpt_path=second_stage_ckpt_path,
			num_samples=1,
			seed=seed,
			device=self._device,
			dtype=self._dtype,
			compile=False,
			init_from='safetensors' if is_safetensors else 'resume',
			output_dir=self.output_dir,
		)
		data_adapter_second_stage = TiltedEncodec(end_of_audio_token=self.END_OF_AUDIO_TOKEN)
		self.llm_second_stage = Model(
			config_second_stage, TrainedBPETokeniser, EncodecDecoder, data_adapter_fn=data_adapter_second_stage.decode
		)

		self.enhancer = get_enhancer('df')
		self.precision = {'float16': torch.float16, 'bfloat16': torch.bfloat16}[self._dtype]
		build_model_kwargs = {
			'precision': self.precision,
			'device': self._device,
			'compile': False,
			'compile_prefill': True,
		}
		if is_safetensors:
			checkpoint_path, spk_emb_ckpt_path = Path(f'{self._model_dir}/first_stage.safetensors'), Path(f'{self._model_dir}/speaker_encoder.safetensors')
		else:
			checkpoint_path, spk_emb_ckpt_path = Path(f'{self._model_dir}/first_stage.pt'), Path(f'{self._model_dir}/speaker_encoder.pt')

		# Swap in the safetensors loader only for this build_model call. This assumes
		# build_model resolves _load_model through its module globals at call time
		# (which is what the module-attribute patch relies on). The original loader is
		# restored afterwards: leaving the patch in place would make later .pt loads
		# in this process (e.g. convert_to_safetensors) use the safetensors loader.
		original_load_model = fam.llm.fast_inference_utils._load_model
		if is_safetensors:
			fam.llm.fast_inference_utils._load_model = load_safetensors_model
		try:
			self.model, self.tokenizer, self.smodel, self.model_size = fam.llm.fast_inference_utils.build_model(
				checkpoint_path=checkpoint_path,
				spk_emb_ckpt_path=spk_emb_ckpt_path,
				**build_model_kwargs
			)
		finally:
			fam.llm.fast_inference_utils._load_model = original_load_model

	def _resolve_model_dir(self, model_name: str, enforce_safetensors: bool) -> str:
		"""Return a local directory for ``model_name``, downloading from the Hub if needed.

		Raises:
			AssertionError: ``enforce_safetensors`` is set and the model lacks one of
				``_SAFETENSORS_FILES``.
		"""
		is_local = os.path.exists(model_name)
		if not enforce_safetensors:
			print(self._UNSAFE_WARNING)
			return model_name if is_local else snapshot_download(repo_id=model_name)

		if is_local:
			available = set(os.listdir(model_name))
		else:
			available = {os.path.basename(x) for x in HfFileSystem().ls(model_name, detail=False)}
		if not available.issuperset(self._SAFETENSORS_FILES):
			# Raised explicitly rather than with `assert` so the check is not stripped
			# under `python -O`; AssertionError is kept for existing callers.
			raise AssertionError('Model is not compatible with safetensors')

		if is_local:
			return model_name
		# allow_patterns must be a list: a single comma-joined string is matched as
		# one glob pattern and would select none of the files.
		return snapshot_download(repo_id=model_name, allow_patterns=list(self._SAFETENSORS_FILES))

	@torch.inference_mode()
	def generate(self, text: str, source: str = 'https://upload.wikimedia.org/wikipedia/commons/e/e1/King_Charles_Addresses_Scottish_Parliament_-_12_September_2022.flac'):
		"""Synthesise ``text`` using ``source`` as the speaker reference audio.

		Returns:
			The value returned by ``synthesise``. NOTE(review): presumably the path of
			the generated audio file; confirm against ``fam.llm.fast_inference.TTS``.
		"""
		return self.synthesise(text, source)

	def save(self, path: str):
		"""Save the three models as safetensors files into the directory ``path``.

		``path`` is created if missing. If the loaded model directory has a
		``config.json``, it is copied alongside the weights, because loading a
		safetensors directory reads that file.
		NOTE(review): a model loaded from ``.pt`` files has no ``config.json`` to
		copy; use ``convert_to_safetensors`` to produce a loadable directory.
		"""
		import shutil

		os.makedirs(path, exist_ok=True)
		save_model(self.model, os.path.join(path, 'first_stage.safetensors'))
		save_model(self.smodel, os.path.join(path, 'speaker_encoder.safetensors'))
		save_model(self.llm_second_stage.model, os.path.join(path, 'second_stage.safetensors'))

		src_config = os.path.join(self._model_dir, 'config.json')
		dst_config = os.path.join(path, 'config.json')
		# Skip the copy when saving back into the source directory (same file).
		if os.path.isfile(src_config) and not (os.path.exists(dst_config) and os.path.samefile(src_config, dst_config)):
			shutil.copyfile(src_config, dst_config)

	@classmethod
	def from_hub(cls, path: str):
		"""Load a safetensors-only model from a local path or a Hub repo id."""
		# TODO: TEMPORARY OUTPUT DIR
		return cls(path, enforce_safetensors=True)