import os

# set CUDA_MODULE_LOADING=LAZY to speed up the serverless function
os.environ["CUDA_MODULE_LOADING"] = "LAZY"
# set SAFETENSORS_FAST_GPU=1 to speed up the serverless function
os.environ["SAFETENSORS_FAST_GPU"] = "1"
import cv2
import torch
import time
import imageio
import numpy as np
from tqdm import tqdm
import moviepy.editor as mp
import torch

from audio import load_wav, melspectrogram
from fete_model import FETE_model
from preprocess_videos import face_detect, load_from_npz

# Frame rate of the generated video.
fps = 25
# Number of mel-spectrogram columns advanced per video frame
# (presumably 80 mel columns per second of audio; confirm against audio.melspectrogram).
mel_idx_multiplier = 80.0 / fps

# Width, in mel columns, of the audio window fed to the model for each frame.
mel_step_size = 16

# Everything below follows from a single question: is a CUDA device usable?
use_fp16 = torch.cuda.is_available()
device = "cuda" if use_fp16 else "cpu"
batch_size = 64 if use_fp16 else 4
print("Using {} for inference.".format(device))
if use_fp16:
    print("Using FP16 for inference.")
torch.backends.cudnn.benchmark = device == "cuda"


def init_model():
    """Build the FETE model, load its weights and prepare it for inference.

    The checkpoint is read from ``checkpoints/obama-fp16.safetensors`` next to
    this file. A ``.pth``/``.ckpt`` path would be read with ``torch.load`` and
    is expected to keep its weights under the ``"state_dict"`` key; any other
    path is read as a safetensors file.

    Returns:
        The model moved to ``device`` and switched to eval mode. When
        ``use_fp16`` is set, its modules are cast to half precision except
        the attention convolutions, which stay in float32.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    checkpoint_path = os.path.join(here, "checkpoints/obama-fp16.safetensors")
    model = FETE_model()

    if checkpoint_path.endswith((".pth", ".ckpt")):
        if device == "cuda":
            checkpoint = torch.load(checkpoint_path)
        else:
            # Keep every storage where torch.load deserialized it (CPU).
            checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
        raw_state = checkpoint["state_dict"]
    else:
        from safetensors import safe_open

        with safe_open(checkpoint_path, framework="pt", device=device) as handle:
            raw_state = {key: handle.get_tensor(key) for key in handle.keys()}

    # Drop every "module." fragment from the parameter names before loading.
    state = {key.replace("module.", ""): tensor for key, tensor in raw_state.items()}
    model.load_state_dict(state)

    model = model.to(device)
    model.eval()
    print("Model loaded")
    if not use_fp16:
        return model

    attention_tags = (".query_conv", ".key_conv", ".value_conv")
    for name, module in model.named_modules():
        # keep attention layers in full precision to avoid error
        keep_full_precision = any(tag in name for tag in attention_tags)
        module.to(torch.float if keep_full_precision else torch.half)
    print("Model converted to half precision to accelerate inference")
    return model


def make_mask(image_size=256, border_size=32):
    """Create the three-channel blending mask used to feather a face crop.

    The mask is 0 in the interior and ramps linearly up to 1 towards the
    left, right and bottom edges over ``border_size`` pixels; the top edge
    gets no ramp. Inside the bottom band every value is raised to at least
    0.6.

    Args:
        image_size: Side length of the square mask.
        border_size: Width of the ramp along each feathered edge.

    Returns:
        A float32 array of shape ``(image_size, image_size, 3)`` with the
        same values in all three channels.
    """
    ramp = np.linspace(1, 0, border_size)
    # One copy of the 1 -> 0 ramp per image row: shape (image_size, border_size).
    side_band = np.repeat(ramp[np.newaxis, :], image_size, axis=0)

    mask = np.zeros((image_size, image_size), dtype=np.float32)
    mask[-border_size:, :] += side_band.T[::-1]  # bottom: 0 -> 1 going down
    mask[:, :border_size] = side_band  # left: 1 -> 0 going right
    mask[:, -border_size:] = side_band[:, ::-1]  # right: 0 -> 1 going right

    # Raise the floor of the bottom band to 0.6.
    bottom_band = mask[-border_size:, :]
    bottom_band[bottom_band < 0.6] = 0.6

    return np.repeat(mask[:, :, np.newaxis], 3, axis=2)


# Blending mask built once at import time with the default size;
# blend_images resizes it to each face crop before use.
face_mask = make_mask()


def blend_images(foreground, background):
    """Mix two equally sized images using the module-level ``face_mask``.

    The mask is resized to the foreground's height and width and the result
    is ``foreground * mask + background * (1 - mask)``.

    Returns:
        The blended image, clipped to [0, 255] and cast to uint8.
    """
    width, height = foreground.shape[1], foreground.shape[0]
    weights = cv2.resize(face_mask, (width, height))

    mixed = cv2.multiply(foreground.astype(np.float32), weights)
    background_part = cv2.multiply(background.astype(np.float32), 1 - weights)
    mixed += background_part

    return np.clip(mixed, 0, 255).astype(np.uint8)


def smooth_coord(last_coord, current_coord, factor=0.4):
    """Move ``last_coord`` a fraction of the way towards ``current_coord``.

    Args:
        last_coord: Previous coordinates (any array-like of numbers).
        current_coord: New coordinates, same shape as ``last_coord``.
        factor: Fraction of the difference to apply (0 keeps the previous
            value, 1 jumps to the current one).

    Returns:
        The interpolated coordinates as a list of ints (truncated towards zero).
    """
    previous = np.array(last_coord)
    step = (np.array(current_coord) - previous) * factor
    smoothed = previous + np.array(step)
    return smoothed.astype(int).tolist()


def add_black(imgs):
    """Pad every frame with black bars: 100 rows on top and 20 at the bottom.

    The list is updated in place (each entry is replaced by its padded
    version) and is also returned. ``remove_black`` strips the same bars.

    Args:
        imgs: List of three-channel uint8 images.

    Returns:
        The same list object, now holding the padded frames.
    """
    for position, frame in enumerate(imgs):
        frame_width = frame.shape[1]
        top_bar = np.zeros((100, frame_width, 3), dtype=np.uint8)
        bottom_bar = np.zeros((20, frame_width, 3), dtype=np.uint8)
        imgs[position] = cv2.vconcat([top_bar, frame, bottom_bar])
    return imgs


def remove_black(img):
    """Drop the first 100 and the last 20 rows of ``img`` (the bars added by ``add_black``)."""
    top_rows, bottom_rows = 100, 20
    return img[top_rows:-bottom_rows]


def resize_length(input_attributes, length):
    """Resample a sequence to ``length`` entries by picking source rows.

    Output entry ``i`` is the source entry at index
    ``int(i * (len(source) / length))``; no interpolation is performed.

    Args:
        input_attributes: Array-like indexed along its first axis.
        length: Number of entries to produce.

    Returns:
        The resampled entries as an array, transposed (so for 2-D input the
        resampled axis becomes the last one).
    """
    source = np.array(input_attributes)
    picked = []
    for step in range(length):
        source_index = int(step * (source.shape[0] / length))
        picked.append(source[source_index])
    return np.array(picked).T


def output_chunks(input_attributes):
    """Cut a 2-D array into one fixed-width window per video frame.

    Window ``i`` starts at column ``int(i * mel_idx_multiplier)`` and is
    ``mel_step_size`` columns wide. As soon as a window would run past the
    last column, the final ``mel_step_size`` columns are emitted instead and
    the splitting stops.

    Args:
        input_attributes: Array indexed as ``[feature, time]``.

    Returns:
        List of column slices of ``input_attributes``.
    """
    total_columns = len(input_attributes[0])
    windows = []

    frame_index = 0
    while True:
        left = int(frame_index * mel_idx_multiplier)
        right = left + mel_step_size
        if right > total_columns:
            # Not enough columns left: finish with the trailing window.
            windows.append(input_attributes[:, total_columns - mel_step_size :])
            return windows
        windows.append(input_attributes[:, left:right])
        frame_index += 1


def prepare_data(face_path, audio_path, pose, emotion, blink, img_size=256, pads=None):
    """Load the face source and audio and build the batch generator.

    Args:
        face_path: Path to a still image (``.jpg``/``.jpeg``/``.png``, any
            letter case) or to a video readable by OpenCV.
        audio_path: Path to the driving audio, loaded at 16 kHz.
        pose: Pose attribute sequence, indexed along its first axis.
        emotion: Emotion attribute sequence, indexed along its first axis.
        blink: Blink attribute sequence, indexed along its first axis.
        img_size: Side length the face crops are resized to.
        pads: Padding forwarded to face detection; defaults to ``[0, 0, 0, 0]``.

    Returns:
        ``(gen, steps)``: the ``datagen`` generator and the number of batches
        it will yield.

    Raises:
        ValueError: If no frame can be read from ``face_path`` or the mel
            spectrogram contains NaN.
    """
    if pads is None:
        # Fresh list per call instead of a shared mutable default.
        pads = [0, 0, 0, 0]

    # splitext looks only at the last suffix, so "./dir/face.jpg" and
    # "face.v2.png" are recognised (splitting on "." and taking item 1 is not).
    extension = os.path.splitext(face_path)[1].lower()
    static = os.path.isfile(face_path) and extension in (".jpg", ".png", ".jpeg")

    if static:
        image = cv2.imread(face_path)
        if image is None:
            raise ValueError("Could not read image: {}".format(face_path))
        full_frames = [image]
    else:
        video_stream = cv2.VideoCapture(face_path)
        full_frames = []
        try:
            while True:
                still_reading, frame = video_stream.read()
                if not still_reading:
                    break
                full_frames.append(frame)
        finally:
            video_stream.release()
    print("Number of frames available for inference: " + str(len(full_frames)))
    if not full_frames:
        # Without this, datagen would try to extend an empty frame list forever.
        raise ValueError("No frames could be read from: {}".format(face_path))

    wav = load_wav(audio_path, 16000)
    mel = melspectrogram(wav)
    # The whole spectrogram is used; the attribute tracks are resampled to its length.
    len_ = mel.shape[1]
    mel = mel[:, :len_]

    pose = resize_length(pose, len_)
    emotion = resize_length(emotion, len_)
    blink = resize_length(blink, len_)

    if np.isnan(mel.reshape(-1)).sum() > 0:
        raise ValueError("Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again")

    mel_chunks = output_chunks(mel)
    pose_chunks = output_chunks(pose)
    emotion_chunks = output_chunks(emotion)
    blink_chunks = output_chunks(blink)

    gen = datagen(face_path, full_frames, mel_chunks, pose_chunks, emotion_chunks, blink_chunks, static=static, img_size=img_size, pads=pads)
    steps = int(np.ceil(float(len(mel_chunks)) / batch_size))

    return gen, steps


def preprocess_batch(batch):
    """Stack equally shaped 2-D arrays into one float tensor on ``device``.

    Returns:
        A ``torch.FloatTensor`` of shape ``(len(batch), 1, rows, cols)``,
        where ``rows`` and ``cols`` come from the first item.
    """
    count = len(batch)
    rows, cols = batch[0].shape[0], batch[0].shape[1]
    stacked = np.reshape(batch, [count, 1, rows, cols])
    return torch.FloatTensor(stacked).to(device)


def _collate_batch(img_batch, mel_batch, pose_batch, emotion_batch, blink_batch, scale_factor):
    """Convert the accumulated per-frame lists into the model's input tensors."""
    img_masked = np.asarray(img_batch).copy()
    # Blank the interior of the masked copy, keeping a border of
    # 16 * scale_factor pixels on every side.
    border = 16 * scale_factor
    img_masked[:, border:-border, border:-border] = 0.0

    # Masked and unmasked crops are joined on the channel axis, scaled to
    # [0, 1] and reordered from NHWC to NCHW.
    stacked = np.concatenate((img_masked, img_batch), axis=3) / 255.0
    tensors = (
        torch.FloatTensor(np.transpose(stacked, (0, 3, 1, 2))).to(device),
        preprocess_batch(mel_batch),
        preprocess_batch(pose_batch),
        preprocess_batch(emotion_batch),
        preprocess_batch(blink_batch),
    )
    if use_fp16:
        tensors = tuple(tensor.half() for tensor in tensors)
    return tensors


def datagen(face_path, frames, mels, poses, emotions, blinks, static=False, img_size=256, pads=None):
    """Yield model-ready batches, one entry per audio chunk.

    Frames are padded with ``add_black``, face boxes are taken from the cache
    (``load_from_npz``) or, if that fails, from ``face_detect``. When there
    are fewer frames than audio chunks the frame list is extended by
    appending its reverse until it is long enough.

    Args:
        face_path: Path of the face source; its base name (up to the first
            dot) is the cache key for stored face boxes.
        frames: List of frames (images as arrays).
        mels: Mel-spectrogram chunks, one per output frame.
        poses: Pose chunks, aligned with ``mels``.
        emotions: Emotion chunks, aligned with ``mels``.
        blinks: Blink chunks, aligned with ``mels``.
        static: If True, the first frame is reused for every chunk.
        img_size: Side length the face crops are resized to.
        pads: Padding forwarded to ``face_detect``; defaults to ``[0, 0, 0, 0]``.

    Yields:
        ``((img, mel, pose, emotion, blink), frame_batch, coords_batch)``:
        the input tensors (half precision when ``use_fp16`` is set), copies
        of the padded frames, and each face box as ``(y1, y2, x1, x2)``.

    Raises:
        ValueError: If there are audio chunks but no frames.
    """
    if pads is None:
        # Fresh list per call instead of a shared mutable default.
        pads = [0, 0, 0, 0]
    if len(mels) > 0 and len(frames) == 0:
        # An empty list can never be grown by mirroring; fail instead of looping forever.
        raise ValueError("No frames available for inference")

    scale_factor = img_size // 128

    frames = frames[: len(mels)]
    frames = add_black(frames)

    def crop_faces(boxes):
        # Pair each frame's face crop with its box, reordered to (y1, y2, x1, x2).
        return [[image[y1:y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(frames, boxes)]

    try:
        video_name = os.path.basename(face_path).split(".")[0]
        face_det_results = crop_faces(load_from_npz(video_name))
    except Exception as e:
        # Best effort: any problem with the cached boxes falls back to detection.
        print("No existing coords found, running face detection...", "Error: ", e)
        face_det_results = crop_faces(face_detect([frames[0]] if static else frames, pads))

    # Ping-pong the footage until it covers every audio chunk, then trim.
    while len(frames) < len(mels):
        face_det_results = face_det_results + face_det_results[::-1]
        frames = frames + frames[::-1]
    face_det_results = face_det_results[: len(mels)]
    frames = frames[: len(mels)]

    img_batch, mel_batch, pose_batch, emotion_batch, blink_batch, frame_batch, coords_batch = [], [], [], [], [], [], []
    for i in range(len(mels)):
        idx = 0 if static else i % len(frames)
        face, box = face_det_results[idx]

        img_batch.append(cv2.resize(face, (img_size, img_size)))
        mel_batch.append(mels[i])
        pose_batch.append(poses[i])
        emotion_batch.append(emotions[i])
        blink_batch.append(blinks[i])
        # Copy so the caller can draw on the frame without touching the source list.
        frame_batch.append(frames[idx].copy())
        coords_batch.append(box)

        if len(img_batch) >= batch_size:
            yield _collate_batch(img_batch, mel_batch, pose_batch, emotion_batch, blink_batch, scale_factor), frame_batch, coords_batch
            img_batch, mel_batch, pose_batch, emotion_batch, blink_batch, frame_batch, coords_batch = [], [], [], [], [], [], []

    # Flush the final, possibly smaller, batch.
    if len(img_batch) > 0:
        yield _collate_batch(img_batch, mel_batch, pose_batch, emotion_batch, blink_batch, scale_factor), frame_batch, coords_batch


def _render_frame(pred_face, frame, box):
    """Paste one predicted face into its frame and strip the padding bars."""
    y1, y2, x1, x2 = (int(value) for value in box)
    patch = cv2.resize(pred_face.astype(np.uint8), (x2 - x1, y2 - y1))
    try:
        frame[y1:y2, x1:x2] = blend_images(frame[y1:y2, x1:x2], patch)
    except Exception as e:
        # Best effort: if blending fails, paste the prediction unblended.
        print(e)
        frame[y1:y2, x1:x2] = patch
    return remove_black(frame)


def _mux_audio(video_path, audio_path, outfile):
    """Write ``outfile`` as ``video_path`` with ``audio_path`` as its soundtrack."""
    video_clip = mp.VideoFileClip(video_path)
    try:
        audio_clip = mp.AudioFileClip(audio_path)
        try:
            video_clip.set_audio(audio_clip).write_videofile(outfile, codec="libx264")
        finally:
            audio_clip.close()
    finally:
        video_clip.close()


def infenrece(model, face_path, audio_path, pose, emotion, blink, preview=False):
    """Run the model over the audio and write the result under ``/tmp``.

    (The function name keeps its historical spelling so existing callers
    continue to work.)

    Args:
        model: Model called as ``model(img, mel, pose, emotion, blink)``.
        face_path: Image or video providing the face.
        audio_path: Driving audio; also used as the output soundtrack.
        pose: Pose attribute sequence.
        emotion: Emotion attribute sequence.
        blink: Blink attribute sequence.
        preview: If True, stop after the first frame and save it as a JPEG.

    Returns:
        Path of the written file: ``/tmp/<timestamp>.jpg`` in preview mode,
        otherwise ``/tmp/<timestamp>.mp4``.

    Raises:
        RuntimeError: In preview mode, if no frame was generated.
    """
    timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime(time.time()))
    gen, steps = prepare_data(face_path, audio_path, pose, emotion, blink)
    steps = 1 if preview else steps

    if preview:
        outfile = "/tmp/{}.jpg".format(timestamp)
        tmp_video = None
        writer = None
    else:
        outfile = "/tmp/{}.mp4".format(timestamp)
        tmp_video = "/tmp/temp_{}.mp4".format(timestamp)
        writer = imageio.get_writer(tmp_video, fps=fps, codec="libx264", quality=10, pixelformat="yuv420p", macro_block_size=1)

    try:
        try:
            for inputs, frames, coords in tqdm(gen, total=steps):
                with torch.no_grad():
                    pred = model(*inputs)

                # NCHW prediction -> NHWC, scaled by 255 for the uint8 cast in _render_frame.
                pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.0

                for p, f, c in zip(pred, frames, coords):
                    f = _render_frame(p, f, c)
                    if preview:
                        cv2.imwrite(outfile, f, [int(cv2.IMWRITE_JPEG_QUALITY), 95])
                        return outfile
                    writer.append_data(cv2.cvtColor(f, cv2.COLOR_BGR2RGB))
        finally:
            # Close the encoder even when generation fails part-way.
            if writer is not None:
                writer.close()

        if preview:
            raise RuntimeError("No frame was generated for the preview")

        _mux_audio(tmp_video, audio_path, outfile)
        print("Saved to {}".format(outfile) if os.path.exists(outfile) else "Failed to save {}".format(outfile))
    finally:
        # The silent intermediate video is never needed after this point.
        if tmp_video is not None and os.path.exists(tmp_video):
            try:
                os.remove(tmp_video)
            except OSError as e:
                print("Could not remove temporary file {}: {}".format(tmp_video, e))
    return outfile


if __name__ == "__main__":
    # Demo run: load the model, then animate the bundled sample clip with
    # the attribute tracks provided by attributtes_utils.
    fete_model = init_model()

    from attributtes_utils import input_pose, input_emotion, input_blink

    demo_pose = input_pose()
    demo_emotion = input_emotion()
    demo_blink = input_blink()

    infenrece(
        fete_model,
        face_path="./assets/sample.mp4",
        audio_path="./assets/sample.wav",
        pose=demo_pose,
        emotion=demo_emotion,
        blink=demo_blink,
    )