File size: 2,128 Bytes
88435ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from typing import Any, TypeVar

from neollm.utils.utils import cprint

Immutable = tuple[Any, ...] | str | int | float | bool
_T = TypeVar("_T")
_TD = TypeVar("_TD")


def _sorted_canonical(items: list[Any]) -> list[Any]:
    """Sort canonicalized items into a deterministic order.

    Natural ordering is tried first, so the result matches a plain ``sorted``
    call whenever the items are mutually comparable. Mixed types (e.g.
    ``{1, "a"}`` or a dict with both int and str keys) make natural ordering
    raise ``TypeError``; those fall back to ordering by ``repr``, which is
    deterministic for the canonical forms built by ``_to_immutable``.

    Args:
        items (list[Any]): items to sort; not modified.

    Returns:
        list[Any]: a new sorted list.
    """
    try:
        return sorted(items)
    except TypeError:
        return sorted(items, key=repr)


def _to_immutable(x: Any) -> Immutable | None:
    """Convert ``x`` into a hashable canonical form so it can be stored in a set.

    Lists and tuples become tuples, dicts become tuples of ``(key, value)``
    pairs ordered by key (so key insertion order does not matter), and sets
    become sorted tuples. Container contents are converted recursively.

    Args:
        x (Any): value to convert.

    Returns:
        Immutable | None: the canonical form. ``None`` and scalars
        (str/int/float/bool) are returned unchanged; any other type is
        converted with ``str`` after printing a warning.
    """
    if x is None:
        # JSON null is hashable as-is. Stringifying it would collide with the
        # literal string "None" and print a warning for every null field.
        return None
    if isinstance(x, (list, tuple)):
        return tuple(map(_to_immutable, x))
    if isinstance(x, dict):
        # Keys are unique, so natural ordering of the pairs is decided by the key alone.
        pairs = [(key, _to_immutable(value)) for key, value in x.items()]
        return tuple(_sorted_canonical(pairs))
    if isinstance(x, (set, frozenset)):
        return tuple(_sorted_canonical([_to_immutable(item) for item in x]))
    if isinstance(x, (str, int, float, bool)):
        return x
    cprint("_to_immutable: not supported: 無理やりstr(*)", color="yellow", background=True)
    return str(x)


def _remove_duplicate(arr: list[_T | None]) -> list[_T]:
    """Drop empty values and duplicates from a list, keeping first occurrences.

    A value counts as empty when it is falsy (``None``, ``0``, ``""``, ``[]``,
    ``{}``, ``False``, ...). Duplicates are detected by comparing canonical
    forms from ``_to_immutable``, so e.g. two dicts with the same items in a
    different key order are treated as equal.

    Args:
        arr (list[_T | None]): input list; not modified.

    Returns:
        list[_T]: the remaining values, in their original order.
    """
    seen: set[Immutable | None] = set()
    unique: list[_T] = []
    for x in arr:
        if not x:  # None and other falsy "initial" values are discarded
            continue
        key = _to_immutable(x)
        if key not in seen:
            seen.add(key)
            unique.append(x)
    return unique


def get_entity(arr: list[_T | None], default: _TD, index: int | None = None) -> _T | _TD:
    """Pick a single entity out of a list of candidate values.

    ``arr`` is first cleaned with ``_remove_duplicate`` (falsy values and
    duplicates are dropped) before a value is chosen.

    Args:
        arr (list[_T | None]): candidate values.
        default (_TD): returned when no candidate survives cleaning.
        index (int | None, optional): position to pick when several distinct
            candidates remain; ignored when exactly one remains. Defaults to None.

    Returns:
        _T | _TD: the chosen candidate, or ``default``. When several candidates
        remain and ``index`` is None, a warning is printed and the first one is
        returned.
    """
    candidates = _remove_duplicate(arr)
    if not candidates:
        return default
    ambiguous = len(candidates) > 1
    if ambiguous and index is None:
        cprint("get_entity: not unique", color="yellow", background=True)
        cprint(candidates, color="yellow", background=True)
    position = index if ambiguous and index is not None else 0
    return candidates[position]