File size: 3,159 Bytes
f5f3483
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# Copyright 2024 The etils Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Epath FLAGS utils."""

from __future__ import annotations

import os
import sys
import typing
from typing import Optional

from etils.epath import abstract_path
from etils.epath import typing as epath_typing
from typing_extensions import Literal

# Imported only for static type checkers: the `flags.*` names used in the
# annotations below are never evaluated at runtime (this module uses
# `from __future__ import annotations`), so absl need not be importable.
if typing.TYPE_CHECKING:
  from absl import flags

# absl is an optional dependency: only touch it when something else has
# already imported `absl.flags`, so importing this module never pulls it in.
if 'absl.flags' in sys.modules:
  from absl import flags  # pylint: disable=g-import-not-at-top
  # Skip this module when detecting in which module the flag is defined.
  # This is required to avoid duplicate flag issues when reloading adhoc
  # imports.
  flags.disclaim_key_flags()  # pylint: disable=used-before-assignment


@typing.overload
def DEFINE_path(  # For consistency with other flags, pylint: disable=invalid-name
    name: str,
    default: None,
    help: str,  # pylint: disable=redefined-builtin
    flag_values: flags.FlagValues = ...,
    *,
    required: Literal[True],
    **kwargs,
) -> flags.FlagHolder[abstract_path.Path]:
  """Typing case: `required=True` with `default=None` holds a `Path`."""


@typing.overload
def DEFINE_path(  # For consistency with other flags, pylint: disable=invalid-name
    name: str,
    default: None,
    help: str,  # pylint: disable=redefined-builtin
    flag_values: flags.FlagValues = ...,
    *,
    required: Literal[False] = False,
    **kwargs,
) -> flags.FlagHolder[Optional[abstract_path.Path]]:
  """Typing case: optional flag with `default=None` holds `Path | None`."""


@typing.overload
def DEFINE_path(  # For consistency with other flags, pylint: disable=invalid-name
    name: str,
    default: epath_typing.PathLike,
    help: str,  # pylint: disable=redefined-builtin
    flag_values: flags.FlagValues = ...,
    *,
    required: Literal[False] = False,
    **kwargs,
) -> flags.FlagHolder[abstract_path.Path]:
  """Typing case: optional flag with a path-like default holds a `Path`."""


def DEFINE_path(  # pylint: disable=invalid-name
    name,
    default,
    help,  # pylint: disable=redefined-builtin
    flag_values=None,
    *,
    required=False,
    **kwargs,
):
  """Registers an absl flag whose parsed value is an `epath.Path`.

  Args:
    name: Name of the flag.
    default: Default value, either `None` or a path-like value.
    help: Help string for the flag.
    flag_values: `FlagValues` registry to register the flag in. When `None`,
      `absl.flags.FLAGS` is used.
    required: Forwarded unchanged to `flags.DEFINE`.
    **kwargs: Extra keyword arguments forwarded unchanged to `flags.DEFINE`.

  Returns:
    Whatever `flags.DEFINE` returns for the new flag (its `FlagHolder`).
  """
  # absl is an optional dependency, so it is only imported on demand.
  from absl import flags  # pylint: disable=g-import-not-at-top

  class _PathParser(flags.ArgumentParser):
    # Turns a flag's textual value into an `abstract_path.Path`.

    def parse(self, value):
      return abstract_path.Path(value)

  class _PathSerializer(flags.ArgumentSerializer):
    # Serializes a flag value back to text via `os.fspath`.

    def serialize(self, value):
      return os.fspath(value)

  registry = flags.FLAGS if flag_values is None else flag_values
  parser = _PathParser()
  serializer = _PathSerializer()
  return flags.DEFINE(
      parser,
      name,
      default,
      help,
      registry,
      serializer,
      required=required,
      **kwargs,
  )  # pytype: disable=bad-return-type