import torch
from torch import nn
from transformers import HubertConfig, HubertModel
import logging

# Ignore fairseq's logger
logging.getLogger("fairseq").setLevel(logging.WARNING)
logging.getLogger("torch.distributed.nn.jit.instantiator").setLevel(logging.WARNING)

from fairseq import checkpoint_utils

models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
    ["content-vec-best-legacy-500.pt"], suffix=""
)
model = models[0]
model.eval()
model.eval()


class HubertModelWithFinalProj(HubertModel):
    def __init__(self, config):
        super().__init__(config)

        self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)


# Default Config
hubert = HubertModelWithFinalProj(HubertConfig())

# huggingface: fairseq
mapping = {
    "masked_spec_embed": "mask_emb",
    "encoder.layer_norm.bias": "encoder.layer_norm.bias",
    "encoder.layer_norm.weight": "encoder.layer_norm.weight",
    "encoder.pos_conv_embed.conv.bias": "encoder.pos_conv.0.bias",
    "encoder.pos_conv_embed.conv.weight_g": "encoder.pos_conv.0.weight_g",
    "encoder.pos_conv_embed.conv.weight_v": "encoder.pos_conv.0.weight_v",
    "feature_projection.layer_norm.bias": "layer_norm.bias",
    "feature_projection.layer_norm.weight": "layer_norm.weight",
    "feature_projection.projection.bias": "post_extract_proj.bias",
    "feature_projection.projection.weight": "post_extract_proj.weight",
    "final_proj.bias": "final_proj.bias",
    "final_proj.weight": "final_proj.weight",
}

# Convert encoder
for layer in range(12):
    for j in ["q", "k", "v"]:
        mapping[
            f"encoder.layers.{layer}.attention.{j}_proj.weight"
        ] = f"encoder.layers.{layer}.self_attn.{j}_proj.weight"
        mapping[
            f"encoder.layers.{layer}.attention.{j}_proj.bias"
        ] = f"encoder.layers.{layer}.self_attn.{j}_proj.bias"

    mapping[
        f"encoder.layers.{layer}.final_layer_norm.bias"
    ] = f"encoder.layers.{layer}.final_layer_norm.bias"
    mapping[
        f"encoder.layers.{layer}.final_layer_norm.weight"
    ] = f"encoder.layers.{layer}.final_layer_norm.weight"

    mapping[
        f"encoder.layers.{layer}.layer_norm.bias"
    ] = f"encoder.layers.{layer}.self_attn_layer_norm.bias"
    mapping[
        f"encoder.layers.{layer}.layer_norm.weight"
    ] = f"encoder.layers.{layer}.self_attn_layer_norm.weight"

    mapping[
        f"encoder.layers.{layer}.attention.out_proj.bias"
    ] = f"encoder.layers.{layer}.self_attn.out_proj.bias"
    mapping[
        f"encoder.layers.{layer}.attention.out_proj.weight"
    ] = f"encoder.layers.{layer}.self_attn.out_proj.weight"

    mapping[
        f"encoder.layers.{layer}.feed_forward.intermediate_dense.bias"
    ] = f"encoder.layers.{layer}.fc1.bias"
    mapping[
        f"encoder.layers.{layer}.feed_forward.intermediate_dense.weight"
    ] = f"encoder.layers.{layer}.fc1.weight"

    mapping[
        f"encoder.layers.{layer}.feed_forward.output_dense.bias"
    ] = f"encoder.layers.{layer}.fc2.bias"
    mapping[
        f"encoder.layers.{layer}.feed_forward.output_dense.weight"
    ] = f"encoder.layers.{layer}.fc2.weight"

# Convert Conv Layers
for layer in range(7):
    mapping[
        f"feature_extractor.conv_layers.{layer}.conv.weight"
    ] = f"feature_extractor.conv_layers.{layer}.0.weight"

    if layer != 0:
        continue

    mapping[
        f"feature_extractor.conv_layers.{layer}.layer_norm.weight"
    ] = f"feature_extractor.conv_layers.{layer}.2.weight"
    mapping[
        f"feature_extractor.conv_layers.{layer}.layer_norm.bias"
    ] = f"feature_extractor.conv_layers.{layer}.2.bias"

hf_keys = set(hubert.state_dict().keys())
fair_keys = set(model.state_dict().keys())

hf_keys -= set(mapping.keys())
fair_keys -= set(mapping.values())

for i, j in zip(sorted(hf_keys), sorted(fair_keys)):
    print(i, j)

print(hf_keys, fair_keys)
print(len(hf_keys), len(fair_keys))

# try loading the weights
new_state_dict = {}
for k, v in mapping.items():
    new_state_dict[k] = model.state_dict()[v]

x = hubert.load_state_dict(new_state_dict, strict=False)
print(x)
hubert.eval()

with torch.no_grad():
    new_input = torch.randn(1, 16384)

    result1 = hubert(new_input, output_hidden_states=True)["hidden_states"][9]
    result1 = hubert.final_proj(result1)

    result2 = model.extract_features(
        **{
            "source": new_input,
            "padding_mask": torch.zeros(1, 16384, dtype=torch.bool),
            # "features_only": True,
            "output_layer": 9,
        }
    )[0]
    result2 = model.final_proj(result2)

    assert torch.allclose(result1, result2, atol=1e-3)

print("Sanity check passed")

# Save huggingface model
hubert.save_pretrained(".")
print("Saved model")