# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import requests
from tqdm import tqdm
import torch
import numpy as np
import laion_clap
from clap_module.factory import load_state_dict
import librosa
import pyloudnorm as pyln

# int16 quantization helpers, following the LAION-CLAP README: https://github.com/LAION-AI/CLAP
def int16_to_float32(x):
    """Convert int16 PCM samples to float32.

    Each sample is divided by 32767 (the int16 maximum), mapping the int16
    range to approximately [-1, 1], and the result is cast to float32.
    """
    scaled = x / 32767.0
    return scaled.astype(np.float32)

def float32_to_int16(x):
    """Quantize floating-point samples to int16 PCM.

    Samples are clipped to [-1, 1], scaled by 32767, and cast to int16
    (NumPy's cast truncates toward zero).
    """
    clipped = np.clip(x, -1.0, 1.0)
    scaled = clipped * 32767.0
    return scaled.astype(np.int16)


# Hugging Face location of the LAION-CLAP checkpoints supported below.
_CLAP_BASE_URL = 'https://huggingface.co/lukewys/laion_clap/resolve/main/'

# Local directory (relative to the current working directory) where checkpoints are cached.
_CLAP_CACHE_DIR = 'CLAP'

# Supported checkpoint filenames -> keyword arguments for laion_clap.CLAP_Module.
_CLAP_MODEL_KWARGS = {
    'music_speech_audioset_epoch_15_esc_89.98.pt': {'enable_fusion': False, 'amodel': 'HTSAT-base'},
    'music_audioset_epoch_15_esc_90.14.pt': {'enable_fusion': False, 'amodel': 'HTSAT-base'},
    'music_speech_epoch_15_esc_89.25.pt': {'enable_fusion': False, 'amodel': 'HTSAT-base'},
    '630k-audioset-fusion-best.pt': {'enable_fusion': True},
}


def _download_checkpoint(url, clap_path, clap_model):
    """Stream `url` into `clap_path`, showing a tqdm progress bar.

    The data is written to a temporary ``<clap_path>.part`` file and only
    renamed to `clap_path` once the download has completed. A failed or
    interrupted download therefore never leaves a truncated checkpoint that a
    later run would skip re-downloading and then fail to load.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    print('Downloading ', clap_model, '...')
    os.makedirs(os.path.dirname(clap_path), exist_ok=True)
    tmp_path = clap_path + '.part'
    try:
        with requests.get(url, stream=True, timeout=60) as response:
            # Without this, an HTML error page would be saved as the checkpoint.
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))
            with open(tmp_path, 'wb') as file:
                with tqdm(total=total_size, unit='B', unit_scale=True) as progress_bar:
                    for data in response.iter_content(chunk_size=8192):
                        file.write(data)
                        progress_bar.update(len(data))
        os.replace(tmp_path, clap_path)
    finally:
        # Only present if the download failed before the rename.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def _load_clap_model(clap_model, device):
    """Build the CLAP module for `clap_model` on `device` and load its weights.

    Downloads the checkpoint into _CLAP_CACHE_DIR first if it is not cached yet.
    `clap_model` must be a key of _CLAP_MODEL_KWARGS.
    """
    url = _CLAP_BASE_URL + clap_model
    clap_path = os.path.join(_CLAP_CACHE_DIR, clap_model)
    if not os.path.exists(clap_path):
        _download_checkpoint(url, clap_path, clap_model)

    model = laion_clap.CLAP_Module(device=device, **_CLAP_MODEL_KWARGS[clap_model])

    # fixing CLAP-LION issue, see: https://github.com/LAION-AI/CLAP/issues/118
    pkg = load_state_dict(clap_path)
    pkg.pop('text_branch.embeddings.position_ids', None)
    model.model.load_state_dict(pkg)
    model.eval()
    return model


def _text_embeddings(model, id2text, batch_size=64):
    """Return a dict mapping every id in `id2text` to its CLAP text embedding (computed in batches)."""
    ids = list(id2text)  # materialize once, not once per batch
    text_emb = {}
    for start in tqdm(range(0, len(ids), batch_size)):
        batch_ids = ids[start:start + batch_size]
        batch_texts = [id2text[key] for key in batch_ids]
        with torch.no_grad():
            embeddings = model.get_text_embedding(batch_texts, use_tensor=True)
        text_emb.update(zip(batch_ids, embeddings))
    return text_emb


def _audio_embedding(model, file_path):
    """Return the CLAP audio embedding of one file, with the preprocessing used by this script."""
    with torch.no_grad():
        audio, _ = librosa.load(file_path, sr=48000, mono=True)  # sample rate should be 48000
        # Peak-normalize to -1 dB. NOTE(review): behavior on all-zero (silent) audio not verified.
        audio = pyln.normalize.peak(audio, -1.0)
        audio = audio.reshape(1, -1)  # unsqueeze (1,T)
        # Quantize through int16 and back, as in the LAION-CLAP README example.
        audio = torch.from_numpy(int16_to_float32(float32_to_int16(audio))).float()
        return model.get_audio_embedding_from_data(x=audio, use_tensor=True)


def clap_score(id2text, audio_path, audio_files_extension='.wav',
               clap_model='music_audioset_epoch_15_esc_90.14.pt', device='cuda'):
    """
    Compute the average LAION-CLAP score of a set of generated audio files.

    For each id, the cosine similarity is computed between the LAION-CLAP text
    embedding of the prompt and the LAION-CLAP audio embedding of the generated
    audio. LAION-CLAP: https://github.com/LAION-AI/CLAP

    Audio files are looked up as `<audio_path>/<id><audio_files_extension>`.
    All ids in id2text are evaluated. Ids whose file is missing are skipped,
    and the number of skipped ids is reported.

    Supported models (from https://github.com/LAION-AI/CLAP):
        - music_speech_audioset_epoch_15_esc_89.98.pt (used by musicgen)
        - music_audioset_epoch_15_esc_90.14.pt (default)
        - music_speech_epoch_15_esc_89.25.pt
        - 630k-audioset-fusion-best.pt (with "fusion" to handle longer inputs)
    Checkpoints are downloaded on first use into ./CLAP/.

    Params:
    -- id2text: dictionary mapping each id (generated audio filename in audio_path,
                without extension) to the text prompt used to generate that audio.
    -- audio_path: directory containing the generated audio files to evaluate.
    -- audio_files_extension: extension of the files in audio_path (default '.wav').
    -- clap_model: one of the model filenames above
                   (default: 'music_audioset_epoch_15_esc_90.14.pt').
    -- device: torch device the CLAP model runs on (default: 'cuda').
    Returns:
    -- The mean cosine similarity over all evaluated files, as a torch tensor.
       Returns 0 if no audio file was found.
    Raises:
    -- ValueError: if clap_model is unsupported, audio_path is not a directory,
       or id2text is empty.
    """
    # Validate arguments before paying for a checkpoint download / model load.
    if clap_model not in _CLAP_MODEL_KWARGS:
        raise ValueError('clap_model not implemented')
    if not os.path.isdir(audio_path):
        raise ValueError(f'audio_path: {audio_path} does not exist')
    if not id2text:
        raise ValueError('Must specify id2text')

    model = _load_clap_model(clap_model, device)

    print('[EXTRACTING TEXT EMBEDDINGS] ')
    text_emb = _text_embeddings(model, id2text)

    print('[EVALUATING GENERATIONS] ', audio_path)
    score = 0
    count = 0
    missing = 0
    for key in tqdm(id2text):
        file_path = os.path.join(audio_path, str(key) + audio_files_extension)
        if not os.path.isfile(file_path):
            missing += 1
            continue
        audio_embeddings = _audio_embedding(model, file_path)
        cosine_sim = torch.nn.functional.cosine_similarity(
            audio_embeddings, text_emb[key].unsqueeze(0), dim=1, eps=1e-8)[0]
        print(f"{key} | CLAP score = {cosine_sim}")
        score += cosine_sim
        count += 1

    if missing:
        # A wrong extension or path otherwise silently produces a score of 0.
        print(f'[WARNING] {missing} of {len(id2text)} audio files not found in {audio_path} '
              f'(extension {audio_files_extension!r}); they were skipped.')

    return score / count if count > 0 else 0