File size: 4,289 Bytes
78ab311
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# Copyright (c) Open-MMLab.
import sys
from collections.abc import Iterable
from runpy import run_path
from shlex import split
from typing import Any, Dict, List
from unittest.mock import patch


def check_python_script(cmd):
    """Execute a python script command line inside the current process.

    Unlike ``os.system``, the script is executed in the current process (as
    ``__main__``), so coverage tools are able to track it. ``sys.argv`` is
    replaced by the parsed command for the duration of the run and restored
    afterwards. Two command forms are supported:

    - ./tests/data/scripts/hello.py zz
    - python tests/data/scripts/hello.py zz

    Args:
        cmd (str): The command line to run. It is tokenized with
            ``shlex.split``.
    """
    tokens = split(cmd)
    # Drop a leading ``python`` so that the script path becomes argv[0].
    script_argv = tokens[1:] if tokens[0] == 'python' else tokens
    with patch.object(sys, 'argv', script_argv):
        run_path(script_argv[0], run_name='__main__')


def _any(judge_result):
    """Since built-in ``any`` works only when the element of iterable is not
    iterable, implement the function."""
    if not isinstance(judge_result, Iterable):
        return judge_result

    try:
        for element in judge_result:
            if _any(element):
                return True
    except TypeError:
        # Maybe encounter the case: torch.tensor(True) | torch.tensor(False)
        if judge_result:
            return True
    return False


def assert_dict_contains_subset(dict_obj: Dict[Any, Any],
                                expected_subset: Dict[Any, Any]) -> bool:
    """Tell whether every item of ``expected_subset`` appears in ``dict_obj``.

    An item matches when its key exists in ``dict_obj`` and no element of
    the ``!=`` comparison between the stored and the expected value is
    truthy.

    Args:
        dict_obj (Dict[Any, Any]): Dict object to be checked.
        expected_subset (Dict[Any, Any]): Subset expected to be contained in
            dict_obj.

    Returns:
        bool: Whether the dict_obj contains the expected_subset.
    """
    for key, expected in expected_subset.items():
        # A missing key fails immediately, before any value comparison.
        if key not in dict_obj:
            return False
        # The comparison may be element-wise (arrays, tensors), hence _any.
        if _any(dict_obj[key] != expected):
            return False
    return True


def assert_attrs_equal(obj: Any, expected_attrs: Dict[str, Any]) -> bool:
    """Tell whether ``obj`` carries all expected attributes with equal values.

    An attribute matches when it exists on ``obj`` and no element of the
    ``!=`` comparison between its value and the expected value is truthy.

    Args:
        obj (object): Class object to be checked.
        expected_attrs (Dict[str, Any]): Dict of the expected attrs.

    Returns:
        bool: Whether the attribute of class object is correct.
    """
    for name, expected in expected_attrs.items():
        # A missing attribute fails immediately, before any comparison.
        if not hasattr(obj, name):
            return False
        # The comparison may be element-wise (arrays, tensors), hence _any.
        if _any(getattr(obj, name) != expected):
            return False
    return True


def assert_dict_has_keys(obj: Dict[str, Any],
                         expected_keys: List[str]) -> bool:
    """Tell whether ``obj`` has every key listed in ``expected_keys``.

    Args:
        obj (Dict[str, Any]): Object to be checked.
        expected_keys (List[str]): Keys expected to be contained in the keys
            of the obj.

    Returns:
        bool: Whether the obj has the expected keys.
    """
    required = set(expected_keys)
    available = set(obj.keys())
    # ``<=`` on sets is the subset test; extra keys in ``obj`` are allowed.
    return required <= available


def assert_keys_equal(result_keys: List[str], target_keys: List[str]) -> bool:
    """Tell whether ``result_keys`` and ``target_keys`` hold the same keys.

    Both arguments are compared as sets, so ordering and duplicates are
    ignored.

    Args:
        result_keys (List[str]): Result keys to be checked.
        target_keys (List[str]): Target keys to be checked.

    Returns:
        bool: Whether target_keys is equal to result_keys.
    """
    actual = set(result_keys)
    wanted = set(target_keys)
    return actual == wanted


def assert_is_norm_layer(module) -> bool:
    """Tell whether ``module`` is one of the known normalization layers.

    Args:
        module (nn.Module): The module to be checked.

    Returns:
        bool: Whether the module is a norm layer, i.e. an instance of
            ``_BatchNorm``, ``_InstanceNorm``, ``GroupNorm`` or
            ``LayerNorm``.
    """
    # The layer classes are imported at call time rather than module level.
    from .parrots_wrapper import _BatchNorm, _InstanceNorm
    from torch.nn import GroupNorm, LayerNorm
    return isinstance(module,
                      (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm))


def assert_params_all_zeros(module) -> bool:
    """Tell whether the weight (and bias, if any) of ``module`` is all zeros.

    The weight is always inspected. The bias is inspected only when the
    module has a ``bias`` attribute that is not ``None``; otherwise it is
    counted as zero.

    Args:
        module (nn.Module): The module to be checked.

    Returns:
        bool: Whether the parameters of the module is all zeros.
    """

    def _is_all_zeros(tensor) -> bool:
        # Compare against a zero tensor of the same size created from
        # ``tensor`` itself; ``allclose`` applies its default tolerances.
        return tensor.allclose(tensor.new_zeros(tensor.size()))

    weight_is_zero = _is_all_zeros(module.weight.data)

    bias = getattr(module, 'bias', None)
    bias_is_zero = True if bias is None else _is_all_zeros(bias.data)

    return weight_is_zero and bias_is_zero