File size: 3,268 Bytes
4bdb245
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from __future__ import annotations

import argparse

from typing import List, Literal, Union, Any, Type, TypeVar

from pydantic import BaseModel


def _get_base_type(annotation: Type[Any]) -> Type[Any]:
    if getattr(annotation, "__origin__", None) is Literal:
        assert hasattr(annotation, "__args__") and len(annotation.__args__) >= 1  # type: ignore
        return type(annotation.__args__[0])  # type: ignore
    elif getattr(annotation, "__origin__", None) is Union:
        assert hasattr(annotation, "__args__") and len(annotation.__args__) >= 1  # type: ignore
        non_optional_args: List[Type[Any]] = [
            arg for arg in annotation.__args__ if arg is not type(None)  # type: ignore
        ]
        if non_optional_args:
            return _get_base_type(non_optional_args[0])
    elif (
        getattr(annotation, "__origin__", None) is list
        or getattr(annotation, "__origin__", None) is List
    ):
        assert hasattr(annotation, "__args__") and len(annotation.__args__) >= 1  # type: ignore
        return _get_base_type(annotation.__args__[0])  # type: ignore
    return annotation


def _contains_list_type(annotation: Type[Any] | None) -> bool:
    origin = getattr(annotation, "__origin__", None)

    if origin is list or origin is List:
        return True
    elif origin in (Literal, Union):
        return any(_contains_list_type(arg) for arg in annotation.__args__)  # type: ignore
    else:
        return False


def _parse_bool_arg(arg: str | bytes | bool) -> bool:
    if isinstance(arg, bytes):
        arg = arg.decode("utf-8")

    true_values = {"1", "on", "t", "true", "y", "yes"}
    false_values = {"0", "off", "f", "false", "n", "no"}

    arg_str = str(arg).lower().strip()

    if arg_str in true_values:
        return True
    elif arg_str in false_values:
        return False
    else:
        raise ValueError(f"Invalid boolean argument: {arg}")


def add_args_from_model(
    parser: argparse.ArgumentParser, model: Type[BaseModel]
) -> None:
    """Add one ``--<field>`` option to *parser* for every field of *model*.

    For each entry in ``model.model_fields``:

    * the option's converter is the field's base type as resolved by
      ``_get_base_type`` (``str`` when the field has no annotation); boolean
      fields use ``_parse_bool_arg`` so spellings such as ``yes``/``off``
      are accepted;
    * ``nargs="*"`` is used whenever the annotation contains a list type,
      including lists of booleans;
    * the help text is the field description, suffixed with
      ``(default: <value>)`` when the field is optional and has a static,
      non-``None`` default. Fields without a description get no help text.

    No argparse-level default is set, so options that are not supplied stay
    ``None`` in the parsed namespace.

    Args:
        parser: The parser to extend in place.
        model: The pydantic model class whose fields are exposed.
    """
    for name, field in model.model_fields.items():
        description = field.description

        # Compare against None instead of relying on truthiness so that
        # falsy defaults such as 0 or False are still shown. A default
        # produced by a factory has no static value to display.
        has_static_default = (
            not field.is_required()
            and getattr(field, "default_factory", None) is None
            and field.default is not None
        )
        if description and has_static_default:
            description = f"{description} (default: {field.default})"

        annotation = field.annotation
        base_type = _get_base_type(annotation) if annotation is not None else str

        parser.add_argument(
            f"--{name}",
            dest=name,
            nargs="*" if _contains_list_type(annotation) else None,
            # ``bool("false")`` is True, so booleans need a dedicated parser.
            type=_parse_bool_arg if base_type is bool else base_type,
            # Pass the description through untouched: formatting it into a
            # string would turn a missing description into the text "None".
            help=description,
        )


# Retained so existing imports of ``T`` keep working. It is bound to the model
# *class*, which makes it unsuitable for describing the returned instance.
T = TypeVar("T", bound=Type[BaseModel])

# Bound to model *instances*, so that ``Type[_ModelT] -> _ModelT`` tells type
# checkers the function returns an instance of the class it was given.
_ModelT = TypeVar("_ModelT", bound=BaseModel)


def parse_model_from_args(model: Type[_ModelT], args: argparse.Namespace) -> _ModelT:
    """Build an instance of *model* from an argparse namespace.

    Only namespace attributes that name a field of the model and whose value
    is not ``None`` are forwarded to the constructor. Everything else is
    left out, so the model's own defaults apply to options that were not
    supplied.

    Args:
        model: The pydantic model class to instantiate.
        args: The parsed command-line namespace.

    Returns:
        The instance produced by calling ``model`` with the selected values.
        Any error raised by the model constructor propagates unchanged.
    """
    known_fields = model.model_fields
    supplied = {
        key: value
        for key, value in vars(args).items()
        if value is not None and key in known_fields
    }
    return model(**supplied)