# Workaround (to be fixed upstream): normalize SPACES_ZERO_GPU from "true" to "1" before `spaces` is imported.
import os
if os.getenv('SPACES_ZERO_GPU') == "true":
    os.environ['SPACES_ZERO_GPU'] = "1"

import gradio as gr
import random
import torch
from torch import inference_mode
from typing import Optional, List
import numpy as np
from models import load_model
import utils
import spaces
import huggingface_hub
from inversion_utils import inversion_forward_process, inversion_reverse_process
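
# Gradio demo for ZETA: zero-shot text-based audio editing using DDPM inversion
# (paper, project page, and code links appear in the intro HTML below).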


LDM2 = "cvssp/audioldm2"
MUSIC = "cvssp/audioldm2-music"
LDM2_LARGE = "cvssp/audioldm2-large"
STABLEAUD = "chaowenguo/stable-audio-open-1.0"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
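# Load every checkpoint once at startup; the Stable Audio Open weights require an access token.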
ldm2 = load_model(model_id=LDM2, device=device)
ldm2_large = load_model(model_id=LDM2_LARGE, device=device)
ldm2_music = load_model(model_id=MUSIC, device=device)
ldm_stableaud = load_model(model_id=STABLEAUD, device=device, token=os.getenv('PRIV_TOKEN'))


def randomize_seed_fn(seed, randomize_seed):
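    """Draw a fresh random seed if requested, seed torch's RNG, and return the seed that was used."""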
    if randomize_seed:
        seed = random.randint(0, np.iinfo(np.int32).max)
    torch.manual_seed(seed)
    return seed


def invert(ldm_stable, x0, prompt_src, num_diffusion_steps, cfg_scale_src, duration, save_compute):
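    """Encode the input audio with the model's VAE and run the DDPM-inversion forward process.

    Returns the extracted noise maps (zs), the latent trajectory (wts), and any extra per-step
    information needed later by the reverse (editing) process.
    """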
    # ldm_stable.model.scheduler.set_timesteps(num_diffusion_steps, device=device)

    with inference_mode():
        w0 = ldm_stable.vae_encode(x0)

    # find Zs and wts - forward process
    _, zs, wts, extra_info = inversion_forward_process(ldm_stable, w0, etas=1,
                                                       prompts=[prompt_src],
                                                       cfg_scales=[cfg_scale_src],
                                                       num_inference_steps=num_diffusion_steps,
                                                       numerical_fix=True,
                                                       duration=duration,
                                                       save_compute=save_compute)
    return zs, wts, extra_info


def sample(ldm_stable, zs, wts, extra_info, prompt_tar, tstart, cfg_scale_tar, duration, save_compute):
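    """Run the reverse (editing) process from timestep `tstart` with the target prompt,
    decode the resulting latents, and return a (sample_rate, waveform) tuple for Gradio.
    """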
    # reverse process (via Zs and wT)
    tstart = torch.tensor(tstart, dtype=torch.int)
    w0, _ = inversion_reverse_process(ldm_stable, xT=wts, tstart=tstart,
                                      etas=1., prompts=[prompt_tar],
                                      neg_prompts=[""], cfg_scales=[cfg_scale_tar],
                                      zs=zs[:int(tstart)],
                                      duration=duration,
                                      extra_info=extra_info,
                                      save_compute=save_compute)

    # decode the edited latents with the VAE
    with inference_mode():
        x0_dec = ldm_stable.vae_decode(w0)

    if 'stable-audio' not in ldm_stable.model_id:
        if x0_dec.dim() < 4:
            x0_dec = x0_dec[None, :, :, :]

        with torch.no_grad():
            audio = ldm_stable.decode_to_mel(x0_dec)
    else:
        audio = x0_dec.squeeze(0).T

    return (ldm_stable.get_sr(), audio.squeeze().cpu().numpy())


def get_duration(input_audio,
                 model_id: str,
                 do_inversion: bool,
                 wts: Optional[torch.Tensor], zs: Optional[torch.Tensor], extra_info: Optional[List],
                 saved_inv_model: str,
                 source_prompt: str = "",
                 target_prompt: str = "",
                 steps: int = 200,
                 cfg_scale_src: float = 3.5,
                 cfg_scale_tar: float = 12,
                 t_start: int = 45,
                 randomize_seed: bool = True,
                 save_compute: bool = True,
                 oauth_token: Optional[gr.OAuthToken] = None):
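    """Estimate how many GPU seconds to request from ZeroGPU for this edit.

    The estimate scales with the selected checkpoint, the number of diffusion steps, whether
    inversion has to be (re-)run, and the (trimmed) input duration; the per-forward-pass
    constant is a rough empirical value.
    """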
    if model_id == LDM2:
        factor = 1
    elif model_id == LDM2_LARGE:
        factor = 2.5
    elif model_id == STABLEAUD:
        factor = 3.2
    else:  # MUSIC
        factor = 1

    forwards = 0
    if do_inversion or randomize_seed:
        forwards = steps if source_prompt == "" else steps * 2  # x2 when there is a prompt text
    forwards += int(t_start / 100 * steps) * 2

    duration = min(utils.get_duration(input_audio), utils.MAX_DURATION)
    time_for_maxlength = factor * forwards * 0.15  # ~0.15 s per forward pass (rough empirical estimate)
    
    if model_id != STABLEAUD:
        time_for_maxlength = time_for_maxlength / utils.MAX_DURATION * duration
        
    print('expected time:', time_for_maxlength)
    spare_time = 5
    return max(10, time_for_maxlength + spare_time)


def verify_model_params(model_id: str, input_audio, src_prompt: str, tar_prompt: str, cfg_scale_src: float,
                        oauth_token: gr.OAuthToken | None):
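    """Validate the user's inputs before a GPU is requested: require an input audio and a target
    prompt, suggest the recommended source guidance scale, and verify access to the gated
    Stable Audio Open weights when that checkpoint is selected.
    """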
    if input_audio is None:
        raise gr.Error('Input audio missing!')

    if tar_prompt == "":
        raise gr.Error("Please provide a target prompt to edit the audio.")

    if src_prompt != "":
        if model_id == STABLEAUD and cfg_scale_src != 1:
            gr.Info("Consider using Source Guidance Scale=1 for Stable Audio Open 1.0.")
        elif model_id != STABLEAUD and cfg_scale_src != 3:
            gr.Info(f"Consider using Source Guidance Scale=3 for {model_id}.")

    if model_id == STABLEAUD:
        if oauth_token is None:
            raise gr.Error("You must be logged in to use Stable Audio Open 1.0. Please log in and try again.")
        try:
            huggingface_hub.get_hf_file_metadata(huggingface_hub.hf_hub_url(STABLEAUD, 'transformer/config.json'),
                                                 token=oauth_token.token)
            print('Has Access')
        except huggingface_hub.errors.GatedRepoError:  # moved here from huggingface_hub.utils._errors
            raise gr.Error("You need to accept the license agreement to use Stable Audio Open 1.0. "
                           "Visit the <a href='https://huggingface.co/stabilityai/stable-audio-open-1.0'>"
                           "model page</a> to get access.")


@spaces.GPU(duration=get_duration)
def edit(input_audio,
         model_id: str,
         do_inversion: bool,
         wts: Optional[torch.Tensor], zs: Optional[torch.Tensor], extra_info: Optional[List],
         saved_inv_model: str,
         source_prompt: str = "",
         target_prompt: str = "",
         steps: int = 200,
         cfg_scale_src: float = 3.5,
         cfg_scale_tar: float = 12,
         t_start: int = 45,
         randomize_seed: bool = True,
         save_compute: bool = True,
         oauth_token: Optional[gr.OAuthToken] = None):
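    """Main editing entry point (runs on a ZeroGPU worker).

    Picks the requested checkpoint, re-runs DDPM inversion when needed (new audio, prompt, model,
    or seed), then samples the edited audio with the target prompt. Returns the edited audio
    together with the cached inversion state so repeated edits of the same input can skip inversion.
    """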
    print(model_id)
    if model_id == LDM2:
        ldm_stable = ldm2
    elif model_id == LDM2_LARGE:
        ldm_stable = ldm2_large
    elif model_id == STABLEAUD:
        ldm_stable = ldm_stableaud
    else:  # MUSIC
        ldm_stable = ldm2_music

    ldm_stable.model.scheduler.set_timesteps(steps, device=device)

    # If the inversion was done for a different model, we need to re-run the inversion
    if not do_inversion and (saved_inv_model is None or saved_inv_model != model_id):
        do_inversion = True

    if input_audio is None:
        raise gr.Error('Input audio missing!')
    x0, _, duration = utils.load_audio(input_audio, ldm_stable.get_fn_STFT(), device=device,
                                       stft=('stable-audio' not in ldm_stable.model_id), model_sr=ldm_stable.get_sr())
    if wts is None or zs is None:
        do_inversion = True

    if do_inversion or randomize_seed:  # re-run inversion when required or whenever a new seed is drawn
        zs_tensor, wts_tensor, extra_info_list = invert(ldm_stable=ldm_stable, x0=x0, prompt_src=source_prompt,
                                                        num_diffusion_steps=steps,
                                                        cfg_scale_src=cfg_scale_src,
                                                        duration=duration,
                                                        save_compute=save_compute)
        wts = wts_tensor
        zs = zs_tensor
        extra_info = extra_info_list
        saved_inv_model = model_id
        do_inversion = False
    else:
        wts_tensor = wts.to(device)
        zs_tensor = zs.to(device)
        extra_info_list = [e.to(device) for e in extra_info if e is not None]

    output = sample(ldm_stable, zs_tensor, wts_tensor, extra_info_list, prompt_tar=target_prompt,
                    tstart=int(t_start / 100 * steps), cfg_scale_tar=cfg_scale_tar, duration=duration,
                    save_compute=save_compute)

    return output, wts.cpu(), zs.cpu(), [e.cpu() for e in extra_info if e is not None], saved_inv_model, do_inversion
    # return output, wtszs_file, saved_inv_model, do_inversion


def get_example():
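    """Example rows for the UI: input audio, source prompt, target prompt, T-start,
    model checkpoint, input length, and a precomputed edited output.
    """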
    case = [
        ['Examples/Beethoven.mp3',
         '',
         'A recording of an arcade game soundtrack.',
         45,
         'cvssp/audioldm2-music',
         '27s',
         'Examples/Beethoven_arcade.mp3',
         ],
        ['Examples/Beethoven.mp3',
         'A high quality recording of wind instruments and strings playing.',
         'A high quality recording of a piano playing.',
         45,
         'cvssp/audioldm2-music',
         '27s',
         'Examples/Beethoven_piano.mp3',
         ],
        ['Examples/Beethoven.mp3',
         '',
         'Heavy Rock.',
         40,
         'stabilityai/stable-audio-open-1.0',
         '27s',
         'Examples/Beethoven_rock.mp3',
         ],
        ['Examples/ModalJazz.mp3',
         'Trumpets playing alongside a piano, bass and drums in an upbeat old-timey cool jazz song.',
         'A banjo playing alongside a piano, bass and drums in an upbeat old-timey cool country song.',
         45,
         'cvssp/audioldm2-music',
         '106s',
         'Examples/ModalJazz_banjo.mp3',],
        ['Examples/Shadows.mp3',
         '',
         '8-bit arcade game soundtrack.',
         40,
         'stabilityai/stable-audio-open-1.0',
         '34s',
         'Examples/Shadows_arcade.mp3',],
        ['Examples/Cat.mp3',
         '',
         'A dog barking.',
         75,
         'cvssp/audioldm2-large',
         '10s',
         'Examples/Cat_dog.mp3',]
    ]
    return case


intro = """
<h1 style="font-weight: 1000; text-align: center; margin: 0px;"> ZETA Editing 🎧 </h1>
<h2 style="font-weight: 1000; text-align: center; margin: 0px;">
    Zero-Shot Text-Based Audio Editing Using DDPM Inversion 🎛️ </h2>
<h3 style="margin-top: 0px; margin-bottom: 10px; text-align: center;">
    <a href="https://arxiv.org/abs/2402.10009">[Paper]</a>&nbsp;|&nbsp;
    <a href="https://hilamanor.github.io/AudioEditing/">[Project page]</a>&nbsp;|&nbsp;
    <a href="https://github.com/HilaManor/AudioEditingCode">[Code]</a>
</h3>

<p style="font-size: 1rem; line-height: 1.2em;">
For faster inference without waiting in the queue, you can duplicate the Space and upgrade to a GPU in the settings.
<a href="https://huggingface.co/spaces/hilamanor/audioEditing?duplicate=true">
<img style="margin-top: 0em; margin-bottom: 0em; display:inline" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" >
</a>
</p>
<p style="margin: 0px;">
<b>NEW - 15.10.24:</b> You can now edit using <b>Stable Audio Open 1.0</b>.
You must be <b>logged in</b> after accepting the
<b><a href="https://huggingface.co/stabilityai/stable-audio-open-1.0">license agreement</a></b> to use it.<br>
</p>
<ul style="padding-left:40px; line-height:normal;">
<li style="margin: 0px;">Prompts behave differently - e.g.,
try "8-bit arcade" directly instead of "a recording of...". Check out the new examples below!</li>
<li style="margin: 0px;">Try playing around with <code>T-start=40%</code>.</li>
<li style="margin: 0px;">Under "More Options", use <code>Source Guidance Scale=1</code>;
you can also try fewer timesteps (even 20!).</li>
<li style="margin: 0px;">Stable Audio Open is a general-audio model.
For better music editing, duplicate the space and change to a
<a href="https://huggingface.co/models?other=base_model:finetune:stabilityai/stable-audio-open-1.0">
fine-tuned model for music</a>.</li>
</ul>
<p>
<b>NEW - 15.10.24:</b> Parallel editing is enabled by default and saves a bit of time.
To disable it, uncheck <code>Efficient editing</code> under "More Options".
</p>
"""


help_text = """
<div style="font-size:medium">
<b>Instructions:</b><br>
<ul style="line-height: normal">
<li>You must provide an input audio and a target prompt to edit the audio. </li>
<li>T<sub>start</sub> controls the tradeoff between fidelity to the original signal and adherence to the text.
A lower value favors fidelity; a higher value applies a stronger edit.</li>
<li>Make sure to use a model version that suits your input audio.
For example, use AudioLDM2-music for music and AudioLDM2-large for general audio.
</li>
<li>You can additionally provide a source prompt to further guide the editing process.</li>
<li>Longer inputs take more time to process.</li>
<li><strong>Unlimited length</strong>: This space automatically trims input audio to a maximum length of 30 seconds.
For unlimited length, duplicate the space and change the
<code style="display:inline; background-color: lightgrey;">MAX_DURATION</code> parameter
inside <code style="display:inline; background-color: lightgrey;">utils.py</code>
to <code style="display:inline; background-color: lightgrey;">None</code>.
</li>
</li>
</ul>
</div>

"""

css = '.gradio-container {max-width: 1000px !important; padding-top: 1.5rem !important;}' \
      '.audio-upload .wrap {min-height: 0px;}'

# with gr.Blocks(css='style.css') as demo:
with gr.Blocks(css=css) as demo:
    def reset_do_inversion(do_inversion_user, do_inversion):
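        """Mark that inversion must be re-run, e.g. after the input audio, prompts, or inversion settings change."""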
        # do_inversion = gr.State(value=True)
        do_inversion = True
        do_inversion_user = True
        return do_inversion_user, do_inversion

    # handle the case where the user changes an input while an edit is already running
    # (otherwise the running edit's returned do_inversion=False would overwrite the reset)
    def clear_do_inversion_user(do_inversion_user):
        do_inversion_user = False
        return do_inversion_user

    def post_match_do_inversion(do_inversion_user, do_inversion):
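        """Re-raise `do_inversion` if an inversion-related input changed while the edit was running."""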
        if do_inversion_user:
            do_inversion = True
            do_inversion_user = False
        return do_inversion_user, do_inversion

    gr.HTML(intro)

    wts = gr.State()
    zs = gr.State()
    extra_info = gr.State()
    saved_inv_model = gr.State()
    do_inversion = gr.State(value=True)  # To save some runtime when editing the same thing over and over
    do_inversion_user = gr.State(value=False)

    with gr.Group():
        gr.Markdown("💡 **Note**: inputs longer than **30 sec** are automatically trimmed "
                    "(for unlimited length, see the Help section below)")
        with gr.Row(equal_height=True):
            input_audio = gr.Audio(sources=["upload", "microphone"], type="filepath",
                                   editable=True, label="Input Audio", interactive=True, scale=1, format='wav',
                                   elem_classes=['audio-upload'])
            output_audio = gr.Audio(label="Edited Audio", interactive=False, scale=1, format='wav')

    with gr.Row():
        tar_prompt = gr.Textbox(label="Prompt", info="Describe your desired edited output",
                                placeholder="a recording of a happy upbeat arcade game soundtrack",
                                lines=2, interactive=True)

    with gr.Row():
        t_start = gr.Slider(minimum=15, maximum=85, value=45, step=1, label="T-start (%)", interactive=True, scale=3,
                            info="Lower T-start -> closer to original audio. Higher T-start -> stronger edit.")
        model_id = gr.Dropdown(label="Model Version",
                               choices=[LDM2,
                                        LDM2_LARGE,
                                        MUSIC,
                                        STABLEAUD],
                               info="Choose a checkpoint suitable for your audio and edit",
                               value=MUSIC, interactive=True, type="value", scale=2)
    with gr.Row():
        submit = gr.Button("Edit", variant="primary", scale=3)
        gr.LoginButton(value="Login to HF (For Stable Audio)", scale=1)

    with gr.Accordion("More Options", open=False):
        with gr.Row():
            src_prompt = gr.Textbox(label="Source Prompt", lines=2, interactive=True,
                                    info="Optional: Describe the original audio input",
                                    placeholder="A recording of a happy upbeat classical music piece",)

        with gr.Row(equal_height=True):
            cfg_scale_src = gr.Number(value=3, minimum=0.5, maximum=25, precision=None,
                                      label="Source Guidance Scale", interactive=True, scale=1)
            cfg_scale_tar = gr.Number(value=12, minimum=0.5, maximum=25, precision=None,
                                      label="Target Guidance Scale", interactive=True, scale=1)
            steps = gr.Number(value=50, step=1, minimum=10, maximum=300,
                              info="Higher values (e.g. 200) yield higher-quality generation.",
                              label="Num Diffusion Steps", interactive=True, scale=2)
        with gr.Row(equal_height=True):
            seed = gr.Number(value=0, precision=0, label="Seed", interactive=True)
            randomize_seed = gr.Checkbox(label='Randomize seed', value=False)
            save_compute = gr.Checkbox(label='Efficient editing', value=True)
            length = gr.Number(label="Length", interactive=False, visible=False)

    with gr.Accordion("Help💡", open=False):
        gr.HTML(help_text)

    submit.click(
            fn=verify_model_params,
            inputs=[model_id, input_audio, src_prompt, tar_prompt, cfg_scale_src],
            outputs=[]
        ).success(
            fn=randomize_seed_fn, inputs=[seed, randomize_seed], outputs=[seed], queue=False
        ).then(
            fn=clear_do_inversion_user, inputs=[do_inversion_user], outputs=[do_inversion_user]
        ).then(
            fn=edit,
            inputs=[input_audio,
                    model_id,
                    do_inversion,
                    wts, zs, extra_info,
                    saved_inv_model,
                    src_prompt,
                    tar_prompt,
                    steps,
                    cfg_scale_src,
                    cfg_scale_tar,
                    t_start,
                    randomize_seed,
                    save_compute,
                    ],
            outputs=[output_audio, wts, zs, extra_info, saved_inv_model, do_inversion]
        ).success(
            fn=post_match_do_inversion,
            inputs=[do_inversion_user, do_inversion],
            outputs=[do_inversion_user, do_inversion]
        )

    # If the source audio or any inversion-related setting changed, we have to rerun the inversion
    gr.on(
        triggers=[input_audio.change, src_prompt.change, model_id.change, cfg_scale_src.change,
                  steps.change, save_compute.change],
        fn=reset_do_inversion,
        inputs=[do_inversion_user, do_inversion],
        outputs=[do_inversion_user, do_inversion]
    )

    gr.Examples(
        label="Examples",
        examples=get_example(),
        inputs=[input_audio, src_prompt, tar_prompt, t_start, model_id, length, output_audio],
        outputs=[output_audio]
    )

    demo.queue()
    demo.launch(state_session_capacity=15)