# Copyright (c) OpenMMLab. All rights reserved.
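"""Tests for mmcv.runner.fp16_utils: cast_tensor_type, auto_fp16 and
force_fp32."""
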
import numpy as np
import pytest
import torch
import torch.nn as nn

from mmcv.runner.fp16_utils import auto_fp16, cast_tensor_type, force_fp32


def test_cast_tensor_type():
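    """Test cast_tensor_type() with tensor, str, ndarray, dict, list and
    scalar inputs."""
    # convert torch.float32 to torch.int32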
    inputs = torch.FloatTensor([5.])
    src_type = torch.float32
    dst_type = torch.int32
    outputs = cast_tensor_type(inputs, src_type, dst_type)
    assert isinstance(outputs, torch.Tensor)
    assert outputs.dtype == dst_type

    # convert torch.float to torch.half
    inputs = torch.FloatTensor([5.])
    src_type = torch.float
    dst_type = torch.half
    outputs = cast_tensor_type(inputs, src_type, dst_type)
    assert isinstance(outputs, torch.Tensor)
    assert outputs.dtype == dst_type

    # skip the conversion when the input dtype does not match src_type
    inputs = torch.IntTensor([5])
    src_type = torch.float
    dst_type = torch.half
    outputs = cast_tensor_type(inputs, src_type, dst_type)
    assert isinstance(outputs, torch.Tensor)
    assert outputs.dtype == inputs.dtype

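    # a str input is returned as-is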
    inputs = 'tensor'
    src_type = str
    dst_type = str
    outputs = cast_tensor_type(inputs, src_type, dst_type)
    assert isinstance(outputs, str)

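    # an np.ndarray input is returned as-is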
    inputs = np.array([5.])
    src_type = np.ndarray
    dst_type = np.ndarray
    outputs = cast_tensor_type(inputs, src_type, dst_type)
    assert isinstance(outputs, np.ndarray)

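    # the values of a dict are cast recursively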
    inputs = dict(
        tensor_a=torch.FloatTensor([1.]), tensor_b=torch.FloatTensor([2.]))
    src_type = torch.float32
    dst_type = torch.int32
    outputs = cast_tensor_type(inputs, src_type, dst_type)
    assert isinstance(outputs, dict)
    assert outputs['tensor_a'].dtype == dst_type
    assert outputs['tensor_b'].dtype == dst_type

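    # each element of a list is cast recursively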
    inputs = [torch.FloatTensor([1.]), torch.FloatTensor([2.])]
    src_type = torch.float32
    dst_type = torch.int32
    outputs = cast_tensor_type(inputs, src_type, dst_type)
    assert isinstance(outputs, list)
    assert outputs[0].dtype == dst_type
    assert outputs[1].dtype == dst_type

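    # other types (e.g. int) are returned unchanged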
    inputs = 5
    outputs = cast_tensor_type(inputs, None, None)
    assert isinstance(outputs, int)


def test_auto_fp16():
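    """Test the auto_fp16 decorator with different apply_to and out_fp32
    settings, on CPU and (if available) on GPU."""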

    with pytest.raises(TypeError):
        # ExampleObject is not a subclass of nn.Module

        class ExampleObject:

            @auto_fp16()
            def __call__(self, x):
                return x

        model = ExampleObject()
        input_x = torch.ones(1, dtype=torch.float32)
        model(input_x)

    # apply to all input args
    class ExampleModule(nn.Module):

        @auto_fp16()
        def forward(self, x, y):
            return x, y

    model = ExampleModule()
    input_x = torch.ones(1, dtype=torch.float32)
    input_y = torch.ones(1, dtype=torch.float32)
    output_x, output_y = model(input_x, input_y)
    assert output_x.dtype == torch.float32
    assert output_y.dtype == torch.float32

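    # the cast is only applied once fp16_enabled is set to True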
    model.fp16_enabled = True
    output_x, output_y = model(input_x, input_y)
    assert output_x.dtype == torch.half
    assert output_y.dtype == torch.half

    if torch.cuda.is_available():
        model.cuda()
        output_x, output_y = model(input_x.cuda(), input_y.cuda())
        assert output_x.dtype == torch.half
        assert output_y.dtype == torch.half

    # apply to specified input args
    class ExampleModule(nn.Module):

        @auto_fp16(apply_to=('x', ))
        def forward(self, x, y):
            return x, y

    model = ExampleModule()
    input_x = torch.ones(1, dtype=torch.float32)
    input_y = torch.ones(1, dtype=torch.float32)
    output_x, output_y = model(input_x, input_y)
    assert output_x.dtype == torch.float32
    assert output_y.dtype == torch.float32

    model.fp16_enabled = True
    output_x, output_y = model(input_x, input_y)
    assert output_x.dtype == torch.half
    assert output_y.dtype == torch.float32

    if torch.cuda.is_available():
        model.cuda()
        output_x, output_y = model(input_x.cuda(), input_y.cuda())
        assert output_x.dtype == torch.half
        assert output_y.dtype == torch.float32

    # apply to optional input args
    class ExampleModule(nn.Module):

        @auto_fp16(apply_to=('x', 'y'))
        def forward(self, x, y=None, z=None):
            return x, y, z

    model = ExampleModule()
    input_x = torch.ones(1, dtype=torch.float32)
    input_y = torch.ones(1, dtype=torch.float32)
    input_z = torch.ones(1, dtype=torch.float32)
    output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
    assert output_x.dtype == torch.float32
    assert output_y.dtype == torch.float32
    assert output_z.dtype == torch.float32

    model.fp16_enabled = True
    output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
    assert output_x.dtype == torch.half
    assert output_y.dtype == torch.half
    assert output_z.dtype == torch.float32

    if torch.cuda.is_available():
        model.cuda()
        output_x, output_y, output_z = model(
            input_x.cuda(), y=input_y.cuda(), z=input_z.cuda())
        assert output_x.dtype == torch.half
        assert output_y.dtype == torch.half
        assert output_z.dtype == torch.float32

    # out_fp32=True: outputs are cast back to fp32 when fp16 is enabled
    class ExampleModule(nn.Module):

        @auto_fp16(apply_to=('x', 'y'), out_fp32=True)
        def forward(self, x, y=None, z=None):
            return x, y, z

    model = ExampleModule()
    input_x = torch.ones(1, dtype=torch.half)
    input_y = torch.ones(1, dtype=torch.float32)
    input_z = torch.ones(1, dtype=torch.float32)
    output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
    assert output_x.dtype == torch.half
    assert output_y.dtype == torch.float32
    assert output_z.dtype == torch.float32

    model.fp16_enabled = True
    output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
    assert output_x.dtype == torch.float32
    assert output_y.dtype == torch.float32
    assert output_z.dtype == torch.float32

    if torch.cuda.is_available():
        model.cuda()
        output_x, output_y, output_z = model(
            input_x.cuda(), y=input_y.cuda(), z=input_z.cuda())
        assert output_x.dtype == torch.float32
        assert output_y.dtype == torch.float32
        assert output_z.dtype == torch.float32


def test_force_fp32():
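    """Test the force_fp32 decorator with different apply_to and out_fp16
    settings, on CPU and (if available) on GPU."""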

    with pytest.raises(TypeError):
        # ExampleObject is not a subclass of nn.Module

        class ExampleObject:

            @force_fp32()
            def __call__(self, x):
                return x

        model = ExampleObject()
        input_x = torch.ones(1, dtype=torch.float32)
        model(input_x)

    # apply to all input args
    class ExampleModule(nn.Module):

        @force_fp32()
        def forward(self, x, y):
            return x, y

    model = ExampleModule()
    input_x = torch.ones(1, dtype=torch.half)
    input_y = torch.ones(1, dtype=torch.half)
    output_x, output_y = model(input_x, input_y)
    assert output_x.dtype == torch.half
    assert output_y.dtype == torch.half

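    # likewise, the cast to fp32 only happens once fp16_enabled is True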
    model.fp16_enabled = True
    output_x, output_y = model(input_x, input_y)
    assert output_x.dtype == torch.float32
    assert output_y.dtype == torch.float32

    if torch.cuda.is_available():
        model.cuda()
        output_x, output_y = model(input_x.cuda(), input_y.cuda())
        assert output_x.dtype == torch.float32
        assert output_y.dtype == torch.float32

    # apply to specified input args
    class ExampleModule(nn.Module):

        @force_fp32(apply_to=('x', ))
        def forward(self, x, y):
            return x, y

    model = ExampleModule()
    input_x = torch.ones(1, dtype=torch.half)
    input_y = torch.ones(1, dtype=torch.half)
    output_x, output_y = model(input_x, input_y)
    assert output_x.dtype == torch.half
    assert output_y.dtype == torch.half

    model.fp16_enabled = True
    output_x, output_y = model(input_x, input_y)
    assert output_x.dtype == torch.float32
    assert output_y.dtype == torch.half

    if torch.cuda.is_available():
        model.cuda()
        output_x, output_y = model(input_x.cuda(), input_y.cuda())
        assert output_x.dtype == torch.float32
        assert output_y.dtype == torch.half

    # apply to optional input args
    class ExampleModule(nn.Module):

        @force_fp32(apply_to=('x', 'y'))
        def forward(self, x, y=None, z=None):
            return x, y, z

    model = ExampleModule()
    input_x = torch.ones(1, dtype=torch.half)
    input_y = torch.ones(1, dtype=torch.half)
    input_z = torch.ones(1, dtype=torch.half)
    output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
    assert output_x.dtype == torch.half
    assert output_y.dtype == torch.half
    assert output_z.dtype == torch.half

    model.fp16_enabled = True
    output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
    assert output_x.dtype == torch.float32
    assert output_y.dtype == torch.float32
    assert output_z.dtype == torch.half

    if torch.cuda.is_available():
        model.cuda()
        output_x, output_y, output_z = model(
            input_x.cuda(), y=input_y.cuda(), z=input_z.cuda())
        assert output_x.dtype == torch.float32
        assert output_y.dtype == torch.float32
        assert output_z.dtype == torch.half

    # out_fp16=True: outputs are cast to fp16 when fp16 is enabled
    class ExampleModule(nn.Module):

        @force_fp32(apply_to=('x', 'y'), out_fp16=True)
        def forward(self, x, y=None, z=None):
            return x, y, z

    model = ExampleModule()
    input_x = torch.ones(1, dtype=torch.float32)
    input_y = torch.ones(1, dtype=torch.half)
    input_z = torch.ones(1, dtype=torch.half)
    output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
    assert output_x.dtype == torch.float32
    assert output_y.dtype == torch.half
    assert output_z.dtype == torch.half

    model.fp16_enabled = True
    output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
    assert output_x.dtype == torch.half
    assert output_y.dtype == torch.half
    assert output_z.dtype == torch.half

    if torch.cuda.is_available():
        model.cuda()
        output_x, output_y, output_z = model(
            input_x.cuda(), y=input_y.cuda(), z=input_z.cuda())
        assert output_x.dtype == torch.half
        assert output_y.dtype == torch.half
        assert output_z.dtype == torch.half