# Copyright (c) OpenMMLab. All rights reserved.
import math

import pytest
import torch

from mmocr.models.textrecog.decoders import (ABILanguageDecoder,
                                             ABIVisionDecoder, BaseDecoder,
                                             NRTRDecoder, ParallelSARDecoder,
                                             ParallelSARDecoderWithBS,
                                             SequentialSARDecoder)
from mmocr.models.textrecog.decoders.sar_decoder_with_bs import DecodeNode


def _create_dummy_input():
    """Build one batch of random dummy inputs shared by the SAR decoder tests.

    Returns:
        tuple: ``(feat, out_enc, tgt_dict, img_metas)`` where

        - ``feat`` is a random float tensor of shape ``(1, 512, 4, 40)``;
        - ``out_enc`` is a random float tensor of shape ``(1, 512)``;
        - ``tgt_dict`` maps ``'padded_targets'`` to the LongTensor
          ``[[1, 1, 1, 1, 36]]`` (the trailing 36 is the padding index the
          tests pass as ``padding_idx``);
        - ``img_metas`` is a one-element list ``[{'valid_ratio': 1.0}]``.
    """
    batch_size = 1
    feature_map = torch.rand(batch_size, 512, 4, 40)
    encoder_out = torch.rand(batch_size, 512)
    targets = dict(padded_targets=torch.LongTensor([[1, 1, 1, 1, 36]]))
    metas = [dict(valid_ratio=1.0)]
    return feature_map, encoder_out, targets, metas


def test_base_decoder():
    """BaseDecoder leaves both forward paths unimplemented."""
    base = BaseDecoder()
    # Each entry invokes one forward path with placeholder arguments; both
    # must raise NotImplementedError on the base class.
    unimplemented_calls = (
        lambda: base.forward_train(None, None, None, None),
        lambda: base.forward_test(None, None, None),
    )
    for invoke in unimplemented_calls:
        with pytest.raises(NotImplementedError):
            invoke()


def test_parallel_sar_decoder():
    """Check ParallelSARDecoder img_metas validation and output shapes."""
    sar = ParallelSARDecoder(num_classes=37, padding_idx=36, max_seq_len=5)
    sar.init_weights()
    sar.train()

    feat, out_enc, targets, metas = _create_dummy_input()

    # img_metas of the wrong length (empty, or twice the batch) is rejected.
    for bad_metas in ([], metas * 2):
        with pytest.raises(AssertionError):
            sar(feat, out_enc, targets, bad_metas, True)

    # Both train (True) and test (False) modes yield (1, 5, 36): the last
    # dim is one less than num_classes (presumably padding is excluded).
    expected_shape = torch.Size([1, 5, 36])
    for train_mode in (True, False):
        prediction = sar(feat, out_enc, targets, metas, train_mode)
        assert prediction.shape == expected_shape


def test_sequential_sar_decoder():
    """Check SequentialSARDecoder img_metas validation and output shapes."""
    # test sequential sar decoder
    decoder = SequentialSARDecoder(
        num_classes=37, padding_idx=36, max_seq_len=5)
    decoder.init_weights()
    decoder.train()

    feat, out_enc, tgt_dict, img_metas = _create_dummy_input()
    # img_metas of the wrong length (empty, or twice the batch) must be
    # rejected; train_mode is omitted so the decoder's default applies.
    with pytest.raises(AssertionError):
        decoder(feat, out_enc, tgt_dict, [])
    with pytest.raises(AssertionError):
        decoder(feat, out_enc, tgt_dict, img_metas * 2)

    # Train mode: output is (batch, max_seq_len, num_classes - 1).
    out_train = decoder(feat, out_enc, tgt_dict, img_metas, True)
    assert out_train.shape == torch.Size([1, 5, 36])

    # Test mode: same output shape as train mode.
    out_test = decoder(feat, out_enc, tgt_dict, img_metas, False)
    assert out_test.shape == torch.Size([1, 5, 36])


def test_parallel_sar_decoder_with_beam_search():
    """Check ParallelSARDecoderWithBS validation/inference and DecodeNode."""
    # A string beam width and a zero beam width must both be rejected.
    for invalid_width in ('beam', 0):
        with pytest.raises(AssertionError):
            ParallelSARDecoderWithBS(beam_width=invalid_width)

    feat, out_enc, targets, metas = _create_dummy_input()
    bs_decoder = ParallelSARDecoderWithBS(
        beam_width=1, num_classes=37, padding_idx=36, max_seq_len=5)
    bs_decoder.init_weights()
    bs_decoder.train()

    # img_metas of the wrong length (empty, or twice the batch) is rejected.
    for bad_metas in ([], metas * 2):
        with pytest.raises(AssertionError):
            bs_decoder(feat, out_enc, targets, bad_metas)

    prediction = bs_decoder(feat, out_enc, targets, metas, train_mode=False)
    assert prediction.shape == torch.Size([1, 5, 36])

    # DecodeNode must reject non-list arguments, string scores, and lists
    # whose lengths differ.
    invalid_node_args = (
        (1, 1),
        ([1, 2], ['4', '3']),
        ([1, 2], [0.5]),
    )
    for first, second in invalid_node_args:
        with pytest.raises(AssertionError):
            DecodeNode(first, second)

    # eval() of a valid node is expected to equal 0.7 + 0.8.
    node = DecodeNode([1, 2], [0.7, 0.8])
    assert math.isclose(node.eval(), 1.5)


def test_transformer_decoder():
    """Check NRTRDecoder output shapes in train and test mode."""
    decoder = NRTRDecoder(num_classes=37, padding_idx=36, max_seq_len=5)
    decoder.init_weights()
    decoder.train()

    # NOTE(review): presumably (batch, seq_len, d_model) — confirm against
    # NRTRDecoder's expected encoder output layout.
    out_enc = torch.rand(1, 25, 512)
    tgt_dict = {'padded_targets': torch.LongTensor([[1, 1, 1, 1, 36]])}
    img_metas = [{'valid_ratio': 1.0}]

    # ``feat`` is passed as None; only the encoder output is supplied.
    out_train = decoder(None, out_enc, tgt_dict, img_metas, True)
    assert out_train.shape == torch.Size([1, 5, 36])

    out_test = decoder(None, out_enc, tgt_dict, img_metas, False)
    assert out_test.shape == torch.Size([1, 5, 36])


def test_abi_language_decoder():
    """ABILanguageDecoder returns 'feature' and 'logits' of expected shape."""
    batch, seq_len, num_classes = 2, 25, 90
    lang_decoder = ABILanguageDecoder(max_seq_len=seq_len)
    input_logits = torch.randn(batch, seq_len, num_classes)
    output = lang_decoder(
        feat=None,
        out_enc=input_logits,
        targets_dict=None,
        img_metas=None,
    )
    expected_feature = torch.Size([batch, seq_len, 512])
    expected_logits = torch.Size([batch, seq_len, num_classes])
    assert output['feature'].shape == expected_feature
    assert output['logits'].shape == expected_logits


def test_abi_vision_decoder():
    """ABIVisionDecoder returns feature, logits and attention of set shapes."""
    batch, channels, height, width = 2, 128, 8, 32
    seq_len = 10
    vision_decoder = ABIVisionDecoder(
        in_channels=channels,
        num_channels=16,
        max_seq_len=seq_len,
        use_result=None,
    )
    feature_map = torch.randn(batch, channels, height, width)
    output = vision_decoder(feature_map, None)

    expected_shapes = {
        'feature': [batch, seq_len, 128],
        'logits': [batch, seq_len, 90],
        'attn_scores': [batch, seq_len, height, width],
    }
    for key, shape in expected_shapes.items():
        assert output[key].shape == torch.Size(shape)